package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleBinaryOperator;

/**
 * Imported binary operation that combines two inputs cell by cell using a scalar operator.
 *
 * <p>The lower-rank input is aligned with the trailing dimensions of the higher-rank input.
 * Where exactly one side of an aligned pair has size 1, that side is summed away before the join,
 * so the other side's size is kept in the result.
 */
public class Join extends IntermediateOperation {

    /** Scalar operator applied to each pair of cells, as {@code operator(input 0, input 1)}. */
    private final DoubleBinaryOperator operator;

    /**
     * Creates a join operation.
     *
     * @param modelName the name of the model this operation belongs to
     * @param nodeName  the name of this operation
     * @param inputs    the two inputs; the operator is applied with input 0 as the left operand
     * @param operator  the scalar operator combining cells
     */
    public Join(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) {
        super(modelName, nodeName, inputs);
        this.operator = operator;
    }

    /**
     * Computes the result type.
     *
     * <p>The result has one dimension per dimension of the higher-rank input, in that input's order.
     * Where the lower-rank input has an aligned dimension, the result size is the larger of the two
     * sizes. Unknown sizes count as -1.
     *
     * @return the result type, or null if either input type is not yet known
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!allInputTypesPresent(2)) {
            return null;
        }
        OrderedTensorType largerType = largestInput().type().get();
        OrderedTensorType smallerType = smallestInput().type().get();
        OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
        int rankDifference = largerType.rank() - smallerType.rank();
        for (int i = 0; i < largerType.rank(); i++) {
            TensorType.Dimension dimension = largerType.dimensions().get(i);
            long size = dimension.size().orElse(-1L);
            int alignedIndex = i - rankDifference; // position of the aligned dimension in the smaller input
            if (alignedIndex >= 0) {
                size = Math.max(size, smallerType.dimensions().get(alignedIndex).size().orElse(-1L));
            }
            if (dimension.type() == TensorType.Dimension.Type.indexedBound) {
                builder.add(TensorType.Dimension.indexed(dimension.name(), size));
            } else if (dimension.type() == TensorType.Dimension.Type.indexedUnbound) {
                builder.add(TensorType.Dimension.indexed(dimension.name()));
            } else if (dimension.type() == TensorType.Dimension.Type.mapped) {
                builder.add(TensorType.Dimension.mapped(dimension.name()));
            }
        }
        return builder.build();
    }

    /**
     * Builds the tensor function for this join.
     *
     * <p>Each aligned dimension that has size 1 on one side but not on the other is summed away
     * on the size-1 side. The join then keeps the other side's size. Operands are passed to the
     * join in input order, so non-commutative operators (e.g. subtract, divide) keep their meaning
     * even when input 1 has the higher rank.
     *
     * @return the join function, or null if either input type or function is not yet available
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        if (!allInputTypesPresent(2) || !allInputFunctionsPresent(2)) {
            return null;
        }
        IntermediateOperation larger = largestInput();
        IntermediateOperation smaller = smallestInput();
        OrderedTensorType largerType = larger.type().get();
        OrderedTensorType smallerType = smaller.type().get();

        List<String> largerDimensionsToReduce = new ArrayList<>();
        List<String> smallerDimensionsToReduce = new ArrayList<>();
        int rankDifference = largerType.rank() - smallerType.rank();
        for (int i = 0; i < smallerType.rank(); i++) {
            TensorType.Dimension smallerDimension = smallerType.dimensions().get(i);
            TensorType.Dimension largerDimension = largerType.dimensions().get(i + rankDifference);
            long smallerSize = smallerDimension.size().orElse(-1L);
            long largerSize = largerDimension.size().orElse(-1L);
            if (smallerSize == 1 && largerSize != 1) {
                smallerDimensionsToReduce.add(smallerDimension.name());
            }
            if (largerSize == 1 && smallerSize != 1) {
                // Use the larger input's own name; it equals the smaller's only after renaming.
                largerDimensionsToReduce.add(largerDimension.name());
            }
        }

        TensorFunction largerFunction = reduceIfNeeded(larger.function().get(), largerDimensionsToReduce);
        TensorFunction smallerFunction = reduceIfNeeded(smaller.function().get(), smallerDimensionsToReduce);

        // Keep the original operand order: largestInput() may be input 1.
        if (larger == inputs.get(0)) {
            return new com.yahoo.tensor.functions.Join(largerFunction, smallerFunction, operator);
        }
        return new com.yahoo.tensor.functions.Join(smallerFunction, largerFunction, operator);
    }

    /** Wraps the function in a sum-reduce over the given dimensions, or returns it unchanged if there are none. */
    private static TensorFunction reduceIfNeeded(TensorFunction function, List<String> dimensions) {
        if (dimensions.isEmpty()) {
            return function;
        }
        return new Reduce(function, Reduce.Aggregator.sum, dimensions);
    }

    /**
     * Requires each dimension of the lower-rank input to share a name with the aligned trailing
     * dimension of the higher-rank input.
     */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        if (!allInputTypesPresent(2)) {
            return;
        }
        OrderedTensorType largerType = largestInput().type().get();
        OrderedTensorType smallerType = smallestInput().type().get();
        int rankDifference = largerType.rank() - smallerType.rank();
        for (int i = 0; i < smallerType.rank(); i++) {
            renamer.addConstraint(largerType.dimensions().get(i + rankDifference).name(),
                                  smallerType.dimensions().get(i).name(),
                                  DimensionRenamer.Constraint.equal(false),
                                  this);
        }
    }

    /** Returns the input with the higher rank, or input 0 if the ranks are equal. */
    private IntermediateOperation largestInput() {
        return inputs.get(0).type().get().rank() >= inputs.get(1).type().get().rank() ? inputs.get(0) : inputs.get(1);
    }

    /** Returns the input with the lower rank, or input 1 if the ranks are equal. */
    private IntermediateOperation smallestInput() {
        return inputs.get(0).type().get().rank() < inputs.get(1).type().get().rank() ? inputs.get(0) : inputs.get(1);
    }

    /** Returns a copy of this operation with the given inputs and the same model name, name and operator. */
    @Override
    public Join withInputs(List<IntermediateOperation> inputs) {
        return new Join(modelName(), name(), inputs, operator);
    }

    @Override
    public String operationName() {
        return "Join";
    }

}
