package org.jpmml.converter;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.dmg.pmml.Extension;
import org.dmg.pmml.Field;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.UnivariateStats;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.visitors.FeatureExpander;
import org.jpmml.converter.visitors.ModelCleanerBattery;
import org.jpmml.converter.visitors.PMMLCleanerBattery;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/jpmml/converter/ModelEncoder.class */
/**
 * PMML encoder that wraps a single top-level {@link Model}, optionally preceded by a chain of
 * transformer models.
 *
 * <p>Besides the PMML document itself, this encoder collects per-model metadata: field decorators,
 * feature importances and univariate statistics. Metadata may be registered against a specific
 * model, or against the {@code null} key when the model does not exist yet.
 * {@link #encodeModel(Model)} re-keys {@code null} entries onto the final model. The private
 * {@code encode*} steps of {@link #encodePMML(Model)} reject metadata that is still under the
 * {@code null} key at that point.
 */
public class ModelEncoder extends PMMLEncoder {
    private final List<Model> transformers = new ArrayList<>();
    private final Map<Model, ListMultimap<String, Decorator>> decorators = new LinkedHashMap<>();
    private final Map<Model, ListMultimap<Feature, Number>> featureImportances = new LinkedHashMap<>();
    private final Map<Model, List<UnivariateStats>> univariateStats = new LinkedHashMap<>();
    private static final Logger logger = LoggerFactory.getLogger(ModelEncoder.class);

    /**
     * Encodes a complete PMML document around the given model.
     *
     * <p>If {@link #encodeModel(Model)} yields a model, it is added to the document, the model
     * cleaners are run, and the collected decorators, feature importances and univariate stats are
     * written into it. The PMML cleaners are always run last.
     *
     * @param model the model to embed; may be {@code null}
     * @return the encoded PMML document
     * @throws IllegalStateException if metadata is still registered under the {@code null} key
     */
    public PMML encodePMML(Model model) {
        PMML pmml = encodePMML();

        Model encodedModel = encodeModel(model);
        if (encodedModel != null) {
            pmml.addModels(encodedModel);

            new ModelCleanerBattery().applyTo(pmml);

            encodeDecorators(pmml);
            encodeFeatureImportances(pmml);
            encodeUnivariateStats(pmml);
        }

        new PMMLCleanerBattery().applyTo(pmml);

        return pmml;
    }

    /**
     * Produces the top-level model.
     *
     * <p>Metadata registered under the {@code null} key is moved onto {@code model}. If transformers
     * were registered, the transformers and then {@code model} are chained into a
     * {@link MiningModel}, and the univariate stats of {@code model} are moved onto that chain.
     *
     * @param model the final model; may be {@code null}
     * @return {@code model} itself when there are no transformers, otherwise the model chain
     */
    public Model encodeModel(Model model) {
        List<Model> transformers = getTransformers();

        if (model != null) {
            transferContent(null, model);
        }

        if (transformers.isEmpty()) {
            return model;
        }

        List<Model> models = new ArrayList<>(transformers);
        if (model != null) {
            models.add(model);
        }

        MiningModel miningModel = MiningModelUtil.createModelChain(models, Segmentation.MissingPredictionTreatment.CONTINUE);

        // Univariate stats are attached to the outermost model (the chain)
        transferUnivariateStats(model, miningModel);

        return miningModel;
    }

    /**
     * Returns the live (mutable) list of registered transformer models.
     */
    public List<Model> getTransformers() {
        return this.transformers;
    }

    /**
     * Appends a transformer model; transformers precede the final model in the model chain.
     */
    public void addTransformer(Model model) {
        this.transformers.add(model);
    }

    /**
     * Returns the live (mutable) map of per-model field decorators, keyed by field name.
     */
    public Map<Model, ListMultimap<String, Decorator>> getDecorators() {
        return this.decorators;
    }

    /**
     * Registers a decorator for the given field under the {@code null} (not-yet-known) model.
     */
    public void addDecorator(Field<?> field, Decorator decorator) {
        addDecorator(null, field, decorator);
    }

