package ai.vespa.rankingexpression.importer.onnx;

import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import onnx.Onnx;

/**
 * Converts ONNX {@code TensorProto} messages into Vespa {@link Tensor} instances.
 *
 * <p>Supported element types are FLOAT, DOUBLE, INT32 and INT64. Each can be stored either in the
 * type-specific repeated field of the proto or in its {@code raw_data} bytes. Raw bytes are read
 * as little-endian. Every element is converted to {@code double} when written into the Vespa
 * tensor, so INT64 values with magnitude above 2^53 lose precision.
 *
 * <p>NOTE(review): this file was recovered by a decompiler, which reported that this class was
 * originally package-private and some nested classes private. Visibility is left as-is so that
 * existing callers keep compiling.
 */
public class TensorConverter {

    /** Reads FLOAT elements from the {@code float_data} repeated field of a tensor proto. */
    public static class FloatValues extends Values {
        private final Onnx.TensorProto tensorProto;

        FloatValues(Onnx.TensorProto tensorProto) {
            this.tensorProto = tensorProto;
        }

        @Override
        double get(int i) {
            return tensorProto.getFloatData(i);
        }

        @Override
        int size() {
            return tensorProto.getFloatDataCount();
        }
    }

    /** Reads FLOAT elements from the little-endian {@code raw_data} bytes of a tensor proto. */
    public static class RawFloatValues extends RawValues {
        private final FloatBuffer values;
        private final int size;

        RawFloatValues(Onnx.TensorProto tensorProto) {
            this.values = bytes(tensorProto).asFloatBuffer();
            this.size = values.remaining();
        }

        @Override
        double get(int i) {
            return values.get(i);
        }

        @Override
        int size() {
            return size;
        }
    }

    /** Reads DOUBLE elements from the little-endian {@code raw_data} bytes of a tensor proto. */
    private static final class RawDoubleValues extends RawValues {
        private final DoubleBuffer values;
        private final int size;

        RawDoubleValues(Onnx.TensorProto tensorProto) {
            this.values = bytes(tensorProto).asDoubleBuffer();
            this.size = values.remaining();
        }

        @Override
        double get(int i) {
            return values.get(i);
        }

        @Override
        int size() {
            return size;
        }
    }

    /** Reads INT32 elements from the little-endian {@code raw_data} bytes of a tensor proto. */
    private static final class RawIntValues extends RawValues {
        private final IntBuffer values;
        private final int size;

        RawIntValues(Onnx.TensorProto tensorProto) {
            this.values = bytes(tensorProto).asIntBuffer();
            this.size = values.remaining();
        }

        @Override
        double get(int i) {
            return values.get(i);
        }

        @Override
        int size() {
            return size;
        }
    }

    /** Reads INT64 elements from the little-endian {@code raw_data} bytes of a tensor proto. */
    private static final class RawLongValues extends RawValues {
        private final LongBuffer values;
        private final int size;

        RawLongValues(Onnx.TensorProto tensorProto) {
            this.values = bytes(tensorProto).asLongBuffer();
            this.size = values.remaining();
        }

        @Override
        double get(int i) {
            return values.get(i);
        }

        @Override
        int size() {
            return size;
        }
    }

    /** Reads DOUBLE elements from the {@code double_data} repeated field of a tensor proto. */
    private static final class DoubleValues extends Values {
        private final Onnx.TensorProto tensorProto;

        DoubleValues(Onnx.TensorProto tensorProto) {
            this.tensorProto = tensorProto;
        }

        @Override
        double get(int i) {
            return tensorProto.getDoubleData(i);
        }

        @Override
        int size() {
            return tensorProto.getDoubleDataCount();
        }
    }

    /** Reads INT32 elements from the {@code int32_data} repeated field of a tensor proto. */
    private static final class IntValues extends Values {
        private final Onnx.TensorProto tensorProto;

        IntValues(Onnx.TensorProto tensorProto) {
            this.tensorProto = tensorProto;
        }

