package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Matmul;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.text.ExpressionFormatter;
import java.util.List;
import java.util.Optional;

/* loaded from: input_file:ai/vespa/rankingexpression/importer/operations/MatMul.class */
/**
 * Imported matrix-multiplication operation taking exactly two inputs.
 *
 * <p>The result type is built from the first dimension of the first input and the second
 * dimension of the second input. The resulting function is a {@link Matmul} over the two
 * input functions.
 *
 * <p>Note: the compiler-generated bridge for the covariant {@link #withInputs(List)} override
 * must not be declared explicitly. Doing so declares two methods with the same erasure,
 * which does not compile.
 */
public class MatMul extends IntermediateOperation {

    /**
     * Creates a matmul operation.
     *
     * @param modelName name of the model this operation belongs to
     * @param nodeName  name of this operation
     * @param inputs    the operations feeding this one; two are expected
     */
    public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) {
        super(modelName, nodeName, inputs);
    }

    /**
     * Computes the output type from the two input types.
     *
     * <p>No rank validation happens here: the indexed lookups assume the first input has at
     * least one dimension and the second at least two, and fail otherwise.
     *
     * @return the output type, or {@code null} when the two input types are not all present
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!allInputTypesPresent(2)) {
            return null;
        }
        OrderedTensorType leftType = inputs.get(0).type().get();
        OrderedTensorType rightType = inputs.get(1).type().get();

        OrderedTensorType.Builder resultType = new OrderedTensorType.Builder(resultValueType());
        resultType.add(leftType.dimensions().get(0));
        resultType.add(rightType.dimensions().get(1));
        return resultType.build();
    }

    /**
     * Builds the {@link Matmul} tensor function over the two input functions.
     *
     * @return the function, or {@code null} when input types or input functions are missing
     * @throws IllegalArgumentException if either input has rank below 2, or the ranks differ
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        if (!allInputTypesPresent(2)) {
            return null;
        }
        OrderedTensorType leftType = inputs.get(0).type().get();
        OrderedTensorType rightType = inputs.get(1).type().get();
        int leftRank = leftType.type().rank();
        int rightRank = rightType.type().rank();
        if (leftRank < 2 || rightRank < 2) {
            throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
        }
        if (leftRank != rightRank) {
            throw new IllegalArgumentException("Tensors in matmul must have the same rank");
        }

        Optional<TensorFunction> leftFunction = inputs.get(0).function();
        Optional<TensorFunction> rightFunction = inputs.get(1).function();
        if (leftFunction.isEmpty() || rightFunction.isEmpty()) {
            return null;
        }
        // The first input's second dimension is handed to Matmul as its dimension argument
        // (presumably the dimension that is summed over - confirm against Matmul).
        String dimensionName = leftType.dimensions().get(1).name();
        return new Matmul(leftFunction.get(), rightFunction.get(), dimensionName);
    }

    /**
     * Registers the naming constraints between the dimensions of the two inputs.
     * Does nothing unless both input types are present.
     *
     * <p>The boolean passed to each {@code DimensionRenamer.Constraint} factory is defined by
     * that class (presumably a "soft constraint" flag - confirm there).
     *
     * @param dimensionRenamer the renamer collecting the constraints
     * @throws IllegalArgumentException if either input has fewer than two dimensions
     */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer dimensionRenamer) {
        if (!allInputTypesPresent(2)) {
            return;
        }
        IntermediateOperation left = inputs.get(0);
        IntermediateOperation right = inputs.get(1);
        List<TensorType.Dimension> leftDimensions = left.type().get().dimensions();
        List<TensorType.Dimension> rightDimensions = right.type().get().dimensions();
        assertTwoDimensions(leftDimensions, left, "first argument");
        assertTwoDimensions(rightDimensions, right, "second argument");

        String leftFirst = leftDimensions.get(0).name();
        String leftSecond = leftDimensions.get(1).name();
        String rightFirst = rightDimensions.get(0).name();
        String rightSecond = rightDimensions.get(1).name();

        // The second dimension of the first input must share its name with the first
        // dimension of the second input.
        dimensionRenamer.addConstraint(leftSecond, rightFirst, DimensionRenamer.Constraint.equal(false), this);
        // Ordering between the two dimensions that make up the result type.
        dimensionRenamer.addConstraint(leftFirst, rightSecond, DimensionRenamer.Constraint.lessThan(false), this);
        // Ordering within each input.
        dimensionRenamer.addConstraint(leftFirst, leftSecond, DimensionRenamer.Constraint.lessThan(true), this);
        dimensionRenamer.addConstraint(rightFirst, rightSecond, DimensionRenamer.Constraint.greaterThan(true), this);
    }

    /**
     * Verifies that an input exposes at least two dimensions.
     *
     * @param dimensions  the dimensions of the input being checked
     * @param supplier    the operation that produced the input; used in the error message
     * @param inputDescription human-readable description of which argument this is
     * @throws IllegalArgumentException if fewer than two dimensions are present
     */
    private void assertTwoDimensions(List<TensorType.Dimension> dimensions,
                                     IntermediateOperation supplier,
                                     String inputDescription) {
        if (dimensions.size() >= 2) {
            return;
        }
        throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this
                + " but got just " + dimensions + " from\n"
                + ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString()));
    }

    /**
     * Returns a copy of this operation, with the same model name and name, over the given inputs.
     *
     * @param inputs the inputs of the new operation
     * @return a new {@code MatMul}
     */
    @Override
    public MatMul withInputs(List<IntermediateOperation> inputs) {
        return new MatMul(modelName(), name(), inputs);
    }

    /** Returns the display name of this operation type. */
    @Override
    public String operationName() {
        return "MatMul";
    }
}
