package category_encoders;

import com.google.common.base.Functions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import numpy.core.ScalarUtil;
import org.dmg.pmml.FieldName;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ValueUtil;
import org.jpmml.python.HasArray;
import org.jpmml.sklearn.SkLearnEncoder;
import pandas.core.BlockManager;
import pandas.core.DataFrame;
import pandas.core.Index;
import pandas.core.Series;
import pandas.core.SeriesUtil;
import pandas.core.SingleBlockManager;
import sklearn.preprocessing.EncoderUtil;

/* loaded from: input_file:category_encoders/MeanEncoder.class */
/**
 * Base class for encoders that replace each category of an input column with a single numeric
 * value derived from per-category "sum" and "count" statistics.
 *
 * <p>Subclasses decide how a (sum, count) pair is turned into the encoded value by implementing
 * {@link #createFunction()}.
 */
public abstract class MeanEncoder extends MapEncoder {

    /**
     * Turns the aggregated statistics of one category into its encoded value.
     *
     * <p>The first argument is taken from the DataFrame's first block and the second from its
     * second block. {@code toMeanSeries} only accepts DataFrames whose columns are exactly
     * {@code ["sum", "count"]}. It assumes block order matches that column order (TODO confirm for
     * all pickled inputs).
     */
    @FunctionalInterface
    public interface MeanFunction extends BiFunction<Double, Integer, Double> {

        @Override
        Double apply(Double sum, Integer count);
    }

    /**
     * @param module forwarded unchanged to {@link MapEncoder} (presumably the Python module name)
     * @param name forwarded unchanged to {@link MapEncoder} (presumably the Python class name)
     */
    public MeanEncoder(String module, String name) {
        super(module, name);
    }

    /**
     * Creates the function that converts per-category (sum, count) statistics into encoded values.
     *
     * <p>This method is invoked once per mapping entry when {@link #getMapping()} is evaluated.
     */
    public abstract MeanFunction createFunction();

    /**
     * Encodes every input feature as a {@link MapFeature} that looks up the per-category value.
     *
     * @param features the input features, positionally aligned with {@code getCols()}
     * @param encoder the encoder used to declare each input field as categorical
     * @return one encoded feature per input feature, in the same order
     * @throws IllegalArgumentException if invariant-column dropping is requested, if
     *     {@code handleMissing} is not {@code "error"} or {@code "value"}, if {@code handleUnknown}
     *     is not {@code "error"}, or if a column has no mapping entry
     */
    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        List<?> cols = getCols();
        Boolean dropInvariant = getDropInvariant();
        String handleMissing = getHandleMissing();
        String handleUnknown = getHandleUnknown();
        Map<Object, Series> mapping = getMapping();

        if (dropInvariant.booleanValue()) {
            throw new IllegalArgumentException("Dropping invariant columns is not supported");
        }

        Object missingCategory = toMissingCategory(handleMissing);
        checkHandleUnknown(handleUnknown);

