package sklearn.impute;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MissingValueTreatmentMethod;
import org.dmg.pmml.OpType;
import org.jpmml.converter.Feature;
import org.jpmml.converter.TypeUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.StepUtil;
import sklearn.Transformer;

/* loaded from: input_file:sklearn/impute/SimpleImputer.class */
/**
 * Transformer for a scikit-learn {@code SimpleImputer}.
 *
 * <p>Reads the {@code strategy}, {@code statistics_}, {@code missing_values} and
 * {@code add_indicator} attributes. It encodes one imputed feature per input feature. When
 * {@code add_indicator} is true, it also appends one indicator feature per input feature,
 * after all the imputed ones.
 */
public class SimpleImputer extends Transformer {

    private static final String STRATEGY_CONSTANT = "constant";
    private static final String STRATEGY_MEAN = "mean";
    private static final String STRATEGY_MEDIAN = "median";
    private static final String STRATEGY_MOST_FREQUENT = "most_frequent";

    /**
     * Creates a new instance; both arguments are forwarded unchanged to {@link Transformer}.
     */
    public SimpleImputer(String str, String str2) {
        super(str, str2);
    }

    /**
     * Returns the operational type of the imputed features, derived from the strategy.
     *
     * <ul>
     *   <li>{@code "mean"} and {@code "median"} return {@link OpType#CONTINUOUS}.</li>
     *   <li>{@code "most_frequent"} returns {@link OpType#CATEGORICAL}.</li>
     *   <li>{@code "constant"} delegates to {@code StepUtil.getOpType(getDataType())}.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if the strategy is not one of the values above
     */
    @Override // sklearn.Transformer, sklearn.HasType
    public OpType getOpType() {
        String strategy = getStrategy();
        switch (strategy) {
            case STRATEGY_CONSTANT:
                return StepUtil.getOpType(getDataType());
            case STRATEGY_MEAN:
            case STRATEGY_MEDIAN:
                return OpType.CONTINUOUS;
            case STRATEGY_MOST_FREQUENT:
                return OpType.CATEGORICAL;
            default:
                throw new IllegalArgumentException(strategy);
        }
    }

    /**
     * Returns the data type of the imputed features, derived from the strategy.
     *
     * <p>{@code "mean"} and {@code "median"} always return {@link DataType#DOUBLE}.
     * {@code "constant"} and {@code "most_frequent"} infer the type from the
     * {@code statistics_} values via {@code TypeUtil.getDataType}, passing
     * {@link DataType#STRING} as the second argument (presumably the fallback type; confirm
     * against {@code TypeUtil}).
     *
     * @throws IllegalArgumentException if the strategy is not one of the values above
     */
    @Override // sklearn.Transformer, sklearn.HasType
    public DataType getDataType() {
        String strategy = getStrategy();
        // Read before dispatching so that a missing statistics_ attribute fails the same way
        // for every strategy, as in the original implementation.
        List<?> statistics = getStatistics();
        switch (strategy) {
            case STRATEGY_CONSTANT:
            case STRATEGY_MOST_FREQUENT:
                return TypeUtil.getDataType(statistics, DataType.STRING);
            case STRATEGY_MEAN:
            case STRATEGY_MEDIAN:
                return DataType.DOUBLE;
            default:
                throw new IllegalArgumentException(strategy);
        }
    }

    /**
     * Returns the number of input features: the first dimension of the {@code statistics_}
     * array shape.
     */
    @Override // sklearn.Transformer, sklearn.HasNumberOfFeatures
    public int getNumberOfFeatures() {
        return getStatisticsShape()[0];
    }

    /**
     * Encodes each input feature as an imputed feature.
     *
     * <p>If {@code missing_values} is NaN, it is passed on as {@code null}. When
     * {@code add_indicator} is true, one indicator feature per input feature is appended after
     * all imputed features, in input order.
     *
     * @param features the input features; must have the same size as {@code statistics_}
     * @param encoder the encoder used to register derived fields
     * @return the imputed features, optionally followed by the indicator features
     */
    @Override // sklearn.Transformer
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        Boolean addIndicator = getAddIndicator();
        Object missingValues = getMissingValues();
        List<?> statistics = getStatistics();
        String strategy = getStrategy();

        ClassDictUtil.checkSize(new Collection<?>[]{features, statistics});

        if (ValueUtil.isNaN(missingValues)) {
            missingValues = null;
        }

        MissingValueTreatmentMethod missingValueTreatment = parseStrategy(strategy);

        List<Feature> indicatorFeatures = new ArrayList<>();
        List<Feature> result = new ArrayList<>();

        for (int i = 0; i < features.size(); i++) {
            Feature feature = features.get(i);
            Object statistic = statistics.get(i);

            if (addIndicator.booleanValue()) {
                indicatorFeatures.add(
                    ImputerUtil.encodeIndicatorFeature(this, feature, missingValues, encoder));
            }

            result.add(ImputerUtil.encodeFeature(
                this, feature, addIndicator, missingValues, statistic, missingValueTreatment, encoder));
        }

        // Indicator columns follow all imputed columns rather than being interleaved.
        if (addIndicator.booleanValue()) {
            result.addAll(indicatorFeatures);
        }

        return result;
    }

    /** Returns the {@code add_indicator} attribute, defaulting to {@link Boolean#FALSE}. */
    public Boolean getAddIndicator() {
        return getOptionalBoolean("add_indicator", Boolean.FALSE);
    }

    /** Returns the optional {@code missing_values} scalar attribute. */
    public Object getMissingValues() {
        return getOptionalScalar("missing_values");
    }

    /** Returns the fitted {@code statistics_} array as a list. */
    public List<?> getStatistics() {
        return getArray("statistics_");
    }

    /** Returns the shape of the fitted {@code statistics_} array (requested with rank 1). */
    public int[] getStatisticsShape() {
        return getArrayShape("statistics_", 1);
    }

    /** Returns the {@code strategy} attribute. */
    public String getStrategy() {
        return getString("strategy");
    }

    /**
     * Maps an imputation strategy name to the PMML missing-value treatment method.
     *
     * @param str one of {@code "constant"}, {@code "mean"}, {@code "median"},
     *     {@code "most_frequent"}
     * @return {@code AS_VALUE}, {@code AS_MEAN}, {@code AS_MEDIAN} or {@code AS_MODE} respectively
     * @throws IllegalArgumentException if {@code str} is not a supported strategy
     */
    private static MissingValueTreatmentMethod parseStrategy(String str) {
        switch (str) {
            case STRATEGY_CONSTANT:
                return MissingValueTreatmentMethod.AS_VALUE;
            case STRATEGY_MEAN:
                return MissingValueTreatmentMethod.AS_MEAN;
            case STRATEGY_MEDIAN:
                return MissingValueTreatmentMethod.AS_MEDIAN;
            case STRATEGY_MOST_FREQUENT:
                return MissingValueTreatmentMethod.AS_MODE;
            default:
                throw new IllegalArgumentException(str);
        }
    }
}
