package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Slice;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Optional;

/**
 * Imported "Gather" operation.
 *
 * <p>Takes two inputs: input 0 (called "data" below) and input 1 (called "indices" below).
 * The output type consists of the data dimensions before the gather axis, followed by all
 * dimensions of the indices input, followed by the data dimensions after the gather axis.
 * The generated function slices the data input, using the values of the indices input as
 * the coordinate along the gather axis.
 */
public class Gather extends IntermediateOperation {

    private final IntermediateOperation.AttributeMap attributeMap;

    /**
     * The gather axis in the data input. Assigned by {@link #lazyGetType()} and read by
     * {@link #lazyGetFunction()} and {@link #addDimensionNameConstraints(DimensionRenamer)}.
     */
    private int axis;

    /**
     * Creates a gather operation.
     *
     * @param modelName    the model name, passed on to the superclass
     * @param nodeName     the node name, passed on to the superclass
     * @param inputs       the input operations; index 0 is the data, index 1 the indices
     * @param attributeMap the attributes of the node; the optional {@code "axis"} attribute is read
     */
    public Gather(String modelName, String nodeName, List<IntermediateOperation> inputs, IntermediateOperation.AttributeMap attributeMap) {
        super(modelName, nodeName, inputs);
        this.attributeMap = attributeMap;
    }

    /**
     * Builds the output type and resolves the gather axis.
     *
     * <p>Both inputs are flagged for export as ranking functions, since the generated
     * function refers to them by function name.
     *
     * @return the output type, or {@code null} when not all input types are available yet
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!allInputTypesPresent(2)) {
            return null;
        }
        OrderedTensorType dataType = inputs.get(0).type().get();
        OrderedTensorType indicesType = inputs.get(1).type().get();

        // "axis" defaults to 0; a negative value counts from the end of the data dimensions.
        axis = (int) attributeMap.get("axis").orElse(DoubleValue.zero).asDouble();
        if (axis < 0) {
            axis = dataType.rank() + axis;
        }

        // A dimension whose size is not set is given the size -1.
        OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
        for (int i = 0; i < axis; i++) {
            addDimension(i, dataType.dimensions().get(i).size().orElse(-1L), typeBuilder);
        }
        for (int i = 0; i < indicesType.rank(); i++) {
            addDimension(i + axis, indicesType.dimensions().get(i).size().orElse(-1L), typeBuilder);
        }
        // NOTE: the name suffix used here is i + indices rank, whereas the function and the
        // constraint code address the matching output dimension at position i + indices rank - 1.
        // Only the generated name is affected, and the names remain distinct.
        for (int i = axis + 1; i < dataType.rank(); i++) {
            addDimension(i + indicesType.rank(), dataType.dimensions().get(i).size().orElse(-1L), typeBuilder);
        }

        inputs.get(0).exportAsRankingFunction = true;
        inputs.get(1).exportAsRankingFunction = true;

        return typeBuilder.build();
    }

    /** Adds an indexed dimension named {@code <vespaName>_<index>} with the given size to the builder. */
    private void addDimension(int index, long size, OrderedTensorType.Builder typeBuilder) {
        // Plain concatenation keeps the generated identifier independent of the default locale.
        String dimensionName = vespaName() + "_" + index;
        typeBuilder.add(TensorType.Dimension.indexed(dimensionName, size));
    }

    /**
     * Builds the tensor function: a generate over the output type whose cell expression is a
     * slice of the data input's ranking function.
     *
     * @return the function, or {@code null} when not all input functions are available yet
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        if (!allInputFunctionsPresent(2)) {
            return null;
        }
        IntermediateOperation data = inputs.get(0);
        IntermediateOperation indices = inputs.get(1);
        OrderedTensorType dataType = data.type().get();
        OrderedTensorType indicesType = indices.type().get();
        String dataFunctionName = data.rankingExpressionFunctionName();
        String indicesFunctionName = indices.rankingExpressionFunctionName();

        List<Slice.DimensionValue<Reference>> dataSliceDimensions = new ArrayList<>();

        // Data dimensions before the axis map directly to the output dimension at the same position.
        for (int i = 0; i < axis; i++) {
            addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), i);
        }

        if (indicesType.rank() == 0 && indices.isConstant()) {
            // Constant scalar index: embed the value directly; a negative value has the axis size added.
            double constantIndex = indices.getConstantValue().get().asDouble();
            ExpressionNode indexExpression = new ConstantNode(new DoubleValue(constantIndex));
            if (constantIndex < 0.0d) {
                ExpressionNode axisSize = new ConstantNode(new DoubleValue(axisSize(dataType)));
                indexExpression = new EmbracedNode(new ArithmeticNode(indexExpression, ArithmeticOperator.PLUS, axisSize));
            }
            addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression);
        } else {
            // Look the index up in the indices function, addressed by the output dimensions
            // that correspond to the indices dimensions (positions axis .. axis + rank - 1).
            List<Slice.DimensionValue<Reference>> indicesSliceDimensions = new ArrayList<>();
            for (int i = 0; i < indicesType.rank(); i++) {
                addSliceDimension(indicesSliceDimensions, indicesType.dimensions().get(i).name(), axis + i);
            }
            ExpressionNode indexLookup = createSliceExpression(indicesSliceDimensions, indicesFunctionName);
            addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), createIndexExpression(dataType, indexLookup));
        }

        // Data dimensions after the axis are shifted by the indices rank, minus the replaced axis.
        for (int i = axis + 1; i < dataType.rank(); i++) {
            addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), (i + indicesType.rank()) - 1);
        }

        ExpressionNode cellExpression = createSliceExpression(dataSliceDimensions, dataFunctionName);
        return Generate.bound(type.type(), TensorFunctionNode.wrapScalar(cellExpression));
    }

    /**
     * Returns the size of the data dimension at the gather axis.
     *
     * @throws NoSuchElementException if that dimension has no size set
     */
    private long axisSize(OrderedTensorType dataType) {
        return dataType.dimensions().get(axis).size()
                .orElseThrow(() -> new NoSuchElementException(
                        "Gather '" + name() + "': the size of data dimension " + axis + " (the gather axis) is not set"));
    }

