package com.yahoo.vespa.model.ml;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.component.ComponentId;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.config.model.provision.InMemoryProvisioner;
import com.yahoo.config.provision.NodeResources;
import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.tensor.Tensor;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.class */
/**
 * Tests that ONNX models in an application package are exposed for stateless evaluation
 * through a {@link ModelsEvaluator} component in the "container" cluster.
 */
public class StatelessOnnxEvaluationTest {

    /**
     * Exception message which, when found anywhere in the cause chain of a failure while a GPU is
     * required, is tolerated by {@link #assertModelEvaluation} (no usable GPU in the test environment).
     */
    private static final String CUDA_INIT_FAILED_MESSAGE = "GPU device is required, but CUDA initialization failed";

    @Test
    void testStatelessOnnxModelNameCollision() {
        Assumptions.assumeTrue(OnnxEvaluator.isRuntimeAvailable());
        Path appDir = Path.fromString("src/test/cfg/application/onnx_name_collision");
        try {
            ApplicationContainerCluster cluster = (ApplicationContainerCluster)
                    new ImportedModelTester("onnx", appDir).createVespaModel().getContainerClusters().get("container");
            RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder();
            cluster.getConfig(builder);
            RankProfilesConfig config = new RankProfilesConfig(builder);

            Assertions.assertEquals(2, config.rankprofile().size());
            Set<String> names = rankProfileNames(config);
            Assertions.assertTrue(names.contains("foobar"));
            Assertions.assertTrue(names.contains("barfoo"));
        } finally {
            deleteGeneratedModels(appDir);
        }
    }

    @Test
    void testStatelessOnnxModelEvaluation() throws Exception {
        Assumptions.assumeTrue(OnnxEvaluator.isRuntimeAvailable());
        Path appDir = Path.fromString("src/test/cfg/application/onnx");
        Path copyDir = appDir.append("copy");
        try {
            assertModelEvaluation(new ImportedModelTester("onnx_rt", appDir, new DeployState.Builder()).createVespaModel(), appDir, false);

            // Deploy again from a copy of the application that includes the models-generated
            // directory left behind by the first deployment.
            copyDir.toFile().mkdirs();
            IOUtils.copy(appDir.append("services.xml").toString(), copyDir.append("services.xml").toString());
            IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
                                  copyDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
            IOUtils.copyDirectory(appDir.append(ApplicationPackage.SCHEMAS_DIR).toFile(),
                                  copyDir.append(ApplicationPackage.SCHEMAS_DIR).toFile());
            // Model file references are still resolved against the original application directory.
            assertModelEvaluation(new ImportedModelTester("onnx_rt", copyDir).createVespaModel(), appDir, false);
        } finally {
            deleteGeneratedModels(appDir);
            IOUtils.recursiveDeleteDir(copyDir.toFile());
        }
    }

    @Test
    void testStatelessOnnxModelEvaluationWithGpu() {
        Assumptions.assumeTrue(OnnxEvaluator.isRuntimeAvailable());
        NodeResources gpuNodeResources = new NodeResources(4.0, 16.0, 125.0, 10.0,
                                                           NodeResources.DiskSpeed.fast,
                                                           NodeResources.StorageType.local,
                                                           NodeResources.Architecture.x86_64,
                                                           new NodeResources.GpuResources(1, 16.0));
        DeployState.Builder deployState = new DeployState.Builder()
                .modelHostProvisioner(new InMemoryProvisioner(6, gpuNodeResources, false))
                .properties(new TestProperties().setMultitenant(true).setHostedVespa(true));
        Path appDir = Path.fromString("src/test/cfg/application/onnx");
        try {
            assertModelEvaluation(new ImportedModelTester("onnx_rt", appDir, deployState).createVespaModel(), appDir, true);
        } finally {
            deleteGeneratedModels(appDir);
        }
    }

