package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Matmul;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.text.ExpressionFormatter;
import java.util.List;
import java.util.Optional;

/**
 * The Gemm (general matrix multiplication) operation: {@code alpha * A' * B' + beta * C}, where
 * {@code A'} and {@code B'} are the first two inputs, optionally transposed according to the
 * {@code transA}/{@code transB} attributes, and the third input C is optional.
 * Attribute names follow the ONNX Gemm operator.
 */
public class Gemm extends IntermediateOperation {

    private static final DoubleValue ZERO = DoubleValue.frozen(0.0d);
    private static final DoubleValue ONE = DoubleValue.frozen(1.0d);

    private final IntermediateOperation.AttributeMap attributeMap;
    private final float alpha;
    private final float beta;
    /** 1 if input A is transposed, otherwise 0; used directly as a dimension index. */
    private final int transposeA;
    /** 1 if input B is transposed, otherwise 0; used directly as a dimension index. */
    private final int transposeB;

    /**
     * Creates a Gemm operation.
     *
     * @param modelName the name of the model this operation belongs to
     * @param nodeName the name of this operation node
     * @param inputs A, B and optionally C
     * @param attributeMap the operation attributes; {@code alpha} and {@code beta} default to 1,
     *                     {@code transA} and {@code transB} default to 0 (no transpose)
     */
    public Gemm(String modelName, String nodeName, List<IntermediateOperation> inputs,
                IntermediateOperation.AttributeMap attributeMap) {
        super(modelName, nodeName, inputs);
        this.attributeMap = attributeMap;
        this.alpha = (float) attributeMap.get("alpha").orElse(ONE).asDouble();
        this.beta = (float) attributeMap.get("beta").orElse(ONE).asDouble();
        this.transposeA = transposeFlag(attributeMap, "transA");
        this.transposeB = transposeFlag(attributeMap, "transB");
    }

    /**
     * Reads a transpose attribute as 0 or 1. Any nonzero value means "transpose" (as in ONNX);
     * normalizing keeps the value usable as a dimension index, where a raw value such as 2 or -1
     * would index out of range.
     */
    private static int transposeFlag(IntermediateOperation.AttributeMap attributes, String attributeName) {
        int value = (int) attributes.get(attributeName).orElse(ZERO).asDouble();
        return value != 0 ? 1 : 0;
    }

