package org.jpmml.rexp;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Entity;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.neural_network.Connection;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.neural_network.NeuralNetworkUtil;

/* loaded from: input_file:org/jpmml/rexp/NNetConverter.class */
/**
 * Converts an R {@code nnet}-style model object (elements {@code n}, {@code wts}, {@code linout},
 * {@code softmax}, {@code censored}, {@code lev}, ...) into a PMML {@link NeuralNetwork}.
 *
 * <p>The network has an optional logistic hidden layer followed by an output layer. Binary
 * classification (a single output unit) is expanded with a logistic transform and a two-neuron
 * label binarizer. Multi-class classification may use a softmax-normalized output layer.
 */
public class NNetConverter extends ModelConverter<RGenericVector> {

    public NNetConverter(RGenericVector rGenericVector) {
        super(rGenericVector);
    }

    /**
     * Registers the label and input features of this model with the encoder.
     *
     * <p>The formula is built from the model's {@code terms} and {@code xlevels} elements. The label
     * is set using the optional {@code lev} element. Features are added according to the
     * {@code coefnames} element.
     *
     * @param rExpEncoder the encoder that receives the label and features
     */
    @Override // org.jpmml.rexp.ModelConverter
    public void encodeSchema(RExpEncoder rExpEncoder) {
        RGenericVector nnet = (RGenericVector) getObject();

        RStringVector lev = (RStringVector) nnet.getValue("lev", true);
        RExp terms = nnet.getValue("terms");
        RGenericVector xlevels = (RGenericVector) nnet.getValue("xlevels");
        RStringVector coefnames = (RStringVector) nnet.getValue("coefnames");

        Formula formula = FormulaUtil.createFormula(terms, new XLevelsFormulaContext(xlevels), rExpEncoder);

        SchemaUtil.setLabel(formula, terms, lev, rExpEncoder);
        SchemaUtil.addFeatures(formula, coefnames, true, rExpEncoder);
    }

    /**
     * Encodes the fitted network as a PMML {@link NeuralNetwork}.
     *
     * <p>The method name was assigned by the decompiler (originally {@code encodeModel}). It must
     * stay as-is so that it keeps overriding the superclass declaration.
     *
     * <p>The model is a classification model when the {@code lev} element is present, and a
     * regression model otherwise. Regression requires linear outputs ({@code linout} absent or
     * {@code TRUE}).
     *
     * @param schema the label and features previously registered by {@link #encodeSchema}
     * @return the encoded neural network model
     * @throws IllegalArgumentException if the layer sizes, weight count or output options are
     *     inconsistent with what this converter supports
     */
    @Override // org.jpmml.rexp.ModelConverter
    public Model mo0encodeModel(Schema schema) {
        RGenericVector nnet = (RGenericVector) getObject();

        RDoubleVector n = (RDoubleVector) nnet.getValue("n");
        RBooleanVector linout = (RBooleanVector) nnet.getValue("linout", true);
        RBooleanVector softmax = (RBooleanVector) nnet.getValue("softmax", true);
        RBooleanVector censored = (RBooleanVector) nnet.getValue("censored", true);
        RDoubleVector wts = (RDoubleVector) nnet.getValue("wts");
        RStringVector lev = (RStringVector) nnet.getValue("lev", true);

        // n = (input units, hidden units, output units)
        if (n.size() != 3) {
            throw new IllegalArgumentException("Expected 3 layer sizes in 'n', got " + n.size());
        }

        MiningFunction miningFunction;
        if (lev != null) {
            miningFunction = MiningFunction.CLASSIFICATION;
        } else {
            // Only linear output units are supported for regression
            if (linout != null && !linout.asScalar().booleanValue()) {
                throw new IllegalArgumentException("Regression requires linear output units (linout = TRUE)");
            }
            miningFunction = MiningFunction.REGRESSION;
        }

        int inputCount = ValueUtil.asInt(n.getValue(0));
        int featureCount = schema.getFeatures().size();
        if (inputCount != featureCount) {
            throw new IllegalArgumentException("Expected " + featureCount + " input units, got " + inputCount);
        }

        NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(schema.getFeatures(), DataType.DOUBLE);

        List<NeuralLayer> neuralLayers = new ArrayList<NeuralLayer>();

        // Entities feeding the next layer. Starts with the inputs and is replaced by each new layer's neurons.
        List<? extends Entity> entities = neuralInputs.getNeuralInputs();

        // Position of the next unread weight in 'wts'. Every neuron consumes 1 bias + entities.size() weights.
        int offset = 0;

        int hiddenCount = ValueUtil.asInt(n.getValue(1));
        if (hiddenCount > 0) {
            NeuralLayer hiddenLayer = encodeNeuralLayer("hidden", hiddenCount, entities, wts, offset)
                .setActivationFunction(NeuralNetwork.ActivationFunction.LOGISTIC);

            offset += hiddenCount * (entities.size() + 1);

            neuralLayers.add(hiddenLayer);
            entities = hiddenLayer.getNeurons();
        }

        int outputCount = ValueUtil.asInt(n.getValue(2));
        if (outputCount < 1) {
            throw new IllegalArgumentException("Expected at least 1 output unit, got " + outputCount);
        }

        NeuralLayer outputLayer = encodeNeuralLayer("output", outputCount, entities, wts, offset);
        offset += outputCount * (entities.size() + 1);
        neuralLayers.add(outputLayer);

        // Every weight must have been consumed by the layered encoding above. Leftover weights
        // (e.g. skip-layer connections) cannot be represented here and would otherwise be
        // silently dropped, yielding an incorrect model.
        if (offset != wts.size()) {
            throw new IllegalArgumentException("Expected " + offset + " weights, got " + wts.size()
                + " (skip-layer connections are not supported)");
        }

        NeuralLayer finalLayer = outputLayer;

        if (outputCount == 1) {
            if (miningFunction == MiningFunction.CLASSIFICATION) {
                // Binary classification: squash the single output, then expand it into
                // complementary "event/false" and "event/true" scores
                NeuralLayer logisticLayer = encodeLogisticTransform(getOnlyNeuron(outputLayer));
                neuralLayers.add(logisticLayer);

                finalLayer = encodeLabelBinarizerTransform(getOnlyNeuron(logisticLayer));
                neuralLayers.add(finalLayer);
            }
        } else if (softmax != null && softmax.asScalar().booleanValue()) {
            if (censored != null && censored.asScalar().booleanValue()) {
                throw new IllegalArgumentException("Censored softmax outputs are not supported");
            }

            outputLayer.setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
        }

        List<Neuron> neurons = finalLayer.getNeurons();

        NeuralNetwork neuralNetwork = new NeuralNetwork(miningFunction, NeuralNetwork.ActivationFunction.IDENTITY,
            ModelUtil.createMiningSchema(schema.getLabel()), neuralInputs, neuralLayers);

        switch (miningFunction) {
            case REGRESSION:
                neuralNetwork.setNeuralOutputs(
                    NeuralNetworkUtil.createRegressionNeuralOutputs(neurons, (ContinuousLabel) schema.getLabel()));
                break;
            case CLASSIFICATION: {
                CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

                neuralNetwork.setNeuralOutputs(NeuralNetworkUtil.createClassificationNeuralOutputs(neurons, categoricalLabel))
                    .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
                break;
            }
            default:
                throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
        }

        return neuralNetwork;
    }

