package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.Slice;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/* loaded from: input_file:ai/vespa/rankingexpression/importer/operations/MatMul.class */
/**
 * Matrix multiplication of two imported input tensors.
 *
 * <p>The product is expressed as a tensor join of the two inputs using multiplication, followed
 * by a sum-reduction over the last dimension of the first input.
 *
 * <p>{@link #addDimensionNameConstraints(DimensionRenamer)} requires that dimension to share its
 * name with the second-to-last dimension of the second input. This is what makes the join act
 * as a contraction over it.
 *
 * <p>Leading ("batch") dimensions of inputs with rank above 2 are aligned from the right. A batch
 * dimension of size 1 is broadcast against a larger matching dimension of the other input.
 */
public class MatMul extends IntermediateOperation {

    /**
     * @param modelName the model name, as later returned by {@code modelName()}
     * @param nodeName  the node name, as later returned by {@code name()}
     * @param inputs    the two operands: index 0 is the left operand, index 1 the right
     */
    public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) {
        super(modelName, nodeName, inputs);
    }

    /**
     * Computes the result type of the multiplication.
     *
     * <p>The result contains, in order:
     * <ul>
     *   <li>the leading (batch) dimensions of the higher-rank input, each replaced by the matching
     *       right-aligned dimension of the other input when that one is larger (broadcasting);</li>
     *   <li>the second-to-last dimension of the first input, if it has rank of at least 2;</li>
     *   <li>the last dimension of the second input, if it has rank of at least 2.</li>
     * </ul>
     *
     * @return the result type, or null if the two input types are not yet available
     * @throws IllegalArgumentException if either input has rank 0
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!allInputTypesPresent(2)) {
            return null;
        }
        OrderedTensorType typeA = inputs.get(0).type().get();
        OrderedTensorType typeB = inputs.get(1).type().get();
        if (typeA.type().rank() < 1 || typeB.type().rank() < 1) {
            throw new IllegalArgumentException("Tensors in matmul must have rank of at least 1");
        }

        OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
        boolean aHasLargestRank = typeA.rank() >= typeB.rank();
        OrderedTensorType largestRankType = aHasLargestRank ? typeA : typeB;
        OrderedTensorType smallestRankType = aHasLargestRank ? typeB : typeA;

        for (int i = 0; i < largestRankType.rank() - 2; i++) {
            TensorType.Dimension dimension = largestRankType.dimensions().get(i);
            // Batch dimensions are aligned from the right; keep the larger of the two (broadcasting).
            int j = smallestRankType.rank() - largestRankType.rank() + i;
            if (j >= 0 && sizeOf(smallestRankType.dimensions().get(j)) > sizeOf(dimension)) {
                dimension = smallestRankType.dimensions().get(j);
            }
            builder.add(dimension);
        }
        if (typeA.rank() >= 2) {
            builder.add(typeA.dimensions().get(typeA.rank() - 2));
        }
        if (typeB.rank() >= 2) {
            builder.add(typeB.dimensions().get(typeB.rank() - 1));
        }
        return builder.build();
    }

    /**
     * Builds the tensor function computing the product.
     *
     * <p>Both inputs are first adjusted for broadcasting. They are then joined with multiplication
     * and sum-reduced over the last dimension of the first input.
     *
     * @return the function, or null if the input types or functions are not yet available
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        if (!allInputTypesPresent(2) || !allInputFunctionsPresent(2)) {
            return null;
        }
        OrderedTensorType typeA = inputs.get(0).type().get();
        OrderedTensorType typeB = inputs.get(1).type().get();

        TensorFunction functionA = handleBroadcasting(inputs.get(0).function().get(), typeA, typeB);
        TensorFunction functionB = handleBroadcasting(inputs.get(1).function().get(), typeB, typeA);

        // Fully qualified so it is unambiguous that the tensor-function classes are meant here
        // (Join is not imported in this file).
        return new com.yahoo.tensor.functions.Reduce(
                new com.yahoo.tensor.functions.Join(functionA, functionB, ScalarFunctions.multiply()),
                Reduce.Aggregator.sum,
                typeA.dimensions().get(typeA.rank() - 1).name());
    }

    /**
     * Slices away size-1 batch dimensions of {@code type} that must be broadcast.
     *
     * <p>Each dimension before the last two of {@code type} is examined. It is sliced at index 0
     * when it has size 1 and the matching right-aligned dimension of {@code otherType} is larger.
     * The sliced dimension therefore no longer takes part in the subsequent join.
     *
     * @param tensorFunction the function producing a tensor of {@code type}
     * @param type           the type of {@code tensorFunction}'s result
     * @param otherType      the type of the other matmul operand
     * @return {@code tensorFunction} unchanged if nothing needs slicing, otherwise a slice of it
     */
    private TensorFunction handleBroadcasting(TensorFunction tensorFunction,
                                              OrderedTensorType type,
                                              OrderedTensorType otherType) {
        List<Slice.DimensionValue> slices = new ArrayList<>();
        for (int i = 0; i < type.rank() - 2; i++) {
            long size = sizeOf(type.dimensions().get(i));
            int j = otherType.rank() - type.rank() + i;
            if (j >= 0 && sizeOf(otherType.dimensions().get(j)) > size && size == 1) {
                slices.add(new Slice.DimensionValue(
                        Optional.of(type.dimensionNames().get(i)),
                        TensorFunctionNode.wrapScalar(new EmbracedNode(new ConstantNode(DoubleValue.zero)))));
            }
        }
        return slices.isEmpty() ? tensorFunction : new Slice(tensorFunction, slices);
    }

    /**
     * Adds the dimension-name constraints matmul needs to the given renamer.
     *
     * <p>The last dimension of the first input (the summed-over dimension) must get the same name
     * as the second-to-last dimension of the second input. Further constraints order the
     * remaining dimensions relative to each other.
     *
     * <p>Batch dimensions (those before the last two) are kept distinct from the other input's
     * last two dimensions. They are also equated with the other input's dimension at the same
     * right-aligned position.
     *
     * <p>Does nothing unless both input types are present.
     *
     * <p>NOTE(review): the boolean passed to each {@code DimensionRenamer.Constraint} factory
     * appears to mark a soft (preferred) constraint when true. Confirm against DimensionRenamer.
     */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer dimensionRenamer) {
        if (!allInputTypesPresent(2)) {
            return;
        }
        OrderedTensorType typeA = inputs.get(0).type().get();
        OrderedTensorType typeB = inputs.get(1).type().get();

        String lastDimA = typeA.dimensions().get(typeA.rank() - 1).name();
        String lastDimB = typeB.dimensions().get(typeB.rank() - 1).name();
        // For a rank-1 input the "second-to-last" dimension falls back to its only dimension.
        String secondLastDimA = typeA.dimensions().get(Math.max(0, typeA.rank() - 2)).name();
        String secondLastDimB = typeB.dimensions().get(Math.max(0, typeB.rank() - 2)).name();

        // The summed-over dimension: last of A must be named like second-to-last of B.
        dimensionRenamer.addConstraint(lastDimA, secondLastDimB,
                                       DimensionRenamer.Constraint.equal(false), this);

        // Both of these end up in the result type, so order them (presumably also keeps them distinct).
        if (typeA.rank() >= 2 && typeB.rank() >= 2) {
            dimensionRenamer.addConstraint(secondLastDimA, lastDimB,
                                           DimensionRenamer.Constraint.lessThan(false), this);
        }
        // Preferred orderings for the dimensions involved in the join.
        if (typeA.rank() >= 2) {
            dimensionRenamer.addConstraint(secondLastDimA, lastDimA,
                                           DimensionRenamer.Constraint.lessThan(true), this);
        }
        if (typeB.rank() >= 2) {
            dimensionRenamer.addConstraint(secondLastDimB, lastDimB,
                                           DimensionRenamer.Constraint.greaterThan(true), this);
        }

        // Same rules for each input's batch dimensions, first A relative to B, then B relative to A.
        addBatchDimensionConstraints(dimensionRenamer, typeA, typeB);
        addBatchDimensionConstraints(dimensionRenamer, typeB, typeA);
    }

    /**
     * Adds constraints for the batch dimensions of {@code type} (all but its last two) relative
     * to {@code otherType}.
     */
    private void addBatchDimensionConstraints(DimensionRenamer dimensionRenamer,
                                              OrderedTensorType type,
                                              OrderedTensorType otherType) {
        for (int i = 0; i < type.rank() - 2; i++) {
            String batchDim = type.dimensionNames().get(i);
            // Must sort before every later dimension of the same input.
            for (int j = i + 1; j < type.rank(); j++) {
                dimensionRenamer.addConstraint(batchDim, type.dimensionNames().get(j),
                                               DimensionRenamer.Constraint.lessThan(false), this);
            }
            // Must not share a name with the other input's last two dimensions.
            for (int j = Math.max(0, otherType.rank() - 2); j < otherType.rank(); j++) {
                dimensionRenamer.addConstraint(batchDim, otherType.dimensionNames().get(j),
                                               DimensionRenamer.Constraint.notEqual(false), this);
            }
            // Must share a name with the other input's dimension at the same right-aligned position.
            int matching = otherType.rank() - type.rank() + i;
            if (matching >= 0) {
                dimensionRenamer.addConstraint(batchDim, otherType.dimensionNames().get(matching),
                                               DimensionRenamer.Constraint.equal(false), this);
            }
        }
    }

    /**
     * Returns a copy of this operation with the given inputs.
     *
     * <p>The covariant return type is fine: javac generates the erased bridge method itself. It
     * must not be declared by hand, because that would clash with this method's erasure.
     */
    @Override
    public MatMul withInputs(List<IntermediateOperation> inputs) {
        return new MatMul(modelName(), name(), inputs);
    }

    @Override
    public String operationName() {
        return "MatMul";
    }

    /** Returns the size of the given dimension; fails with {@code Optional.get()} if it is absent. */
    private static long sizeOf(TensorType.Dimension dimension) {
        return dimension.size().get();
    }
}
