package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Slice;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;

/**
 * Intermediate operation that re-lays out the cells of its first input into a new indexed shape
 * with the same total number of cells.
 *
 * <p>The target shape comes either from a constant second input or, when there is only one
 * input, from the {@code "shape"} attribute. In the requested shape, a 0 entry copies the size of
 * the input dimension at the same position, and a single negative entry (conventionally -1) is
 * inferred so that the total cell count is preserved.
 */
public class Reshape extends IntermediateOperation {

    private final IntermediateOperation.AttributeMap attributeMap;

    public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs,
                   IntermediateOperation.AttributeMap attributeMap) {
        super(modelName, nodeName, inputs);
        this.attributeMap = attributeMap;
    }

    @Override
    protected OrderedTensorType lazyGetType() {
        // Validate before touching inputs.get(0) so an empty input list gets a clear error.
        if (inputs.size() != 1 && inputs.size() != 2)
            throw new IllegalArgumentException("Expected 1 or 2 inputs for '" + name + "', got " + inputs.size());

        // reshape() references the input by its ranking expression function name, so the
        // input must be exported as a separate ranking function.
        inputs.get(0).exportAsRankingFunction = true;
        return inputs.size() == 2 ? typeWithShapeAsInput() : typeWithShapeAsAttribute();
    }

    /** Computes the output type from a shape given as a constant second input. */
    private OrderedTensorType typeWithShapeAsInput() {
        IntermediateOperation shapeInput = inputs.get(1);
        Tensor shapeTensor = shapeInput.getConstantValue()
                .orElseThrow(() -> new IllegalArgumentException("Reshape " + name + ": Shape input must be a constant."))
                .asTensor();
        OrderedTensorType inputType = inputs.get(0).type().get();

        List<Integer> shape = new ArrayList<>();
        shapeTensor.valueIterator().forEachRemaining(value -> shape.add(value.intValue()));
        return buildOutputType(resolveShape(shape, inputType));
    }

    /** Computes the output type from the {@code "shape"} attribute. */
    private OrderedTensorType typeWithShapeAsAttribute() {
        if (attributeMap.getList("shape").isEmpty() || attributeMap.getList("shape").get().isEmpty())
            throw new IllegalArgumentException("Reshape in " + name + ": Shape attribute is empty.");

        OrderedTensorType inputType = inputs.get(0).type().get();
        List<Value> shapeValues = attributeMap.getList("shape").get();
        List<Integer> shape = new ArrayList<>(shapeValues.size());
        for (Value value : shapeValues)
            shape.add((int) value.asDouble());
        return buildOutputType(resolveShape(shape, inputType));
    }

    /**
     * Resolves the special entries of a requested shape against the input type.
     * A 0 entry takes the size of the input dimension at the same position. A single negative
     * entry is replaced by the size that keeps the total cell count equal to the input's.
     *
     * @param shape the requested dimension sizes; modified in place and returned
     * @param inputType the type of the tensor being reshaped
     * @return the resolved dimension sizes
     * @throws IllegalArgumentException if a 0 entry has no matching input dimension, more than one
     *         entry is negative, or the negative entry cannot be inferred exactly
     */
    private List<Integer> resolveShape(List<Integer> shape, OrderedTensorType inputType) {
        int inferredIndex = -1;
        long knownSize = 1;
        for (int i = 0; i < shape.size(); i++) {
            int size = shape.get(i);
            if (size == 0) {
                if (i >= inputType.dimensions().size())
                    throw new IllegalArgumentException("Reshape " + name + ": 0 value for dimension not found in input");
                size = inputType.dimensions().get(i).size().get().intValue();
                shape.set(i, size);
            }
            if (size < 0) {
                if (inferredIndex >= 0)
                    throw new IllegalArgumentException("Reshape " + name + ": at most one dimension size can be negative, got " + shape);
                inferredIndex = i;
            } else {
                knownSize *= size;
            }
        }
        if (inferredIndex >= 0) {
            long tensorSize = OrderedTensorType.tensorSize(inputType.type()).longValue();
            if (knownSize == 0 || tensorSize % knownSize != 0)
                throw new IllegalArgumentException("Reshape " + name + ": cannot infer size of dimension " + inferredIndex +
                                                   " for input of size " + tensorSize + " and shape " + shape);
            shape.set(inferredIndex, (int) (tensorSize / knownSize));
        }
        return shape;
    }

    /** Builds an indexed output type with dimensions named {@code <vespaName>_<index>}. */
    private OrderedTensorType buildOutputType(List<Integer> dimensionSizes) {
        OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
        for (int i = 0; i < dimensionSizes.size(); i++)
            builder.add(TensorType.Dimension.indexed(String.format("%s_%d", vespaName(), i), dimensionSizes.get(i)));
        return builder.build();
    }

    @Override
    protected TensorFunction<Reference> lazyGetFunction() {
        // Check all types before any function, matching the original evaluation order.
        if (!inputs.stream().allMatch(input -> input.type().isPresent())) return null;
        if (!inputs.stream().allMatch(input -> input.function().isPresent())) return null;
        return reshape(inputs.get(0).function().get(), inputs.get(0).type().get(), type);
    }

    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        addConstraintsFrom(type, renamer);
    }

    @Override
    public Reshape withInputs(List<IntermediateOperation> inputs) {
        return new Reshape(modelName(), name(), inputs, attributeMap);
    }

    /**
     * Returns a function that generates a tensor of {@code outputType}. Each output cell is read
     * from the input cell with the same row-major flat index. The input is referenced by its
     * ranking expression function name.
     *
     * @throws IllegalArgumentException if the input and output types have different cell counts
     */
    public TensorFunction<Reference> reshape(TensorFunction<Reference> inputFunction,
                                             OrderedTensorType inputType,
                                             OrderedTensorType outputType) {
        if (!OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type())))
            throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");

        String inputFunctionName = inputs.get(0).rankingExpressionFunctionName();
        // Row-major flat index of the output cell currently being generated.
        ExpressionNode flatIndex = new EmbracedNode(unrollTensorExpression(outputType));

        long innerSize = 1;
        for (int i = 0; i < inputType.rank(); i++)
            innerSize *= inputType.dimensions().get(i).size().get();

        List<Slice.DimensionValue<Reference>> address = new ArrayList<>();
        for (int i = 0; i < inputType.rank(); i++) {
            TensorType.Dimension dimension = inputType.dimensions().get(i);
            long dimensionSize = dimension.size().get();
            long outerSize = innerSize;  // product of this and all later dimension sizes
            innerSize /= dimensionSize;   // product of all later dimension sizes

            ExpressionNode index;
            if (dimensionSize == 1) {
                index = new ConstantNode(DoubleValue.zero);
            } else if (i == inputType.rank() - 1) {
                index = new ArithmeticNode(flatIndex, ArithmeticOperator.MODULO, new ConstantNode(new DoubleValue(dimensionSize)));
            } else {
                ExpressionNode withinOuter = new EmbracedNode(
                        new ArithmeticNode(flatIndex, ArithmeticOperator.MODULO, new ConstantNode(new DoubleValue(outerSize))));
                index = new ArithmeticNode(withinOuter, ArithmeticOperator.DIVIDE, new ConstantNode(new DoubleValue(innerSize)));
            }
            address.add(new Slice.DimensionValue<>(Optional.of(dimension.name()),
                                                   TensorFunctionNode.wrapScalar(new EmbracedNode(index))));
        }

        TensorFunction<Reference> input = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName));
        ExpressionNode slice = new TensorFunctionNode(new Slice<>(input, address));
        return Generate.bound(outputType.type(), TensorFunctionNode.wrapScalar(slice));
    }

    /**
     * Returns an expression computing the row-major flat index of a cell of {@code type}.
     * The result has the form {@code s0*d0 + s1*d1 + ... + dn}, where each si is the product of
     * the sizes of the later dimensions.
     */
    private static ExpressionNode unrollTensorExpression(OrderedTensorType type) {
        if (type.rank() == 0)
            return new ConstantNode(DoubleValue.zero);

        List<ExpressionNode> children = new ArrayList<>();
        List<ArithmeticOperator> operators = new ArrayList<>();
        long stride = 1;  // long: the product of dimension sizes can exceed int range
        for (int i = type.dimensions().size() - 1; i >= 0; i--) {
            TensorType.Dimension dimension = type.dimensions().get(i);
            children.add(0, new ReferenceNode(dimension.name()));
            if (stride > 1) {
                operators.add(0, ArithmeticOperator.MULTIPLY);
                children.add(0, new ConstantNode(new DoubleValue(stride)));
            }
            stride *= OrderedTensorType.dimensionSize(dimension);
            if (i > 0)
                operators.add(0, ArithmeticOperator.PLUS);
        }
        return new ArithmeticNode(children, operators);
    }

    @Override
    public String operationName() {
        return "Reshape";
    }

}
