package ai.vespa.rankingexpression.importer.tensorflow;

import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.ModelImporter;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import com.yahoo.collections.Pair;
import com.yahoo.io.IOUtils;
import com.yahoo.system.ProcessExecuter;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.logging.Logger;
import org.tensorflow.SavedModelBundle;

/**
 * Imports TensorFlow models.
 *
 * <p>Models given as a directory are converted to ONNX with the external
 * {@code vespa-convert-tf2onnx} tool, trying several ONNX opsets in turn, and the result is
 * imported with {@link OnnxImporter}. An already-loaded {@link SavedModelBundle} can also be
 * imported directly through {@link #importModel(String, String, SavedModelBundle)}.
 */
public class TensorFlowImporter extends ModelImporter {

    private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName());

    /** ONNX opsets tried in this order; the first one the converter succeeds with is used. */
    private static final int[] onnxOpsetsToTry = {8, 10, 12};

    private final OnnxImporter onnxImporter = new OnnxImporter();

    /**
     * Returns whether the given path looks like a TensorFlow model: a directory that directly
     * contains at least one file whose path ends in {@code .pb} or {@code .pbtxt}.
     *
     * @param modelPath the path to inspect
     * @return true if the path is a readable directory holding a {@code .pb}/{@code .pbtxt} file
     */
    @Override
    public boolean canImport(String modelPath) {
        File dir = new File(modelPath);
        if (!dir.isDirectory()) {
            return false;
        }
        // listFiles() returns null on an I/O error (e.g. missing permissions) even for a directory.
        File[] children = dir.listFiles();
        if (children == null) {
            return false;
        }
        for (File child : children) {
            String childPath = child.toString();
            if (childPath.endsWith(".pbtxt") || childPath.endsWith(".pb")) {
                return true;
            }
        }
        return false;
    }

    /**
     * Imports the TensorFlow model in {@code modelDir} by converting it to ONNX first.
     *
     * @param modelName the name to give the imported model
     * @param modelDir the directory containing the TensorFlow saved model
     * @return the imported model
     * @throws IllegalArgumentException if the conversion or the ONNX import fails
     */
    @Override
    public ImportedModel importModel(String modelName, String modelDir) {
        return convertToOnnxAndImport(modelName, modelDir);
    }

    /**
     * Imports an already-loaded TensorFlow saved model bundle directly, without ONNX conversion.
     *
     * @param modelName the name to give the imported model
     * @param modelDir the path the model was loaded from
     * @param model the loaded saved model bundle
     * @return the imported model
     * @throws IllegalArgumentException if the graph cannot be imported
     */
    public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) {
        try {
            return convertIntermediateGraphToModel(GraphImporter.importGraph(modelName, model), modelDir);
        } catch (IOException e) {
            // The bundle's toString() carries no useful information; identify the model by name and path.
            throw new IllegalArgumentException("Could not import TensorFlow model '" + modelName +
                                               "' from '" + modelDir + "'", e);
        }
    }

    /**
     * Converts the saved model in {@code modelDir} to ONNX into a temporary directory and imports it.
     * Each opset in {@link #onnxOpsetsToTry} is tried in turn. The temporary directory is always
     * removed afterwards, whether the import succeeds or fails.
     */
    private ImportedModel convertToOnnxAndImport(String modelName, String modelDir) {
        Path tempDir = null;
        try {
            tempDir = Files.createTempDirectory("tf2onnx");
            String convertedPath = tempDir.toString() + File.separatorChar + "converted.onnx";
            String lastFailureOutput = "";
            for (int opset : onnxOpsetsToTry) {
                log.info("Converting TensorFlow model '" + modelDir + "' to ONNX with opset " + opset + "...");
                Pair<Integer, String> result = convertToOnnx(modelDir, convertedPath, opset);
                if (result.getFirst() == 0) {
                    log.info("Conversion to ONNX with opset " + opset + " successful.");
                    // The converted file is consumed before the finally block deletes the temp dir.
                    return onnxImporter.importModel(modelName, convertedPath);
                }
                String output = result.getSecond();
                log.fine(() -> "Conversion to ONNX with opset " + opset + " failed. Reason: " + output);
                lastFailureOutput = output;
            }
            throw new IllegalArgumentException("Unable to convert TensorFlow model in '" + modelDir +
                                               "' to ONNX. Reason: " + lastFailureOutput);
        } catch (IOException e) {
            throw new IllegalArgumentException("Conversion from TensorFlow to ONNX failed for '" + modelDir + "'", e);
        } finally {
            // Clean up on every path: success, failed conversion, I/O error, or ONNX import failure.
            if (tempDir != null) {
                IOUtils.recursiveDeleteDir(tempDir.toFile());
            }
        }
    }

    /**
     * Runs the external TensorFlow-to-ONNX converter.
     *
     * @return the process exit code paired with its captured output
     */
    private Pair<Integer, String> convertToOnnx(String savedModelDir, String outputPath, int opset) throws IOException {
        // Pass arguments as separate tokens; the single-string overload splits on whitespace,
        // which breaks paths that contain spaces.
        return new ProcessExecuter().exec(new String[] {
                "vespa-convert-tf2onnx",
                "--saved-model", savedModelDir,
                "--output", outputPath,
                "--opset", String.valueOf(opset)
        });
    }
}
