package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;

/* loaded from: input_file:org/jpmml/rexp/BinaryTreeConverter.class */
/**
 * Converts an R S4 object that carries {@code responses} and {@code tree} slots into a PMML {@link TreeModel}.
 * NOTE(review): this looks like a {@code BinaryTree} object, presumably from the party package's {@code ctree()};
 * confirm.
 *
 * <p>{@link #encodeSchema} must run before {@link #encodeModel}. It determines the mining function from the response
 * variable and records the schema position of every split variable.
 */
public class BinaryTreeConverter extends TreeModelConverter<S4Object> {

    /**
     * Regression or classification. Set by {@link #encodeResponse} from the response's {@code is_nominal} flag.
     * It is {@code null} until then.
     */
    private MiningFunction miningFunction = null;

    /**
     * Maps each split variable to the index of its feature in the schema. Entries are added in pre-order traversal
     * order, the same order in which the corresponding features are added to the encoder.
     */
    private final Map<FieldName, Integer> featureIndexes = new LinkedHashMap<>();

    public BinaryTreeConverter(S4Object binaryTree) {
        super(binaryTree);
    }

    /**
     * Encodes the label (from the {@code responses} slot) and one feature per distinct split variable (from the
     * {@code tree} slot).
     *
     * @param encoder the encoder that receives the label and feature data fields
     */
    @Override
    public void encodeSchema(RExpEncoder encoder) {
        S4Object binaryTree = (S4Object) getObject();

        S4Object responses = (S4Object) binaryTree.getAttributeValue("responses");
        RGenericVector tree = (RGenericVector) binaryTree.getAttributeValue("tree");

        encodeResponse(responses, encoder);
        encodeVariableList(tree, encoder);
    }

    /**
     * Builds the tree model. It attaches an output with class probabilities (classification only) and an entity-id
     * field named {@code nodeId}.
     *
     * @param schema the schema produced from {@link #encodeSchema}
     * @return the encoded tree model
     * @throws IllegalArgumentException if the mining function is neither regression nor classification
     */
    @Override
    public TreeModel encodeModel(Schema schema) {
        RGenericVector tree = (RGenericVector) ((S4Object) getObject()).getAttributeValue("tree");

        Output output;
        switch (this.miningFunction) {
            case REGRESSION:
                output = new Output();
                break;
            case CLASSIFICATION:
                output = ModelUtil.createProbabilityOutput(DataType.DOUBLE, (CategoricalLabel) schema.getLabel());
                break;
            default:
                throw new IllegalArgumentException("Unsupported mining function: " + this.miningFunction);
        }

        output.addOutputFields(ModelUtil.createEntityIdField(FieldName.create("nodeId")));

        return encodeTreeModel(tree, schema).setOutput(output);
    }

    /**
     * Creates the label data field and selects the mining function.
     *
     * <p>A nominal response yields a categorical label whose values are the response levels, and the mining function
     * becomes classification. A non-nominal response yields a continuous double label, and the mining function
     * becomes regression.
     *
     * @throws IllegalArgumentException if the {@code is_nominal} flag for the response is missing
     */
    private void encodeResponse(S4Object responses, RExpEncoder encoder) {
        RGenericVector variables = (RGenericVector) responses.getAttributeValue("variables");
        RBooleanVector isNominal = (RBooleanVector) responses.getAttributeValue("is_nominal");
        RGenericVector levels = (RGenericVector) responses.getAttributeValue("levels");

        // A single response variable is expected; names().asScalar() presumably rejects other lengths (confirm).
        String variableName = variables.names().asScalar();
        FieldName name = FieldName.create(variableName);

        Boolean nominal = isNominal.getValue(variableName);

        DataField dataField;
        if (Boolean.TRUE.equals(nominal)) {
            this.miningFunction = MiningFunction.CLASSIFICATION;

            DataType dataType = RExpUtil.getDataType(RExpUtil.getClassNames(variables.getValue(variableName)).asScalar());
            List<String> categories = ((RStringVector) levels.getValue(variableName)).getValues();

            dataField = encoder.createDataField(name, OpType.CATEGORICAL, dataType, categories);
        } else if (Boolean.FALSE.equals(nominal)) {
            this.miningFunction = MiningFunction.REGRESSION;

            dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
        } else {
            throw new IllegalArgumentException("Missing is_nominal flag for response variable " + variableName);
        }

        encoder.setLabel(dataField);
    }

    /**
     * Walks the tree in pre-order and registers a feature the first time each split variable is seen.
     *
     * <p>An integer split point produces a categorical field whose values come from the split point's
     * {@code levels} attribute. A double split point produces a continuous double field.
     *
     * @throws IllegalArgumentException if a split point is neither an integer nor a double vector
     */
    private void encodeVariableList(RGenericVector node, RExpEncoder encoder) {
        RBooleanVector terminal = (RBooleanVector) node.getValue("terminal");
        if (Boolean.TRUE.equals(terminal.asScalar())) {
            return;
        }

        RGenericVector psplit = (RGenericVector) node.getValue("psplit");
        RGenericVector left = (RGenericVector) node.getValue("left");
        RGenericVector right = (RGenericVector) node.getValue("right");

        RNumberVector<?> splitpoint = (RNumberVector<?>) psplit.getValue("splitpoint");
        FieldName name = FieldName.create(((RStringVector) psplit.getValue("variableName")).asScalar());

        if (encoder.getDataField(name) == null) {
            DataField dataField;
            if (splitpoint instanceof RIntegerVector) {
                RStringVector categories = (RStringVector) splitpoint.getAttributeValue("levels");

                dataField = encoder.createDataField(name, OpType.CATEGORICAL, null, categories.getValues());
            } else if (splitpoint instanceof RDoubleVector) {
                dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
            } else {
                throw new IllegalArgumentException("Unsupported split point for variable " + name);
            }

            encoder.addFeature((Field<?>) dataField);

            // The feature was just appended, so its index equals the number of features registered before it.
            this.featureIndexes.put(name, this.featureIndexes.size());
        }

        encodeVariableList(left, encoder);
        encodeVariableList(right, encoder);
    }

