package com.yahoo.vespa.model.ml;

import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
import ai.vespa.rankingexpression.importer.lightgbm.LightGBMImporter;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import ai.vespa.rankingexpression.importer.vespa.VespaImporter;
import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
import com.yahoo.config.FileReference;
import com.yahoo.config.model.ApplicationPackageTester;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.schema.derived.FileDistributedConstants;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.model.VespaModel;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.Optional;
import org.junit.jupiter.api.Assertions;
import org.xml.sax.SAXException;

/* loaded from: input_file:com/yahoo/vespa/model/ml/ImportedModelTester.class */
/**
 * Test helper which builds a {@link DeployState} from an application package directory,
 * with a fixed list of model importers registered, and offers assertions on the
 * distributable constants of a {@link VespaModel} created from that state.
 */
public class ImportedModelTester {

    /** The model importers handed to the deploy state builder. */
    private final List<MlModelImporter> importers = List.of(new TensorFlowImporter(),
                                                            new OnnxImporter(),
                                                            new LightGBMImporter(),
                                                            new XGBoostImporter(),
                                                            new VespaImporter());

    /** Model name, used when composing the expected path of generated constant files. */
    private final String modelName;

    /** Directory of the application package the deploy state is built from. */
    private final Path applicationDir;

    /** Deploy state built once in the constructor and reused by {@link #createVespaModel()}. */
    private final DeployState deployState;

    /**
     * Creates a tester using a fresh {@link DeployState.Builder}.
     *
     * @param str  the model name
     * @param path the application package directory
     */
    public ImportedModelTester(String str, Path path) {
        this(str, path, new DeployState.Builder());
    }

    /**
     * Creates a tester whose deploy state is produced by the given builder, after the
     * application package and the model importers have been set on it.
     *
     * @param str     the model name
     * @param path    the application package directory
     * @param builder the deploy state builder to complete and build
     */
    public ImportedModelTester(String str, Path path, DeployState.Builder builder) {
        this.modelName = str;
        this.applicationDir = path;
        ApplicationPackageTester packageTester = ApplicationPackageTester.create(path.toString());
        this.deployState = builder.applicationPackage(packageTester.app())
                                  .modelImporters(importers)
                                  .build();
    }

    /**
     * Creates a new {@link VespaModel} from the deploy state of this tester.
     *
     * @return the created model
     * @throws RuntimeException wrapping any {@link IOException} or {@link SAXException}
     *                          thrown while constructing the model
     */
    public VespaModel createVespaModel() {
        try {
            return new VespaModel(deployState);
        }
        catch (SAXException | IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Asserts that the model has a distributable constant with the given name, that its
     * file name ends with the expected package-relative path, and that this file name is
     * among the file references of the model. Only if a size is given, it is also asserted
     * that the constant file exists under the application directory and that its decoded
     * content has that size.
     *
     * @param str        the name of the constant
     * @param vespaModel the model to inspect
     * @param optional   the expected size of the decoded constant, or empty to skip the file checks
     * @throws UncheckedIOException if an {@link IOException} occurs while performing the checks
     */
    public void assertLargeConstant(String str, VespaModel vespaModel, Optional<Long> optional) {
        try {
            Path relativeFile = constantFileInPackage(str);
            String expectedSuffix = relativeFile.toString();

            FileDistributedConstants.DistributableConstant constant =
                    (FileDistributedConstants.DistributableConstant) vespaModel.rankProfileList().constants().asMap().get(str);
            Assertions.assertNotNull(constant);
            Assertions.assertEquals(str, constant.getName());
            Assertions.assertTrue(constant.getFileName().endsWith(expectedSuffix));
            Assertions.assertTrue(vespaModel.fileReferences().contains(new FileReference(constant.getFileName())));

            // Without an expected size there is nothing to verify about the file content
            if (!optional.isPresent()) {
                return;
            }

            Path writtenFile = applicationDir.append(relativeFile);
            Assertions.assertTrue(writtenFile.toFile().exists(), "Constant file '" + writtenFile + "' has been written");
            long expectedSize = optional.get().longValue();
            Assertions.assertEquals(expectedSize, decodedSize(writtenFile));
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Returns the package-relative path expected for the ".tbf" file of the named constant. */
    private Path constantFileInPackage(String constantName) {
        Path constantsDir = Path.fromString("models.generated/" + modelName + "/constants");
        return constantsDir.append(constantName + ".tbf");
    }

    /** Reads the given file, decodes it with {@link TypedBinaryFormat} and returns the size of the result. */
    private static long decodedSize(Path file) throws IOException {
        byte[] content = IOUtils.readFileBytes(file.toFile());
        return TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(content)).size();
    }

}
