package sklearn.ensemble.gradient_boosting;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.MatrixUtil;
import sklearn.Classifier;
import sklearn.EstimatorUtil;
import sklearn.tree.DecisionTreeRegressor;

/* loaded from: input_file:sklearn/ensemble/gradient_boosting/GradientBoostingClassifier.class */
/**
 * Converter for a fitted scikit-learn {@code GradientBoostingClassifier}.
 *
 * <p>Reads the estimator's attributes ({@code n_features}, {@code loss_}, {@code init_},
 * {@code learning_rate}, {@code estimators_}) through {@link #get(String)} and encodes the
 * boosted ensemble as a PMML {@link MiningModel}.
 */
public class GradientBoostingClassifier extends Classifier {
    public GradientBoostingClassifier(String str, String str2) {
        super(str, str2);
    }

    /**
     * Returns the number of input features, read from the {@code n_features} attribute.
     */
    @Override // sklearn.Estimator, sklearn.HasNumberOfFeatures
    public int getNumberOfFeatures() {
        return ValueUtil.asInt((Number) get("n_features"));
    }

    /**
     * Returns {@link DataType#FLOAT} as this estimator's data type.
     * NOTE(review): presumably mirrors scikit-learn trees evaluating on float32 inputs — confirm.
     */
    @Override // sklearn.Estimator
    public DataType getDataType() {
        return DataType.FLOAT;
    }

    /**
     * Encodes the fitted ensemble as a PMML {@link MiningModel}.
     *
     * <p>The decompiler renamed this method from {@code encodeModel} when it merged it with its
     * bridge method.
     *
     * <p>When the loss reports {@code K == 1}, the label must have exactly two categories. One
     * regressor is encoded for the second category, using prior probability index 0, and is
     * wrapped with {@link MiningModelUtil#createBinaryLogisticClassification}. When {@code K >= 2},
     * one regressor is encoded per category from the corresponding column of
     * {@code estimators_}. The results are combined with SIMPLEMAX normalization.
     *
     * @param schema the feature/label schema to encode against
     * @return the encoded classification model
     * @throws IllegalArgumentException if the loss reports a non-positive {@code K}, or if the
     *     number of estimators is not a multiple of the number of classes
     */
    @Override // sklearn.Estimator
    public MiningModel mo20encodeModel(Schema schema) {
        LossFunction loss = getLoss();
        int k = loss.getK().intValue();
        HasPriorProbability init = getInit();
        Number learningRate = getLearningRate();
        List<DecisionTreeRegressor> estimators = getEstimators();

        // Member models are encoded against the anonymous schema; only the outer model uses `schema`.
        Schema anonymousSchema = schema.toAnonymousSchema();
        CategoricalLabel label = anonymousSchema.getLabel();

        if (k == 1) {
            EstimatorUtil.checkSize(2, label);
            MiningModel binaryModel = encodeCategoryRegressor(label.getValue(1), estimators, init.getPriorProbability(0), learningRate, null, anonymousSchema);
            return MiningModelUtil.createBinaryLogisticClassification(schema, binaryModel, loss.getCoefficient(), true);
        }

        if (k < 2) {
            throw new IllegalArgumentException("Expected the loss function to report K >= 1, got K=" + k);
        }

        EstimatorUtil.checkSize(k, label);

        int columns = label.size();
        // A trailing partial row would otherwise be silently dropped by the integer division below.
        if (estimators.size() % columns != 0) {
            throw new IllegalArgumentException("Expected the number of estimators (" + estimators.size() + ") to be a multiple of the number of classes (" + columns + ")");
        }
        // NOTE(review): assumes estimators_ is flattened row-major as (n_stages x K), so column i holds class i's trees — confirm against MatrixUtil.getColumn.
        int rows = estimators.size() / columns;

        List<MiningModel> models = new ArrayList<>(columns);
        for (int i = 0; i < columns; i++) {
            List<DecisionTreeRegressor> classEstimators = MatrixUtil.getColumn(estimators, rows, columns, i);
            models.add(encodeCategoryRegressor(label.getValue(i), classEstimators, init.getPriorProbability(i), learningRate, loss.getFunction(), anonymousSchema));
        }
        return MiningModelUtil.createClassification(schema, models, RegressionModel.NormalizationMethod.SIMPLEMAX, true);
    }

