package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;

/* loaded from: input_file:ai/vespa/rankingexpression/importer/operations/Sum.class */
public class Sum extends IntermediateOperation {

    /** Attributes of the imported node; only {@code keep_dims} is read by this class. */
    private final IntermediateOperation.AttributeMap attributeMap;

    /**
     * Names of the input-0 dimensions to sum over.
     * Set by {@link #lazyGetType()} and rewritten by {@link #renameDimensions(DimensionRenamer)}.
     */
    private List<String> reduceDimensions;

    public Sum(String modelName, String nodeName, List<IntermediateOperation> inputs,
               IntermediateOperation.AttributeMap attributeMap) {
        super(modelName, nodeName, inputs);
        this.attributeMap = attributeMap;
    }

    /**
     * Resolves which dimensions of input 0 are reduced and returns the resulting type.
     *
     * <p>Input 1 must be a constant tensor whose cell values are dimension indices into input 0.
     * A negative index counts from the end (-1 is the last dimension).
     *
     * @return the reduced type, or null if the two input types are not yet available
     * @throws IllegalArgumentException if the reduction indices are not constant or an index is
     *         out of range for the rank of input 0
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if ( ! allInputTypesPresent(2)) {
            return null;
        }
        Optional<Value> reductionIndices = this.inputs.get(1).getConstantValue();
        if ( ! reductionIndices.isPresent()) {
            throw new IllegalArgumentException("Sum in " + this.name + ": Reduction indices must be a constant.");
        }
        Tensor indices = reductionIndices.get().asTensor();
        OrderedTensorType inputType = this.inputs.get(0).type().get();
        int rank = inputType.dimensions().size();

        List<String> dimensions = new ArrayList<>();
        for (Iterator<Tensor.Cell> cells = indices.cellIterator(); cells.hasNext(); ) {
            int rawIndex = cells.next().getValue().intValue();
            // Negative indices wrap around from the end: -1 -> rank - 1.
            int index = rawIndex < 0 ? rank + rawIndex : rawIndex;
            if (index < 0 || index >= rank) {
                throw new IllegalArgumentException("Sum in " + this.name + ": Reduction index " + rawIndex +
                                                   " is out of range for input of rank " + rank + ".");
            }
            dimensions.add(inputType.dimensions().get(index).name());
        }
        this.reduceDimensions = dimensions;
        return reducedType(inputType, shouldKeepDimensions());
    }

    /**
     * Builds a sum-reduce of input 0 over {@link #reduceDimensions}.
     *
     * <p>When {@code keep_dims} is set, the reduced dimensions are re-added with size 1. This is
     * done by joining (multiplying) the reduction with a tensor of ones whose type is the result
     * value type plus each reduced dimension as an indexed dimension of size 1.
     *
     * @return the tensor function, or null if the two input types are not yet available
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        if ( ! allInputTypesPresent(2)) {
            return null;
        }
        TensorFunction reduce = new Reduce(this.inputs.get(0).function().get(), Reduce.Aggregator.sum, this.reduceDimensions);
        if (shouldKeepDimensions()) {
            TensorType.Builder onesTypeBuilder = new TensorType.Builder(resultValueType());
            for (String dimension : this.reduceDimensions) {
                onesTypeBuilder.indexed(dimension, 1L);
            }
            TensorType onesType = onesTypeBuilder.build();
            TensorFunction ones = new Generate(onesType,
                                               new GeneratorLambdaFunctionNode(onesType, new ConstantNode(new DoubleValue(1.0d)))
                                                       .asLongListToDoubleOperator());
            // Fully qualified: a different Join class may be visible from this package.
            reduce = new com.yahoo.tensor.functions.Join(reduce, ones, ScalarFunctions.multiply());
        }
        return reduce;
    }

    /**
     * Renames the stored reduce dimensions through the renamer.
     *
     * <p>If any dimension has no new name, the stored names are left untouched.
     */
    @Override
    public void renameDimensions(DimensionRenamer dimensionRenamer) {
        super.renameDimensions(dimensionRenamer);
        List<String> renamed = new ArrayList<>(this.reduceDimensions.size());
        for (String dimension : this.reduceDimensions) {
            Optional<String> newName = dimensionRenamer.dimensionNameOf(dimension);
            if ( ! newName.isPresent()) {
                return;
            }
            renamed.add(newName.get());
        }
        this.reduceDimensions = renamed;
    }

    /** Returns true if the {@code keep_dims} attribute is present and true. */
    private boolean shouldKeepDimensions() {
        return this.attributeMap.get("keep_dims").map(Value::asBoolean).orElse(false);
    }

    /**
     * Returns the input type without the reduced dimensions.
     *
     * <p>If {@code keepDimensions} is set, each reduced dimension is replaced by an indexed
     * dimension of size 1 instead of being dropped.
     */
    private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
        OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
        for (TensorType.Dimension dimension : inputType.type().dimensions()) {
            if ( ! this.reduceDimensions.contains(dimension.name())) {
                builder.add(dimension);
            } else if (keepDimensions) {
                builder.add(TensorType.Dimension.indexed(dimension.name(), 1L));
            }
        }
        return builder.build();
    }

    @Override
    public Sum withInputs(List<IntermediateOperation> inputs) {
        return new Sum(modelName(), name(), inputs, this.attributeMap);
    }

    @Override
    public String operationName() {
        return "Sum";
    }

}
