package sklego.meta;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import numpy.core.ScalarUtil;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.OrdinalLabel;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.python.CastFunction;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Classifier;

/* loaded from: input_file:sklego/meta/OrdinalClassifier.class */
public class OrdinalClassifier extends Classifier {
    public OrdinalClassifier(String str, String str2) {
        super(str, str2);
    }

    public Model encodeModel(Schema schema) {
        Classifier classifier;
        Map<?, ? extends Classifier> estimators = getEstimators();
        Map<?, ?> estimatorCategories = getEstimatorCategories();
        SkLearnEncoder encoder = schema.getEncoder();
        OrdinalLabel label = schema.getLabel();
        schema.getFeatures();
        SchemaUtil.checkSize(estimators.size() + 1, label);
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        int size = label.size() - 1;
        for (int i = 0; i < size; i++) {
            Object value = label.getValue(i);
            if (estimatorCategories == null || estimatorCategories.isEmpty()) {
                classifier = estimators.get(value);
            } else {
                Object obj = estimatorCategories.get(value);
                if (obj == null) {
                    throw new IllegalArgumentException();
                }
                classifier = estimators.get(obj);
            }
            if (classifier == null) {
                throw new IllegalArgumentException();
            }
            if (!classifier.hasProbabilityDistribution()) {
                throw new IllegalArgumentException();
            }
            CategoricalLabel categoricalLabel = new CategoricalLabel(DataType.DOUBLE, Arrays.asList("<=" + ValueUtil.asString(value), ">" + ValueUtil.asString(value)));
            Model encode = classifier.encode(schema.toRelabeledSchema(categoricalLabel));
            List export = encoder.export(encode, FieldNameUtil.create("probability", new Object[]{categoricalLabel.getValue(1)}));
            if (export.size() != 1) {
                throw new IllegalArgumentException();
            }
            arrayList.add(encode);
            arrayList2.addAll(export);
        }
        SchemaUtil.checkSize(estimators.size(), arrayList2);
        ArrayList arrayList3 = new ArrayList();
        for (int i2 = 0; i2 < estimators.size(); i2++) {
            arrayList3.add(RegressionModelUtil.createRegressionTable(Collections.singletonList(arrayList2.get(i2)), Collections.singletonList(1), Double.valueOf(0.0d)).setTargetCategory(label.getValue(i2)));
        }
        arrayList3.add(RegressionModelUtil.createRegressionTable(Collections.emptyList(), Collections.emptyList(), Double.valueOf(1.0d)).setTargetCategory(label.getValue(estimators.size())));
        RegressionModel normalizationMethod = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(label), arrayList3).setNormalizationMethod(RegressionModel.NormalizationMethod.NONE);
        encodePredictProbaOutput(normalizationMethod, DataType.DOUBLE, label);
        arrayList.add(normalizationMethod);
        return MiningModelUtil.createModelChain(arrayList, Segmentation.MissingPredictionTreatment.RETURN_MISSING);
    }

    protected DiscreteLabel encodeLabel(String str, OpType opType, DataType dataType, List<?> list, SkLearnEncoder skLearnEncoder) {
        return super.encodeLabel(str, OpType.ORDINAL, DataType.STRING, list, skLearnEncoder);
    }

    public Classifier getEstimator() {
        return (Classifier) get("estimator", Classifier.class);
    }

    public Map<?, ? extends Classifier> getEstimators() {
        Map dict = getDict("estimators_");
        Function<Object, Object> function = new Function<Object, Object>() { // from class: sklego.meta.OrdinalClassifier.1
            @Override // java.util.function.Function
            public Object apply(Object obj) {
                return Classifier.canonicalizeValue(ScalarUtil.decode(obj));
            }
        };
        CastFunction<Classifier> castFunction = new CastFunction<Classifier>(Classifier.class) { // from class: sklego.meta.OrdinalClassifier.2
            protected String formatMessage(Object obj) {
                return "Dict attribute 'estimators_' contains an unsupported item value (" + ClassDictUtil.formatClass(obj) + ")";
            }
        };
        return (Map) dict.entrySet().stream().collect(Collectors.toMap(entry -> {
            return function.apply(entry.getKey());
        }, entry2 -> {
            return (Classifier) castFunction.apply(entry2.getValue());
        }));
    }

    private Map<?, ?> getEstimatorCategories() {
        if (!containsKey("pmml_classes_")) {
            return null;
        }
        List classes = getClasses("classes_");
        List classes2 = getClasses("pmml_classes_");
        ClassDictUtil.checkSize(new Collection[]{classes, classes2});
        LinkedHashMap linkedHashMap = new LinkedHashMap();
        for (int i = 0; i < classes.size(); i++) {
            linkedHashMap.put(classes2.get(i), classes.get(i));
        }
        return linkedHashMap;
    }
}
