package sklearn.neural_network;

import com.google.common.collect.Iterables;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.OpType;
import org.dmg.pmml.neural_network.Connection;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.NeuralOutput;
import org.dmg.pmml.neural_network.NeuralOutputs;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.HasArray;

/* loaded from: input_file:sklearn/neural_network/NeuralNetworkUtil.class */
public class NeuralNetworkUtil {

    /**
     * Compiler-generated lookup table mapping {@link MiningFunction} constants to switch-case
     * indices (the pattern javac emits for a {@code switch} over an enum).
     *
     * <p>Each assignment sits in its own {@code try} so that a {@link MiningFunction} constant
     * that is missing at run time (reported as {@link NoSuchFieldError}) only skips that one entry
     * instead of failing class initialization.
     *
     * <p>NOTE(review): no visible code in this file reads this table; it presumably backed a
     * switch inside {@code encodeNeuralNetwork}, whose body was not recovered. TODO confirm.
     */
    static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$dmg$pmml$MiningFunction = new int[MiningFunction.values().length];

        static {
            try {
                $SwitchMap$org$dmg$pmml$MiningFunction[MiningFunction.REGRESSION.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // Intentionally ignored: constant absent at run time, leave its slot as 0.
            }
            try {
                $SwitchMap$org$dmg$pmml$MiningFunction[MiningFunction.CLASSIFICATION.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // Intentionally ignored: constant absent at run time, leave its slot as 0.
            }
        }
    }

    private NeuralNetworkUtil() {
    }

    /**
     * Returns the size of the first dimension of the first array in {@code list}.
     *
     * @param list arrays whose first element must be two-dimensional
     * @return {@code list.get(0).getArrayShape()[0]}
     * @throws IllegalArgumentException if {@code list} is empty or its first array is not 2-D
     */
    public static int getNumberOfFeatures(List<? extends HasArray> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Expected at least one array, got an empty list");
        }
        int[] arrayShape = list.get(0).getArrayShape();
        if (arrayShape.length != 2) {
            throw new IllegalArgumentException("Expected a 2-D array, got an array of rank " + arrayShape.length);
        }
        return arrayShape[0];
    }

    /**
     * Intended to build a PMML {@link NeuralNetwork} for the given mining function.
     *
     * <p>NOTE(review): the decompiler (JADX) could not recover this method's body (it reported
     * 685 skipped instructions), so this stub unconditionally throws. The original bytecode must
     * be consulted, or the method re-implemented, before this class can encode a network.
     * Judging by the private helpers in this class, it presumably parses an activation name and
     * wires per-layer weight and bias arrays into neural layers. TODO confirm against the original.
     *
     * @throws UnsupportedOperationException always, until the body is restored
     */
    public static org.dmg.pmml.neural_network.NeuralNetwork encodeNeuralNetwork(org.dmg.pmml.MiningFunction r8, java.lang.String r9, java.util.List<? extends org.jpmml.sklearn.HasArray> r10, java.util.List<? extends org.jpmml.sklearn.HasArray> r11, org.jpmml.converter.Schema r12) {
        throw new UnsupportedOperationException("Method not decompiled: sklearn.neural_network.NeuralNetworkUtil.encodeNeuralNetwork(org.dmg.pmml.MiningFunction, java.lang.String, java.util.List, java.util.List, org.jpmml.converter.Schema):org.dmg.pmml.neural_network.NeuralNetwork");
    }