    /** Creates the always-true root node, encodes the tree under it and wraps it in a binary-split TreeModel. */
    private TreeModel encodeTreeModel(RGenericVector tree, Schema schema) {
        Node root = new Node().setPredicate(new True());

        encodeNode(root, tree, schema);

        return new TreeModel(this.miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
                .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
    }

    /**
     * Populates {@code node} from the R node list, recursing into both children.
     *
     * <p>A terminal node gets a score. A non-terminal node gets a left and a right child. Categorical splits become
     * set predicates. Continuous splits become {@code <= splitpoint} on the left and {@code > splitpoint} on the
     * right.
     *
     * @throws IllegalArgumentException if the node has surrogate splits, or its split variable is unknown
     */
    private void encodeNode(Node node, RGenericVector tree, Schema schema) {
        RIntegerVector nodeId = (RIntegerVector) tree.getValue("nodeID");
        RBooleanVector terminal = (RBooleanVector) tree.getValue("terminal");

        node.setId(String.valueOf(nodeId.asScalar()));

        if (Boolean.TRUE.equals(terminal.asScalar())) {
            RDoubleVector prediction = (RDoubleVector) tree.getValue("prediction");

            encodeScore(node, prediction, schema);
            return;
        }

        RGenericVector psplit = (RGenericVector) tree.getValue("psplit");
        RGenericVector ssplits = (RGenericVector) tree.getValue("ssplits");
        RGenericVector left = (RGenericVector) tree.getValue("left");
        RGenericVector right = (RGenericVector) tree.getValue("right");

        if (ssplits.size() > 0) {
            throw new IllegalArgumentException("Surrogate splits are not supported");
        }

        RNumberVector<?> splitpoint = (RNumberVector<?>) psplit.getValue("splitpoint");
        RStringVector variableName = (RStringVector) psplit.getValue("variableName");

        Integer index = this.featureIndexes.get(FieldName.create(variableName.asScalar()));
        if (index == null) {
            throw new IllegalArgumentException("Unknown split variable: " + variableName.asScalar());
        }

        Feature feature = schema.getFeature(index);

        Predicate leftPredicate;
        Predicate rightPredicate;
        if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

            List<String> values = categoricalFeature.getValues();
            // Categorical split points are integer vectors (see encodeVariableList): 1 = left, 0 = right.
            List<Integer> splitValues = ((RIntegerVector) splitpoint).getValues();

            leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, true));
            rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, false));
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();

            String value = ValueUtil.formatValue((Double) splitpoint.asScalar());

            leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
            rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
        }

        Node leftChild = new Node().setPredicate(leftPredicate);
        encodeNode(leftChild, left, schema);

        Node rightChild = new Node().setPredicate(rightPredicate);
        encodeNode(rightChild, right, schema);

        node.addNodes(leftChild, rightChild);
    }

    /** Dispatches to the regression or classification scorer according to the mining function. */
    private Node encodeScore(Node node, RDoubleVector prediction, Schema schema) {
        switch (this.miningFunction) {
            case REGRESSION:
                return encodeRegressionScore(node, prediction);
            case CLASSIFICATION:
                return encodeClassificationScore(node, prediction, schema);
            default:
                throw new IllegalArgumentException("Unsupported mining function: " + this.miningFunction);
        }
    }

    /**
     * Returns the elements of {@code values} whose positionally matching entry in {@code splits} is 1 (when
     * {@code left}) or 0 (otherwise). Order is preserved.
     *
     * @throws IllegalArgumentException if the two lists differ in length
     */
    private static <E> List<E> selectValues(List<E> values, List<Integer> splits, boolean left) {
        if (values.size() != splits.size()) {
            throw new IllegalArgumentException("Expected " + values.size() + " split indicators, got " + splits.size());
        }

        int wanted = left ? 1 : 0;

        List<E> result = new ArrayList<>();
        for (int i = 0; i < values.size(); i++) {
            if (splits.get(i) == wanted) {
                result.add(values.get(i));
            }
        }
        return result;
    }

    /**
     * Sets the node score to the single predicted value.
     *
     * @throws IllegalArgumentException if the prediction does not have exactly one element
     */
    private static Node encodeRegressionScore(Node node, RDoubleVector prediction) {
        if (prediction.size() != 1) {
            throw new IllegalArgumentException("Expected exactly one prediction value, got " + prediction.size());
        }

        node.setScore(ValueUtil.formatValue((Double) prediction.asScalar()));
        return node;
    }

    /**
     * Adds one score distribution per class and sets the score to the class with the highest probability. On ties
     * the first such class wins.
     *
     * @throws IllegalArgumentException if the number of probabilities differs from the number of classes
     */
    private static Node encodeClassificationScore(Node node, RDoubleVector probabilities, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        if (categoricalLabel.size() != probabilities.size()) {
            throw new IllegalArgumentException("Expected " + categoricalLabel.size() + " class probabilities, got " + probabilities.size());
        }

        Double maxProbability = null;
        for (int i = 0; i < categoricalLabel.size(); i++) {
            String value = categoricalLabel.getValue(i);
            Double probability = probabilities.getValue(i);

            // Strict comparison keeps the earliest class on ties.
            if (maxProbability == null || maxProbability.compareTo(probability) < 0) {
                node.setScore(value);
                maxProbability = probability;
            }

            node.addScoreDistributions(new ScoreDistribution(value, probability.doubleValue()));
        }
        return node;
    }
}
