package io.spokestack.spokestack.nlu.tensorflow;

import android.os.SystemClock;
import com.google.gson.Gson;
import com.google.gson.stream.JsonReader;
import io.spokestack.spokestack.SpeechConfig;
import io.spokestack.spokestack.nlu.NLUContext;
import io.spokestack.spokestack.nlu.NLUResult;
import io.spokestack.spokestack.nlu.NLUService;
import io.spokestack.spokestack.nlu.Slot;
import io.spokestack.spokestack.nlu.tensorflow.Metadata;
import io.spokestack.spokestack.nlu.tensorflow.parsers.DigitsParser;
import io.spokestack.spokestack.nlu.tensorflow.parsers.IdentityParser;
import io.spokestack.spokestack.nlu.tensorflow.parsers.IntegerParser;
import io.spokestack.spokestack.nlu.tensorflow.parsers.SelsetParser;
import io.spokestack.spokestack.tensorflow.TensorflowModel;
import io.spokestack.spokestack.util.AsyncResult;
import io.spokestack.spokestack.util.EventTracer;
import io.spokestack.spokestack.util.TraceListener;
import io.spokestack.spokestack.util.Tuple;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;

/* loaded from: input_file:io/spokestack/spokestack/nlu/tensorflow/TensorflowNLU.class */
public final class TensorflowNLU implements NLUService {
    private final ExecutorService executor;
    private final NLUContext context;
    private TextEncoder textEncoder;
    private int sepTokenId;
    private int padTokenId;
    private Thread loadThread;
    private TensorflowModel nluModel;
    private TFNLUOutput outputParser;
    private int maxTokens;
    private volatile boolean ready;

    /* loaded from: input_file:io/spokestack/spokestack/nlu/tensorflow/TensorflowNLU$Builder.class */
    public static class Builder {
        private NLUContext context;
        private TensorflowModel.Loader modelLoader;
        private ThreadFactory threadFactory;
        private TextEncoder textEncoder;
        private final List<TraceListener> traceListeners = new ArrayList();
        private final Map<String, String> slotParserClasses = new HashMap();
        private SpeechConfig config = new SpeechConfig();

        public Builder() {
            this.config.put("trace-level", Integer.valueOf(EventTracer.Level.ERROR.value()));
            registerSlotParser("digits", DigitsParser.class.getName());
            registerSlotParser("integer", IntegerParser.class.getName());
            registerSlotParser("selset", SelsetParser.class.getName());
            registerSlotParser("entity", IdentityParser.class.getName());
        }

        Builder setThreadFactory(ThreadFactory threadFactory) {
            this.threadFactory = threadFactory;
            return this;
        }

        public Builder setConfig(SpeechConfig speechConfig) {
            this.config = speechConfig;
            return this;
        }

        Builder setModelLoader(TensorflowModel.Loader loader) {
            this.modelLoader = loader;
            return this;
        }

        Builder setTextEncoder(TextEncoder textEncoder) {
            this.textEncoder = textEncoder;
            return this;
        }

        public Builder setProperty(String str, Object obj) {
            this.config.put(str, obj);
            return this;
        }

        public Builder registerSlotParser(String str, String str2) {
            this.slotParserClasses.put(str, str2);
            return this;
        }

        public Builder addTraceListener(TraceListener traceListener) {
            this.traceListeners.add(traceListener);
            return this;
        }

        public TensorflowNLU build() {
            this.context = new NLUContext(this.config);
            Iterator<TraceListener> it = this.traceListeners.iterator();
            while (it.hasNext()) {
                this.context.addTraceListener(it.next());
            }
            if (this.modelLoader == null) {
                this.modelLoader = new TensorflowModel.Loader();
            }
            if (this.textEncoder == null) {
                this.textEncoder = new WordpieceTextEncoder(this.config, this.context);
            }
            if (this.threadFactory == null) {
                this.threadFactory = Thread::new;
            }
            return new TensorflowNLU(this);
        }
    }

    private TensorflowNLU(Builder builder) {
        this.executor = Executors.newSingleThreadExecutor();
        this.ready = false;
        this.context = builder.context;
        load(transferSlotParsers(builder.slotParserClasses, builder.config), builder.textEncoder, builder.modelLoader, builder.threadFactory);
    }

    private SpeechConfig transferSlotParsers(Map<String, String> map, SpeechConfig speechConfig) {
        for (Map.Entry<String, String> entry : map.entrySet()) {
            String str = "slot-" + entry.getKey();
            if (!speechConfig.containsKey(str)) {
                speechConfig.put(str, entry.getValue());
            }
        }
        return speechConfig;
    }

    public TensorflowNLU(SpeechConfig speechConfig, NLUContext nLUContext) {
        this.executor = Executors.newSingleThreadExecutor();
        this.ready = false;
        this.context = nLUContext;
        load(speechConfig, new WordpieceTextEncoder(speechConfig, this.context), new TensorflowModel.Loader(), Thread::new);
    }

