package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/* loaded from: input_file:ai/vespa/rankingexpression/importer/operations/Softmax.class */
/**
 * Softmax over a trailing range of dimensions, built from two chained operations.
 *
 * <p>The constructor inserts a {@code SoftmaxPartialOperation} as input 0 of this operation. That
 * partial operation computes {@code exp(x - max(x))}, where the max is reduced over the softmax
 * dimensions. This operation then divides that result by its sum over the same dimensions.
 * Subtracting the max before exponentiating keeps {@code exp} from overflowing. Softmax is
 * shift-invariant, so the result is mathematically unchanged.
 *
 * <p>The softmax dimensions run from {@code axis} through the last dimension of the input type.
 * The {@code "axis"} attribute sets the start; a negative value counts from the end. Without the
 * attribute, the start is 0 for rank-1 inputs and 1 otherwise, presumably so that dimension 0 is
 * treated as a batch dimension (TODO confirm against the importer's model conventions).
 */
public class Softmax extends IntermediateOperation {

    /** Suffix appended to the softmax node name to name the inserted partial operation. */
    private static final String PARTIAL_SUFFIX = "_partial";

    private final IntermediateOperation.AttributeMap attributeMap;

    /**
     * Computes {@code exp(x - max(x))} over the enclosing softmax's reduce dimensions.
     *
     * <p>This is a non-static inner class because it uses the enclosing {@link Softmax}'s
     * {@code reduceDimensions()}.
     */
    private class SoftmaxPartialOperation extends IntermediateOperation {

        /** Name of the owning softmax node, without {@link #PARTIAL_SUFFIX}; reused by {@link #withInputs}. */
        private final String softmaxNodeName;

        /**
         * @param modelName the model name
         * @param softmaxNodeName the softmax node's name; this operation is named {@code softmaxNodeName + "_partial"}
         * @param inputs the inputs, or {@code null} for none (an empty list is used instead)
         */
        private SoftmaxPartialOperation(String modelName, String softmaxNodeName, List<IntermediateOperation> inputs) {
            super(modelName, softmaxNodeName + PARTIAL_SUFFIX, inputs != null ? inputs : Collections.emptyList());
            this.softmaxNodeName = softmaxNodeName;
        }

        /**
         * Returns the type of the single input (exp/subtract do not change the type), or {@code null}
         * if it is not yet known. As a side effect, this operation and its input are both flagged
         * for export as ranking functions.
         */
        @Override
        protected OrderedTensorType lazyGetType() {
            if (!allInputTypesPresent(1)) {
                return null;
            }
            // NOTE(review): presumably exported so that the enclosing Softmax, which uses this result
            // twice (numerator and sum), references it rather than duplicating the expression.
            this.inputs.get(0).exportAsRankingFunction = true;
            this.exportAsRankingFunction = true;
            return this.inputs.get(0).type().get();
        }

        /** Returns {@code map(join(x, reduce(x, max, dims), f(a,b)(a - b)), f(a)(exp(a)))}, or {@code null} if the input function is not yet known. */
        @Override
        protected TensorFunction lazyGetFunction() {
            if (!allInputFunctionsPresent(1)) {
                return null;
            }
            List<String> reduceDimensions = Softmax.this.reduceDimensions();
            TensorFunction input = this.inputs.get(0).function().get();
            // Shift by the max over the softmax dimensions so the largest exponent is exp(0).
            TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions);
            TensorFunction shifted = new com.yahoo.tensor.functions.Join(input, max, ScalarFunctions.subtract());
            return new com.yahoo.tensor.functions.Map(shifted, ScalarFunctions.exp());
        }

        /** Returns a copy of this operation with the given inputs and the same name. */
        @Override
        public SoftmaxPartialOperation withInputs(List<IntermediateOperation> inputs) {
            // Pass the un-suffixed name: passing name() here would append PARTIAL_SUFFIX a second time.
            return new SoftmaxPartialOperation(modelName(), softmaxNodeName, inputs);
        }

        @Override
        public String operationName() {
            return "SoftMaxPartial";
        }
    }

    /**
     * Creates a softmax operation and inserts its partial (exp) operation as input 0.
     *
     * @param modelName the model name
     * @param nodeName the node name; the partial operation is named {@code nodeName + "_partial"}
     * @param inputs the inputs; input 0 is the tensor to apply softmax to
     * @param attributeMap node attributes; only {@code "axis"} is read
     */
    public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs,
                   IntermediateOperation.AttributeMap attributeMap) {
        super(modelName, nodeName, inputs);
        this.attributeMap = attributeMap;
        // The partial op is created without inputs. Presumably insert(...) splices it between the
        // current input 0 and this operation (TODO confirm in IntermediateOperation.insert).
        insert(new SoftmaxPartialOperation(modelName, nodeName, null), 0);
    }

    /** Returns the type of input 0 (softmax preserves the type), or {@code null} if it is not yet known. */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!allInputTypesPresent(1)) {
            return null;
        }
        return this.inputs.get(0).type().get();
    }

    /**
     * Returns {@code join(e, reduce(e, sum, dims), f(a,b)(a / b))}, where {@code e} is input 0
     * (the partial exp operation), or {@code null} if the input function is not yet known.
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        if (!allInputFunctionsPresent(1)) {
            return null;
        }
        List<String> reduceDimensions = reduceDimensions();
        TensorFunction exponentials = this.inputs.get(0).function().get();
        TensorFunction sum = new Reduce(exponentials, Reduce.Aggregator.sum, reduceDimensions);
        return new com.yahoo.tensor.functions.Join(exponentials, sum, ScalarFunctions.divide());
    }

    /**
     * Returns a new softmax with the given inputs and the same model name, node name and attributes.
     *
     * <p>NOTE(review): the constructor always inserts a new partial op at input 0. If {@code inputs}
     * already starts with this operation's partial op, the copy will contain two chained partial ops.
     * Confirm how callers build {@code inputs}.
     */
    @Override
    public Softmax withInputs(List<IntermediateOperation> inputs) {
        return new Softmax(modelName(), name(), inputs, this.attributeMap);
    }

    @Override
    public String operationName() {
        return "SoftMax";
    }

    /**
     * Returns the names of the dimensions to normalize over: dimension {@code axis} through the last
     * dimension of input 0's type.
     *
     * <p>{@code axis} comes from the {@code "axis"} attribute (truncated to int) when present.
     * Otherwise it is 0 for rank-1 inputs and 1 for all other ranks. A negative axis is offset by the
     * rank. An axis at or beyond the rank yields an empty list.
     */
    private List<String> reduceDimensions() {
        OrderedTensorType inputType = this.inputs.get(0).type().get();
        int rank = inputType.rank();
        int axis = rank == 1 ? 0 : 1;
        if (this.attributeMap.get("axis").isPresent()) {
            axis = (int) this.attributeMap.get("axis").get().asDouble();
        }
        if (axis < 0) {
            axis += rank;
        }
        List<String> dimensionNames = new ArrayList<>();
        for (int d = axis; d < rank; d++) {
            dimensionNames.add(inputType.dimensions().get(d).name());
        }
        return dimensionNames;
    }
}
