package ai.vespa.rankingexpression.importer.tensorflow;

import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.Argument;
import ai.vespa.rankingexpression.importer.operations.ConcatV2;
import ai.vespa.rankingexpression.importer.operations.Const;
import ai.vespa.rankingexpression.importer.operations.Constant;
import ai.vespa.rankingexpression.importer.operations.ExpandDims;
import ai.vespa.rankingexpression.importer.operations.Identity;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import ai.vespa.rankingexpression.importer.operations.Join;
import ai.vespa.rankingexpression.importer.operations.Map;
import ai.vespa.rankingexpression.importer.operations.MatMul;
import ai.vespa.rankingexpression.importer.operations.Mean;
import ai.vespa.rankingexpression.importer.operations.Merge;
import ai.vespa.rankingexpression.importer.operations.NoOp;
import ai.vespa.rankingexpression.importer.operations.PlaceholderWithDefault;
import ai.vespa.rankingexpression.importer.operations.Reshape;
import ai.vespa.rankingexpression.importer.operations.Select;
import ai.vespa.rankingexpression.importer.operations.Shape;
import ai.vespa.rankingexpression.importer.operations.Softmax;
import ai.vespa.rankingexpression.importer.operations.Squeeze;
import ai.vespa.rankingexpression.importer.operations.Sum;
import ai.vespa.rankingexpression.importer.operations.Switch;
import ai.vespa.rankingexpression.importer.vespa.parser.ModelParserConstants;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.functions.ScalarFunctions;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;

/* loaded from: input_file:ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.class */
/**
 * Converts a TensorFlow {@link SavedModelBundle} into an {@link IntermediateGraph}.
 *
 * <p>Import starts from the outputs declared in the model's signature definitions and walks the
 * TensorFlow graph backwards through each node's inputs, mapping every visited node to an
 * {@link IntermediateOperation}. Nodes already present in the intermediate graph are reused.
 */
class GraphImporter {

    GraphImporter() {
    }

