package org.jpmml.sparkml.feature;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.spark.ml.feature.ImputerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.dmg.pmml.DataField;
import org.dmg.pmml.Field;
import org.dmg.pmml.MissingValueTreatmentMethod;
import org.dmg.pmml.Value;
import org.jpmml.converter.Feature;
import org.jpmml.converter.MissingValueDecorator;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.MultiFeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;

/**
 * Converts a fitted Spark ML {@link ImputerModel} into PMML.
 *
 * <p>No new derived features are created. Instead, the {@link DataField} behind each input column
 * is decorated with a {@link MissingValueDecorator} that carries the PMML treatment method (derived
 * from the Spark strategy) and the surrogate value that Spark computed for that column. The input
 * features themselves are returned unchanged.
 */
public class ImputerModelConverter extends MultiFeatureConverter<ImputerModel> {

    public ImputerModelConverter(ImputerModel imputerModel) {
        super(imputerModel);
    }

    /**
     * Attaches missing-value treatment to the data field behind every input column of the imputer.
     *
     * @param sparkMLEncoder encoder used to resolve input columns to features and to register decorators
     * @return the input features, one per input column, in input-column order
     * @throws IllegalArgumentException if the surrogate data frame does not hold exactly one row, if an
     *     input column is not backed by a {@link DataField}, or if the imputation strategy is unsupported
     */
    @Override
    public List<Feature> encodeFeatures(SparkMLEncoder sparkMLEncoder) {
        ImputerModel imputerModel = (ImputerModel) getTransformer();

        // getMissingValue() is a primitive double, so no null check is needed.
        double missingValue = imputerModel.getMissingValue();
        MissingValueTreatmentMethod missingValueTreatment = parseStrategy(imputerModel.getStrategy());

        // The fitted surrogates are expected as a single row, looked up below by input column name.
        Dataset<Row> surrogateDF = imputerModel.surrogateDF();
        List<Row> surrogateRows = surrogateDF.collectAsList();
        if (surrogateRows.size() != 1) {
            throw new IllegalArgumentException("Expected exactly one row of surrogate values, got " + surrogateRows.size());
        }
        Row surrogateRow = surrogateRows.get(0);

        FeatureConverter.InOutMode inputMode = getInputMode();

        List<Feature> result = new ArrayList<>();
        for (String inputCol : inputMode.getInputCols(imputerModel)) {
            Feature feature = sparkMLEncoder.getOnlyFeature(inputCol);

            // Missing-value treatment can only be declared on a field of the data dictionary.
            Field field = feature.getField();
            if (!(field instanceof DataField)) {
                throw new IllegalArgumentException("Expected a data field for column " + inputCol + ", got " + field);
            }
            DataField dataField = (DataField) field;

            sparkMLEncoder.addDecorator(dataField, new MissingValueDecorator(missingValueTreatment, surrogateRow.getAs(inputCol)));

            // Only a non-NaN sentinel is registered as an explicit missing value of the field.
            if (!Double.isNaN(missingValue)) {
                PMMLUtil.addValues(dataField, Collections.singletonList(missingValue), Value.Property.MISSING);
            }

            result.add(feature);
        }

        return result;
    }

    /**
     * Maps a Spark ML imputation strategy name to the corresponding PMML missing value treatment.
     *
     * @param str the strategy name, as returned by {@code ImputerModel#getStrategy()}
     * @return {@code AS_MEAN} for {@code "mean"}, {@code AS_MEDIAN} for {@code "median"},
     *     {@code AS_MODE} for {@code "mode"}
     * @throws IllegalArgumentException if the strategy name is not recognized
     * @throws NullPointerException if {@code str} is {@code null}
     */
    public static MissingValueTreatmentMethod parseStrategy(String str) {
        switch (str) {
            case "mean":
                return MissingValueTreatmentMethod.AS_MEAN;
            case "median":
                return MissingValueTreatmentMethod.AS_MEDIAN;
            // Supported by Spark's Imputer since 3.1.0.
            case "mode":
                return MissingValueTreatmentMethod.AS_MODE;
            default:
                throw new IllegalArgumentException(str);
        }
    }
}
