package ai.onnxruntime;

import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtUtil;
import ai.onnxruntime.platform.Fp16Conversions;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Arrays;

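/**
 * An ONNX Runtime value holding a native sparse tensor in COO, CSRC or block-sparse layout.
 *
 * <p>Source note: decompiled from {@code ai/onnxruntime/OnnxSparseTensor.class}.
 */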
public final class OnnxSparseTensor extends OnnxTensorLike {
    /** Sparse layout of this tensor, fixed at construction. */
    private final SparseTensorType sparseTensorType;
    // NOTE(review): the three buffer fields below are assigned in the constructor but never read
    // in this class. They hold the buffers that createSparseTensor handed to native code, which
    // suggests they exist only to keep those buffers reachable while the native tensor is alive
    // -- confirm. All three are null when the tensor was built from a native handle alone.
    /** Indices buffer handed to native code at creation (the outer indices for CSRC), or null. */
    private final Buffer indices;
    /** Inner indices buffer handed to native code for CSRC tensors; null for other layouts. */
    private final LongBuffer innerIndices;
    /** Values buffer handed to native code at creation, or null. */
    private final Buffer values;

    /*
     * Synthetic lookup tables mapping enum ordinals to dense integer switch labels, in the shape
     * javac emits for enum switches (JADX: access modifiers changed from package-private;
     * renamed from ai.onnxruntime.OnnxSparseTensor$1 because that name is not valid in source).
     *
     * Each slot is filled in its own try/catch for NoSuchFieldError, as javac does, so that a
     * constant missing from a separately compiled enum leaves its slot at 0 (the "no label"
     * value) instead of failing class initialisation.
     */
    /* loaded from: input_file:ai/onnxruntime/OnnxSparseTensor$1.class */
    public static /* synthetic */ class AnonymousClass1 {
        /** Switch labels 1..12 for OnnxJavaType constants, indexed by ordinal; 0 means unmapped. */
        static final /* synthetic */ int[] $SwitchMap$ai$onnxruntime$OnnxJavaType = new int[OnnxJavaType.values().length];

        /**
         * Switch labels 1..4 for SparseTensorType constants, indexed by ordinal; 0 means unmapped.
         * The decompiled output assigned this table in the static initialiser below without ever
         * declaring it, so the class could not compile; it is declared here as a blank final.
         */
        static final /* synthetic */ int[] $SwitchMap$ai$onnxruntime$OnnxSparseTensor$SparseTensorType;

        static {
            // The catch blocks are deliberately empty: an absent constant keeps label 0.
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.FLOAT.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.FLOAT16.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.BFLOAT16.ordinal()] = 3;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.DOUBLE.ordinal()] = 4;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.INT16.ordinal()] = 5;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.INT32.ordinal()] = 6;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.INT64.ordinal()] = 7;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.BOOL.ordinal()] = 8;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.INT8.ordinal()] = 9;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.UINT8.ordinal()] = 10;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.STRING.ordinal()] = 11;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxJavaType[OnnxJavaType.UNKNOWN.ordinal()] = 12;
            } catch (NoSuchFieldError ignored) {
            }
            $SwitchMap$ai$onnxruntime$OnnxSparseTensor$SparseTensorType = new int[SparseTensorType.values().length];
            try {
                $SwitchMap$ai$onnxruntime$OnnxSparseTensor$SparseTensorType[SparseTensorType.COO.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxSparseTensor$SparseTensorType[SparseTensorType.BLOCK_SPARSE.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxSparseTensor$SparseTensorType[SparseTensorType.CSRC.ordinal()] = 3;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                $SwitchMap$ai$onnxruntime$OnnxSparseTensor$SparseTensorType[SparseTensorType.UNDEFINED.ordinal()] = 4;
            } catch (NoSuchFieldError ignored) {
            }
        }
    }

    /* loaded from: input_file:ai/onnxruntime/OnnxSparseTensor$BlockSparseTensor.class */
    /**
     * A block-sparse tensor: the stored values form blocks and a 32-bit index buffer locates them.
     */
    public static final class BlockSparseTensor extends SparseTensor<IntBuffer> {
        /**
         * Validates and wraps block-sparse data. The buffers are stored, not copied.
         *
         * @param indices block indices; must hold exactly elementCount(indicesShape) remaining entries
         * @param indicesShape shape of the indices; at least two dimensions
         * @param data stored values; must hold exactly numNonZero remaining entries
         * @param dataShape shape of the stored values; at least three dimensions, with
         *     elementCount(dataShape) == numNonZero
         * @param denseShape shape of the equivalent dense tensor
         * @param type element type of the values (STRING is rejected)
         * @param numNonZero number of stored value elements
         * @throws IllegalArgumentException if any shape or buffer is inconsistent with the others
         */
        public BlockSparseTensor(IntBuffer indices, long[] indicesShape, Buffer data, long[] dataShape, long[] denseShape, OnnxJavaType type, long numNonZero) {
            super(indices, indicesShape, data, dataShape, denseShape, type, numNonZero);
            if (OrtUtil.elementCount(dataShape) != numNonZero) {
                throw new IllegalArgumentException("Expected " + numNonZero + " entries in the data shape, found " + Arrays.toString(dataShape));
            }
            final int dataCount = data.remaining();
            if (numNonZero != dataCount) {
                throw new IllegalArgumentException("Expected " + numNonZero + " elements in the data buffer, found " + dataCount);
            }
            final long expectedIndexCount = OrtUtil.elementCount(indicesShape);
            final int indexCount = indices.remaining();
            if (expectedIndexCount != indexCount) {
                throw new IllegalArgumentException("Expected " + expectedIndexCount + " elements in the indices buffer, found " + indexCount);
            }
            // Rank checks run last so the count mismatches above take precedence in the message.
            if (dataShape.length < 3) {
                throw new IllegalArgumentException("Expected [numBlocks, blockSize, blockSize] or larger, but data shape was " + Arrays.toString(dataShape));
            }
            if (indicesShape.length < 2) {
                throw new IllegalArgumentException("Expected [numBlocks, co-ordinates] or larger, but indices shape was " + Arrays.toString(indicesShape));
            }
        }

        /** @return {@link OnnxJavaType#INT32}, the index type of block-sparse tensors */
        @Override // ai.onnxruntime.OnnxSparseTensor.SparseTensor
        public OnnxJavaType getIndicesType() {
            return OnnxJavaType.INT32;
        }

        /** @return {@link SparseTensorType#BLOCK_SPARSE} */
        @Override // ai.onnxruntime.OnnxSparseTensor.SparseTensor
        public SparseTensorType getSparsityType() {
            return SparseTensorType.BLOCK_SPARSE;
        }
    }

    /* loaded from: input_file:ai/onnxruntime/OnnxSparseTensor$COOTensor.class */
    /**
     * A sparse tensor in coordinate (COO) layout: a 64-bit index buffer lists the coordinates of
     * each of the numNonZero stored values, and the values shape is always [numNonZero].
     */
    public static final class COOTensor extends SparseTensor<LongBuffer> {
        /**
         * Validates and wraps COO data. The buffers are stored, not copied.
         *
         * @param indices coordinates; must hold exactly elementCount(indicesShape) remaining entries
         * @param indicesShape [numNonZero, dimension] or [numNonZero]
         * @param data stored values; must hold exactly numNonZero remaining entries
         * @param denseShape shape of the equivalent dense tensor
         * @param type element type of the values (STRING is rejected)
         * @param numNonZero number of stored values
         * @throws IllegalArgumentException if the shapes and buffers are inconsistent
         */
        public COOTensor(LongBuffer indices, long[] indicesShape, Buffer data, long[] denseShape, OnnxJavaType type, long numNonZero) {
            super(indices, indicesShape, data, new long[]{numNonZero}, denseShape, type, numNonZero);
            final int rank = indicesShape.length;
            final boolean rankAccepted = rank == 1 || rank == 2;
            // Short-circuit keeps indicesShape[0] from being read when the rank is wrong.
            if (!rankAccepted || indicesShape[0] != numNonZero) {
                throw new IllegalArgumentException("Invalid indices shape, expected [numNonZero, dimension] or [numNonZero] found " + Arrays.toString(indicesShape));
            }
            final long expectedIndexCount = OrtUtil.elementCount(indicesShape);
            final int suppliedIndexCount = indices.remaining();
            if (expectedIndexCount != suppliedIndexCount) {
                throw new IllegalArgumentException("Unexpected number of indices found in buffer, expected " + expectedIndexCount + " found " + suppliedIndexCount);
            }
            final int suppliedValueCount = data.remaining();
            if (suppliedValueCount != numNonZero) {
                throw new IllegalArgumentException("Expected data.remaining() - " + suppliedValueCount + " to equal numNonZero - " + numNonZero);
            }
        }

        /** @return {@link OnnxJavaType#INT64}, the index type of COO tensors */
        @Override // ai.onnxruntime.OnnxSparseTensor.SparseTensor
        public OnnxJavaType getIndicesType() {
            return OnnxJavaType.INT64;
        }

        /** @return {@link SparseTensorType#COO} */
        @Override // ai.onnxruntime.OnnxSparseTensor.SparseTensor
        public SparseTensorType getSparsityType() {
            return SparseTensorType.COO;
        }
    }

    /* loaded from: input_file:ai/onnxruntime/OnnxSparseTensor$CSRCTensor.class */
    /**
     * A sparse tensor in compressed sparse row (CSRC) layout, described by an outer index buffer
     * with one entry per dense row plus one, an inner index buffer with one entry per stored
     * value, and the values themselves. Both index buffers hold 64-bit indices.
     */
    public static final class CSRCTensor extends SparseTensor<LongBuffer> {
        /** Inner indices, one per stored value; stored as supplied, not copied. */
        private final LongBuffer innerIndices;

        /**
         * Validates and wraps CSRC data. The buffers are stored, not copied.
         *
         * @param outerIndices outer indices; must hold exactly denseShape[0] + 1 remaining entries
         * @param innerIndices inner indices; must hold exactly numNonZero remaining entries
         * @param data stored values; must hold exactly numNonZero remaining entries
         * @param denseShape shape of the equivalent dense tensor; must have at least one dimension
         * @param type element type of the values (STRING is rejected)
         * @param numNonZero number of stored values
         * @throws IllegalArgumentException if denseShape is empty or the buffers are inconsistent
         *     with the shapes
         */
        public CSRCTensor(LongBuffer outerIndices, LongBuffer innerIndices, Buffer data, long[] denseShape, OnnxJavaType type, long numNonZero) {
            super(outerIndices, new long[]{outerIndices.remaining()}, data, new long[]{numNonZero}, denseShape, type, numNonZero);
            this.innerIndices = innerIndices;
            // Guard the denseShape[0] read below so an empty shape is reported as an argument
            // error instead of escaping as an ArrayIndexOutOfBoundsException.
            if (denseShape.length == 0) {
                throw new IllegalArgumentException("Dense shape must have at least one dimension for a CSRC tensor, found " + Arrays.toString(denseShape));
            }
            long expectedOuterCount = denseShape[0] + 1;
            if (outerIndices.remaining() != expectedOuterCount) {
                throw new IllegalArgumentException("Outer indices should be equal to the number of rows + 1 in the dense shape, found " + outerIndices.remaining() + ", expected " + expectedOuterCount);
            }
            if (innerIndices.remaining() != numNonZero) {
                throw new IllegalArgumentException("Inner indices should be equal to the number of non-zero elements, found " + innerIndices.remaining() + ", expected " + numNonZero);
            }
        }

        /** @return a one-element shape holding the number of remaining inner indices */
        public long[] getInnerIndicesShape() {
            return new long[]{this.innerIndices.remaining()};
        }

        /** @return the inner index buffer exactly as supplied to the constructor */
        public LongBuffer getInnerIndices() {
            return this.innerIndices;
        }

        /** @return {@link OnnxJavaType#INT64}, the index type of CSRC tensors */
        @Override // ai.onnxruntime.OnnxSparseTensor.SparseTensor
        public OnnxJavaType getIndicesType() {
            return OnnxJavaType.INT64;
        }

        /** @return {@link SparseTensorType#CSRC} */
        @Override // ai.onnxruntime.OnnxSparseTensor.SparseTensor
        public SparseTensorType getSparsityType() {
            return SparseTensorType.CSRC;
        }
    }

    /* loaded from: input_file:ai/onnxruntime/OnnxSparseTensor$SparseTensor.class */
    /**
     * Java-side description of a sparse tensor: an indices buffer, a values buffer, their shapes,
     * the dense shape and the element type. Subclasses fix the layout and the index type. Buffers
     * and shape arrays are stored and returned as supplied, without copying.
     *
     * @param <T> buffer type holding the indices
     */
    public static abstract class SparseTensor<T extends Buffer> {
        private final long[] indicesShape;
        private final long[] valuesShape;
        private final long[] denseShape;
        private final OnnxJavaType type;
        private final long numNonZero;
        /** Indices buffer as supplied. */
        final T indices;
        /** Values buffer as supplied. */
        final Buffer values;

        /**
         * Checks the values buffer against numNonZero and rejects STRING element types, then
         * records every argument.
         *
         * @throws IllegalArgumentException if values.remaining() != numNonZero or type is STRING
         */
        SparseTensor(T indices, long[] indicesShape, Buffer values, long[] valuesShape, long[] denseShape, OnnxJavaType type, long numNonZero) {
            final int available = values.remaining();
            if (available != numNonZero) {
                throw new IllegalArgumentException("Expected numNonZero and data.remaining to be equal, found " + numNonZero + " and " + available + " respectively");
            }
            if (type == OnnxJavaType.STRING) {
                throw new IllegalArgumentException("String SparseTensors are not supported.");
            }
            this.indices = indices;
            this.indicesShape = indicesShape;
            this.values = values;
            this.valuesShape = valuesShape;
            this.denseShape = denseShape;
            this.type = type;
            this.numNonZero = numNonZero;
        }

        /** @return the shape of the equivalent dense tensor */
        public long[] getDenseShape() {
            return denseShape;
        }

        /** @return the element type of the values */
        public OnnxJavaType getType() {
            return type;
        }

        /** @return the number of stored (non-zero) values */
        public long getNumNonZeroElements() {
            return numNonZero;
        }

        /** @return the indices buffer */
        public T getIndices() {
            return indices;
        }

        /** @return the values buffer */
        public Buffer getValues() {
            return values;
        }

        /** @return the shape of the values */
        public long[] getValuesShape() {
            return valuesShape;
        }

        /** @return the shape of the indices */
        public long[] getIndicesShape() {
            return indicesShape;
        }

        /** @return the sparse layout implemented by this subclass */
        public abstract SparseTensorType getSparsityType();

        /** @return the element type of the indices buffer */
        public abstract OnnxJavaType getIndicesType();
    }

    /* loaded from: input_file:ai/onnxruntime/OnnxSparseTensor$SparseTensorType.class */
    /**
     * Sparse tensor layouts, each carrying the integer code exchanged with native code.
     */
    public enum SparseTensorType {
        /** Unknown or unset layout (code 0). */
        UNDEFINED(0),
        /** Coordinate layout (code 1). */
        COO(1),
        /** Compressed sparse row layout (code 2). */
        CSRC(2),
        /** Block-sparse layout (code 4). */
        BLOCK_SPARSE(4);

        /** Integer code of this layout. */
        public final int value;

        SparseTensorType(int code) {
            this.value = code;
        }

        /**
         * Decodes an integer layout code.
         *
         * @param i the code
         * @return the matching layout, or UNDEFINED for any unrecognised code (including 3)
         */
        public static SparseTensorType mapFromInt(int i) {
            switch (i) {
                case 1:
                    return COO;
                case 2:
                    return CSRC;
                case 4:
                    return BLOCK_SPARSE;
                default:
                    return UNDEFINED;
            }
        }
    }

    /**
     * Wraps an existing native sparse tensor whose layout is given as an integer code. No
     * Java-side buffers are retained.
     *
     * @param nativeHandle handle of the native sparse tensor
     * @param allocatorHandle allocator handle forwarded to the superclass
     * @param sparsityCode layout code, decoded with {@link SparseTensorType#mapFromInt(int)}
     * @param info tensor metadata
     */
    OnnxSparseTensor(long nativeHandle, long allocatorHandle, int sparsityCode, TensorInfo info) {
        this(nativeHandle, allocatorHandle, SparseTensorType.mapFromInt(sparsityCode), info, null, null, null);
    }

    /**
     * Wraps a native sparse tensor that has no inner indices (COO or block-sparse), retaining the
     * given indices and values buffers.
     */
    OnnxSparseTensor(long nativeHandle, long allocatorHandle, SparseTensorType layout, TensorInfo info, Buffer indices, Buffer values) {
        this(nativeHandle, allocatorHandle, layout, info, indices, null, values);
    }

    /**
     * Primary constructor: forwards the handles and metadata to the superclass and records the
     * layout plus the (possibly null) Java-side buffers.
     */
    OnnxSparseTensor(long nativeHandle, long allocatorHandle, SparseTensorType layout, TensorInfo info, Buffer indices, LongBuffer innerIndices, Buffer values) {
        super(nativeHandle, allocatorHandle, info);
        this.values = values;
        this.innerIndices = innerIndices;
        this.indices = indices;
        this.sparseTensorType = layout;
    }

    /**
     * Creates a native sparse tensor from a Java-side description, using the environment's
     * default allocator.
     *
     * @param ortEnvironment environment whose default allocator is used
     * @param sparseTensor the sparse tensor description
     * @param <T> buffer type of the indices
     * @return the new native-backed sparse tensor
     * @throws OrtException if native creation fails
     */
    public static <T extends Buffer> OnnxSparseTensor createSparseTensor(OrtEnvironment ortEnvironment, SparseTensor<T> sparseTensor) throws OrtException {
        final OrtAllocator allocator = ortEnvironment.defaultAllocator;
        return createSparseTensor(ortEnvironment, allocator, sparseTensor);
    }

    /**
     * Creates a native sparse tensor from a Java-side description, using the given allocator.
     *
     * <p>The indices and values (and, for CSRC, the inner indices) are passed through
     * {@link OrtUtil#prepareBuffer} before being handed to native code. The prepared buffers are
     * also stored on the resulting object.
     *
     * @param ortEnvironment the environment (not otherwise used here)
     * @param ortAllocator allocator whose handle is passed to native code
     * @param sparseTensor the sparse tensor description
     * @param <T> buffer type of the indices
     * @return the new native-backed sparse tensor
     * @throws IllegalStateException if the allocator is closed or the prepared indices buffer is
     *     neither an IntBuffer nor a LongBuffer
     * @throws IllegalArgumentException if the layout is UNDEFINED
     * @throws OrtException if native creation fails
     */
    static <T extends Buffer> OnnxSparseTensor createSparseTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, SparseTensor<T> sparseTensor) throws OrtException {
        if (ortAllocator.isClosed()) {
            throw new IllegalStateException("Trying to create an OnnxSparseTensor on a closed OrtAllocator.");
        }
        final TensorInfo info = TensorInfo.constructFromSparseTensor(sparseTensor);
        final OnnxJavaType indexType = sparseTensor.getIndicesType();
        final OrtUtil.BufferTuple indexTuple = OrtUtil.prepareBuffer(sparseTensor.getIndices(), indexType);
        final OrtUtil.BufferTuple valueTuple = OrtUtil.prepareBuffer(sparseTensor.getValues(), info.type);
        if (!(indexTuple.data instanceof LongBuffer || indexTuple.data instanceof IntBuffer)) {
            throw new IllegalStateException("Unexpected type of indices buffer, found " + indexTuple.data.getClass() + ", expected IntBuffer or LongBuffer");
        }
        final long apiHandle = OnnxRuntime.ortApiHandle;
        final long allocatorHandle = ortAllocator.handle;
        final SparseTensorType layout = sparseTensor.getSparsityType();
        switch (layout) {
            case COO:
            case BLOCK_SPARSE: {
                long nativeHandle = createSparseTensorFromBuffer(
                        apiHandle, allocatorHandle,
                        indexTuple.data, indexTuple.pos, indexTuple.size,
                        valueTuple.data, valueTuple.pos,
                        info.shape, sparseTensor.getIndicesShape(), sparseTensor.getValuesShape(),
                        info.onnxType.value, layout.value);
                return new OnnxSparseTensor(nativeHandle, allocatorHandle, layout, info, indexTuple.data, valueTuple.data);
            }
            case CSRC: {
                OrtUtil.BufferTuple innerTuple = OrtUtil.prepareBuffer(((CSRCTensor) sparseTensor).getInnerIndices(), indexType);
                long nativeHandle = createCSRCSparseTensorFromBuffer(
                        apiHandle, allocatorHandle,
                        indexTuple.data, indexTuple.pos, indexTuple.size,
                        innerTuple.data, innerTuple.pos, innerTuple.size,
                        valueTuple.data, valueTuple.pos,
                        info.shape, sparseTensor.getValuesShape(), info.onnxType.value);
                return new OnnxSparseTensor(nativeHandle, allocatorHandle, layout, info, indexTuple.data, (LongBuffer) innerTuple.data, valueTuple.data);
            }
            case UNDEFINED:
            default:
                throw new IllegalArgumentException("Cannot create an UNDEFINED sparse tensor.");
        }
    }

    /**
     * Reports the kind of ONNX value this object holds.
     *
     * @return always {@code ONNX_TYPE_SPARSETENSOR}
     */
    @Override // ai.onnxruntime.OnnxValue
    public OnnxValue.OnnxValueType getType() {
        final OnnxValue.OnnxValueType kind = OnnxValue.OnnxValueType.ONNX_TYPE_SPARSETENSOR;
        return kind;
    }

    /**
     * Copies this native sparse tensor into a new Java-side {@link SparseTensor}.
     *
     * <p>The values and indices shape are fetched before dispatching on the layout. The number of
     * non-zero elements is taken to be the number of values copied out. NOTE(review): for
     * FLOAT16/BFLOAT16 tensors getValuesBuffer converts the values, yet the element type passed on
     * is still info.type -- confirm downstream users expect this.
     *
     * @return a COOTensor, BlockSparseTensor or CSRCTensor matching this tensor's layout
     * @throws OrtException declared by the signature; nothing in this body raises it directly
     * @throws IllegalStateException if the layout is UNDEFINED
     */
    @Override // ai.onnxruntime.OnnxValue
    public SparseTensor<? extends Buffer> getValue() throws OrtException {
        final Buffer data = getValuesBuffer();
        final long[] indexShape = getIndicesShape();
        final long nonZeroCount = data.remaining();
        switch (this.sparseTensorType) {
            case COO:
                return new COOTensor((LongBuffer) getIndicesBuffer(), indexShape, data, this.info.shape, this.info.type, nonZeroCount);
            case BLOCK_SPARSE:
                return new BlockSparseTensor((IntBuffer) getIndicesBuffer(), indexShape, data, getValuesShape(), this.info.shape, this.info.type, nonZeroCount);
            case CSRC:
                return new CSRCTensor((LongBuffer) getIndicesBuffer(), getInnerIndicesBuffer(), data, this.info.shape, this.info.type, nonZeroCount);
            case UNDEFINED:
            default:
                throw new IllegalStateException("Undefined sparsity type in this sparse tensor.");
        }
    }

    /**
     * Set once close() has passed the native handle to the native close routine, so that later
     * calls do not hand the same (already released) handle to native code again. Other accessors
     * do not consult this flag.
     */
    private boolean released;

    /**
     * Passes this tensor's native handle to the native close routine (presumably releasing the
     * underlying native value). Only the first call has any effect; later calls return
     * immediately. This method is not synchronized.
     */
    @Override // ai.onnxruntime.OnnxValue, java.lang.AutoCloseable
    public void close() {
        if (released) {
            return;
        }
        close(OnnxRuntime.ortApiHandle, this.nativeHandle);
        // Set only after the native call returns, so a failed release can be retried.
        released = true;
    }

    /**
     * Returns the sparse layout chosen when this tensor was constructed.
     *
     * @return the layout of this tensor
     */
    public SparseTensorType getSparseTensorType() {
        return sparseTensorType;
    }

    /**
     * Copies this tensor's indices out of native memory into a new heap buffer positioned at 0.
     *
     * <p>COO and CSRC tensors yield a LongBuffer (for CSRC these are the outer indices), and
     * block-sparse tensors yield an IntBuffer. The native bytes are read in native byte order.
     *
     * @return a freshly allocated copy of the indices
     * @throws IllegalStateException if the layout is UNDEFINED
     */
    public Buffer getIndicesBuffer() {
        switch (this.sparseTensorType) {
            case COO:
            case CSRC: {
                ByteBuffer raw = getIndicesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle);
                LongBuffer longView = raw.order(ByteOrder.nativeOrder()).asLongBuffer();
                long[] longCopy = new long[longView.capacity()];
                longView.get(longCopy);
                return LongBuffer.wrap(longCopy);
            }
            case BLOCK_SPARSE: {
                ByteBuffer raw = getIndicesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle);
                IntBuffer intView = raw.order(ByteOrder.nativeOrder()).asIntBuffer();
                int[] intCopy = new int[intView.capacity()];
                intView.get(intCopy);
                return IntBuffer.wrap(intCopy);
            }
            case UNDEFINED:
            default:
                throw new IllegalStateException("UNDEFINED sparse tensor type.");
        }
    }

    /**
     * Copies the inner indices of a CSRC tensor out of native memory into a new heap buffer
     * positioned at 0, reading the native bytes in native byte order.
     *
     * @return a freshly allocated copy of the inner indices
     * @throws IllegalStateException if this tensor is not CSRC
     */
    public LongBuffer getInnerIndicesBuffer() {
        final SparseTensorType layout = this.sparseTensorType;
        if (layout == SparseTensorType.CSRC) {
            ByteBuffer raw = getInnerIndicesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle);
            LongBuffer view = raw.order(ByteOrder.nativeOrder()).asLongBuffer();
            long[] copied = new long[view.capacity()];
            view.get(copied);
            return LongBuffer.wrap(copied);
        }
        throw new IllegalStateException("Inner indices are only available for CSRC sparse tensors, this sparse tensor is " + layout);
    }

    /**
     * Copies this tensor's values out of native memory into a new heap buffer positioned at 0,
     * reading the native bytes in native byte order.
     *
     * <p>The buffer type follows the element type: FloatBuffer for FLOAT, DoubleBuffer for
     * DOUBLE, ShortBuffer for INT16, IntBuffer for INT32, LongBuffer for INT64, and ByteBuffer for
     * BOOL, INT8 and UINT8. FLOAT16 and BFLOAT16 values are converted through
     * {@link Fp16Conversions} (presumably widened to float -- confirm).
     *
     * @return a freshly allocated buffer holding the values
     * @throws IllegalStateException for STRING, UNKNOWN or any other unsupported element type
     */
    public Buffer getValuesBuffer() {
        ByteBuffer raw = getValuesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle).order(ByteOrder.nativeOrder());
        // Dispatch on the enum itself rather than on the ordinal-indexed synthetic switch map.
        switch (this.info.type) {
            case FLOAT: {
                FloatBuffer view = raw.asFloatBuffer();
                FloatBuffer copy = FloatBuffer.allocate(view.capacity());
                copy.put(view);
                copy.rewind();
                return copy;
            }
            case FLOAT16:
                return Fp16Conversions.convertFp16BufferToFloatBuffer(raw.asShortBuffer());
            case BFLOAT16:
                return Fp16Conversions.convertBf16BufferToFloatBuffer(raw.asShortBuffer());
            case DOUBLE: {
                DoubleBuffer view = raw.asDoubleBuffer();
                DoubleBuffer copy = DoubleBuffer.allocate(view.capacity());
                copy.put(view);
                copy.rewind();
                return copy;
            }
            case INT16: {
                ShortBuffer view = raw.asShortBuffer();
                ShortBuffer copy = ShortBuffer.allocate(view.capacity());
                copy.put(view);
                copy.rewind();
                return copy;
            }
            case INT32: {
                IntBuffer view = raw.asIntBuffer();
                IntBuffer copy = IntBuffer.allocate(view.capacity());
                copy.put(view);
                copy.rewind();
                return copy;
            }
            case INT64: {
                LongBuffer view = raw.asLongBuffer();
                LongBuffer copy = LongBuffer.allocate(view.capacity());
                copy.put(view);
                copy.rewind();
                return copy;
            }
            case BOOL:
            case INT8:
            case UINT8: {
                ByteBuffer copy = ByteBuffer.allocate(raw.capacity());
                copy.put(raw);
                copy.rewind();
                return copy;
            }
            case STRING:
                throw new IllegalStateException("Unsupported data type String");
            case UNKNOWN:
            default:
                throw new IllegalStateException("Unsupported data type");
        }
    }

    /**
     * Queries native code for the shape of this tensor's indices.
     *
     * @return the indices shape
     */
    public long[] getIndicesShape() {
        final long api = OnnxRuntime.ortApiHandle;
        return getIndicesShape(api, nativeHandle);
    }

    /**
     * Queries native code for the shape of the inner indices of a CSRC tensor.
     *
     * @return the inner indices shape
     * @throws IllegalStateException if this tensor is not CSRC
     */
    public long[] getInnerIndicesShape() {
        if (sparseTensorType != SparseTensorType.CSRC) {
            throw new IllegalStateException("Inner indices are only available for CSRC sparse tensors, this sparse tensor is " + sparseTensorType);
        }
        return getInnerIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle);
    }

    /**
     * Queries native code for the shape of this tensor's values.
     *
     * @return the values shape
     */
    public long[] getValuesShape() {
        final long api = OnnxRuntime.ortApiHandle;
        return getValuesShape(api, nativeHandle);
    }

    // JNI entry points. Every call site in this class passes OnnxRuntime.ortApiHandle as
    // apiHandle; the instance methods also receive this tensor's native handle.

    /** Native query for the indices shape. */
    private native long[] getIndicesShape(long apiHandle, long nativeHandle);

    /** Native query for the CSRC inner indices shape. */
    private native long[] getInnerIndicesShape(long apiHandle, long nativeHandle);

    /** Native query for the values shape. */
    private native long[] getValuesShape(long apiHandle, long nativeHandle);

    /** Returns the raw index bytes; callers reinterpret them in native byte order and copy them. */
    private native ByteBuffer getIndicesBuffer(long apiHandle, long nativeHandle);

    /** Returns the raw CSRC inner index bytes; callers reinterpret and copy them. */
    private native ByteBuffer getInnerIndicesBuffer(long apiHandle, long nativeHandle);

    /** Returns the raw value bytes; callers reinterpret and copy them. */
    private native ByteBuffer getValuesBuffer(long apiHandle, long nativeHandle);

    /** Native close routine for the given tensor handle. */
    private native void close(long apiHandle, long nativeHandle);

    /**
     * Creates a native CSRC sparse tensor from prepared buffers and returns its handle.
     * Each {@code *Pos}/{@code *Size} pair comes from the corresponding OrtUtil.BufferTuple.
     */
    private static native long createCSRCSparseTensorFromBuffer(long apiHandle, long allocatorHandle, Buffer outerIndices, int outerIndicesPos, long outerIndicesSize, Buffer innerIndices, int innerIndicesPos, long innerIndicesSize, Buffer values, int valuesPos, long[] denseShape, long[] valuesShape, int onnxType) throws OrtException;

    /**
     * Creates a native COO or block-sparse tensor from prepared buffers and returns its handle.
     * The final argument is the {@link SparseTensorType#value} code of the layout.
     */
    private static native long createSparseTensorFromBuffer(long apiHandle, long allocatorHandle, Buffer indices, int indicesPos, long indicesSize, Buffer values, int valuesPos, long[] denseShape, long[] indicesShape, long[] valuesShape, int onnxType, int sparsityType) throws OrtException;
}
