package org.jpmml.rexp;

import com.google.common.math.DoubleMath;
import com.google.common.primitives.UnsignedLong;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FortranMatrixUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;

/* loaded from: input_file:org/jpmml/rexp/RandomForestConverter.class */
public class RandomForestConverter extends TreeModelConverter<RGenericVector> {
    /** Divisor/modulus used by {@link #selectValues} to walk a category bit mask one bit at a time. */
    private static final UnsignedLong TWO = UnsignedLong.valueOf(2);

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/jpmml/rexp/RandomForestConverter$ScoreEncoder.class */
    /**
     * Converts the numeric prediction stored for a terminal node into the
     * string that is written as that node's score.
     *
     * @param <V> the numeric type of the stored node predictions
     */
    public interface ScoreEncoder<V extends Number> {
        /**
         * Encodes a single node prediction.
         *
         * @param v the prediction value read from the forest's node prediction data
         * @return the score string to set on the tree node
         */
        String encode(V v);
    }

    /**
     * Creates a converter for the given R object.
     *
     * @param rGenericVector the R object to convert; it is handed to the superclass
     *                       and later retrieved through {@code getObject()}
     */
    public RandomForestConverter(RGenericVector rGenericVector) {
        super(rGenericVector);
    }

    /**
     * Declares the label and the features of the model on the given encoder.
     *
     * <p>The formula-based path is taken when the model object carries a
     * {@code terms} element; otherwise the non-formula path is used.
     *
     * @param rExpEncoder the encoder that receives the label and feature definitions
     */
    @Override // org.jpmml.rexp.ModelConverter
    public void encodeSchema(RExpEncoder rExpEncoder) {
        RGenericVector randomForest = (RGenericVector) getObject();
        // The second argument makes the lookup optional, so an absent element yields null.
        boolean hasTerms = randomForest.getValue("terms", true) != null;
        if (!hasTerms) {
            encodeNonFormula(rExpEncoder);
            return;
        }
        encodeFormula(rExpEncoder);
    }

    /**
     * Encodes the forest as a PMML mining model, dispatching on the model's
     * {@code type} element.
     *
     * <p>The method name is a decompiler artifact; it was renamed from
     * {@code encodeModel} when the bridge method was merged.
     *
     * @param schema the schema holding the label and features of the model
     * @return the regression or classification mining model
     * @throws IllegalArgumentException if the type is neither {@code "regression"}
     *                                  nor {@code "classification"}
     */
    @Override // org.jpmml.rexp.ModelConverter
    /* renamed from: encodeModel, reason: merged with bridge method [inline-methods] */
    public MiningModel mo0encodeModel(Schema schema) {
        RGenericVector randomForest = (RGenericVector) getObject();
        RStringVector type = (RStringVector) randomForest.getValue("type");
        RGenericVector forest = (RGenericVector) randomForest.getValue("forest");
        String typeName = type.asScalar();
        // Plain string switch; the decompiled hashCode()/boolean two-stage switch was not valid Java.
        switch (typeName) {
            case "regression":
                return encodeRegression(forest, schema);
            case "classification":
                return encodeClassification(forest, schema);
            default:
                throw new IllegalArgumentException("Unsupported random forest type: " + typeName);
        }
    }

