package sklearn.preprocessing;

import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.MissingValueTreatmentMethod;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.MissingValueDecorator;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Transformer;

/* loaded from: input_file:sklearn/preprocessing/Imputer.class */
/**
 * Converter for scikit-learn's {@code Imputer} transformer.
 *
 * <p>Each input feature has its missing values replaced by the matching entry of the fitted
 * {@code statistics_} array:
 * <ul>
 *   <li>features backed directly by a {@link DataField} get a {@link MissingValueDecorator}
 *       attached and are passed through unchanged;</li>
 *   <li>all other features are wrapped in a new derived field computing
 *       {@code if(<missing test>, statistic, feature)}.</li>
 * </ul>
 */
public class Imputer extends Transformer {

    public Imputer(String str, String str2) {
        super(str, str2);
    }

    /**
     * Encodes the imputation step for every input feature.
     *
     * @param ids names used to build derived-field names for non-{@code DataField} features
     * @param features the input features; one statistic is consumed per feature
     * @param encoder the encoder used to look up fields and register decorators/derived fields
     * @return the resulting features, in the same order as {@code features}
     * @throws IllegalArgumentException if {@code missing_values} is neither a String nor a Number,
     *     or if {@code strategy} is not one of {@code mean}, {@code median}, {@code most_frequent}
     *     (the latter is only checked when a {@code DataField}-backed feature is encountered)
     */
    @Override // sklearn.Transformer
    public List<Feature> encodeFeatures(List<String> ids, List<Feature> features, SkLearnEncoder encoder) {
        List<? extends Number> statistics = getStatistics();
        // NOTE(review): presumably rejects inputs whose sizes disagree — confirm in ClassDictUtil.
        ClassDictUtil.checkSize(ids, features, statistics);

        // null => missing_values was a String marker; the derived-field path then uses isMissing.
        Number targetValue = getTargetValue(getMissingValues());

        List<Feature> result = new ArrayList<>(features.size());
        for (int i = 0; i < features.size(); i++) {
            String id = ids.get(i);
            Feature feature = features.get(i);
            Number statistic = statistics.get(i);

            if (encoder.getField(feature.getName()) instanceof DataField) {
                result.add(decorateDataField(feature, statistic, targetValue, encoder));
            } else {
                result.add(createImputedFeature(id, feature, statistic, targetValue, encoder));
            }
        }
        return result;
    }

    /**
     * Attaches a missing-value decorator to the data field behind {@code feature}.
     * Returns {@code feature} itself.
     */
    private Feature decorateDataField(Feature feature, Number statistic, Number targetValue, SkLearnEncoder encoder) {
        MissingValueDecorator decorator = new MissingValueDecorator()
            .setMissingValueReplacement(ValueUtil.formatValue(statistic))
            .setMissingValueTreatment(parseStrategy(getStrategy()));

        // A numeric marker must be declared explicitly as a missing value.
        if (targetValue != null) {
            decorator.addMissingValues(new String[]{ValueUtil.formatValue(targetValue)});
        }

        encoder.addDecorator(feature.getName(), decorator);
        return feature;
    }

    /**
     * Creates a derived field that substitutes {@code statistic} for missing values of {@code feature}.
     */
    private Feature createImputedFeature(String id, Feature feature, Number statistic, Number targetValue, SkLearnEncoder encoder) {
        Expression ref = feature.ref();
        String name = createName(id);

        // Without a numeric marker, fall back to PMML's native missing-value test.
        Expression isTarget = (targetValue == null)
            ? PMMLUtil.createApply("isMissing", new Expression[]{ref})
            : PMMLUtil.createApply("equal", new Expression[]{ref, PMMLUtil.createConstant(targetValue)});

        // A fresh feature.ref() is used for the "else" branch so that no expression node is
        // shared between two places in the generated tree.
        Expression imputed = PMMLUtil.createApply("if",
            new Expression[]{isTarget, PMMLUtil.createConstant(statistic), feature.ref()});

        return new ContinuousFeature(encoder, encoder.createDerivedField(name, imputed));
    }

    /** Returns the raw {@code missing_values} parameter of the fitted scikit-learn object. */
    public Object getMissingValues() {
        return get("missing_values");
    }

    /** Returns the fitted {@code statistics_} array (one replacement value per input column). */
    public List<? extends Number> getStatistics() {
        return ClassDictUtil.getArray(this, "statistics_");
    }

    /** Returns the {@code strategy} parameter of the fitted scikit-learn object. */
    public String getStrategy() {
        return (String) get("strategy");
    }

    /**
     * Interprets the {@code missing_values} parameter.
     *
     * @param obj the raw parameter value
     * @return the numeric marker, or {@code null} when {@code obj} is a String
     * @throws IllegalArgumentException if {@code obj} is neither a String nor a Number (including null)
     */
    private static Number getTargetValue(Object obj) {
        if (obj instanceof String) {
            return null;
        }
        if (obj instanceof Number) {
            return (Number) obj;
        }
        throw new IllegalArgumentException("Unsupported missing_values value: " + obj);
    }

    /**
     * Maps a scikit-learn imputation strategy name to the PMML missing-value treatment.
     *
     * @param str the strategy name
     * @return the corresponding treatment method
     * @throws IllegalArgumentException if {@code str} is not a supported strategy
     */
    private static MissingValueTreatmentMethod parseStrategy(String str) {
        switch (str) {
            case "mean":
                return MissingValueTreatmentMethod.AS_MEAN;
            case "median":
                return MissingValueTreatmentMethod.AS_MEDIAN;
            case "most_frequent":
                return MissingValueTreatmentMethod.AS_MODE;
            default:
                throw new IllegalArgumentException(str);
        }
    }
}