    /**
     * Computes the result type: A's output dimension followed by B's output dimension, taking the
     * transpose attributes into account. If C is given, verifies that it is broadcastable to that type.
     *
     * @return the result type, or null if not all input types are known yet
     * @throws IllegalArgumentException if A or B has fewer than two dimensions, if C does not have one
     *         or two dimensions, or if C's dimension sizes are neither equal to the result's nor 1
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!check2or3InputsPresent()) return null;

        List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
        List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
        if (aDimensions.size() < 2 || bDimensions.size() < 2) {
            throw new IllegalArgumentException("GEMM: inputs A and B must have 2 dimensions, got " +
                                               inputs.get(0).type().get() + " and " + inputs.get(1).type().get());
        }

        TensorType.Dimension rowDimension = aDimensions.get(transposeA);
        TensorType.Dimension columnDimension = bDimensions.get(1 - transposeB);

        OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
        builder.add(rowDimension);
        builder.add(columnDimension);
        OrderedTensorType result = builder.build();

        if (inputs.size() == 3) {
            List<TensorType.Dimension> cDimensions = inputs.get(2).type().get().dimensions();
            if (cDimensions.size() == 2) {
                checkBroadcastable(rowDimension, cDimensions.get(0), result);
                checkBroadcastable(columnDimension, cDimensions.get(1), result);
            } else if (cDimensions.size() == 1) {
                // A one-dimensional C lines up with the trailing (column) dimension of the result
                checkBroadcastable(columnDimension, cDimensions.get(0), result);
            } else if (cDimensions.isEmpty()) {
                throw new IllegalArgumentException("GEMM: optional input C has no dimensions.");
            } else {
                throw new IllegalArgumentException("GEMM: optional input C must have 1 or 2 dimensions, but got " +
                                                   inputs.get(2).type().get());
            }
        }
        return result;
    }

    /**
     * Throws unless the given dimension of input C either has size 1 (and is broadcast) or has the
     * same size as the corresponding result dimension.
     */
    private void checkBroadcastable(TensorType.Dimension resultDimension, TensorType.Dimension cDimension,
                                    OrderedTensorType resultType) {
        long cSize = cDimension.size().get();
        if (cSize == 1) return;
        if (cSize != resultDimension.size().get()) {
            throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() +
                                               " is not compatible or not broadcastable to " + resultType.type());
        }
    }

    /**
     * Builds {@code alpha * matmul(A, B)}, joined by addition with {@code beta * C} when C is given.
     *
     * @return the tensor function, or null if input types are not yet known or any input has no function
     * @throws IllegalArgumentException if A or B does not have rank exactly 2
     */
    @Override
    protected TensorFunction<Reference> lazyGetFunction() {
        if (!check2or3InputsPresent()) return null;

        OrderedTensorType aType = inputs.get(0).type().get();
        OrderedTensorType bType = inputs.get(1).type().get();
        if (aType.type().rank() != 2 || bType.type().rank() != 2)
            throw new IllegalArgumentException("Tensors in Gemm must have rank of exactly 2");

        Optional<TensorFunction<Reference>> aFunction = inputs.get(0).function();
        Optional<TensorFunction<Reference>> bFunction = inputs.get(1).function();
        if (aFunction.isEmpty() || bFunction.isEmpty()) return null;

        // The dimension of A that is summed over: the one which is not kept in the result
        String joinDimension = aType.dimensions().get(1 - transposeA).name();
        TensorFunction<Reference> product = new Matmul<>(aFunction.get(), bFunction.get(), joinDimension);
        TensorFunction<Reference> scaledProduct = scale(product, alpha);
        if (inputs.size() != 3) return scaledProduct;

        // Treat a missing C function like a missing A/B function, rather than failing on Optional.get()
        Optional<TensorFunction<Reference>> cFunction = inputs.get(2).function();
        if (cFunction.isEmpty()) return null;
        return new com.yahoo.tensor.functions.Join<>(scaledProduct, scale(cFunction.get(), beta), ScalarFunctions.add());
    }

    /** Wraps the given function in an expression multiplying it by a scalar factor. */
    private static TensorFunction<Reference> scale(TensorFunction<Reference> function, float factor) {
        return new TensorFunctionNode.ExpressionTensorFunction(
                new ArithmeticNode(new TensorFunctionNode(function),
                                   ArithmeticOperator.MULTIPLY,
                                   new ConstantNode(new DoubleValue(factor))));
    }

    /**
     * Returns whether the types of all inputs are known.
     *
     * @throws IllegalArgumentException if there are not exactly 2 or 3 inputs
     */
    private boolean check2or3InputsPresent() {
        if (inputs.size() != 2 && inputs.size() != 3)
            throw new IllegalArgumentException("Expected 2 or 3 inputs for '" + this.name + "', got " + inputs.size());
        return inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent);
    }

    /**
     * Adds dimension-renaming constraints. The summed-over dimensions of A and B must get the same
     * name, and the dimensions of C are tied to the corresponding result dimensions. The remaining
     * constraints order the output dimensions relative to each other and to the summed-over dimensions.
     */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        if (!check2or3InputsPresent()) return;

        List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
        List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
        assertTwoDimensions(aDimensions, inputs.get(0), "first argument");
        assertTwoDimensions(bDimensions, inputs.get(1), "second argument");

        String aOutput = aDimensions.get(transposeA).name();
        String aInner = aDimensions.get(1 - transposeA).name();
        String bInner = bDimensions.get(transposeB).name();
        String bOutput = bDimensions.get(1 - transposeB).name();

        renamer.addConstraint(aInner, bInner, DimensionRenamer.Constraint.equal(false), this);
        renamer.addConstraint(aOutput, bOutput, DimensionRenamer.Constraint.lessThan(false), this);

        if (inputs.size() == 3) {
            List<TensorType.Dimension> cDimensions = inputs.get(2).type().get().dimensions();
            if (cDimensions.size() == 2) {
                renamer.addConstraint(aOutput, cDimensions.get(0).name(), DimensionRenamer.Constraint.equal(false), this);
                renamer.addConstraint(bOutput, cDimensions.get(1).name(), DimensionRenamer.Constraint.equal(false), this);
            } else if (cDimensions.size() == 1) {
                renamer.addConstraint(bOutput, cDimensions.get(0).name(), DimensionRenamer.Constraint.equal(false), this);
            }
        }

        renamer.addConstraint(aOutput, aInner, DimensionRenamer.Constraint.lessThan(true), this);
        renamer.addConstraint(bInner, bOutput, DimensionRenamer.Constraint.greaterThan(true), this);
    }

    /** Throws an IllegalArgumentException describing the given argument if it has fewer than two dimensions. */
    private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
        if (dimensions.size() >= 2) return;
        throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
                                           " but got just " + dimensions + " from\n" +
                                           ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString()));
    }

    /** Returns a copy of this operation with the same attributes but the given inputs. */
    @Override
    public Gemm withInputs(List<IntermediateOperation> inputs) {
        return new Gemm(modelName(), name(), inputs, attributeMap);
    }

    @Override
    public String operationName() {
        return "Gemm";
    }

}