    /**
     * Returns the loss function's PMML function definition as a singleton set when the loss
     * provides one. Otherwise it defers to the superclass.
     */
    @Override // sklearn.Estimator
    public Set<DefineFunction> encodeDefineFunctions() {
        DefineFunction encodeFunction = getLoss().encodeFunction();
        return encodeFunction != null ? Collections.singleton(encodeFunction) : super.encodeDefineFunctions();
    }

    /**
     * Returns the {@code loss_} attribute as a {@link LossFunction}.
     *
     * @throws IllegalArgumentException if the attribute is missing or is not a {@link LossFunction}
     */
    public LossFunction getLoss() {
        Object obj = get("loss_");
        if (!(obj instanceof LossFunction)) {
            throw new IllegalArgumentException("The loss function object (" + ClassDictUtil.formatClass(obj) + ") is not a LossFunction or is not a supported LossFunction subclass");
        }
        return (LossFunction) obj;
    }

    /**
     * Returns the {@code init_} attribute as a {@link HasPriorProbability}.
     *
     * @throws IllegalArgumentException if the attribute is missing or does not implement
     *     {@link HasPriorProbability}
     */
    public HasPriorProbability getInit() {
        Object obj = get("init_");
        if (!(obj instanceof HasPriorProbability)) {
            throw new IllegalArgumentException("The init estimator object (" + ClassDictUtil.formatClass(obj) + ") is not a HasPriorProbability or is not a supported HasPriorProbability subclass");
        }
        return (HasPriorProbability) obj;
    }

    /**
     * Returns the {@code learning_rate} attribute. The value may be {@code null} if the
     * attribute is absent.
     */
    public Number getLearningRate() {
        return (Number) get("learning_rate");
    }

    /**
     * Returns the {@code estimators_} attribute as a flat list of tree regressors.
     */
    public List<DecisionTreeRegressor> getEstimators() {
        return ClassDictUtil.getArray(this, "estimators_");
    }

    /**
     * Encodes the boosted trees for one class and attaches non-final output fields to the result.
     *
     * <p>The first output field, {@code decisionFunction_<category>}, carries the model's predicted
     * value. When {@code function} is non-null, a second field,
     * {@code <function>DecisionFunction_<category>}, applies that PMML function to the first.
     *
     * @param category the class label value this regressor scores
     * @param estimators the trees belonging to this class
     * @param priorProbability the initial (prior) value for this class
     * @param learningRate the shrinkage applied to each tree
     * @param function name of a PMML function to apply to the decision value, or {@code null}
     * @param schema the (anonymous) schema to encode against
     * @return the encoded per-class model with its output fields set
     */
    private static MiningModel encodeCategoryRegressor(String category, List<DecisionTreeRegressor> estimators, Number priorProbability, Number learningRate, String function, Schema schema) {
        OutputField decisionFunction = new OutputField(FieldName.create("decisionFunction_" + category), DataType.DOUBLE)
            .setOpType(OpType.CONTINUOUS)
            .setResultFeature(ResultFeature.PREDICTED_VALUE)
            .setFinalResult(false);

        Output output = new Output().addOutputFields(new OutputField[]{decisionFunction});

        if (function != null) {
            OutputField transformedDecisionFunction = new OutputField(FieldName.create(function + "DecisionFunction_" + category), DataType.DOUBLE)
                .setOpType(OpType.CONTINUOUS)
                .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
                .setFinalResult(false)
                .setExpression(PMMLUtil.createApply(function, new Expression[]{new FieldRef(decisionFunction.getName())}));
            output.addOutputFields(new OutputField[]{transformedDecisionFunction});
        }

        return GradientBoostingUtil.encodeGradientBoosting(estimators, priorProbability, learningRate, schema).setOutput(output);
    }
}
