package sklearn.tree;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import sklearn.Estimator;

/* loaded from: input_file:sklearn/tree/TreeModelUtil.class */
public class TreeModelUtil {
    private TreeModelUtil() {
    }

    public static <E extends Estimator & HasTree> List<TreeModel> encodeTreeModelSegmentation(List<E> list, final MiningFunction miningFunction, final Schema schema) {
        return new ArrayList(Lists.transform(list, new Function<E, TreeModel>() { // from class: sklearn.tree.TreeModelUtil.1
            /* JADX WARN: Incorrect types in method signature: (TE;)Lorg/dmg/pmml/tree/TreeModel; */
            public TreeModel apply(Estimator estimator) {
                return TreeModelUtil.encodeTreeModel(estimator, miningFunction, TreeModelUtil.toTreeModelSchema(schema.toAnonymousSchema(), estimator.getDataType()));
            }
        }));
    }

    public static <E extends Estimator & HasTree> TreeModel encodeTreeModel(E e, MiningFunction miningFunction, Schema schema) {
        Tree tree = e.getTree();
        int[] childrenLeft = tree.getChildrenLeft();
        int[] childrenRight = tree.getChildrenRight();
        int[] feature = tree.getFeature();
        double[] threshold = tree.getThreshold();
        double[] values = tree.getValues();
        Node predicate = new Node().setId("1").setPredicate(new True());
        encodeNode(predicate, 0, childrenLeft, childrenRight, feature, threshold, values, miningFunction, schema);
        return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema), predicate).setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
    }

    private static void encodeNode(Node node, int i, int[] iArr, int[] iArr2, int[] iArr3, double[] dArr, double[] dArr2, MiningFunction miningFunction, Schema schema) {
        SimplePredicate value;
        SimplePredicate value2;
        int i2 = iArr3[i];
        if (i2 >= 0) {
            BinaryFeature feature = schema.getFeature(i2);
            float f = (float) dArr[i];
            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = feature;
                if (f < 0.0f || f > 1.0f) {
                    throw new IllegalArgumentException();
                }
                value = new SimplePredicate(binaryFeature.getName(), SimplePredicate.Operator.NOT_EQUAL).setValue(binaryFeature.getValue());
                value2 = new SimplePredicate(binaryFeature.getName(), SimplePredicate.Operator.EQUAL).setValue(binaryFeature.getValue());
            } else {
                ContinuousFeature continuousFeature = feature.toContinuousFeature(DataType.FLOAT);
                String formatValue = ValueUtil.formatValue(Float.valueOf(f));
                value = new SimplePredicate(continuousFeature.getName(), SimplePredicate.Operator.LESS_OR_EQUAL).setValue(formatValue);
                value2 = new SimplePredicate(continuousFeature.getName(), SimplePredicate.Operator.GREATER_THAN).setValue(formatValue);
            }
            int i3 = iArr[i];
            int i4 = iArr2[i];
            Node predicate = new Node().setId(String.valueOf(i3 + 1)).setPredicate(value);
            encodeNode(predicate, i3, iArr, iArr2, iArr3, dArr, dArr2, miningFunction, schema);
            Node predicate2 = new Node().setId(String.valueOf(i4 + 1)).setPredicate(value2);
            encodeNode(predicate2, i4, iArr, iArr2, iArr3, dArr, dArr2, miningFunction, schema);
            node.addNodes(new Node[]{predicate, predicate2});
            return;
        }
        if (!MiningFunction.CLASSIFICATION.equals(miningFunction)) {
            if (!MiningFunction.REGRESSION.equals(miningFunction)) {
                throw new IllegalArgumentException();
            }
            node.setScore(ValueUtil.formatValue(Double.valueOf(dArr2[i])));
            return;
        }
        CategoricalLabel label = schema.getLabel();
        double[] row = getRow(dArr2, iArr.length, label.size(), i);
        double d = 0.0d;
        for (double d2 : row) {
            d += d2;
        }
        node.setRecordCount(Double.valueOf(d));
        String str = null;
        Double d3 = null;
        for (int i5 = 0; i5 < label.size(); i5++) {
            ScoreDistribution scoreDistribution = new ScoreDistribution(label.getValue(i5), row[i5]);
            node.addScoreDistributions(new ScoreDistribution[]{scoreDistribution});
            double d4 = row[i5] / d;
            if (d3 == null || d3.compareTo(Double.valueOf(d4)) < 0) {
                str = scoreDistribution.getValue();
                d3 = Double.valueOf(d4);
            }
        }
        node.setScore(str);
    }

    public static Schema toTreeModelSchema(Schema schema, DataType dataType) {
        ArrayList arrayList = new ArrayList();
        for (BinaryFeature binaryFeature : schema.getFeatures()) {
            if (binaryFeature instanceof BinaryFeature) {
                arrayList.add(binaryFeature);
            } else {
                arrayList.add(binaryFeature.toContinuousFeature(dataType));
            }
        }
        return new Schema(schema.getLabel(), arrayList);
    }

    private static double[] getRow(double[] dArr, int i, int i2, int i3) {
        if (dArr.length != i * i2) {
            throw new IllegalArgumentException("Expected " + (i * i2) + " element(s), got " + dArr.length + " element(s)");
        }
        double[] dArr2 = new double[i2];
        System.arraycopy(dArr, i3 * i2, dArr2, 0, i2);
        return dArr2;
    }
}