    /**
     * Registers a decorator for the given field of the given model.
     *
     * @param model the owning model, or {@code null} if not yet known
     * @param field the field whose mining field is to be decorated
     * @param decorator the decorator to apply
     */
    public void addDecorator(Model model, Field<?> field, Decorator decorator) {
        ListMultimap<String, Decorator> fieldDecorators = getDecorators().computeIfAbsent(model, key -> ArrayListMultimap.create());

        fieldDecorators.put(field.requireName(), decorator);
    }

    /**
     * Returns the live (mutable) map of per-model feature importances.
     */
    public Map<Model, ListMultimap<Feature, Number>> getFeatureImportances() {
        return this.featureImportances;
    }

    /**
     * Registers a feature importance under the {@code null} (not-yet-known) model.
     */
    public void addFeatureImportance(Feature feature, Number number) {
        addFeatureImportance(null, feature, number);
    }

    /**
     * Registers a feature importance for the given model.
     *
     * @param model the owning model, or {@code null} if not yet known
     * @param feature the feature
     * @param number the importance value
     */
    public void addFeatureImportance(Model model, Feature feature, Number number) {
        ListMultimap<Feature, Number> modelFeatureImportances = getFeatureImportances().computeIfAbsent(model, key -> ArrayListMultimap.create());

        modelFeatureImportances.put(feature, number);
    }

    /**
     * Returns the live (mutable) map of per-model univariate statistics.
     */
    public Map<Model, List<UnivariateStats>> getUnivariateStats() {
        return this.univariateStats;
    }

    /**
     * Registers univariate statistics under the {@code null} (not-yet-known) model.
     */
    public void addUnivariateStats(UnivariateStats univariateStats) {
        addUnivariateStats(null, univariateStats);
    }

    /**
     * Registers univariate statistics for the given model.
     *
     * @param model the owning model, or {@code null} if not yet known
     * @param univariateStats the statistics element
     */
    public void addUnivariateStats(Model model, UnivariateStats univariateStats) {
        List<UnivariateStats> modelUnivariateStats = getUnivariateStats().computeIfAbsent(model, key -> new ArrayList<>());

        modelUnivariateStats.add(univariateStats);
    }

    /**
     * Moves all metadata (decorators, feature importances, univariate stats) from one model key to
     * another. If the target key already holds metadata, the moved metadata is appended to it
     * rather than replacing it.
     */
    public void transferContent(Model model, Model model2) {
        transferDecorators(model, model2);
        transferFeatureImportances(model, model2);
        transferUnivariateStats(model, model2);
    }

    /**
     * Moves decorators from {@code model} to {@code model2}, appending to any existing ones.
     */
    public void transferDecorators(Model model, Model model2) {
        transferValue(this.decorators, model, model2, ModelEncoder::mergeMultimaps);
    }

    /**
     * Moves feature importances from {@code model} to {@code model2}, appending to any existing ones.
     */
    public void transferFeatureImportances(Model model, Model model2) {
        transferValue(this.featureImportances, model, model2, ModelEncoder::mergeMultimaps);
    }

    /**
     * Moves univariate stats from {@code model} to {@code model2}, appending to any existing ones.
     */
    public void transferUnivariateStats(Model model, Model model2) {
        transferValue(this.univariateStats, model, model2, ModelEncoder::mergeLists);
    }

    /**
     * Applies each registered decorator to the matching mining field of its model.
     *
     * @throws IllegalStateException if decorators are still registered under the {@code null} key
     */
    private void encodeDecorators(PMML pmml) {
        Map<Model, ListMultimap<String, Decorator>> decorators = getDecorators();
        if (decorators.isEmpty()) {
            return;
        }

        if (decorators.containsKey(null)) {
            throw new IllegalStateException("Decorators are not associated with any model");
        }

        for (Map.Entry<Model, ListMultimap<String, Decorator>> entry : decorators.entrySet()) {
            Model model = entry.getKey();
            ListMultimap<String, Decorator> fieldDecorators = entry.getValue();

            MiningSchema miningSchema = model.requireMiningSchema();
            if (!miningSchema.hasMiningFields()) {
                continue;
            }

            for (MiningField miningField : miningSchema.getMiningFields()) {
                // ListMultimap#get returns an empty view (never null) for unknown keys
                for (Decorator decorator : fieldDecorators.get(miningField.getName())) {
                    decorator.decorate(miningField);
                }
            }
        }
    }