    /**
     * Asserts that the "container" cluster of the given model has a {@link ModelsEvaluator} component
     * configured with the single rank profile "mul", checks the ONNX model config, and evaluates the
     * model with input1=[2] and input2=[3], expecting a result whose sum is 6.
     *
     * @param vespaModel  the deployed model to inspect
     * @param appDir      directory that ONNX model file references are resolved against
     * @param gpuRequired expected value of {@code gpu_device_required} in the ONNX model config; when true,
     *                    an {@link IllegalArgumentException} caused by failed CUDA initialization is tolerated
     */
    private void assertModelEvaluation(VespaModel vespaModel, Path appDir, boolean gpuRequired) {
        ApplicationContainerCluster cluster = (ApplicationContainerCluster) vespaModel.getContainerClusters().get("container");
        Assertions.assertNotNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName())));

        RankProfilesConfig.Builder rankProfilesBuilder = new RankProfilesConfig.Builder();
        cluster.getConfig(rankProfilesBuilder);
        RankProfilesConfig rankProfilesConfig = new RankProfilesConfig(rankProfilesBuilder);

        RankingConstantsConfig.Builder constantsBuilder = new RankingConstantsConfig.Builder();
        cluster.getConfig(constantsBuilder);
        RankingConstantsConfig constantsConfig = new RankingConstantsConfig(constantsBuilder);

        RankingExpressionsConfig.Builder expressionsBuilder = new RankingExpressionsConfig.Builder();
        cluster.getConfig(expressionsBuilder);
        RankingExpressionsConfig expressionsConfig = expressionsBuilder.build();

        OnnxModelsConfig.Builder onnxBuilder = new OnnxModelsConfig.Builder();
        cluster.getConfig(onnxBuilder);
        OnnxModelsConfig onnxModelsConfig = new OnnxModelsConfig(onnxBuilder);

        Assertions.assertEquals(1, rankProfilesConfig.rankprofile().size());
        Assertions.assertTrue(rankProfileNames(rankProfilesConfig).contains("mul"));

        OnnxModelsConfig.Model onnxModel = onnxModelsConfig.model().get(0);
        Assertions.assertEquals(2, onnxModel.stateless_intraop_threads());
        Assertions.assertEquals(-1, onnxModel.stateless_interop_threads());
        Assertions.assertEquals("", onnxModel.stateless_execution_mode());
        Assertions.assertEquals(gpuRequired, onnxModel.gpu_device_required());
        Assertions.assertEquals(0, onnxModel.gpu_device());

        // Map each model file reference to the file of the same relative path under appDir.
        Map<String, File> files = new HashMap<>();
        for (OnnxModelsConfig.Model model : onnxModelsConfig.model()) {
            String fileRef = model.fileref().value();
            files.put(fileRef, appDir.append(fileRef).toFile());
        }

        try {
            ModelsEvaluator evaluator = new ModelsEvaluator(rankProfilesConfig, constantsConfig, expressionsConfig,
                                                            onnxModelsConfig, MockFileAcquirer.returnFiles(files));
            Assertions.assertEquals(1, evaluator.models().size());
            Model mul = evaluator.models().get("mul");
            double result = mul.evaluatorOf(new String[0])
                               .bind("input1", Tensor.from("tensor<float>(d0[1]):[2]"))
                               .bind("input2", Tensor.from("tensor<float>(d0[1]):[3]"))
                               .evaluate()
                               .sum()
                               .asDouble();
            Assertions.assertEquals(6.0, result, 1.0E-9);
        } catch (IllegalArgumentException e) {
            // When a GPU is required, tolerate only failures caused by CUDA initialization; rethrow anything else.
            if ( ! (gpuRequired && hasCauseWithMessage(e, CUDA_INIT_FAILED_MESSAGE)))
                throw e;
        }
    }

    /** Returns the names of all rank profiles in the given config. */
    private static Set<String> rankProfileNames(RankProfilesConfig config) {
        return config.rankprofile().stream()
                     .map(profile -> profile.name())
                     .collect(Collectors.toSet());
    }

    /**
     * Returns whether the given throwable, or any throwable in its cause chain, has exactly the given message.
     * Null messages are handled safely (they never match).
     */
    private static boolean hasCauseWithMessage(Throwable throwable, String message) {
        for (Throwable t = throwable; t != null; t = t.getCause()) {
            if (message.equals(t.getMessage())) return true;
        }
        return false;
    }

    /** Recursively deletes the models-generated directory under the given application directory. */
    private static void deleteGeneratedModels(Path appDir) {
        IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
    }
}
