package javax.visrec.ri.ml.classification;

import deepnetts.data.ExampleImage;
import deepnetts.data.ImageSet;
import deepnetts.net.ConvolutionalNetwork;
import deepnetts.net.train.BackpropagationTrainer;
import deepnetts.net.train.opt.OptimizerType;
import deepnetts.util.DeepNettsException;
import deepnetts.util.FileIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.visrec.AbstractImageClassifier;
import javax.visrec.ml.ClassifierCreationException;

/* loaded from: input_file:javax/visrec/ri/ml/classification/ImageClassifierNetwork.class */
/**
 * Image classifier backed by a Deep Netts {@link ConvolutionalNetwork}.
 *
 * <p>Instances are either constructed directly around an existing network, or
 * created and trained via {@link #builder()}.
 */
public class ImageClassifierNetwork extends AbstractImageClassifier<BufferedImage, ConvolutionalNetwork> {
    /** Image width the network was trained with (from the builder); 0 when not known. */
    private int inputWidth;
    /** Image height the network was trained with (from the builder); 0 when not known. */
    private int inputHeight;

    /**
     * Trains a convolutional network, described by a Deep Netts JSON file, on a
     * labelled image set and wraps the result in an {@link ImageClassifierNetwork}.
     */
    public static class Builder implements javax.visrec.util.Builder<ImageClassifierNetwork> {
        private static final Logger LOGGER = Logger.getLogger(ImageClassifierNetwork.class.getName());

        /**
         * Parameterless build; not supported by this builder.
         *
         * <p>NOTE: the decompiler renamed this method from {@code build}.
         *
         * @throws UnsupportedOperationException always
         */
        public ImageClassifierNetwork m5build() {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        /**
         * Loads a labelled image set, creates a network from a Deep Netts JSON
         * model file, trains it with SGD and optionally saves the trained model.
         *
         * <p>Properties read from {@code map}:
         * <ul>
         *   <li>{@code visrec.imageWidth}, {@code visrec.imageHeight}: image size (ints)</li>
         *   <li>{@code visrec.labelsFile}: file with the class labels</li>
         *   <li>{@code visrec.trainingFile}: file describing the training images</li>
         *   <li>{@code visrec.sgd.maxError} (float), {@code visrec.sgd.maxEpochs} (int),
         *       {@code visrec.sgd.learningRate} (float): training settings</li>
         *   <li>{@code visrec.model.deepnetts}: JSON network definition</li>
         *   <li>{@code visrec.model.saveTo}: optional path for the trained model;
         *       if absent the model is not saved</li>
         * </ul>
         *
         * @param map configuration properties; values are converted via {@code String.valueOf}
         * @return the trained classifier, or {@code null} if the network could not be
         *         created or training data could not be loaded (the cause is logged)
         * @throws NumberFormatException if a numeric property is missing or malformed
         */
        public ImageClassifierNetwork build(Map<String, Object> map) {
            int imageWidth = Integer.parseInt(String.valueOf(map.get("visrec.imageWidth")));
            int imageHeight = Integer.parseInt(String.valueOf(map.get("visrec.imageHeight")));
            String labelsFile = String.valueOf(map.get("visrec.labelsFile"));
            String trainingFile = String.valueOf(map.get("visrec.trainingFile"));
            float maxError = Float.parseFloat(String.valueOf(map.get("visrec.sgd.maxError")));
            int maxEpochs = Integer.parseInt(String.valueOf(map.get("visrec.sgd.maxEpochs")));
            float learningRate = Float.parseFloat(String.valueOf(map.get("visrec.sgd.learningRate")));
            // Keep the raw value: String.valueOf(null) would yield the file name "null".
            Object saveTo = map.get("visrec.model.saveTo");

            ImageSet imageSet = new ImageSet(imageWidth, imageHeight);
            LOGGER.info("Loading images...");
            imageSet.loadLabels(new File(labelsFile));
            try {
                // NOTE(review): 1000 is passed to loadImages unchanged from the original;
                // presumably a cap on the number of images loaded -- confirm.
                imageSet.loadImages(new File(trainingFile), 1000);
                imageSet.shuffle();
                LOGGER.info("Done!");

                LOGGER.info("Creating neural network...");
                ConvolutionalNetwork network =
                        FileIO.createFromJson(new File(String.valueOf(map.get("visrec.model.deepnetts"))));
                network.setOutputLabels(imageSet.getTargetNames());

                LOGGER.info("Training neural network");
                BackpropagationTrainer trainer = network.getTrainer();
                trainer.setLearningRate(learningRate)
                       .setMaxError(maxError)
                       .setMaxEpochs(maxEpochs)
                       .setBatchMode(false)
                       .setOptimizer(OptimizerType.SGD);
                trainer.train(imageSet);

                ImageClassifierNetwork classifier = new ImageClassifierNetwork(network);
                // Record the dimensions the network was trained with so the getters report them.
                classifier.inputWidth = imageWidth;
                classifier.inputHeight = imageHeight;

                saveModel(network, saveTo);
                return classifier;
            } catch (IOException e) {
                LOGGER.log(Level.SEVERE, "Could not create neural network from model file", e);
                return null;
            } catch (DeepNettsException e) {
                LOGGER.log(Level.SEVERE, "Could not load training images or train neural network", e);
                return null;
            }
        }

        /**
         * Writes the trained network to {@code saveTo} when that property was given.
         * Save failures are logged and do not abort the build.
         */
        private static void saveModel(ConvolutionalNetwork network, Object saveTo) {
            if (saveTo == null) {
                LOGGER.info("No visrec.model.saveTo property given; trained model not saved");
                return;
            }
            try {
                FileIO.writeToFile(network, String.valueOf(saveTo));
            } catch (IOException e) {
                LOGGER.log(Level.SEVERE, "Could not save trained model to " + saveTo, e);
            }
        }

        /**
         * Erased-signature bridge delegating to {@link #build(Map)}.
         *
         * <p>NOTE: synthetic bridge method recovered (and renamed) by the decompiler.
         */
        @SuppressWarnings("unchecked")
        public /* bridge */ /* synthetic */ Object m4build(Map map) throws ClassifierCreationException {
            return build((Map<String, Object>) map);
        }
    }

