package sklearn.ensemble.voting;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import sklearn.Classifier;
import sklearn.HasEstimatorEnsemble;
import sklearn.StepUtil;

/* loaded from: input_file:sklearn/ensemble/voting/VotingClassifier.class */
/**
 * Converter for a scikit-learn {@code VotingClassifier}.
 *
 * <p>Every fitted member classifier is encoded separately. The results are combined into a
 * single PMML {@link MiningModel} whose segmentation method is derived from the {@code voting}
 * and {@code weights} attributes.
 */
public class VotingClassifier extends Classifier implements HasEstimatorEnsemble<Classifier> {

    /**
     * Creates a new converter instance.
     *
     * <p>Both arguments are forwarded unchanged to {@code Classifier(String, String)}. They are
     * presumably the Python module and class name; confirm against {@code Classifier}.
     */
    public VotingClassifier(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the number of input features.
     *
     * <p>The value is whatever {@link StepUtil#getNumberOfFeatures} reports for the fitted
     * member estimators.
     */
    @Override
    public int getNumberOfFeatures() {
        return StepUtil.getNumberOfFeatures(getEstimators());
    }

    /**
     * Encodes this ensemble as a classification {@link MiningModel}.
     *
     * <p>Each member classifier is encoded against the same {@code schema}, and the encoded
     * models become the segments of the mining model. The multiple-model method comes from
     * {@link #parseVoting(String, boolean)}. The weighted variant is chosen only when a
     * non-empty {@code weights} list is present. The model also gets a probability output of
     * type {@link DataType#DOUBLE}.
     *
     * <p>The method name is the decompiler-renamed {@code encodeModel}. It must stay in sync
     * with the method it overrides in {@code sklearn.Estimator}.
     *
     * @param schema the schema every member estimator is encoded against; its label must be
     *               a {@link CategoricalLabel}
     * @return the combined mining model
     * @throws IllegalArgumentException if the {@code voting} attribute is neither
     *                                  {@code "hard"} nor {@code "soft"}
     */
    @Override
    public Model mo1encodeModel(Schema schema) {
        List<? extends Classifier> estimators = getEstimators();
        List<? extends Number> weights = getWeights();

        // Explicit cast: a voting classifier always has a categorical target. The cast is
        // harmless if getLabel() is already declared to return CategoricalLabel.
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        List<Model> models = new ArrayList<>(estimators.size());
        for (Classifier estimator : estimators) {
            models.add(estimator.encode(schema));
        }

        // An empty weights list is treated like "no weights" when choosing the method.
        boolean weighted = (weights != null && !weights.isEmpty());
        Segmentation.MultipleModelMethod multipleModelMethod = parseVoting(getVoting(), weighted);

        return new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel))
            .setSegmentation(MiningModelUtil.createSegmentation(multipleModelMethod, models, weights))
            .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
    }

    /**
     * Returns the fitted member classifiers, read from the {@code estimators_} attribute.
     */
    @Override
    public List<? extends Classifier> getEstimators() {
        return getList("estimators_", Classifier.class);
    }

    /**
     * Returns the {@code voting} attribute, which is read as a string.
     */
    public String getVoting() {
        return getString("voting");
    }

    /**
     * Returns the {@code weights} attribute, or {@code null} when it is absent or null.
     *
     * <p>A value that is already a {@link List} is returned as-is. Its elements are assumed,
     * not checked, to be {@link Number}s. Any other value is read through
     * {@code getNumberArray("weights")}, presumably for a NumPy array; confirm.
     */
    @SuppressWarnings("unchecked")
    public List<? extends Number> getWeights() {
        Object weights = getOptionalObject("weights");
        if (weights == null || weights instanceof List) {
            // Unchecked: the element type of a plain List cannot be verified at runtime.
            return (List<? extends Number>) weights;
        }
        return getNumberArray("weights");
    }

    /**
     * Maps scikit-learn's {@code voting} parameter to a PMML multiple-model method.
     *
     * @param voting   the voting mode, either {@code "hard"} or {@code "soft"}
     * @param weighted whether the weighted variant of the method should be selected
     * @return {@code MAJORITY_VOTE} or {@code WEIGHTED_MAJORITY_VOTE} for {@code "hard"};
     *         {@code AVERAGE} or {@code WEIGHTED_AVERAGE} for {@code "soft"}
     * @throws IllegalArgumentException if {@code voting} is any other value; the exception
     *                                  message is the offending value
     * @throws NullPointerException     if {@code voting} is {@code null}
     */
    private static Segmentation.MultipleModelMethod parseVoting(String voting, boolean weighted) {
        switch (voting) {
            case "hard":
                return weighted
                    ? Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE
                    : Segmentation.MultipleModelMethod.MAJORITY_VOTE;
            case "soft":
                return weighted
                    ? Segmentation.MultipleModelMethod.WEIGHTED_AVERAGE
                    : Segmentation.MultipleModelMethod.AVERAGE;
            default:
                throw new IllegalArgumentException(voting);
        }
    }
}