    /**
     * Expands features into PMML fields and writes feature importances into each model's mining
     * schema.
     *
     * @throws IllegalStateException if importances are still registered under the {@code null} key
     * @throws IllegalArgumentException if the feature expander has no result for a model
     */
    private void encodeFeatureImportances(PMML pmml) {
        Map<Model, ListMultimap<Feature, Number>> featureImportances = getFeatureImportances();
        if (featureImportances.isEmpty()) {
            return;
        }

        if (featureImportances.containsKey(null)) {
            throw new IllegalStateException("Feature importances are not associated with any model");
        }

        // Per model, the names of all features that carry an importance value
        Map<Model, Set<String>> modelFeatureNames = featureImportances.entrySet().stream()
            .collect(Collectors.toMap(entry -> entry.getKey(), entry -> entry.getValue().keySet().stream().map(feature -> feature.getName()).collect(Collectors.toSet())));

        FeatureExpander featureExpander = new FeatureExpander(modelFeatureNames);
        featureExpander.applyTo(pmml);

        for (Map.Entry<Model, ListMultimap<Feature, Number>> entry : featureImportances.entrySet()) {
            Model model = entry.getKey();

            Map<String, Set<Field<?>>> expandedFeatures = featureExpander.getExpandedFeatures(model);
            if (expandedFeatures == null) {
                throw new IllegalArgumentException("Feature expansion produced no result for a model");
            }

            encodeModelFeatureImportances(model, entry.getValue(), expandedFeatures);
        }
    }

    /**
     * Writes one model's importances into its mining schema: an {@code importance} attribute on
     * ACTIVE mining fields plus a feature-importances extension table.
     */
    private static void encodeModelFeatureImportances(Model model, ListMultimap<Feature, Number> modelFeatureImportances, Map<String, Set<Field<?>>> expandedFeatures) {
        MathContext mathContext = model.getMathContext();
        Collection<Map.Entry<Feature, Number>> entries = modelFeatureImportances.entries();

        // Split each non-zero feature importance evenly among the fields the feature expanded to
        ListMultimap<String, Number> fieldImportances = ArrayListMultimap.create();
        for (Map.Entry<Feature, Number> entry : entries) {
            String name = entry.getKey().getName();
            Number importance = entry.getValue();

            if (ValueUtil.isZero(importance)) {
                continue;
            }

            Set<Field<?>> fields = expandedFeatures.get(name);
            if (fields == null) {
                logger.warn("Unused feature '{}' has non-zero importance", name);
                continue;
            }

            Number fieldImportance = ValueUtil.divide(mathContext, importance, Integer.valueOf(fields.size()));
            for (Field<?> field : fields) {
                fieldImportances.put(field.requireName(), fieldImportance);
            }
        }

        MiningSchema miningSchema = model.requireMiningSchema();
        if (!miningSchema.hasMiningFields()) {
            return;
        }

        for (MiningField miningField : miningSchema.getMiningFields()) {
            if (miningField.getUsageType() != MiningField.UsageType.ACTIVE) {
                continue;
            }

            List<Number> importances = fieldImportances.get(miningField.getName());
            if (!importances.isEmpty()) {
                miningField.setImportance(ValueUtil.sum(mathContext, importances));
            }
        }

        InlineTable inlineTable = createFeatureImportancesTable(mathContext, entries);

        miningSchema.addExtensions(PMMLUtil.createExtension(Extensions.FEATURE_IMPORTANCES, inlineTable));
    }

