package org.jpmml.sparkml.feature;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.ml.feature.OneHotEncoderModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.sparkml.BinarizedCategoricalFeature;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.MultiFeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;

/* loaded from: input_file:org/jpmml/sparkml/feature/OneHotEncoderModelConverter.class */
/**
 * Converts a fitted Spark {@link OneHotEncoderModel} into PMML features.
 *
 * <p>Each input column is resolved to a single {@link CategoricalFeature}. That feature is
 * expanded into one {@link BinaryFeature} per category value. When the model's
 * {@code dropLast} flag is set, the last value gets no indicator. The resulting binary
 * features are registered under the model's output column(s).
 */
public class OneHotEncoderModelConverter extends MultiFeatureConverter<OneHotEncoderModel> {

    public OneHotEncoderModelConverter(OneHotEncoderModel oneHotEncoderModel) {
        super(oneHotEncoderModel);
    }

    /**
     * Encodes every input column of the model as a {@link BinarizedCategoricalFeature}.
     *
     * @param sparkMLEncoder the encoder holding the already-registered input features
     * @return one {@link BinarizedCategoricalFeature} per input column, in input-column order
     * @throws ClassCastException if an input column's feature is not a {@link CategoricalFeature}
     */
    @Override // org.jpmml.sparkml.FeatureConverter
    public List<Feature> encodeFeatures(SparkMLEncoder sparkMLEncoder) {
        OneHotEncoderModel oneHotEncoderModel = (OneHotEncoderModel) getTransformer();
        boolean dropLast = oneHotEncoderModel.getDropLast();
        FeatureConverter.InOutMode inputMode = getInputMode();

        List<Feature> result = new ArrayList<>();
        for (String inputCol : inputMode.getInputCols(oneHotEncoderModel)) {
            // Explicit cast: one-hot encoding is only defined for categorical inputs.
            CategoricalFeature categoricalFeature = (CategoricalFeature) sparkMLEncoder.getOnlyFeature(inputCol);
            List<BinaryFeature> binaryFeatures =
                    encodeFeature(sparkMLEncoder, categoricalFeature, categoricalFeature.getValues(), dropLast);
            result.add(new BinarizedCategoricalFeature(sparkMLEncoder, categoricalFeature, binaryFeatures));
        }
        return result;
    }

    /**
     * Encodes the model's inputs and registers the binary features under its output column(s).
     *
     * <p>In SINGLE output mode, exactly one encoded feature is expected. In MULTIPLE output mode,
     * the number of output columns must equal the number of encoded features, and they are
     * paired by position.
     *
     * @param sparkMLEncoder the encoder to read input features from and register output features into
     * @throws IllegalArgumentException if the output column count does not match the encoded
     *     feature count, if SINGLE mode yields more than one feature, or if the output mode
     *     is neither SINGLE nor MULTIPLE
     */
    @Override // org.jpmml.sparkml.FeatureConverter
    public void registerFeatures(SparkMLEncoder sparkMLEncoder) {
        OneHotEncoderModel oneHotEncoderModel = (OneHotEncoderModel) getTransformer();
        List<Feature> features = encodeFeatures(sparkMLEncoder);
        FeatureConverter.InOutMode outputMode = getOutputMode();

        if (FeatureConverter.InOutMode.SINGLE.equals(outputMode)) {
            BinarizedCategoricalFeature binarizedFeature =
                    (BinarizedCategoricalFeature) Iterables.getOnlyElement(features);
            sparkMLEncoder.putFeatures(oneHotEncoderModel.getOutputCol(), binarizedFeature.getBinaryFeatures());
        } else if (FeatureConverter.InOutMode.MULTIPLE.equals(outputMode)) {
            String[] outputCols = oneHotEncoderModel.getOutputCols();
            if (outputCols.length != features.size()) {
                throw new IllegalArgumentException("Expected " + outputCols.length + " features, got " + features.size() + " features");
            }
            for (int i = 0; i < outputCols.length; i++) {
                BinarizedCategoricalFeature binarizedFeature = (BinarizedCategoricalFeature) features.get(i);
                sparkMLEncoder.putFeatures(outputCols[i], binarizedFeature.getBinaryFeatures());
            }
        } else {
            // Fail loudly rather than silently registering nothing for the output column(s).
            throw new IllegalArgumentException("Unsupported output mode: " + outputMode);
        }
    }

    /**
     * Creates one {@link BinaryFeature} per category value of {@code feature}.
     *
     * @param encoder the PMML encoder passed to each {@link BinaryFeature}
     * @param feature the categorical source feature
     * @param values the category values, in encoding order
     * @param dropLast if {@code true}, no indicator is created for the last value
     *     (mirrors Spark's {@code dropLast}); an empty {@code values} list yields an empty result
     * @return a new mutable list of binary features, one per retained value
     */
    public static List<BinaryFeature> encodeFeature(PMMLEncoder encoder, Feature feature, List<?> values, boolean dropLast) {
        // Guard the empty case: subList(0, -1) would throw IndexOutOfBoundsException.
        List<?> retained = (dropLast && !values.isEmpty()) ? values.subList(0, values.size() - 1) : values;

        List<BinaryFeature> result = new ArrayList<>(retained.size());
        for (Object value : retained) {
            result.add(new BinaryFeature(encoder, feature, value));
        }
        return result;
    }
}