    /**
     * Returns an expression referring to the named function, sliced by the given dimension
     * values, or the bare reference when there are no dimension values.
     */
    private ExpressionNode createSliceExpression(List<Slice.DimensionValue<Reference>> dimensionValues, String referenceName) {
        TensorFunctionNode.ExpressionTensorFunction source =
                new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(referenceName));
        if (dimensionValues.isEmpty()) {
            return new TensorFunctionNode(source);
        }
        return new TensorFunctionNode(new Slice<>(source, dimensionValues));
    }

    /**
     * Returns the expression {@code (index + size) % size}, where size is the size of the data
     * dimension at the gather axis. Adding the size before the modulo maps an index in
     * {@code [-size, 0)} onto {@code [0, size)}.
     */
    private ExpressionNode createIndexExpression(OrderedTensorType dataType, ExpressionNode indexExpression) {
        ConstantNode axisSize = new ConstantNode(new DoubleValue(axisSize(dataType)));
        ExpressionNode shifted = new EmbracedNode(new ArithmeticNode(indexExpression, ArithmeticOperator.PLUS, axisSize));
        return new ArithmeticNode(shifted, ArithmeticOperator.MODULO, axisSize);
    }

    /** Appends a slice coordinate binding the named dimension to the given expression, in parentheses. */
    private void addSliceDimension(List<Slice.DimensionValue<Reference>> dimensionValues, String dimensionName, ExpressionNode expression) {
        dimensionValues.add(new Slice.DimensionValue<>(Optional.of(dimensionName),
                                                       TensorFunctionNode.wrapScalar(new EmbracedNode(expression))));
    }

    /** Appends a slice coordinate binding the named dimension to a reference to the output dimension at the given position. */
    private void addSliceDimension(List<Slice.DimensionValue<Reference>> dimensionValues, String dimensionName, int outputDimensionIndex) {
        addSliceDimension(dimensionValues, dimensionName, new ReferenceNode(type.dimensions().get(outputDimensionIndex).name()));
    }

    /**
     * Registers the output dimensions with the renamer, constrains each output dimension to
     * sort before every later one, and ties each output dimension to the input dimension it
     * originates from. Does nothing unless both input types are available.
     */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        if (!allInputTypesPresent(2)) {
            return;
        }
        int outputRank = type.dimensions().size();
        for (int i = 0; i < outputRank; i++) {
            String dimensionName = type.dimensions().get(i).name();
            renamer.addDimension(dimensionName);
            for (int j = i + 1; j < outputRank; j++) {
                renamer.addConstraint(dimensionName, type.dimensions().get(j).name(), DimensionRenamer.Constraint.lessThan(), this);
            }
        }

        OrderedTensorType dataType = inputs.get(0).type().get();
        OrderedTensorType indicesType = inputs.get(1).type().get();

        for (int i = 0; i < axis; i++) {
            renamer.addConstraint(type.dimensions().get(i).name(),
                                  dataType.dimensions().get(i).name(),
                                  DimensionRenamer.Constraint.equal(), this);
        }
        for (int i = 0; i < indicesType.rank(); i++) {
            renamer.addConstraint(type.dimensions().get(i + axis).name(),
                                  indicesType.dimensions().get(i).name(),
                                  DimensionRenamer.Constraint.equal(), this);
        }
        for (int i = axis + 1; i < dataType.rank(); i++) {
            renamer.addConstraint(type.dimensions().get((i + indicesType.rank()) - 1).name(),
                                  dataType.dimensions().get(i).name(),
                                  DimensionRenamer.Constraint.equal(), this);
        }
    }

    /** Returns a new gather operation with the same model name, name and attributes, but the given inputs. */
    @Override
    public Gather withInputs(List<IntermediateOperation> inputs) {
        return new Gather(modelName(), name(), inputs, attributeMap);
    }

    /** Returns the name of this operation type, {@code "Gather"}. */
    @Override
    public String operationName() {
        return "Gather";
    }
}