    private void load(SpeechConfig speechConfig, TextEncoder textEncoder, TensorflowModel.Loader loader, ThreadFactory threadFactory) {
        String string = speechConfig.getString("nlu-model-path");
        String string2 = speechConfig.getString("nlu-metadata-path");
        Map<String, String> slotParsers = getSlotParsers(speechConfig);
        this.textEncoder = textEncoder;
        this.loadThread = threadFactory.newThread(() -> {
            loadModel(loader, string2, string);
            initParsers(slotParsers);
        });
        this.loadThread.start();
        this.padTokenId = textEncoder.encodeSingle("[PAD]");
        this.sepTokenId = textEncoder.encodeSingle("[SEP]");
    }

    private Map<String, String> getSlotParsers(SpeechConfig speechConfig) {
        HashMap hashMap = new HashMap();
        for (Map.Entry<String, Object> entry : speechConfig.getParams().entrySet()) {
            if (entry.getKey().startsWith("slot-")) {
                hashMap.put(entry.getKey().replace("slot-", ""), String.valueOf(entry.getValue()));
            }
        }
        return hashMap;
    }

    private void initParsers(Map<String, String> map) {
        HashMap hashMap = new HashMap();
        for (String str : map.keySet()) {
            try {
                hashMap.put(str, (SlotParser) Class.forName(map.get(str)).getConstructor(new Class[0]).newInstance(new Object[0]));
            } catch (Exception e) {
                this.context.traceError("Error loading slot parsers: %s", e.getLocalizedMessage());
            }
        }
        this.outputParser.registerSlotParsers(hashMap);
        this.ready = true;
    }

    /* JADX WARN: Failed to calculate best type for var: r11v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.calculateFromBounds(FixTypesVisitor.java:156)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.setBestType(FixTypesVisitor.java:133)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.deduceType(FixTypesVisitor.java:238)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.tryDeduceTypes(FixTypesVisitor.java:221)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Failed to calculate best type for var: r11v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.calculateFromBounds(TypeInferenceVisitor.java:145)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.setBestType(TypeInferenceVisitor.java:123)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.lambda$runTypePropagation$2(TypeInferenceVisitor.java:101)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.runTypePropagation(TypeInferenceVisitor.java:101)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.visit(TypeInferenceVisitor.java:75)
     */
    /* JADX WARN: Failed to calculate best type for var: r12v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.calculateFromBounds(FixTypesVisitor.java:156)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.setBestType(FixTypesVisitor.java:133)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.deduceType(FixTypesVisitor.java:238)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.tryDeduceTypes(FixTypesVisitor.java:221)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Failed to calculate best type for var: r12v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.calculateFromBounds(TypeInferenceVisitor.java:145)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.setBestType(TypeInferenceVisitor.java:123)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.lambda$runTypePropagation$2(TypeInferenceVisitor.java:101)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.runTypePropagation(TypeInferenceVisitor.java:101)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.visit(TypeInferenceVisitor.java:75)
     */
    /* JADX WARN: Multi-variable type inference failed. Error: java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.RegisterArg.getSVar()" because the return value of "jadx.core.dex.nodes.InsnNode.getResult()" is null
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.collectRelatedVars(AbstractTypeConstraint.java:31)
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.<init>(AbstractTypeConstraint.java:19)
    	at jadx.core.dex.visitors.typeinference.TypeSearch$1.<init>(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeMoveConstraint(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeConstraint(TypeSearch.java:361)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.collectConstraints(TypeSearch.java:341)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.run(TypeSearch.java:60)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.runMultiVariableSearch(FixTypesVisitor.java:116)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Not initialized variable reg: 11, insn: 0x00ee: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r11 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) A[TRY_LEAVE], block:B:54:0x00ee */
    /* JADX WARN: Not initialized variable reg: 12, insn: 0x00f3: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r12 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]), block:B:56:0x00f3 */
    /* JADX WARN: Type inference failed for: r11v0, types: [java.io.FileReader] */
    /* JADX WARN: Type inference failed for: r12v0, types: [java.lang.Throwable] */
    private void loadModel(TensorflowModel.Loader loader, String str, String str2) {
        try {
            try {
                FileReader fileReader = new FileReader(str);
                Throwable th = null;
                JsonReader jsonReader = new JsonReader(fileReader);
                Throwable th2 = null;
                try {
                    try {
                        Metadata metadata = (Metadata) new Gson().fromJson(jsonReader, Metadata.class);
                        this.nluModel = loader.setPath(str2).load();
                        this.maxTokens = this.nluModel.inputs(0).capacity() / this.nluModel.getInputSize();
                        this.outputParser = new TFNLUOutput(metadata);
                        warmup();
                        if (jsonReader != null) {
                            if (0 != 0) {
                                try {
                                    jsonReader.close();
                                } catch (Throwable th3) {
                                    th2.addSuppressed(th3);
                                }
                            } else {
                                jsonReader.close();
                            }
                        }
                        if (fileReader != null) {
                            if (0 != 0) {
                                try {
                                    fileReader.close();
                                } catch (Throwable th4) {
                                    th.addSuppressed(th4);
                                }
                            } else {
                                fileReader.close();
                            }
                        }
                    } catch (Throwable th5) {
                        th2 = th5;
                        throw th5;
                    }
                } catch (Throwable th6) {
                    if (jsonReader != null) {
                        if (th2 != null) {
                            try {
                                jsonReader.close();
                            } catch (Throwable th7) {
                                th2.addSuppressed(th7);
                            }
                        } else {
                            jsonReader.close();
                        }
                    }
                    throw th6;
                }
            } catch (IOException e) {
                this.context.traceError("Error loading NLU model: %s", e.getLocalizedMessage());
            }
        } finally {
        }
    }

