package javax.visrec.ri.ml.regression;

import deepnetts.data.MLDataItem;
import deepnetts.net.FeedForwardNetwork;
import deepnetts.net.layers.activation.ActivationType;
import deepnetts.net.loss.LossType;
import deepnetts.net.train.BackpropagationTrainer;
import deepnetts.util.Tensor;
import java.util.Map;
import java.util.Objects;
import javax.visrec.ml.data.DataSet;
import javax.visrec.ml.regression.SimpleLinearRegression;

/* loaded from: input_file:javax/visrec/ri/ml/regression/SimpleLinearRegressionNetwork.class */
public class SimpleLinearRegressionNetwork extends SimpleLinearRegression<FeedForwardNetwork> {
    private final float[] input = new float[1];
    private final Tensor inputTensor = Tensor.create(1, 1, this.input);
    private float slope;
    private float intercept;

    /* loaded from: input_file:javax/visrec/ri/ml/regression/SimpleLinearRegressionNetwork$Builder.class */
    public static class Builder implements javax.visrec.util.Builder<SimpleLinearRegressionNetwork> {
        private SimpleLinearRegressionNetwork buildingBlock = new SimpleLinearRegressionNetwork();
        private float learningRate = 0.01f;
        private float maxError = 0.03f;
        private int maxEpochs = 1000;
        private DataSet<? extends MLDataItem> trainingSet;

        public Builder learningRate(float f) {
            this.learningRate = f;
            return this;
        }

        public Builder maxError(float f) {
            this.maxError = f;
            return this;
        }

        public Builder maxEpochs(int i) {
            this.maxEpochs = i;
            return this;
        }

        public Builder trainingSet(DataSet<? extends MLDataItem> dataSet) {
            this.trainingSet = dataSet;
            return this;
        }

        /* renamed from: build, reason: merged with bridge method [inline-methods] */
        public SimpleLinearRegressionNetwork m9build() {
            FeedForwardNetwork build = FeedForwardNetwork.builder().addInputLayer(1).addOutputLayer(1, ActivationType.LINEAR).lossFunction(LossType.MEAN_SQUARED_ERROR).build();
            BackpropagationTrainer backpropagationTrainer = new BackpropagationTrainer(build);
            backpropagationTrainer.setLearningRate(this.learningRate).setMaxError(this.maxError).setMaxEpochs(this.maxEpochs);
            backpropagationTrainer.train(this.trainingSet);
            this.buildingBlock.intercept = build.getOutputLayer().getBiases()[0];
            this.buildingBlock.slope = build.getOutputLayer().getWeights().get(0);
            this.buildingBlock.setModel(build);
            return this.buildingBlock;
        }

        /* renamed from: build, reason: merged with bridge method [inline-methods] */
        public SimpleLinearRegressionNetwork m8build(Map map) {
            return m9build();
        }
    }

    public Float predict(Float f) {
        this.input[0] = f.floatValue();
        FeedForwardNetwork feedForwardNetwork = (FeedForwardNetwork) getModel();
        feedForwardNetwork.setInput(this.inputTensor);
        return Float.valueOf(feedForwardNetwork.getOutput()[0]);
    }

    public static Builder builder() {
        return new Builder();
    }

    public float getSlope() {
        return this.slope;
    }

    public float getIntercept() {
        return this.intercept;
    }
}