    /**
     * Wraps an existing network. Input dimensions stay 0 because they are not known here.
     *
     * @param convolutionalNetwork the network used for classification
     */
    public ImageClassifierNetwork(ConvolutionalNetwork convolutionalNetwork) {
        super(BufferedImage.class, convolutionalNetwork);
    }

    /**
     * Runs the image through the network and collects all outputs above the threshold.
     *
     * @param bufferedImage image to classify
     * @return map from output label to network output, containing only outputs
     *         strictly greater than {@code getThreshold()}
     */
    public Map<String, Float> classify(BufferedImage bufferedImage) {
        ExampleImage exampleImage = new ExampleImage(bufferedImage);
        ConvolutionalNetwork network = (ConvolutionalNetwork) getModel();
        network.setInput(exampleImage.getInput());
        float[] output = network.getOutput();

        Map<String, Float> scores = new HashMap<>();
        for (int i = 0; i < output.length; i++) {
            if (output[i] > getThreshold()) {
                scores.put(network.getOutputLabel(i), output[i]);
            }
        }
        return scores;
    }

    /** @return image width set by the builder, or 0 if constructed directly */
    public int getInputWidth() {
        return this.inputWidth;
    }

    /** @return image height set by the builder, or 0 if constructed directly */
    public int getInputHeight() {
        return this.inputHeight;
    }

    /** @return a new builder that trains a classifier from configuration properties */
    public static javax.visrec.util.Builder<ImageClassifierNetwork> builder() {
        return new Builder();
    }
}