    /**
     * Builds the inline table of raw (unexpanded) feature importances. The table carries summary
     * extensions: count, non-zero count, sum, and (when any are non-zero) min/max of the non-zero
     * values.
     */
    private static InlineTable createFeatureImportancesTable(MathContext mathContext, Collection<Map.Entry<Feature, Number>> entries) {
        List<String> names = new ArrayList<>(entries.size());
        List<Number> importances = new ArrayList<>(entries.size());
        for (Map.Entry<Feature, Number> entry : entries) {
            names.add(FeatureUtil.getName(entry.getKey()));
            importances.add(entry.getValue());
        }

        Map<String, List<?>> data = new LinkedHashMap<>();
        data.put("data:name", names);
        data.put("data:importance", importances);

        List<Number> nonZeroImportances = importances.stream()
            .filter(importance -> !ValueUtil.isZero(importance))
            .collect(Collectors.toList());

        InlineTable inlineTable = PMMLUtil.createInlineTable(data)
            .addExtensions(PMMLUtil.createExtension("numberOfImportances", String.valueOf(importances.size())))
            .addExtensions(PMMLUtil.createExtension("numberOfNonZeroImportances", String.valueOf(nonZeroImportances.size())))
            .addExtensions(PMMLUtil.createExtension("sumOfImportances", String.valueOf(ValueUtil.sum(mathContext, importances))));

        if (!nonZeroImportances.isEmpty()) {
            Comparator<Number> comparator = Comparator.comparingDouble(Number::doubleValue);

            inlineTable
                .addExtensions(PMMLUtil.createExtension("minImportance", String.valueOf(Collections.min(nonZeroImportances, comparator))))
                .addExtensions(PMMLUtil.createExtension("maxImportance", String.valueOf(Collections.max(nonZeroImportances, comparator))));
        }

        return inlineTable;
    }

    /**
     * Attaches registered univariate stats to each model's ModelStats, in mining-field order.
     *
     * @throws IllegalStateException if stats are still registered under the {@code null} key, or if
     *     one model has more than one stats element for the same field (Collectors.toMap contract)
     */
    private void encodeUnivariateStats(PMML pmml) {
        Map<Model, List<UnivariateStats>> univariateStats = getUnivariateStats();
        if (univariateStats.isEmpty()) {
            return;
        }

        if (univariateStats.containsKey(null)) {
            throw new IllegalStateException("Univariate stats are not associated with any model");
        }

        for (Map.Entry<Model, List<UnivariateStats>> entry : univariateStats.entrySet()) {
            Model model = entry.getKey();

            Map<String, UnivariateStats> fieldStats = entry.getValue().stream()
                .collect(Collectors.toMap(UnivariateStats::getField, Function.identity()));

            MiningSchema miningSchema = model.requireMiningSchema();
            if (!miningSchema.hasMiningFields()) {
                continue;
            }

            for (MiningField miningField : miningSchema.getMiningFields()) {
                UnivariateStats stats = fieldStats.get(miningField.getName());
                if (stats != null) {
                    ModelUtil.ensureModelStats(model).addUnivariateStats(stats);
                }
            }
        }
    }

    /**
     * Moves the value stored under {@code from} to {@code to}. If {@code to} already has a value,
     * the two are combined with {@code merger} (existing value first) instead of the existing value
     * being silently overwritten.
     */
    private static <K, V> void transferValue(Map<K, V> map, K from, K to, BinaryOperator<V> merger) {
        V value = map.remove(from);
        if (value != null) {
            map.merge(to, value, merger);
        }
    }

    /** Appends all entries of {@code source} to {@code target} and returns {@code target}. */
    private static <K, V> ListMultimap<K, V> mergeMultimaps(ListMultimap<K, V> target, ListMultimap<K, V> source) {
        target.putAll(source);
        return target;
    }

    /** Appends all elements of {@code source} to {@code target} and returns {@code target}. */
    private static <E> List<E> mergeLists(List<E> target, List<E> source) {
        target.addAll(source);
        return target;
    }
}