        @Override
        double get(int i) {
            return tensorProto.getInt32Data(i);
        }

        @Override
        int size() {
            return tensorProto.getInt32DataCount();
        }
    }

    /** Reads INT64 elements from the {@code int64_data} repeated field of a tensor proto. */
    private static final class LongValues extends Values {
        private final Onnx.TensorProto tensorProto;

        LongValues(Onnx.TensorProto tensorProto) {
            this.tensorProto = tensorProto;
        }

        @Override
        double get(int i) {
            // long -> double may lose precision for magnitudes above 2^53.
            return tensorProto.getInt64Data(i);
        }

        @Override
        int size() {
            return tensorProto.getInt64DataCount();
        }
    }

    /** Base for readers that decode the proto's {@code raw_data} bytes. */
    private static abstract class RawValues extends Values {

        /**
         * Returns a read-only view of the raw data, ordered little-endian.
         * Typed views created from it (asFloatBuffer etc.) inherit this byte order.
         */
        static ByteBuffer bytes(Onnx.TensorProto tensorProto) {
            return tensorProto.getRawData().asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN);
        }
    }

    /** Random-access view of a tensor proto's element values, exposed as doubles. */
    private static abstract class Values {

        /** Returns element {@code i} in the proto's storage order. */
        abstract double get(int i);

        /** Returns the number of elements available. */
        abstract int size();
    }

    TensorConverter() {
    }

    /**
     * Converts an ONNX tensor into a Vespa tensor of the given type.
     *
     * <p>Element {@code i} of the ONNX tensor is written to the Vespa cell at direct index
     * {@code orderedTensorType.toDirectIndex(i)}.
     *
     * @param tensorProto the ONNX tensor to convert
     * @param orderedTensorType the target type; its Vespa type must produce a bound indexed builder
     * @return the converted Vespa tensor
     * @throws IllegalArgumentException if the element type is not supported, or if the target type
     *         is not a bound indexed tensor type
     */
    public static Tensor toVespaTensor(Onnx.TensorProto tensorProto, OrderedTensorType orderedTensorType) {
        Values values = readValuesOf(tensorProto);
        Tensor.Builder builder = Tensor.Builder.of(orderedTensorType.type());
        // Direct-index cell writes are only available on bound indexed builders.
        if ( ! (builder instanceof IndexedTensor.BoundBuilder)) {
            throw new IllegalArgumentException("Cannot convert an ONNX tensor to Vespa tensor type " +
                                               orderedTensorType.type() +
                                               ": only bound indexed tensor types are supported");
        }
        IndexedTensor.BoundBuilder boundBuilder = (IndexedTensor.BoundBuilder) builder;
        int count = values.size();
        for (int i = 0; i < count; i++) {
            boundBuilder.cellByDirectIndex(orderedTensorType.toDirectIndex(i), values.get(i));
        }
        return boundBuilder.build();
    }

    /**
     * Selects a reader for the proto's element values.
     * {@code raw_data} takes precedence when present; otherwise the type-specific repeated field is used.
     */
    private static Values readValuesOf(Onnx.TensorProto tensorProto) {
        if (tensorProto.hasRawData()) {
            switch (tensorProto.getDataType()) {
                case FLOAT: return new RawFloatValues(tensorProto);
                case DOUBLE: return new RawDoubleValues(tensorProto);
                case INT32: return new RawIntValues(tensorProto);
                case INT64: return new RawLongValues(tensorProto);
                default: break;
            }
        } else {
            switch (tensorProto.getDataType()) {
                case FLOAT: return new FloatValues(tensorProto);
                case DOUBLE: return new DoubleValues(tensorProto);
                case INT32: return new IntValues(tensorProto);
                case INT64: return new LongValues(tensorProto);
                default: break;
            }
        }
        throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + tensorProto.getDataType() + " to a Vespa tensor");
    }
}