    /**
     * Declares the label and features for a model that carries a {@code terms}
     * element, by building a formula from it.
     *
     * <p>A variable is reported as categorical to the formula only when the
     * forest's {@code ncat} vector has an entry for it that is greater than one;
     * otherwise no categories are reported for it.
     *
     * @param rExpEncoder the encoder that receives the label and feature definitions
     */
    private void encodeFormula(RExpEncoder rExpEncoder) {
        RGenericVector randomForest = (RGenericVector) getObject();
        RGenericVector forest = (RGenericVector) randomForest.getValue("forest");
        RNumberVector<?> y = (RNumberVector<?>) randomForest.getValue("y", true);
        RExp terms = randomForest.getValue("terms");
        final RNumberVector<?> ncat = (RNumberVector<?>) forest.getValue("ncat");
        RGenericVector xlevels = (RGenericVector) forest.getValue("xlevels");

        XLevelsFormulaContext context = new XLevelsFormulaContext(xlevels) {
            @Override // org.jpmml.rexp.XLevelsFormulaContext, org.jpmml.rexp.FormulaContext
            public List<String> getCategories(String str) {
                if (ncat == null) {
                    return null;
                }
                if (!ncat.hasValue(str)) {
                    return null;
                }
                double categoryCount = ((Number) ncat.getValue(str)).doubleValue();
                if (categoryCount <= 1.0d) {
                    return null;
                }
                return super.getCategories(str);
            }
        };
        Formula formula = FormulaUtil.createFormula(terms, context, rExpEncoder);

        // The response vector is only passed along when it is an integer vector.
        if (!(y instanceof RIntegerVector)) {
            SchemaUtil.setLabel(formula, terms, null, rExpEncoder);
        } else {
            SchemaUtil.setLabel(formula, terms, y, rExpEncoder);
        }

        SchemaUtil.addFeatures(formula, xlevels.names(), false, rExpEncoder);
    }

    /**
     * Declares the label and features for a model that has no {@code terms}
     * element.
     *
     * <p>The label is a data field named {@code _target}: categorical when the
     * {@code y} vector is an integer vector, continuous double otherwise. One
     * feature is declared per entry of the forest's {@code ncat} vector:
     * categorical (with the matching {@code xlevels} entry as its values) when
     * the entry is greater than one, continuous double otherwise. Feature names
     * come from {@code xNames}, falling back to the names of {@code xlevels}.
     *
     * @param rExpEncoder the encoder that receives the label and feature definitions
     */
    private void encodeNonFormula(RExpEncoder rExpEncoder) {
        RGenericVector randomForest = (RGenericVector) getObject();
        RGenericVector forest = (RGenericVector) randomForest.getValue("forest");
        RNumberVector<?> y = (RNumberVector<?>) randomForest.getValue("y", true);
        RStringVector xNames = (RStringVector) randomForest.getValue("xNames", true);
        RNumberVector<?> ncat = (RNumberVector<?>) forest.getValue("ncat");
        RGenericVector xlevels = (RGenericVector) forest.getValue("xlevels");
        if (xNames == null) {
            xNames = xlevels.names();
        }

        FieldName targetName = FieldName.create("_target");
        if (y instanceof RIntegerVector) {
            rExpEncoder.setLabel(rExpEncoder.createDataField(targetName, OpType.CATEGORICAL, null, RExpUtil.getFactorLevels((RNumberVector<?>) y)));
        } else {
            rExpEncoder.setLabel(rExpEncoder.createDataField(targetName, OpType.CONTINUOUS, DataType.DOUBLE));
        }

        for (int index = 0; index < ncat.size(); index++) {
            FieldName featureName = FieldName.create(xNames.getValue(index));
            double categoryCount = ((Number) ncat.getValue(index)).doubleValue();
            if (categoryCount > 1.0d) {
                rExpEncoder.addFeature((Field<?>) rExpEncoder.createDataField(featureName, OpType.CATEGORICAL, null, ((RStringVector) xlevels.getValue(index)).getValues()));
            } else {
                rExpEncoder.addFeature((Field<?>) rExpEncoder.createDataField(featureName, OpType.CONTINUOUS, DataType.DOUBLE));
            }
        }
    }