    private void warmup() {
        this.nluModel.inputs(0).rewind();
        for (int i = 0; i < this.maxTokens; i++) {
            this.nluModel.inputs(0).putInt(0);
        }
        this.nluModel.run();
    }

    int getMaxTokens() {
        return this.maxTokens;
    }

    @Override // java.lang.AutoCloseable
    public void close() throws Exception {
        this.executor.shutdownNow();
        this.nluModel.close();
        this.nluModel = null;
        this.textEncoder = null;
        this.outputParser = null;
    }

    public AsyncResult<NLUResult> classify(String str) {
        return classify(str, this.context);
    }

    @Override // io.spokestack.spokestack.nlu.NLUService
    public AsyncResult<NLUResult> classify(String str, NLUContext nLUContext) {
        ensureReady();
        AsyncResult<NLUResult> asyncResult = new AsyncResult<>(() -> {
            try {
                try {
                    long elapsedRealtime = SystemClock.elapsedRealtime();
                    NLUResult tfClassify = tfClassify(str, nLUContext);
                    if (nLUContext.canTrace(EventTracer.Level.PERF)) {
                        nLUContext.tracePerf("Classification: %5dms", Long.valueOf(SystemClock.elapsedRealtime() - elapsedRealtime));
                    }
                    return tfClassify;
                } catch (Exception e) {
                    NLUResult build = new NLUResult.Builder(str).withError(e).build();
                    nLUContext.reset();
                    return build;
                }
            } finally {
                nLUContext.reset();
            }
        });
        this.executor.submit(asyncResult);
        return asyncResult;
    }

    private void ensureReady() {
        if (this.ready) {
            return;
        }
        try {
            this.loadThread.join();
        } catch (InterruptedException e) {
            this.context.traceError("Interrupted during loading", new Object[0]);
        }
    }

    private NLUResult tfClassify(String str, NLUContext nLUContext) {
        EncodedTokens encode = this.textEncoder.encode(str);
        nLUContext.traceDebug("Token IDs: %s", encode.getIds());
        int[] pad = pad(encode.getIds());
        this.nluModel.inputs(0).rewind();
        for (int i : pad) {
            this.nluModel.inputs(0).putInt(i);
        }
        long elapsedRealtime = SystemClock.elapsedRealtime();
        this.nluModel.run();
        if (nLUContext.canTrace(EventTracer.Level.PERF)) {
            nLUContext.tracePerf("Inference: %5dms", Long.valueOf(SystemClock.elapsedRealtime() - elapsedRealtime));
        }
        Tuple<Metadata.Intent, Float> intent = this.outputParser.getIntent(this.nluModel.outputs(0));
        Metadata.Intent first = intent.first();
        nLUContext.traceDebug("Intent: %s", first.getName());
        Map<String, Slot> parseSlots = this.outputParser.parseSlots(first, this.outputParser.getSlots(nLUContext, encode, this.nluModel.outputs(1)));
        nLUContext.traceDebug("Slots: %s", parseSlots.toString());
        return new NLUResult.Builder(str).withIntent(first.getName()).withConfidence(intent.second().floatValue()).withSlots(parseSlots).build();
    }

    private int[] pad(List<Integer> list) {
        if (list.size() > this.maxTokens) {
            throw new IllegalArgumentException("input: " + list.size() + " tokens; max input length is: " + this.maxTokens);
        }
        int[] iArr = new int[this.maxTokens];
        for (int i = 0; i < list.size(); i++) {
            iArr[i] = list.get(i).intValue();
        }
        if (list.size() < this.maxTokens) {
            iArr[list.size()] = this.sepTokenId;
            if (this.padTokenId != 0) {
                for (int size = list.size() + 2; size < iArr.length; size++) {
                    iArr[size] = this.padTokenId;
                }
            }
        }
        return iArr;
    }

    public void addListener(TraceListener traceListener) {
        this.context.addTraceListener(traceListener);
    }

    public void removeListener(TraceListener traceListener) {
        this.context.removeTraceListener(traceListener);
    }
}
