package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:ai/vespa/rankingexpression/importer/operations/Unsqueeze.class */
/**
 * Imported "Unsqueeze" operation: builds an output type consisting of the single input's
 * dimensions with additional size-1 indexed dimensions inserted at the positions listed in
 * the {@code axes} attribute.
 *
 * <p>Negative axes are interpreted relative to the output rank (input rank plus the number of
 * distinct axes). The names of the inserted dimensions are remembered in
 * {@link #expandDimensions} so that the generated function and later dimension renaming can
 * refer to them.
 */
public class Unsqueeze extends IntermediateOperation {

    /** Attributes of the imported node; must contain {@code axes}. */
    private final IntermediateOperation.AttributeMap attributeMap;

    /** Names of the inserted dimensions; assigned by {@link #lazyGetType()}, may be replaced by {@link #renameDimensions}. */
    private List<String> expandDimensions;

    /**
     * Creates the operation.
     *
     * @param modelName    passed on to the superclass constructor
     * @param nodeName     passed on to the superclass constructor
     * @param inputs       passed on to the superclass constructor
     * @param attributeMap attributes of the node; must provide a list for {@code axes}
     * @throws IllegalArgumentException if the {@code axes} attribute is absent
     */
    public Unsqueeze(String modelName, String nodeName, List<IntermediateOperation> inputs,
                     IntermediateOperation.AttributeMap attributeMap) {
        super(modelName, nodeName, inputs);
        this.attributeMap = attributeMap;
        if (attributeMap.getList("axes").isEmpty()) {
            throw new IllegalArgumentException("Unsqueeze in " + this.name + ": Required attribute 'axes' is missing.");
        }
    }

    /**
     * Computes the output type and records the names of the inserted dimensions.
     *
     * @return the output type, or {@code null} when {@code allInputTypesPresent(1)} is false
     * @throws IllegalArgumentException if, after resolving negative values, the axes are not
     *         distinct positions inside the output rank
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!allInputTypesPresent(1)) {
            return null;
        }
        OrderedTensorType inputType = this.inputs.get(0).type().get();

        // Attribute values are read as doubles and truncated to int; collecting into a set
        // collapses literal duplicates in the attribute.
        Set<Integer> requestedAxes = this.attributeMap.getList("axes").get().stream()
                .map(value -> (int) value.asDouble())
                .collect(Collectors.toSet());

        // Negative axes count from the end of the *output* shape.
        int rank = inputType.rank() + requestedAxes.size();
        Set<Integer> insertPositions = requestedAxes.stream()
                .map(axis -> axis < 0 ? rank + axis : axis)
                .collect(Collectors.toSet());

        // If two axes collapse onto the same position, or a position falls outside the output,
        // the loop below would need more input dimensions than exist. Fail with a clear message
        // instead of an index error, and do so before touching expandDimensions.
        boolean collided = insertPositions.size() != requestedAxes.size();
        boolean outOfRange = insertPositions.stream().anyMatch(axis -> axis < 0 || axis >= rank);
        if (collided || outOfRange) {
            throw new IllegalArgumentException("Unsqueeze in " + this.name + ": 'axes' " + requestedAxes
                    + " must resolve to distinct positions in [0, " + (rank - 1) + "].");
        }

        this.expandDimensions = new ArrayList<>();
        OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
        int inputDimensionIndex = 0;
        for (int outputIndex = 0; outputIndex < rank; outputIndex++) {
            if (insertPositions.contains(outputIndex)) {
                addDimension(outputIndex, typeBuilder);
            } else {
                typeBuilder.add(inputType.dimensions().get(inputDimensionIndex));
                inputDimensionIndex++;
            }
        }
        return typeBuilder.build();
    }

    /** Adds a size-1 indexed dimension named {@code <vespaName()>_<index>} and records its name. */
    private void addDimension(int index, OrderedTensorType.Builder typeBuilder) {
        String dimensionName = String.format("%s_%d", vespaName(), index);
        this.expandDimensions.add(dimensionName);
        typeBuilder.add(TensorType.Dimension.indexed(dimensionName, 1L));
    }

    /**
     * Builds the function: the input joined (by multiplication) with a generated tensor whose
     * cells are the constant 1.0 and whose dimensions are the recorded inserted dimensions.
     *
     * @return the join function, or {@code null} when {@code allInputFunctionsPresent(1)} is false
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        if (!allInputFunctionsPresent(1)) {
            return null;
        }
        TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
        for (String dimensionName : this.expandDimensions) {
            typeBuilder.indexed(dimensionName, 1L);
        }
        TensorType generatedType = typeBuilder.build();
        ConstantNode one = new ConstantNode(new DoubleValue(1.0d));
        Generate ones = new Generate(generatedType,
                new GeneratorLambdaFunctionNode(generatedType, one).asLongListToDoubleOperator());
        return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), ones, ScalarFunctions.multiply());
    }

    /** Delegates to {@code addConstraintsFrom} with this operation's type. */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer dimensionRenamer) {
        addConstraintsFrom(this.type, dimensionRenamer);
    }

    /**
     * Renames via the superclass, then maps every recorded inserted-dimension name through the
     * renamer. If the renamer has no name for any one of them, the recorded names are left
     * completely unchanged (presumably they were already renamed; not verifiable from here).
     */
    @Override
    public void renameDimensions(DimensionRenamer dimensionRenamer) {
        super.renameDimensions(dimensionRenamer);
        List<String> renamed = new ArrayList<>(this.expandDimensions.size());
        for (String dimensionName : this.expandDimensions) {
            Optional<String> newName = dimensionRenamer.dimensionNameOf(dimensionName);
            if (newName.isEmpty()) {
                return; // all-or-nothing: keep the existing names
            }
            renamed.add(newName.get());
        }
        this.expandDimensions = renamed;
    }

    /** Returns a new {@code Unsqueeze} with the same model name, node name and attributes but the given inputs. */
    @Override
    public Unsqueeze withInputs(List<IntermediateOperation> list) {
        return new Unsqueeze(modelName(), name(), list, this.attributeMap);
    }

    /** Returns the literal operation name {@code "Unsqueeze"}. */
    @Override
    public String operationName() {
        return "Unsqueeze";
    }
}