    /**
     * Creates a one-neuron LOGISTIC layer ("logistic/1") that applies the logistic function to
     * the output of {@code neuron} (bias 0, connection weight 1).
     */
    private static NeuralLayer encodeLogisticTransform(Neuron neuron) {
        Neuron logistic = new Neuron()
            .setId("logistic/1")
            .setBias(Double.valueOf(0.0d))
            .addConnections(new Connection[]{new Connection(neuron.getId(), 1.0d)});

        NeuralLayer layer = new NeuralLayer().setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY == null ? null : NeuralNetwork.ActivationFunction.LOGISTIC);
        layer.addNeurons(new Neuron[]{logistic});
        return layer;
    }

    /**
     * Creates an IDENTITY layer that expands the single output {@code x} of {@code neuron} into
     * the complementary pair "event/false" = {@code 1 - x} and "event/true" = {@code x}.
     */
    private static NeuralLayer encodeLabelBinarizerTransform(Neuron neuron) {
        // bias 1, weight -1  =>  1 - x
        Neuron eventFalse = new Neuron()
            .setId("event/false")
            .setBias(Double.valueOf(1.0d))
            .addConnections(new Connection[]{new Connection(neuron.getId(), -1.0d)});
        // bias 0, weight 1   =>  x
        Neuron eventTrue = new Neuron()
            .setId("event/true")
            .setBias(Double.valueOf(0.0d))
            .addConnections(new Connection[]{new Connection(neuron.getId(), 1.0d)});

        NeuralLayer layer = new NeuralLayer().setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY);
        layer.addNeurons(new Neuron[]{eventFalse, eventTrue});
        return layer;
    }

    /**
     * Maps the single output entity in {@code list} to the continuous label of {@code schema}.
     *
     * @param list output entities; must contain exactly one element (checked via
     *     {@code ClassDictUtil.checkSize})
     * @param schema schema whose label is expected to be a {@link ContinuousLabel}
     */
    private static NeuralOutputs encodeRegressionNeuralOutputs(List<? extends Entity> list, Schema schema) {
        // The decompiler dropped this downcast; getLabel() is not declared to return ContinuousLabel.
        ContinuousLabel label = (ContinuousLabel) schema.getLabel();
        ClassDictUtil.checkSize(1, list);

        Entity output = (Entity) Iterables.getOnlyElement(list);
        DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE)
            .setExpression(new FieldRef(label.getName()));
        return new NeuralOutputs().addNeuralOutputs(new NeuralOutput[]{
            new NeuralOutput().setOutputNeuron(output.getId()).setDerivedField(derivedField)});
    }

    /**
     * Maps output entity {@code i} of {@code list} to class value {@code i} of the categorical
     * label of {@code schema}, using a {@link NormDiscrete} indicator expression.
     *
     * @param list output entities; its size must equal the number of label values (checked via
     *     {@code ClassDictUtil.checkSize})
     * @param schema schema whose label is expected to be a {@link CategoricalLabel}
     */
    private static NeuralOutputs encodeClassificationNeuralOutputs(List<? extends Entity> list, Schema schema) {
        // The decompiler dropped this downcast; getLabel() is not declared to return CategoricalLabel.
        CategoricalLabel label = (CategoricalLabel) schema.getLabel();
        ClassDictUtil.checkSize(label.size(), list);

        NeuralOutputs neuralOutputs = new NeuralOutputs();
        for (int i = 0; i < label.size(); i++) {
            DerivedField derivedField = new DerivedField(OpType.CATEGORICAL, DataType.STRING)
                .setExpression(new NormDiscrete(label.getName(), label.getValue(i)));
            neuralOutputs.addNeuralOutputs(new NeuralOutput[]{
                new NeuralOutput().setOutputNeuron(list.get(i).getId()).setDerivedField(derivedField)});
        }
        return neuralOutputs;
    }

    /**
     * Adds to each neuron {@code list.get(i)} a connection from {@code entity} whose weight is
     * {@code list2.get(i)} converted to {@code double}.
     *
     * @param list2 weights, expected to be {@link Number}s; its size is checked against
     *     {@code list} via {@code ClassDictUtil.checkSize}
     */
    private static void connect(Entity entity, List<Neuron> list, List<?> list2) {
        ClassDictUtil.checkSize(list, list2);
        for (int i = 0; i < list.size(); i++) {
            double weight = ValueUtil.asDouble((Number) list2.get(i)).doubleValue();
            list.get(i).addConnections(new Connection[]{new Connection(entity.getId(), weight)});
        }
    }

    /** Returns the sole neuron of {@code neuralLayer}, failing if it does not have exactly one. */
    private static Neuron getOnlyNeuron(NeuralLayer neuralLayer) {
        return (Neuron) Iterables.getOnlyElement(neuralLayer.getNeurons());
    }

    /**
     * Translates a scikit-learn activation name into the PMML activation function.
     *
     * <p>The decompiled original was an invalid {@code boolean}-typed rendering of a string
     * switch; this is the equivalent source form with the same mapping.
     *
     * @param str one of {@code "identity"}, {@code "logistic"}, {@code "relu"}, {@code "tanh"}
     * @throws IllegalArgumentException for any other name (message is the name itself)
     * @throws NullPointerException if {@code str} is null
     */
    private static NeuralNetwork.ActivationFunction parseActivationFunction(String str) {
        switch (str) {
            case "identity":
                return NeuralNetwork.ActivationFunction.IDENTITY;
            case "logistic":
                return NeuralNetwork.ActivationFunction.LOGISTIC;
            case "relu":
                return NeuralNetwork.ActivationFunction.RECTIFIER;
            case "tanh":
                return NeuralNetwork.ActivationFunction.TANH;
            default:
                throw new IllegalArgumentException(str);
        }
    }
}
