package com.yahoo.vespa.model.container.xml;

import com.yahoo.component.ComponentId;
import com.yahoo.config.model.application.provider.FilesApplicationPackage;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.path.Path;
import com.yahoo.text.XML;
import com.yahoo.vespa.config.ConfigDefinitionKey;
import com.yahoo.vespa.config.ConfigPayloadBuilder;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import com.yahoo.vespa.model.container.component.Component;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.w3c.dom.Element;
import org.w3c.dom.NamedNodeMap;
import org.xml.sax.SAXException;

/* loaded from: input_file:com/yahoo/vespa/model/container/xml/EmbedderTestCase.class */
/**
 * Tests resolution of {@code model-id} references in embedder component configuration, both
 * directly through {@link ModelIdResolver} and through deployment of full application packages,
 * for self-hosted and hosted Vespa.
 */
public class EmbedderTestCase {

    private static final String BUNDLED_EMBEDDER_CLASS = "ai.vespa.embedding.BertBaseEmbedder";
    private static final String BUNDLED_EMBEDDER_CONFIG = "embedding.bert-base-embedder";

    /** URL the hosted resolver is expected to produce for model id 'minilm-l6-v2'. */
    private static final String MINILM_L6_V2_URL =
            "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx";
    /** URL the hosted resolver is expected to produce for model id 'bert-base-uncased'. */
    private static final String BERT_BASE_UNCASED_VOCAB_URL =
            "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt";

    @Test
    void testBundledEmbedder_selfhosted() throws IOException, SAXException {
        // Explicit id/url references are expected to pass through self-hosted resolution unchanged.
        String spec =
                "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" +
                "  <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
                "    <transformerModel id='my_model_id' url='my-model-url' />" +
                "    <tokenizerVocab id='my_vocab_id' url='my-vocab-url' />" +
                "  </config>" +
                "</component>";
        assertTransform(spec, spec, false);
    }

    @Test
    void testBundledEmbedder_hosted() throws IOException, SAXException {
        String input =
                "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" +
                "  <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
                "    <transformerModel model-id='minilm-l6-v2' />" +
                "    <tokenizerVocab model-id='bert-base-uncased' />" +
                "  </config>" +
                "</component>";
        String expected =
                "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" +
                "  <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
                "      <transformerModel model-id='minilm-l6-v2' url='" + MINILM_L6_V2_URL + "' />" +
                "      <tokenizerVocab model-id='bert-base-uncased' url='" + BERT_BASE_UNCASED_VOCAB_URL + "' />" +
                "  </config>" +
                "</component>";
        assertTransform(input, expected, true);
    }

    @Test
    void testApplicationComponentWithModelReference_hosted() throws IOException, SAXException {
        // Model ids are expected to be resolved for application-provided component classes too.
        String input =
                "<component id='test' class='ApplicationSpecificEmbedder' bundle='model-integration'>" +
                "  <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
                "    <transformerModel model-id='minilm-l6-v2' />" +
                "    <tokenizerVocab model-id='bert-base-uncased' />" +
                "  </config>" +
                "</component>";
        String expected =
                "<component id='test' class='ApplicationSpecificEmbedder' bundle='model-integration'>" +
                "  <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
                "      <transformerModel model-id='minilm-l6-v2' url='" + MINILM_L6_V2_URL + "' />" +
                "      <tokenizerVocab model-id='bert-base-uncased' url='" + BERT_BASE_UNCASED_VOCAB_URL + "' />" +
                "  </config>" +
                "</component>";
        assertTransform(input, expected, true);
    }

    @Test
    void testUnknownModelId_hosted() throws IOException, SAXException {
        String input =
                "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "'>" +
                "  <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
                "    <transformerModel model-id='my_model_id' />" +
                "    <tokenizerVocab model-id='my_vocab_id' />" +
                "  </config>" +
                "</component>";
        assertTransformThrows(input, "Unknown model id 'my_model_id' on 'transformerModel'", true);
    }

    @Test
    void testApplicationPackageWithEmbedder_selfhosted() throws Exception {
        ConfigPayloadBuilder config = transformerUserConfig("src/test/cfg/application/embed/", false,
                                                            new ConfigDefinitionKey("bert-base-embedder", "embedding"));
        Assertions.assertEquals("minilm-l6-v2 application-url \"\"", config.getObject("transformerModel").getValue());
        Assertions.assertEquals("\"\" \"\" files/vocab.txt", config.getObject("tokenizerVocab").getValue());
        Assertions.assertEquals("4", config.getObject("onnxIntraOpThreads").getValue());
    }

    @Test
    void testApplicationPackageWithEmbedder_hosted() throws Exception {
        ConfigPayloadBuilder config = transformerUserConfig("src/test/cfg/application/embed/", true,
                                                            new ConfigDefinitionKey("bert-base-embedder", "embedding"));
        Assertions.assertEquals("minilm-l6-v2 " + MINILM_L6_V2_URL + " \"\"", config.getObject("transformerModel").getValue());
        Assertions.assertEquals("\"\" \"\" files/vocab.txt", config.getObject("tokenizerVocab").getValue());
        Assertions.assertEquals("4", config.getObject("onnxIntraOpThreads").getValue());
    }

