package tpot.builtins;

import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.Schema;
import org.jpmml.converter.TypeUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Classifier;
import sklearn.Estimator;
import sklearn.EstimatorUtil;
import sklearn.HasEstimator;
import sklearn.Transformer;

/* loaded from: input_file:tpot/builtins/StackingEstimator.class */
public class StackingEstimator extends Transformer implements HasEstimator<Estimator> {

    /* renamed from: tpot.builtins.StackingEstimator$1, reason: invalid class name */
    /* loaded from: input_file:tpot/builtins/StackingEstimator$1.class */
    static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$dmg$pmml$MiningFunction = new int[MiningFunction.values().length];

        static {
            try {
                $SwitchMap$org$dmg$pmml$MiningFunction[MiningFunction.CLASSIFICATION.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$dmg$pmml$MiningFunction[MiningFunction.REGRESSION.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
        }
    }

    public StackingEstimator(String str, String str2) {
        super(str, str2);
    }

    public int getNumberOfFeatures() {
        return getEstimator().getNumberOfFeatures();
    }

    public List<Feature> encodeFeatures(List<Feature> list, SkLearnEncoder skLearnEncoder) {
        CategoricalLabel continuousLabel;
        Classifier estimator = getEstimator();
        MiningFunction miningFunction = estimator.getMiningFunction();
        switch (AnonymousClass1.$SwitchMap$org$dmg$pmml$MiningFunction[miningFunction.ordinal()]) {
            case 1:
                List classes = EstimatorUtil.getClasses(estimator);
                continuousLabel = new CategoricalLabel(TypeUtil.getDataType(classes, DataType.STRING), classes);
                break;
            case 2:
                continuousLabel = new ContinuousLabel(DataType.DOUBLE);
                break;
            default:
                throw new IllegalArgumentException();
        }
        Model encode = estimator.encode(new Schema(skLearnEncoder, continuousLabel, list));
        skLearnEncoder.addTransformer(encode);
        String createFieldName = createFieldName("stack", list);
        ArrayList arrayList = new ArrayList();
        switch (AnonymousClass1.$SwitchMap$org$dmg$pmml$MiningFunction[miningFunction.ordinal()]) {
            case 1:
            case 2:
                arrayList.add(skLearnEncoder.exportPrediction(encode, createFieldName, continuousLabel));
                switch (AnonymousClass1.$SwitchMap$org$dmg$pmml$MiningFunction[miningFunction.ordinal()]) {
                    case 1:
                        if (estimator.hasProbabilityDistribution()) {
                            for (Object obj : EstimatorUtil.getClasses(estimator)) {
                                arrayList.add(skLearnEncoder.exportProbability(encode, FieldNameUtil.create("probability", new Object[]{createFieldName, obj}), obj));
                            }
                            break;
                        }
                        break;
                    case 2:
                        break;
                    default:
                        throw new IllegalArgumentException();
                }
                arrayList.addAll(list);
                return arrayList;
            default:
                throw new IllegalArgumentException();
        }
    }

    public Estimator getEstimator() {
        return (Estimator) get("estimator", Estimator.class);
    }
}
