package ml.shifu.shifu.tensorflow;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import ml.shifu.shifu.container.obj.GenericModelConfig;
import ml.shifu.shifu.core.Computable;
import org.encog.ml.data.MLData;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

/**
 * {@link Computable} backed by a TensorFlow SavedModel.
 *
 * <p>{@link #init(GenericModelConfig)} reads the model location ({@code modelpath}), the SavedModel
 * tags ({@code tags}) and the single output op name ({@code outputnames}) from the config
 * properties, and the input op names from {@link GenericModelConfig#getInputnames()}.
 *
 * <p>{@link #compute(MLData)} feeds the first input with the record's features as a 1 x n float
 * batch, feeds every additional input with the config property of the same name, and returns
 * element {@code [0][0]} of the fetched output.
 */
public class TensorflowModel implements Computable {
    private static final Logger LOG = LoggerFactory.getLogger(TensorflowModel.class);

    /** Config properties; replaced by the config's own map in {@link #init}. */
    public Map<String, Object> properties = new HashMap<>();

    private boolean initiate = false;
    private String modelPath;
    private SavedModelBundle smb;
    private GenericModelConfig config;
    private String[] tags;
    private String[] inputNames;
    /** Name of the single output op to fetch (only one output is supported). */
    private String outputNames;

    /**
     * Scores one record with the loaded model.
     *
     * @param mLData the record; its {@code getData()} values are fed, cast to float, to the first
     *     input op
     * @return element {@code [0][0]} of the first fetched output, or {@link Double#MIN_VALUE} if
     *     the model has not been initialised (kept as the historical "no score" sentinel)
     */
    public double compute(MLData mLData) {
        LOG.debug("tensorflow compute start.");
        double result = Double.MIN_VALUE;
        if (this.initiate && this.smb != null) {
            result = runInference(mLData.getData());
        }
        LOG.debug("return result {}", result);
        return result;
    }

    /**
     * Runs a single forward pass and closes every tensor it creates or receives.
     *
     * <p>Tensors hold native memory that the JVM garbage collector does not reclaim, so all fed
     * and fetched tensors are closed in {@code finally}.
     */
    private double runInference(double[] data) {
        float[] features = new float[data.length];
        for (int i = 0; i < data.length; i++) {
            features[i] = (float) data[i];
        }

        List<Tensor<?>> fed = new ArrayList<>(this.inputNames.length);
        List<Tensor<?>> outputs = null;
        try {
            Session.Runner runner = this.smb.session().runner();

            // First input: the feature vector wrapped as a batch of one row.
            Tensor<?> featureTensor = Tensor.create(new float[][] {features});
            fed.add(featureTensor);
            runner.feed(this.inputNames[0], featureTensor);

            // Remaining inputs are taken from config properties keyed by the input name.
            for (int i = 1; i < this.inputNames.length; i++) {
                String name = this.inputNames[i];
                try {
                    Tensor<?> extra = Tensor.create(this.properties.get(name));
                    fed.add(extra);
                    runner.feed(name, extra);
                } catch (Exception e) {
                    // Best-effort as before: skip this input and let the session report a missing
                    // feed if the graph actually needs it. Pass e last so the stack trace is kept.
                    LOG.error("Invalid input {}", name, e);
                }
            }

            runner.fetch(this.outputNames);
            outputs = runner.run();
            // NOTE(review): assumes the output is a float tensor of shape [1, 1] — confirm.
            float[][] scores = outputs.get(0).copyTo(new float[1][1]);
            return scores[0][0];
        } finally {
            closeAll(fed);
            if (outputs != null) {
                closeAll(outputs);
            }
        }
    }

    /** Closes each tensor in the list, releasing its native memory. */
    private static void closeAll(List<Tensor<?>> tensors) {
        for (Tensor<?> tensor : tensors) {
            tensor.close();
        }
    }

    /**
     * Reads model settings from the config and loads the SavedModel. This is a no-op if the model
     * is already initialised.
     *
     * @param genericModelConfig config supplying properties ({@code modelpath}, {@code tags},
     *     {@code outputnames}) and input names
     * @throws RuntimeException if the config, properties, model path, input names, output name or
     *     tags are missing or empty
     * @throws IllegalArgumentException if {@code outputnames} lists anything other than exactly one
     *     output
     */
    public void init(GenericModelConfig genericModelConfig) {
        LOG.info("Init tensorflow model");
        if (this.initiate) {
            return;
        }
        if (genericModelConfig == null) {
            LOG.error("Config is null");
            throw new RuntimeException("Config is null");
        }
        this.config = genericModelConfig;
        this.properties = this.config.getProperties();
        if (this.properties == null || this.properties.isEmpty()) {
            LOG.error("Properties is null");
            throw new RuntimeException("Properties is null");
        }

        this.modelPath = (String) this.properties.get("modelpath");
        // toStringArray returns null instead of throwing NPE, so the explicit checks below fire.
        this.inputNames = toStringArray(genericModelConfig.getInputnames());

        String[] outputs = toStringArray(this.properties.get("outputnames"));
        if (outputs != null && outputs.length != 1) {
            throw new IllegalArgumentException("Output now only support single output in inference.");
        }
        this.outputNames = outputs == null ? null : outputs[0];

        this.tags = toStringArray(this.properties.get("tags"));
        LOG.debug("Tensorflow model properties: {}", this.properties);

        if (this.modelPath == null || this.modelPath.isEmpty()) {
            LOG.error("Model path is null");
            throw new RuntimeException("Model path is null");
        }
        if (this.inputNames == null || this.inputNames.length == 0) {
            LOG.error("Input names is null");
            throw new RuntimeException("Input names is null");
        }
        if (this.outputNames == null || this.outputNames.isEmpty()) {
            LOG.error("Output names is null");
            throw new RuntimeException("Output names is null");
        }
        if (this.tags == null || this.tags.length == 0) {
            LOG.error("Tags is null");
            throw new RuntimeException("Tags is null");
        }

        LOG.info("Load model from {}.", this.modelPath);
        this.smb = SavedModelBundle.load(this.modelPath, this.tags);
        LOG.info("Init tensorflow model done.");
        this.initiate = true;
    }

    /**
     * Normalises a config value to a {@code String[]}.
     *
     * @param value a single {@code String}, a {@code String[]}, or a {@code Collection} of strings
     * @return the values as an array, or {@code null} if {@code value} is null or of another type
     * @throws ArrayStoreException if a collection contains a non-String element
     */
    private static String[] toStringArray(Object value) {
        if (value instanceof String) {
            return new String[] {(String) value};
        }
        if (value instanceof String[]) {
            return (String[]) value;
        }
        if (value instanceof Collection) {
            return ((Collection<?>) value).toArray(new String[0]);
        }
        return null;
    }

    /**
     * Closes the loaded SavedModel (graph and session) if any. Afterwards {@link #compute} returns
     * the "no score" sentinel until {@link #init} is called again.
     */
    public void releaseResource() {
        if (this.smb != null) {
            this.smb.close();
            this.smb = null;
        }
        this.initiate = false;
    }
}
