package ai.vespa.rankingexpression.importer.onnx;

import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.operations.Argument;
import ai.vespa.rankingexpression.importer.operations.ConcatV2;
import ai.vespa.rankingexpression.importer.operations.Constant;
import ai.vespa.rankingexpression.importer.operations.Identity;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import ai.vespa.rankingexpression.importer.operations.Join;
import ai.vespa.rankingexpression.importer.operations.Map;
import ai.vespa.rankingexpression.importer.operations.MatMul;
import ai.vespa.rankingexpression.importer.operations.NoOp;
import ai.vespa.rankingexpression.importer.operations.Reshape;
import ai.vespa.rankingexpression.importer.operations.Shape;
import ai.vespa.rankingexpression.importer.vespa.parser.ModelParserConstants;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.functions.ScalarFunctions;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.stream.Collectors;
import onnx.Onnx;

/* loaded from: input_file:ai/vespa/rankingexpression/importer/onnx/GraphImporter.class */
class GraphImporter {
    GraphImporter() {
    }

    private static IntermediateOperation mapOperation(Onnx.NodeProto nodeProto, List<IntermediateOperation> list, IntermediateGraph intermediateGraph) {
        String name = intermediateGraph.name();
        String nodeName = getNodeName(nodeProto);
        String lowerCase = nodeProto.getOpType().toLowerCase();
        boolean z = -1;
        switch (lowerCase.hashCode()) {
            case -1354795244:
                if (lowerCase.equals("concat")) {
                    z = 6;
                    break;
                }
                break;
            case -1081244060:
                if (lowerCase.equals("matmul")) {
                    z = 17;
                    break;
                }
                break;
            case -894674659:
                if (lowerCase.equals("square")) {
                    z = 33;
                    break;
                }
                break;
            case -338688742:
                if (lowerCase.equals("reciprocal")) {
                    z = 25;
                    break;
                }
                break;
            case -135761730:
                if (lowerCase.equals("identity")) {
                    z = 14;
                    break;
                }
                break;
            case 96370:
                if (lowerCase.equals("abs")) {
                    z = false;
                    break;
                }
                break;
            case 96417:
                if (lowerCase.equals("add")) {
                    z = true;
                    break;
                }
                break;
            case 98695:
                if (lowerCase.equals("cos")) {
                    z = 7;
                    break;
                }
                break;
            case 99473:
                if (lowerCase.equals("div")) {
                    z = 8;
                    break;
                }
                break;
            case 100526:
                if (lowerCase.equals("elu")) {
                    z = 9;
                    break;
                }
                break;
            case 100893:
                if (lowerCase.equals("exp")) {
                    z = 11;
                    break;
                }
                break;
            case 107332:
                if (lowerCase.equals("log")) {
                    z = 16;
                    break;
                }
                break;
            case 107876:
                if (lowerCase.equals("max")) {
                    z = 18;
                    break;
                }
                break;
            case 108114:
                if (lowerCase.equals("min")) {
                    z = 19;
                    break;
                }
                break;
            case 108484:
                if (lowerCase.equals("mul")) {
                    z = 21;
                    break;
                }
                break;
            case 108944:
                if (lowerCase.equals("neg")) {
                    z = 22;
                    break;
                }
                break;
            case 111192:
                if (lowerCase.equals("pow")) {
                    z = 23;
                    break;
                }
                break;
            case 113880:
                if (lowerCase.equals("sin")) {
                    z = 29;
                    break;
                }
                break;
            case 114240:
                if (lowerCase.equals("sub")) {
                    z = 32;
                    break;
                }
                break;
            case 114593:
                if (lowerCase.equals("tan")) {
                    z = 34;
                    break;
                }
                break;
            case 2988422:
                if (lowerCase.equals("acos")) {
                    z = 2;
                    break;
                }
                break;
            case 3003607:
                if (lowerCase.equals("asin")) {
                    z = 3;
                    break;
                }
                break;
            case 3004320:
                if (lowerCase.equals("atan")) {
                    z = 4;
                    break;
                }
                break;
            case 3049733:
                if (lowerCase.equals("ceil")) {
                    z = 5;
                    break;
                }
                break;
            case 3318169:
                if (lowerCase.equals("less")) {
                    z = 15;
                    break;
                }
                break;
            case 3347397:
                if (lowerCase.equals("mean")) {
                    z = 20;
                    break;
                }
                break;
            case 3496700:
                if (lowerCase.equals("relu")) {
                    z = 26;
                    break;
                }
                break;
            case 3526491:
                if (lowerCase.equals("selu")) {
                    z = 27;
                    break;
                }
                break;
            case 3538208:
                if (lowerCase.equals("sqrt")) {
                    z = 30;
                    break;
                }
                break;
            case 3552487:
                if (lowerCase.equals("tanh")) {
                    z = 35;
                    break;
                }
                break;
            case 96757556:
                if (lowerCase.equals("equal")) {
                    z = 10;
                    break;
                }
                break;
            case 97526796:
                if (lowerCase.equals("floor")) {
                    z = 12;
                    break;
                }
                break;
            case 109399969:
                if (lowerCase.equals("shape")) {
                    z = 28;
                    break;
                }
                break;
            case 283601914:
                if (lowerCase.equals("greater")) {
                    z = 13;
                    break;
                }
                break;
            case 1097148750:
                if (lowerCase.equals("reshape")) {
                    z = 24;
                    break;
                }
                break;
            case 2088248974:
                if (lowerCase.equals("sigmoid")) {
                    z = 31;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                return new Map(name, nodeName, list, ScalarFunctions.abs());
            case true:
                return new Join(name, nodeName, list, ScalarFunctions.add());
            case true:
                return new Map(name, nodeName, list, ScalarFunctions.acos());
            case true:
                return new Map(name, nodeName, list, ScalarFunctions.asin());
            case true:
                return new Map(name, nodeName, list, ScalarFunctions.atan());
            case true:
                return new Map(name, nodeName, list, ScalarFunctions.ceil());
            case true:
                return new ConcatV2(name, nodeName, list);
            case true:
                return new Map(name, nodeName, list, ScalarFunctions.cos());
            case true:
                return new Join(name, nodeName, list, ScalarFunctions.divide());
            case true:
                return new Map(name, nodeName, list, ScalarFunctions.elu());
            case true:
                return new Join(name, nodeName, list, ScalarFunctions.equal());
            case true:
                return new Map(name, nodeName, list, ScalarFunctions.exp());
            case true:
                return new Map(name, nodeName, list, ScalarFunctions.floor());
            case true:
                return new Join(name, nodeName, list, ScalarFunctions.greater());
            case true:
                return new Identity(name, nodeName, list);
            case true:
                return new Join(name, nodeName, list, ScalarFunctions.less());
            case ModelParserConstants.MODEL /* 16 */:
                return new Map(name, nodeName, list, ScalarFunctions.log());
            case ModelParserConstants.TYPE /* 17 */:
                return new MatMul(name, nodeName, list);
            case ModelParserConstants.EXPRESSION_SL /* 18 */:
                return new Join(name, nodeName, list, ScalarFunctions.max());
            case ModelParserConstants.EXPRESSION_ML /* 19 */:
                return new Join(name, nodeName, list, ScalarFunctions.min());
            case true:
                return new Join(name, nodeName, list, ScalarFunctions.mean());
            case true:
                return new Join(name, nodeName, list, ScalarFunctions.multiply());
            case ModelParserConstants.BRACE_SL_LEVEL_3 /* 22 */:
                return new Map(name, nodeName, list, ScalarFunctions.neg());
            case ModelParserConstants.BRACE_SL_CONTENT /* 23 */:
                return new Join(name, nodeName, list, ScalarFunctions.pow());
            case ModelParserConstants.BRACE_ML_LEVEL_1 /* 24 */:
                return new Reshape(name, nodeName, list);
            case ModelParserConstants.BRACE_ML_LEVEL_2 /* 25 */:
                return new Map(name, nodeName, list, ScalarFunctions.reciprocal());
            case ModelParserConstants.BRACE_ML_LEVEL_3 /* 26 */:
                return new Map(name, nodeName, list, ScalarFunctions.relu());
            case ModelParserConstants.BRACE_ML_CONTENT /* 27 */:
                return new Map(name, nodeName, list, ScalarFunctions.selu());
            case ModelParserConstants.SEARCHLIB_SKIP /* 28 */:
                return new Shape(name, nodeName, list);
            case ModelParserConstants.CONSTANT /* 29 */:
                return new Map(name, nodeName, list, ScalarFunctions.sin());
            case ModelParserConstants.CONSTANTS /* 30 */:
                return new Map(name, nodeName, list, ScalarFunctions.sqrt());
            case ModelParserConstants.FILE /* 31 */:
                return new Map(name, nodeName, list, ScalarFunctions.sigmoid());
            case ModelParserConstants.URI /* 32 */:
                return new Join(name, nodeName, list, ScalarFunctions.subtract());
            case ModelParserConstants.IDENTIFIER /* 33 */:
                return new Map(name, nodeName, list, ScalarFunctions.square());
            case ModelParserConstants.CONTEXT /* 34 */:
                return new Map(name, nodeName, list, ScalarFunctions.tan());
            case ModelParserConstants.DOUBLE /* 35 */:
                return new Map(name, nodeName, list, ScalarFunctions.tanh());
            default:
                NoOp noOp = new NoOp(name, nodeName, list);
                noOp.warning("Operation '" + nodeProto.getOpType() + "' is currently not implemented");
                return noOp;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static IntermediateGraph importGraph(String str, Onnx.ModelProto modelProto) {
        Onnx.GraphProto graph = modelProto.getGraph();
        IntermediateGraph intermediateGraph = new IntermediateGraph(str);
        importOperations(graph, intermediateGraph);
        verifyOutputTypes(graph, intermediateGraph);
        return intermediateGraph;
    }

    private static void importOperations(Onnx.GraphProto graphProto, IntermediateGraph intermediateGraph) {
        Iterator<Onnx.ValueInfoProto> it = graphProto.getOutputList().iterator();
        while (it.hasNext()) {
            importOperation(it.next().getName(), graphProto, intermediateGraph);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static IntermediateOperation importOperation(String str, Onnx.GraphProto graphProto, IntermediateGraph intermediateGraph) {
        IntermediateOperation mapOperation;
        if (intermediateGraph.alreadyImported(str)) {
            return intermediateGraph.get(str);
        }
        if (isArgumentTensor(str, graphProto)) {
            Onnx.ValueInfoProto argumentTensor = getArgumentTensor(str, graphProto);
            if (argumentTensor == null) {
                throw new IllegalArgumentException("Could not find argument tensor '" + str + "'");
            }
            mapOperation = new Argument(intermediateGraph.name(), argumentTensor.getName(), TypeConverter.typeFrom(argumentTensor.getType()));
            intermediateGraph.inputs(intermediateGraph.defaultSignature()).put(IntermediateOperation.namePartOf(str), mapOperation.vespaName());
        } else if (isConstantTensor(str, graphProto)) {
            Onnx.TensorProto constantTensor = getConstantTensor(str, graphProto);
            mapOperation = new Constant(intermediateGraph.name(), str, TypeConverter.typeFrom(constantTensor));
            mapOperation.setConstantValueFunction(orderedTensorType -> {
                return new TensorValue(TensorConverter.toVespaTensor(constantTensor, orderedTensorType));
            });
        } else {
            Onnx.NodeProto nodeFromGraph = getNodeFromGraph(str, graphProto);
            mapOperation = mapOperation(nodeFromGraph, importOperationInputs(nodeFromGraph, graphProto, intermediateGraph), intermediateGraph);
            if (isOutputNode(str, graphProto)) {
                intermediateGraph.outputs(intermediateGraph.defaultSignature()).put(IntermediateOperation.namePartOf(str), mapOperation.vespaName());
            }
        }
        intermediateGraph.put(mapOperation.vespaName(), mapOperation);
        return mapOperation;
    }

    private static boolean isArgumentTensor(String str, Onnx.GraphProto graphProto) {
        return getArgumentTensor(str, graphProto) != null && getConstantTensor(str, graphProto) == null;
    }

    private static boolean isConstantTensor(String str, Onnx.GraphProto graphProto) {
        return (getArgumentTensor(str, graphProto) == null || getConstantTensor(str, graphProto) == null) ? false : true;
    }

    private static Onnx.ValueInfoProto getArgumentTensor(String str, Onnx.GraphProto graphProto) {
        for (Onnx.ValueInfoProto valueInfoProto : graphProto.getInputList()) {
            if (valueInfoProto.getName().equals(str)) {
                return valueInfoProto;
            }
        }
        return null;
    }

    private static Onnx.TensorProto getConstantTensor(String str, Onnx.GraphProto graphProto) {
        for (Onnx.TensorProto tensorProto : graphProto.getInitializerList()) {
            if (tensorProto.getName().equals(str)) {
                return tensorProto;
            }
        }
        return null;
    }

    private static boolean isOutputNode(String str, Onnx.GraphProto graphProto) {
        return getOutputNode(str, graphProto) != null;
    }

    private static Onnx.ValueInfoProto getOutputNode(String str, Onnx.GraphProto graphProto) {
        Iterator<Onnx.ValueInfoProto> it = graphProto.getOutputList().iterator();
        while (it.hasNext()) {
            Onnx.ValueInfoProto next = it.next();
            if (!next.getName().equals(str) && !IntermediateOperation.namePartOf(next.getName()).equals(str)) {
            }
            return next;
        }
        return null;
    }

    private static List<IntermediateOperation> importOperationInputs(Onnx.NodeProto nodeProto, Onnx.GraphProto graphProto, IntermediateGraph intermediateGraph) {
        return (List) nodeProto.mo213getInputList().stream().map(str -> {
            return importOperation(str, graphProto, intermediateGraph);
        }).collect(Collectors.toList());
    }

    private static void verifyOutputTypes(Onnx.GraphProto graphProto, IntermediateGraph intermediateGraph) {
        for (String str : intermediateGraph.outputs(intermediateGraph.defaultSignature()).values()) {
            IntermediateOperation intermediateOperation = intermediateGraph.get(str);
            Onnx.ValueInfoProto outputNode = getOutputNode(str, graphProto);
            TypeConverter.verifyType(outputNode.getType(), intermediateOperation.type().orElseThrow(() -> {
                return new IllegalArgumentException("Output of '" + str + "' has no type.");
            }));
        }
    }

    private static Onnx.NodeProto getNodeFromGraph(String str, Onnx.GraphProto graphProto) {
        Optional<Onnx.NodeProto> nodeFromGraphNames;
        if (str.contains(":")) {
            nodeFromGraphNames = getNodeFromGraphOutputs(str, graphProto);
        } else {
            nodeFromGraphNames = getNodeFromGraphNames(str, graphProto);
            if (nodeFromGraphNames.isEmpty()) {
                nodeFromGraphNames = getNodeFromGraphOutputs(str, graphProto);
            }
        }
        return nodeFromGraphNames.orElseThrow(() -> {
            return new IllegalArgumentException("Node '" + str + "' not found in ONNX graph");
        });
    }

    private static Optional<Onnx.NodeProto> getNodeFromGraphOutputs(String str, Onnx.GraphProto graphProto) {
        for (Onnx.NodeProto nodeProto : graphProto.getNodeList()) {
            Iterator it = nodeProto.mo212getOutputList().iterator();
            while (it.hasNext()) {
                if (((String) it.next()).equals(str)) {
                    return Optional.of(nodeProto);
                }
            }
        }
        return Optional.empty();
    }

    private static Optional<Onnx.NodeProto> getNodeFromGraphNames(String str, Onnx.GraphProto graphProto) {
        for (Onnx.NodeProto nodeProto : graphProto.getNodeList()) {
            if (nodeProto.getName().equals(str)) {
                return Optional.of(nodeProto);
            }
        }
        return Optional.empty();
    }

    private static String getNodeName(Onnx.NodeProto nodeProto) {
        String name = nodeProto.getName();
        if (name.length() > 0) {
            return name;
        }
        if (nodeProto.getOutputCount() == 1) {
            return nodeProto.getOutput(0);
        }
        throw new IllegalArgumentException("Unable to find a suitable name for node '" + nodeProto.toString() + "'. Either no explicit name given or no single output name.");
    }
}