    /**
     * Encodes a regression forest as a mining model whose segments are averaged.
     *
     * <p>The per-tree data is read from the forest as column-major matrices with
     * {@code nrnodes} rows and {@code ntree} columns; column {@code i} of each
     * matrix describes tree {@code i}.
     *
     * @param rGenericVector the {@code forest} element of the model object
     * @param schema         the schema holding the label and features of the model
     * @return a regression mining model with one tree model segment per tree
     */
    private MiningModel encodeRegression(RGenericVector rGenericVector, Schema schema) {
        RNumberVector<?> leftDaughter = (RNumberVector<?>) rGenericVector.getValue("leftDaughter");
        RNumberVector<?> rightDaughter = (RNumberVector<?>) rGenericVector.getValue("rightDaughter");
        RDoubleVector nodepred = (RDoubleVector) rGenericVector.getValue("nodepred");
        RNumberVector<?> bestvar = (RNumberVector<?>) rGenericVector.getValue("bestvar");
        RDoubleVector xbestsplit = (RDoubleVector) rGenericVector.getValue("xbestsplit");
        RIntegerVector nrnodes = (RIntegerVector) rGenericVector.getValue("nrnodes");
        RDoubleVector ntree = (RDoubleVector) rGenericVector.getValue("ntree");

        // Regression scores are the formatted numeric node predictions.
        ScoreEncoder<Double> scoreEncoder = new ScoreEncoder<Double>() {
            @Override // org.jpmml.rexp.RandomForestConverter.ScoreEncoder
            public String encode(Double d) {
                return ValueUtil.formatValue(d);
            }
        };

        int rows = nrnodes.asScalar();
        int columns = ValueUtil.asInt(ntree.asScalar());
        Schema segmentSchema = schema.toAnonymousSchema();

        List<TreeModel> treeModels = new ArrayList<>();
        for (int i = 0; i < columns; i++) {
            TreeModel treeModel = encodeTreeModel(
                MiningFunction.REGRESSION,
                scoreEncoder,
                FortranMatrixUtil.getColumn(leftDaughter.getValues(), rows, columns, i),
                FortranMatrixUtil.getColumn(rightDaughter.getValues(), rows, columns, i),
                FortranMatrixUtil.getColumn(nodepred.getValues(), rows, columns, i),
                FortranMatrixUtil.getColumn(bestvar.getValues(), rows, columns, i),
                FortranMatrixUtil.getColumn(xbestsplit.getValues(), rows, columns, i),
                segmentSchema);
            treeModels.add(treeModel);
        }

        return new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel())).setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, treeModels));
    }

    /**
     * Encodes a classification forest as a majority-vote mining model with a
     * probability output.
     *
     * <p>Per-tree data is read as column-major matrices with {@code nrnodes}
     * rows and {@code ntree} columns. The {@code treemap} data holds
     * {@code 2 * nrnodes} values per tree, which are split into two columns of
     * {@code nrnodes} values: the left and the right daughter indices.
     *
     * @param rGenericVector the {@code forest} element of the model object
     * @param schema         the schema holding the label and features of the model;
     *                       its label must be a {@link CategoricalLabel}
     * @return a classification mining model with one tree model segment per tree
     * @throws ClassCastException if the schema label is not categorical
     */
    private MiningModel encodeClassification(RGenericVector rGenericVector, Schema schema) {
        RNumberVector<?> bestvar = (RNumberVector<?>) rGenericVector.getValue("bestvar");
        RNumberVector<?> treemap = (RNumberVector<?>) rGenericVector.getValue("treemap");
        RIntegerVector nodepred = (RIntegerVector) rGenericVector.getValue("nodepred");
        RDoubleVector xbestsplit = (RDoubleVector) rGenericVector.getValue("xbestsplit");
        RIntegerVector nrnodes = (RIntegerVector) rGenericVector.getValue("nrnodes");
        RDoubleVector ntree = (RDoubleVector) rGenericVector.getValue("ntree");

        int rows = nrnodes.asScalar();
        int columns = ValueUtil.asInt(ntree.asScalar());

        // Explicit downcast: the node predictions are indices into the categorical label's values.
        final CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        // Node predictions are 1-based class indices; map them to the label value.
        ScoreEncoder<Integer> scoreEncoder = new ScoreEncoder<Integer>() {
            @Override // org.jpmml.rexp.RandomForestConverter.ScoreEncoder
            public String encode(Integer num) {
                return categoricalLabel.getValue(num.intValue() - 1);
            }
        };

        Schema segmentSchema = schema.toAnonymousSchema();

        List<TreeModel> treeModels = new ArrayList<>();
        for (int i = 0; i < columns; i++) {
            // One tree's slice of treemap: left daughters followed by right daughters.
            List<? extends Number> daughters = FortranMatrixUtil.getColumn(treemap.getValues(), 2 * rows, columns, i);
            TreeModel treeModel = encodeTreeModel(
                MiningFunction.CLASSIFICATION,
                scoreEncoder,
                FortranMatrixUtil.getColumn(daughters, rows, 2, 0),
                FortranMatrixUtil.getColumn(daughters, rows, 2, 1),
                FortranMatrixUtil.getColumn(nodepred.getValues(), rows, columns, i),
                FortranMatrixUtil.getColumn(bestvar.getValues(), rows, columns, i),
                FortranMatrixUtil.getColumn(xbestsplit.getValues(), rows, columns, i),
                segmentSchema);
            treeModels.add(treeModel);
        }

        return new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel)).setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, treeModels)).setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
    }

    /**
     * Builds a binary-split tree model for a single tree of the forest.
     *
     * <p>The root node gets id {@code "1"} and an always-true predicate; the
     * rest of the tree is filled in recursively starting from node index 0.
     *
     * @param miningFunction the mining function of the resulting tree model
     * @param scoreEncoder   turns a terminal node's prediction into its score string
     * @param list           left daughter indices, one per node
     * @param list2          right daughter indices, one per node
     * @param list3          node predictions, one per node
     * @param list4          split variable indices, one per node
     * @param list5          split values, one per node
     * @param schema         the schema providing the label and the features
     * @param <P>            the numeric type of the node predictions
     * @return the encoded tree model
     */
    private <P extends Number> TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder<P> scoreEncoder, List<? extends Number> list, List<? extends Number> list2, List<P> list3, List<? extends Number> list4, List<Double> list5, Schema schema) {
        Node root = new Node();
        root.setId("1");
        root.setPredicate(new True());

        // encodeNode takes the per-node lists in a different order: split variables and split values precede predictions.
        encodeNode(root, 0, scoreEncoder, list, list2, list4, list5, list3, schema);

        TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root);
        treeModel.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
        return treeModel;
    }

    /**
     * Recursively fills in the given node from the per-node data of one tree.
     *
     * <p>A split variable index of zero marks a terminal node: it only receives
     * a score. Otherwise the (1-based) split variable selects a feature from the
     * schema and a pair of complementary predicates is built for the left and
     * right children:
     * <ul>
     *   <li>boolean feature: equality with the feature's first / second value
     *       (the split value must be exactly 0.5);</li>
     *   <li>categorical feature: set membership, where the split value is read
     *       as a bit mask over the feature's values (set bit = left child);</li>
     *   <li>anything else: {@code <=} / {@code >} against the split value on the
     *       feature's continuous form.</li>
     * </ul>
     * Daughter indices are 1-based; zero means that the child is absent.
     *
     * @param node          the node to fill in
     * @param index         the 0-based position of this node in the per-node lists
     * @param scoreEncoder  turns a terminal node's prediction into its score string
     * @param leftDaughter  left daughter indices, one per node
     * @param rightDaughter right daughter indices, one per node
     * @param bestvar       split variable indices, one per node
     * @param xbestsplit    split values, one per node
     * @param nodepred      node predictions, one per node
     * @param schema        the schema providing the features
     * @param <P>           the numeric type of the node predictions
     * @throws IllegalArgumentException if a boolean feature is split at a value other than 0.5
     */
    private <P extends Number> void encodeNode(Node node, int index, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, Schema schema) {
        int splitVariable = ValueUtil.asInt(bestvar.get(index));
        if (splitVariable == 0) {
            node.setScore(scoreEncoder.encode(nodepred.get(index)));
            return;
        }

        Feature feature = schema.getFeature(splitVariable - 1);
        Double split = xbestsplit.get(index);

        Predicate leftPredicate;
        Predicate rightPredicate;
        // BooleanFeature is tested before CategoricalFeature, as in the original branch order.
        if (feature instanceof BooleanFeature) {
            BooleanFeature booleanFeature = (BooleanFeature) feature;
            if (split.doubleValue() != 0.5d) {
                throw new IllegalArgumentException("Expected split value 0.5 for a boolean feature, got " + split);
            }
            leftPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
            rightPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
        } else if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
            List<String> values = categoricalFeature.getValues();
            leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, split, true));
            rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, split, false));
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();
            String threshold = ValueUtil.formatValue(split);
            leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, threshold);
            rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, threshold);
        }

        int left = ValueUtil.asInt(leftDaughter.get(index));
        if (left != 0) {
            Node leftChild = new Node().setId(String.valueOf(left)).setPredicate(leftPredicate);
            encodeNode(leftChild, left - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, schema);
            node.addNodes(leftChild);
        }

        int right = ValueUtil.asInt(rightDaughter.get(index));
        if (right != 0) {
            Node rightChild = new Node().setId(String.valueOf(right)).setPredicate(rightPredicate);
            encodeNode(rightChild, right - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, schema);
            node.addNodes(rightChild);
        }
    }

    /**
     * Selects elements of a list according to the bits of an integer-valued double.
     *
     * <p>The value is read as a bit mask, least significant bit first: bit
     * {@code k} belongs to the element at position {@code k}. Elements beyond
     * the highest set bit see a zero bit.
     *
     * @param list the candidate elements
     * @param d    the bit mask, as a double holding a whole number
     * @param z    {@code true} to keep elements whose bit is set,
     *             {@code false} to keep elements whose bit is clear
     * @param <E>  the element type
     * @return a new list with the selected elements, in their original order
     * @throws IllegalArgumentException if {@code d} is not a whole number
     */
    static <E> List<E> selectValues(List<E> list, Double d, boolean z) {
        UnsignedLong remainingBits = toUnsignedLong(d.doubleValue());
        UnsignedLong wantedBit = z ? UnsignedLong.ONE : UnsignedLong.ZERO;
        List<E> selected = new ArrayList<>();
        for (E element : list) {
            UnsignedLong lowestBit = remainingBits.mod(TWO);
            if (lowestBit.equals(wantedBit)) {
                selected.add(element);
            }
            // Shift the mask right by one bit for the next element.
            remainingBits = remainingBits.dividedBy(TWO);
        }
        return selected;
    }

    /**
     * Converts a whole-number double into the unsigned 64-bit integer of the
     * same value.
     *
     * <p>The conversion is exact over the full unsigned range. A plain
     * {@code (long)} cast would reinterpret negative values as huge unsigned
     * numbers and saturate at {@code Long.MAX_VALUE} for values of 2^63 and
     * above, so the value is routed through {@code BigInteger} and range-checked
     * instead.
     *
     * @param d the value to convert
     * @return the unsigned long equal to {@code d}
     * @throws IllegalArgumentException if {@code d} is not a whole number
     *                                  (including NaN and infinities), is
     *                                  negative, or is 2^64 or greater
     */
    static UnsignedLong toUnsignedLong(double d) {
        if (!DoubleMath.isMathematicalInteger(d)) {
            throw new IllegalArgumentException("Not a whole number: " + d);
        }
        // Rounding is a no-op here because d is already known to be a whole number.
        // UnsignedLong.valueOf(BigInteger) rejects negative values and values >= 2^64.
        return UnsignedLong.valueOf(DoubleMath.roundToBigInteger(d, RoundingMode.UNNECESSARY));
    }
}
