package javax.visrec.ri.spi;

import deepnetts.data.MLDataItem;
import deepnetts.data.TabularDataSet;
import deepnetts.net.FeedForwardNetwork;
import deepnetts.net.layers.activation.ActivationType;
import deepnetts.net.loss.LossType;
import java.io.IOException;
import javax.visrec.ml.ClassifierCreationException;
import javax.visrec.ml.classification.BinaryClassifier;
import javax.visrec.ml.classification.NeuralNetBinaryClassifier;
import javax.visrec.ri.ml.classification.FeedForwardNetBinaryClassifier;
import javax.visrec.ri.util.DataSets;
import javax.visrec.spi.BinaryClassifierFactory;

/* loaded from: input_file:javax/visrec/ri/spi/FloatArrayBinaryClassifierFactory.class */
/**
 * {@link BinaryClassifierFactory} for {@code float[]} inputs.
 *
 * <p>Assembles a Deep Netts {@link FeedForwardNetwork} from the values exposed by a
 * {@link NeuralNetBinaryClassifier.BuildingBlock}, trains it on the building block's
 * training file and wraps the trained network in a {@link FeedForwardNetBinaryClassifier}.
 */
public class FloatArrayBinaryClassifierFactory implements BinaryClassifierFactory<float[]> {

    /**
     * Reports the input type this factory produces classifiers for.
     *
     * @return {@code float[].class}
     */
    public Class<float[]> getTargetClass() {
        return float[].class;
    }

    /**
     * Builds and trains a binary classifier described by the given building block.
     *
     * <p>Steps, in order: assemble the network, configure its trainer, read the training
     * file as CSV, max-normalize the resulting data set, train, and wrap the network.
     *
     * @param buildingBlock source of the network layout (inputs number, hidden layers),
     *                      the trainer settings (max epochs, max error, learning rate)
     *                      and the training file
     * @return a {@link FeedForwardNetBinaryClassifier} wrapping the trained network
     * @throws ClassifierCreationException if reading the training file raises an
     *                                     {@link IOException}; the original exception is
     *                                     kept as the cause
     */
    public BinaryClassifier<float[]> create(NeuralNetBinaryClassifier.BuildingBlock<float[]> buildingBlock) throws ClassifierCreationException {
        FeedForwardNetwork network = assembleNetwork(buildingBlock);
        configureTrainer(network, buildingBlock);

        try {
            // The trailing (1, true, ",") arguments are presumably: one output column,
            // "file has a header row", and the column delimiter -- TODO confirm against
            // javax.visrec.ri.util.DataSets.readCsv.
            TabularDataSet<MLDataItem> trainingSet = DataSets.readCsv(
                    buildingBlock.getTrainingFile(),
                    buildingBlock.getInputsNum(),
                    1,
                    true,
                    ",");
            // Fully qualified: the simple name DataSets already refers to the
            // javax.visrec.ri.util class imported by this file.
            deepnetts.data.DataSets.normalizeMax(trainingSet);
            network.train(trainingSet);
            return new FeedForwardNetBinaryClassifier(network);
        } catch (IOException ioFailure) {
            throw new ClassifierCreationException("Failed to create training set based on training file", ioFailure);
        }
    }

    /**
     * Assembles the untrained network: an input layer sized by the building block's
     * inputs number, one fully connected layer per hidden-layer entry, and a single
     * sigmoid output trained with cross-entropy loss.
     */
    private static FeedForwardNetwork assembleNetwork(NeuralNetBinaryClassifier.BuildingBlock<float[]> buildingBlock) {
        FeedForwardNetwork.Builder layout = FeedForwardNetwork.builder();
        layout.addInputLayer(buildingBlock.getInputsNum());
        for (int layerWidth : buildingBlock.getHiddenLayers()) {
            layout.addFullyConnectedLayer(layerWidth);
        }
        layout.addOutputLayer(1, ActivationType.SIGMOID)
                .lossFunction(LossType.CROSS_ENTROPY);
        return layout.build();
    }

    /**
     * Copies max epochs, max error and learning rate from the building block onto the
     * network's trainer.
     */
    private static void configureTrainer(FeedForwardNetwork network, NeuralNetBinaryClassifier.BuildingBlock<float[]> buildingBlock) {
        network.getTrainer()
                .setMaxEpochs(buildingBlock.getMaxEpochs())
                .setMaxError(buildingBlock.getMaxError())
                .setLearningRate(buildingBlock.getLearningRate());
    }
}