    /**
     * Maps a single TensorFlow node to the intermediate operation implementing it.
     *
     * @param nodeDef the TensorFlow node to map
     * @param inputs  the already-imported (non-control) inputs of the node
     * @param graph   the intermediate graph being built; supplies the model name
     * @return the operation for the node; unknown ops become a {@link NoOp} carrying a warning
     */
    private static IntermediateOperation mapOperation(NodeDef nodeDef,
                                                      List<IntermediateOperation> inputs,
                                                      IntermediateGraph graph) {
        String nodeName = nodeDef.getName();
        String modelName = graph.name();
        // NOTE(review): NodeDef names normally carry no ":port" suffix, so this presumably always
        // yields the default port — confirm whether Switch should get the referencing input's port.
        int nodePort = IntermediateOperation.indexPartOf(nodeName);
        OrderedTensorType nodeType = TypeConverter.typeFrom(nodeDef);
        AttributeConverter attributes = AttributeConverter.convert(nodeDef);

        // Locale.ROOT: op names are ASCII identifiers. A default locale such as Turkish would
        // lower-case 'I' to a dotless 'ı', so e.g. "Identity" would miss its case label.
        switch (nodeDef.getOp().toLowerCase(java.util.Locale.ROOT)) {
            // Array and structural operations
            case "concatv2":
                return new ConcatV2(modelName, nodeName, inputs);
            case "const":
                return new Const(modelName, nodeName, inputs, attributes, nodeType);
            case "expanddims":
                return new ExpandDims(modelName, nodeName, inputs);
            case "identity":
            case "stopgradient":
                return new Identity(modelName, nodeName, inputs);
            case "placeholder":
                return new Argument(modelName, nodeName, nodeType);
            case "placeholderwithdefault":
                return new PlaceholderWithDefault(modelName, nodeName, inputs);
            case "reshape":
                return new Reshape(modelName, nodeName, inputs);
            case "shape":
                return new Shape(modelName, nodeName, inputs);
            case "squeeze":
                return new Squeeze(modelName, nodeName, inputs, attributes);

            // Control flow
            case "merge":
                return new Merge(modelName, nodeName, inputs);
            case "switch":
                return new Switch(modelName, nodeName, inputs, nodePort);

            // Element-wise unary math
            case "abs":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.abs());
            case "acos":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.acos());
            case "asin":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.asin());
            case "atan":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.atan());
            case "ceil":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.ceil());
            case "cos":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.cos());
            case "exp":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.exp());
            case "floor":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.floor());
            case "log":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.log());
            case "negate":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.neg());
            case "reciprocal":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
            case "rsqrt":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.rsqrt());
            case "sigmoid":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
            case "sin":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.sin());
            case "square":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.square());
            case "sqrt":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.sqrt());
            case "tan":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.tan());
            case "tanh":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.tanh());

            // Element-wise binary math
            case "add":
            case "add_n":
            case "biasadd":
                return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
            case "div":
            case "realdiv":
                return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
            case "maximum":
                return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
            case "mul":
            case "multiply":
                return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
            case "squareddifference":
                return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference());
            case "sub":
            case "subtract":
                return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());

            // Matrix and reduction operations
            case "matmul":
                return new MatMul(modelName, nodeName, inputs);
            case "mean":
            case "reducemean":
                return new Mean(modelName, nodeName, inputs, attributes);
            case "sum":
                return new Sum(modelName, nodeName, inputs, attributes);
            case "select":
            case "where3":
                return new Select(modelName, nodeName, inputs);

            // Neural-network activations
            case "elu":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.elu());
            case "relu":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.relu());
            case "selu":
                return mapFunction(modelName, nodeName, inputs, ScalarFunctions.selu());
            case "softmax":
                return new Softmax(modelName, nodeName, inputs);

            // Variables become constants whose values are read from the bundle later
            case "variable":
            case "variablev2":
                return new Constant(modelName, nodeName, nodeType);

            case "noop":
                return new NoOp(modelName, nodeName, inputs);

            default:
                NoOp noOp = new NoOp(modelName, nodeName, inputs);
                noOp.warning("Operation '" + nodeDef.getOp() + "' is currently not implemented");
                return noOp;
        }
    }

    /**
     * Creates an element-wise unary map operation.
     *
     * <p>Fully qualified because the simple name {@code Map} clashes with {@code java.util.Map}.
     */
    private static IntermediateOperation mapFunction(String modelName,
                                                     String nodeName,
                                                     List<IntermediateOperation> inputs,
                                                     java.util.function.DoubleUnaryOperator function) {
        return new ai.vespa.rankingexpression.importer.operations.Map(modelName, nodeName, inputs, function);
    }

    /**
     * Imports a saved TensorFlow model into a new intermediate graph.
     *
     * <p>Steps: register signature inputs/outputs, import every operation reachable from a
     * signature output, then verify that each output's inferred type matches TensorFlow's.
     *
     * @param str              the name of the resulting intermediate graph
     * @param savedModelBundle the loaded TensorFlow model
     * @return the populated intermediate graph
     * @throws IOException if the bundle's meta graph cannot be parsed
     */
    public static IntermediateGraph importGraph(String str, SavedModelBundle savedModelBundle) throws IOException {
        MetaGraphDef metaGraph = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef());
        IntermediateGraph intermediateGraph = new IntermediateGraph(str);
        importSignatures(metaGraph, intermediateGraph);
        importOperations(metaGraph, intermediateGraph, savedModelBundle);
        verifyOutputTypes(metaGraph, intermediateGraph);
        return intermediateGraph;
    }

    /**
     * Registers every signature's inputs and outputs in the intermediate graph.
     *
     * <p>Tensor names are reduced to their node-name part.
     */
    private static void importSignatures(MetaGraphDef metaGraphDef, IntermediateGraph intermediateGraph) {
        metaGraphDef.getSignatureDefMap().forEach((signatureName, signature) -> {
            signature.getInputsMap().forEach((inputName, info) ->
                    intermediateGraph.inputs(signatureName)
                                     .put(inputName, IntermediateOperation.namePartOf(info.getName())));
            signature.getOutputsMap().forEach((outputName, info) ->
                    intermediateGraph.outputs(signatureName)
                                     .put(outputName, IntermediateOperation.namePartOf(info.getName())));
        });
    }

    /** Imports every operation reachable from any signature output. */
    private static void importOperations(MetaGraphDef metaGraphDef,
                                         IntermediateGraph intermediateGraph,
                                         SavedModelBundle savedModelBundle) {
        GraphDef graphDef = metaGraphDef.getGraphDef();
        for (String signature : intermediateGraph.signatures()) {
            for (String outputNode : intermediateGraph.outputs(signature).values()) {
                importOperation(outputNode, graphDef, intermediateGraph, savedModelBundle);
            }
        }
    }

    /**
     * Imports the named operation and, recursively, all of its inputs.
     *
     * <p>Returns the existing operation if this name was already imported.
     *
     * @param str the tensor/node reference; may carry a ":port" suffix or a "^" control prefix
     */
    public static IntermediateOperation importOperation(String str,
                                                        GraphDef graphDef,
                                                        IntermediateGraph intermediateGraph,
                                                        SavedModelBundle savedModelBundle) {
        if (intermediateGraph.alreadyImported(str)) {
            return intermediateGraph.get(str);
        }
        NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(str), graphDef);
        List<IntermediateOperation> inputs = importOperationInputs(node, graphDef, intermediateGraph, savedModelBundle);
        IntermediateOperation operation = mapOperation(node, inputs, intermediateGraph);
        // Registered before control inputs are imported, so cycles through control edges terminate.
        intermediateGraph.put(str, operation);

        List<IntermediateOperation> controlInputs =
                importControlInputs(node, graphDef, intermediateGraph, savedModelBundle);
        if (!controlInputs.isEmpty()) {
            operation.setControlInputs(controlInputs);
        }

        if (operation.isConstant()) {
            operation.setConstantValueFunction(type -> {
                // The fetched TF tensor owns native memory; release it once converted.
                // Assumes toVespaTensor copies the data out — TODO confirm.
                try (Tensor<?> tensor = readVariable(str, savedModelBundle)) {
                    return new TensorValue(TensorConverter.toVespaTensor(tensor, type));
                }
            });
        }
        return operation;
    }

    /** Imports the regular (non-control) inputs of the node, in declaration order. */
    private static List<IntermediateOperation> importOperationInputs(NodeDef nodeDef,
                                                                     GraphDef graphDef,
                                                                     IntermediateGraph intermediateGraph,
                                                                     SavedModelBundle savedModelBundle) {
        return nodeDef.getInputList().stream()
                      .filter(input -> !isControlDependency(input))
                      .map(input -> importOperation(input, graphDef, intermediateGraph, savedModelBundle))
                      .collect(Collectors.toList());
    }

    /** Imports the control-dependency inputs (those prefixed with "^") of the node. */
    private static List<IntermediateOperation> importControlInputs(NodeDef nodeDef,
                                                                   GraphDef graphDef,
                                                                   IntermediateGraph intermediateGraph,
                                                                   SavedModelBundle savedModelBundle) {
        return nodeDef.getInputList().stream()
                      .filter(GraphImporter::isControlDependency)
                      .map(input -> importOperation(input, graphDef, intermediateGraph, savedModelBundle))
                      .collect(Collectors.toList());
    }

    /** Returns whether the input reference denotes a control dependency (TensorFlow "^" prefix). */
    public static boolean isControlDependency(String str) {
        return str.startsWith("^");
    }

    /**
     * Finds the node with exactly this name in the graph.
     *
     * @throws IllegalArgumentException if no such node exists
     */
    private static NodeDef getTensorFlowNodeFromGraph(String str, GraphDef graphDef) {
        for (NodeDef node : graphDef.getNodeList()) {
            if (node.getName().equals(str)) {
                return node;
            }
        }
        throw new IllegalArgumentException("Could not find node '" + str + "'");
    }

    /**
     * Fetches the value of the named tensor by running the model session.
     *
     * <p>The caller owns the returned tensor and must close it.
     *
     * @throws IllegalStateException if the fetch does not yield exactly one tensor
     */
    public static Tensor<?> readVariable(String str, SavedModelBundle savedModelBundle) {
        List<Tensor<?>> tensors = savedModelBundle.session().runner().fetch(str).run();
        if (tensors.size() != 1) {
            int count = tensors.size();
            // Do not leak native memory of tensors we are about to discard.
            for (Tensor<?> tensor : tensors) {
                tensor.close();
            }
            throw new IllegalStateException("Expected 1 tensor from fetching " + str + ", but got " + count);
        }
        return tensors.get(0);
    }

    /**
     * Checks each signature output's inferred type against the TensorFlow node it came from.
     *
     * @throws IllegalArgumentException if an output operation has no type
     */
    private static void verifyOutputTypes(MetaGraphDef metaGraphDef, IntermediateGraph intermediateGraph) {
        GraphDef graphDef = metaGraphDef.getGraphDef();
        for (String signature : intermediateGraph.signatures()) {
            for (String str : intermediateGraph.outputs(signature).values()) {
                IntermediateOperation operation = intermediateGraph.get(str);
                NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(operation.name()), graphDef);
                OrderedTensorType type = operation.type().orElseThrow(
                        () -> new IllegalArgumentException("Output of '" + str + "' has no type."));
                TypeConverter.verifyType(node, type);
            }
        }
    }
}