    @Test
    void testApplicationPackageWithApplicationEmbedder_selfhosted() throws Exception {
        ConfigPayloadBuilder config = transformerUserConfig("src/test/cfg/application/embed_generic/", false,
                                                            new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph"));
        Assertions.assertEquals("minilm-l6-v2 application-url \"\"", config.getObject("model").getValue());
        Assertions.assertEquals("\"\" \"\" files/vocab.txt", config.getObject("vocab").getValue());
    }

    @Test
    void testApplicationPackageWithApplicationEmbedder_hosted() throws Exception {
        ConfigPayloadBuilder config = transformerUserConfig("src/test/cfg/application/embed_generic/", true,
                                                            new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph"));
        Assertions.assertEquals("minilm-l6-v2 " + MINILM_L6_V2_URL + " \"\"", config.getObject("model").getValue());
        Assertions.assertEquals("\"\" \"\" files/vocab.txt", config.getObject("vocab").getValue());
    }

    /**
     * Builds a {@link VespaModel} from the application package at the given path.
     *
     * @param path   directory of the application package
     * @param hosted whether the model is built with hosted-Vespa test properties
     */
    private VespaModel loadModel(Path path, boolean hosted) throws Exception {
        return new VespaModel(new DeployState.Builder()
                                      .properties(new TestProperties().setHostedVespa(hosted))
                                      .applicationPackage(FilesApplicationPackage.fromFile(path.toFile()))
                                      .build());
    }

    /**
     * Loads the application package in {@code applicationDir} and returns the user config with the
     * given key on component 'transformer' of container cluster 'container'. Fails the test with a
     * descriptive message (instead of an NPE) if any of these are missing.
     */
    private ConfigPayloadBuilder transformerUserConfig(String applicationDir, boolean hosted,
                                                       ConfigDefinitionKey key) throws Exception {
        VespaModel model = loadModel(Path.fromString(applicationDir), hosted);
        ApplicationContainerCluster cluster =
                (ApplicationContainerCluster) model.getContainerClusters().get("container");
        Assertions.assertNotNull(cluster, "No container cluster 'container' in " + applicationDir);
        Component<?, ?> transformer =
                (Component<?, ?>) cluster.getComponentsMap().get(new ComponentId("transformer"));
        Assertions.assertNotNull(transformer, "No component 'transformer' in " + applicationDir);
        ConfigPayloadBuilder config = transformer.getUserConfigs().get(key);
        Assertions.assertNotNull(config, "No user config " + key + " on component 'transformer' in " + applicationDir);
        return config;
    }

    /**
     * Runs {@link ModelIdResolver#resolveModelIds} on {@code inputSpec} and asserts that the result
     * is structurally equal to {@code expectedSpec}.
     */
    private void assertTransform(String inputSpec, String expectedSpec, boolean hosted) throws IOException, SAXException {
        Element actual = createElement(inputSpec);
        ModelIdResolver.resolveModelIds(actual, hosted);
        assertSpec(createElement(expectedSpec), actual);
    }

    /**
     * Recursively asserts equal tag names, attribute sets (checked in both directions), trimmed
     * text values and children of the two elements.
     */
    private void assertSpec(Element expected, Element actual) {
        Assertions.assertEquals(expected.getTagName(), actual.getTagName());
        assertAttributes(expected, actual);
        assertAttributes(actual, expected);
        Assertions.assertEquals(XML.getValue(expected).trim(), XML.getValue(actual).trim(),
                                "Content of '" + expected.getTagName() + "' should be identical");
        assertChildren(expected, actual);
    }

    /** Asserts that every attribute present on {@code source} has the same value on {@code other}. */
    private void assertAttributes(Element source, Element other) {
        NamedNodeMap attributes = source.getAttributes();
        for (int i = 0; i < attributes.getLength(); i++) {
            String name = attributes.item(i).getNodeName();
            Assertions.assertEquals(source.getAttribute(name), other.getAttribute(name),
                                    "Attribute '" + name + "' on '" + source.getTagName() + "' should be equal");
        }
    }

    /** Asserts that both elements have the same number of child elements, and compares them pairwise. */
    private void assertChildren(Element expected, Element actual) {
        List<Element> expectedChildren = XML.getChildren(expected);
        List<Element> actualChildren = XML.getChildren(actual);
        Assertions.assertEquals(expectedChildren.size(), actualChildren.size(),
                                "Number of children of '" + expected.getTagName() + "'");
        for (int i = 0; i < expectedChildren.size(); i++) {
            assertSpec(expectedChildren.get(i), actualChildren.get(i));
        }
    }

    /**
     * Asserts that resolving model ids in {@code inputSpec} throws an {@link IllegalArgumentException}
     * whose message contains {@code expectedMessage}.
     */
    private void assertTransformThrows(String inputSpec, String expectedMessage, boolean hosted) throws IOException, SAXException {
        Element element = createElement(inputSpec);
        IllegalArgumentException e = Assertions.assertThrows(IllegalArgumentException.class,
                                                             () -> ModelIdResolver.resolveModelIds(element, hosted),
                                                             "Expected exception was not thrown: " + expectedMessage);
        String actualMessage = e.getMessage();
        Assertions.assertTrue(actualMessage != null && actualMessage.contains(expectedMessage),
                              "Expected error message not found, was: " + actualMessage);
    }

    /** Parses the XML string as UTF-8 and returns its root element. */
    private Element createElement(String xml) throws IOException, SAXException {
        // getDocumentElement() always yields the root element; getFirstChild() could be a comment or PI.
        return XML.getDocumentBuilder()
                  .parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)))
                  .getDocumentElement();
    }
}