        List<Feature> result = new ArrayList<>(features.size());
        for (int i = 0; i < features.size(); i++) {
            Feature feature = features.get(i);
            Object col = cols.get(i);

            Series series = mapping.get(col);
            if (series == null) {
                throw new IllegalArgumentException("No mapping for column " + col);
            }

            Map<?, Double> categoryValues = SeriesUtil.toMap(series, Functions.identity(), ValueUtil::asDouble);

            // Declare the input field as categorical over the known categories (after filtering).
            List<Object> categories = new ArrayList<>(categoryValues.keySet());
            encoder.toCategorical(feature.getName(), EncoderUtil.filterCategories(categories));

            result.add(new MapFeature(encoder, feature, categoryValues, missingCategory) {

                @Override
                public FieldName getDerivedName() {
                    // Qualified explicitly so that members of MapFeature cannot shadow the encoder's.
                    return MeanEncoder.this.createFieldName(MeanEncoder.this.functionName(), getName());
                }
            });
        }
        return result;
    }

    /**
     * Maps a {@code handleMissing} policy to the category used for missing input values.
     *
     * @return {@code null} for {@code "error"}, {@link CategoryEncoder#CATEGORY_NAN} for {@code "value"}
     * @throws IllegalArgumentException for any other policy
     * @throws NullPointerException if {@code handleMissing} is {@code null}
     */
    private static Object toMissingCategory(String handleMissing) {
        switch (handleMissing) {
            case "error":
                return null;
            case "value":
                return CategoryEncoder.CATEGORY_NAN;
            default:
                throw new IllegalArgumentException(handleMissing);
        }
    }

    /**
     * Verifies that the {@code handleUnknown} policy is supported. Only {@code "error"} is accepted.
     *
     * @throws IllegalArgumentException for any other policy
     * @throws NullPointerException if {@code handleUnknown} is {@code null}
     */
    private static void checkHandleUnknown(String handleUnknown) {
        if (!handleUnknown.equals("error")) {
            throw new IllegalArgumentException(handleUnknown);
        }
    }

    /**
     * Returns the fitted mapping with decoded keys. Each sum/count DataFrame is converted into a
     * Series of encoded values using a function obtained from {@link #createFunction()}.
     */
    @Override
    public Map<Object, Series> getMapping() {
        return CategoryEncoderUtil.toTransformedMap((Map) get("mapping", Map.class),
                key -> ScalarUtil.decode(key),
                value -> toMeanSeries((DataFrame) value, createFunction()));
    }

    /**
     * Returns the {@code _mean} attribute as a {@link Double}.
     *
     * <p>NOTE(review): this is presumably the overall mean fitted by the Python encoder; confirm.
     */
    public Double getMean() {
        return ValueUtil.asDouble(getNumber("_mean"));
    }

    /**
     * Collapses a two-column ({@code sum}, {@code count}) DataFrame into a Series of encoded values.
     * The resulting Series is indexed by the DataFrame's row index.
     *
     * @throws IllegalArgumentException if the DataFrame does not have exactly two axes and two
     *     blocks, if its columns are not exactly {@code ["sum", "count"]}, or if the two blocks
     *     differ in length
     */
    private static Series toMeanSeries(DataFrame dataFrame, MeanFunction meanFunction) {
        BlockManager data = dataFrame.getData();

        List<?> axes = data.getAxesArray();
        if (axes.size() != 2) {
            throw new IllegalArgumentException("Expected 2 axes, got " + axes.size());
        }

        List<?> columns = ((Index) axes.get(0)).getDataData();
        Index rowIndex = (Index) axes.get(1);
        // The original implementation also read the row labels without using them. The call is
        // kept so that a malformed row index still fails at this point.
        rowIndex.getDataData();

        if (!Arrays.asList("sum", "count").equals(columns)) {
            throw new IllegalArgumentException("Expected columns [sum, count], got " + columns);
        }

        List<?> blocks = data.getBlockValues();
        if (blocks.size() != 2) {
            throw new IllegalArgumentException("Expected 2 blocks, got " + blocks.size());
        }

        List<?> sums = ((HasArray) blocks.get(0)).getArrayContent();
        List<?> counts = ((HasArray) blocks.get(1)).getArrayContent();
        if (sums.size() != counts.size()) {
            throw new IllegalArgumentException("Sum and count blocks differ in length: "
                    + sums.size() + " != " + counts.size());
        }

        final List<Double> values = new ArrayList<>(sums.size());
        for (int i = 0; i < sums.size(); i++) {
            Double sum = ValueUtil.asDouble((Number) sums.get(i));
            Integer count = ValueUtil.asInteger((Number) counts.get(i));
            values.add(meanFunction.apply(sum, count));
        }

        // One-dimensional array view over the computed values.
        HasArray valueArray = new HasArray() {

            public List<?> getArrayContent() {
                return values;
            }

            public int[] getArrayShape() {
                return new int[]{values.size()};
            }
        };

        SingleBlockManager blockManager = new SingleBlockManager();
        blockManager.setOnlyBlockItem(rowIndex);
        blockManager.setOnlyBlockValue(valueArray);

        Series series = new Series();
        series.setBlockManager(blockManager);
        return series;
    }
}