    /**
     * Builds a fully connected layer of {@code neuronCount} neurons named {@code prefix + "/" + k}
     * (k = 1..neuronCount).
     *
     * <p>Weights are read sequentially from {@code weights} starting at {@code offset}. Each neuron
     * takes one bias followed by one weight per entry of {@code inputs}.
     */
    private static NeuralLayer encodeNeuralLayer(String prefix, int neuronCount, List<? extends Entity> inputs, RDoubleVector weights, int offset) {
        NeuralLayer neuralLayer = new NeuralLayer();

        int stride = inputs.size() + 1;
        for (int index = 0; index < neuronCount; index++) {
            int biasPos = offset + index * stride;

            Neuron neuron = NeuralNetworkUtil.createNeuron(inputs, weights.getValues().subList(biasPos + 1, biasPos + stride), weights.getValue(biasPos))
                .setId(prefix + "/" + (index + 1));

            neuralLayer.addNeurons(neuron);
        }

        return neuralLayer;
    }

    /** Wraps {@code entity} in a single logistic neuron ("logistic/1") with weight 1 and bias 0. */
    private static NeuralLayer encodeLogisticTransform(Entity entity) {
        Neuron neuron = new Neuron()
            .setId("logistic/1")
            .setBias(0d)
            .addConnections(new Connection(entity.getId(), 1d));

        return new NeuralLayer()
            .setActivationFunction(NeuralNetwork.ActivationFunction.LOGISTIC)
            .addNeurons(neuron);
    }

    /** Expands a probability {@code p} into two neurons: "event/false" = 1 - p and "event/true" = p. */
    private static NeuralLayer encodeLabelBinarizerTransform(Entity entity) {
        Neuron eventFalse = new Neuron()
            .setId("event/false")
            .setBias(1d)
            .addConnections(new Connection(entity.getId(), -1d));

        Neuron eventTrue = new Neuron()
            .setId("event/true")
            .setBias(0d)
            .addConnections(new Connection(entity.getId(), 1d));

        return new NeuralLayer().addNeurons(eventFalse, eventTrue);
    }

    /** Returns the sole neuron of {@code neuralLayer}; fails if the layer does not have exactly one. */
    private static Neuron getOnlyNeuron(NeuralLayer neuralLayer) {
        return Iterables.getOnlyElement(neuralLayer.getNeurons());
    }
}
