package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

/* loaded from: input_file:ai/vespa/modelintegration/evaluator/OnnxEvaluator.class */
/**
 * Evaluates an ONNX model through the ONNX Runtime Java API, converting between
 * Vespa {@link Tensor}s and ONNX Runtime values via {@code TensorConverter}.
 *
 * <p>Every checked {@link OrtException} raised by the runtime is rethrown as a
 * {@link RuntimeException} with the message {@code "ONNX Runtime exception"} and the
 * original exception as its cause.
 */
public class OnnxEvaluator {
    private final OrtEnvironment environment;
    private final OrtSession session;

    /**
     * Creates an evaluator for the given model using default options.
     *
     * @param str the model location handed to {@code OrtEnvironment.createSession}
     *            (presumably a file system path to the ONNX model)
     * @throws RuntimeException if ONNX Runtime fails to set up the session
     */
    public OnnxEvaluator(String str) {
        this(str, null);
    }

    /**
     * Creates an evaluator for the given model.
     *
     * @param str                  the model location handed to {@code OrtEnvironment.createSession}
     *                             (presumably a file system path to the ONNX model)
     * @param onnxEvaluatorOptions the options to create the session with; when {@code null},
     *                             a default-constructed {@code OnnxEvaluatorOptions} is used
     * @throws RuntimeException if ONNX Runtime fails to build the options or create the session
     */
    public OnnxEvaluator(String str, OnnxEvaluatorOptions onnxEvaluatorOptions) {
        // Session creation is inside the try as well: createSession is an ONNX Runtime call
        // and must be wrapped the same way as every other runtime call in this class.
        try {
            OnnxEvaluatorOptions effectiveOptions =
                    (onnxEvaluatorOptions != null) ? onnxEvaluatorOptions : new OnnxEvaluatorOptions();
            this.environment = OrtEnvironment.getEnvironment();
            this.session = this.environment.createSession(str, effectiveOptions.getOptions());
        } catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    /**
     * Runs the model on the given inputs and returns a single requested output.
     *
     * @param map the input tensors, keyed by name
     * @param str the name of the output to request from the session
     * @return the requested output converted to a Vespa tensor
     * @throws RuntimeException if ONNX Runtime fails during conversion or evaluation
     */
    public Tensor evaluate(Map<String, Tensor> map, String str) {
        Map<String, OnnxTensor> onnxInputs = null;
        try {
            onnxInputs = TensorConverter.toOnnxTensors(map, this.environment, this.session);
            // Only one output is requested, so it is the first (index 0) entry of the result.
            try (OrtSession.Result result = this.session.run(onnxInputs, Collections.singleton(str))) {
                return TensorConverter.toVespaTensor(result.get(0));
            }
        } catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        } finally {
            // Release the converted inputs on both the success and the failure path.
            closeInputs(onnxInputs);
        }
    }

    /**
     * Runs the model on the given inputs and returns every output the session produces.
     *
     * @param map the input tensors, keyed by name
     * @return a new mutable map from output name to the output converted to a Vespa tensor
     * @throws RuntimeException if ONNX Runtime fails during conversion or evaluation
     */
    public Map<String, Tensor> evaluate(Map<String, Tensor> map) {
        Map<String, OnnxTensor> onnxInputs = null;
        try {
            onnxInputs = TensorConverter.toOnnxTensors(map, this.environment, this.session);
            Map<String, Tensor> outputs = new HashMap<>();
            try (OrtSession.Result result = this.session.run(onnxInputs)) {
                // Convert while the result is still open; the converted tensors are what we return.
                for (Map.Entry<String, OnnxValue> output : result) {
                    outputs.put(output.getKey(), TensorConverter.toVespaTensor(output.getValue()));
                }
            }
            return outputs;
        } catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        } finally {
            // Release the converted inputs on both the success and the failure path.
            closeInputs(onnxInputs);
        }
    }

    /**
     * Returns the session's input info converted to Vespa tensor types.
     *
     * @return the result of {@code TensorConverter.toVespaTypes} applied to the session's input info
     * @throws RuntimeException if ONNX Runtime fails to provide the input info
     */
    public Map<String, TensorType> getInputInfo() {
        try {
            return TensorConverter.toVespaTypes(this.session.getInputInfo());
        } catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    /**
     * Returns the session's output info converted to Vespa tensor types.
     *
     * @return the result of {@code TensorConverter.toVespaTypes} applied to the session's output info
     * @throws RuntimeException if ONNX Runtime fails to provide the output info
     */
    public Map<String, TensorType> getOutputInfo() {
        try {
            return TensorConverter.toVespaTypes(this.session.getOutputInfo());
        } catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    /**
     * Probes the runtime by delegating to {@link #isRuntimeAvailable(String)} with an empty
     * model path.
     *
     * @return the result of {@code isRuntimeAvailable("")}
     */
    public static boolean isRuntimeAvailable() {
        return isRuntimeAvailable("");
    }

    /**
     * Probes the runtime by attempting to construct an evaluator for the given model.
     *
     * @param str the model location to construct the probe evaluator with
     * @return {@code true} if construction succeeded; {@code false} if it threw a
     *         {@link NoClassDefFoundError}, {@link RuntimeException} or {@link UnsatisfiedLinkError}
     */
    public static boolean isRuntimeAvailable(String str) {
        try {
            new OnnxEvaluator(str);
            return true;
        } catch (NoClassDefFoundError | RuntimeException | UnsatisfiedLinkError e) {
            // Deliberate best-effort probe: any of these means the runtime could not be used.
            return false;
        }
    }

    /** Closes every tensor in the given map; does nothing when the map is {@code null}. */
    private static void closeInputs(Map<String, OnnxTensor> onnxInputs) {
        if (onnxInputs != null) {
            onnxInputs.values().forEach(OnnxTensor::close);
        }
    }
}
