package sklearn2pmml.pipeline;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import h2o.estimators.BaseEstimator;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import net.razorvine.pickle.objects.ClassDict;
import numpy.core.ScalarUtil;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Extension;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Header;
import org.dmg.pmml.MiningBuildTask;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Value;
import org.dmg.pmml.VerificationField;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.VisitorAction;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.TypeUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.visitors.AbstractExtender;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import sklearn.Classifier;
import sklearn.ClassifierUtil;
import sklearn.Estimator;
import sklearn.HasClassifierOptions;
import sklearn.Initializer;
import sklearn.Step;
import sklearn.StepUtil;
import sklearn.Transformer;
import sklearn.pipeline.FeatureUnion;
import sklearn.pipeline.Pipeline;
import sklearn.pipeline.PipelineClassifier;
import sklearn.pipeline.PipelineRegressor;
import sklearn.pipeline.PipelineTransformer;
import sklearn2pmml.decoration.Domain;

/* loaded from: input_file:sklearn2pmml/pipeline/PMMLPipeline.class */
/**
 * The top-level {@code sklearn2pmml.PMMLPipeline} step.
 *
 * <p>Besides the steps of a regular {@link Pipeline}, it carries optional attributes (header,
 * active/target field names, post-processing transformers, verification data and a textual
 * representation) that are consumed by {@link #encodePMML(SkLearnEncoder)} when producing a
 * standalone PMML document.
 */
public class PMMLPipeline extends Pipeline {
    private static final Logger logger = LoggerFactory.getLogger(PMMLPipeline.class);

    public PMMLPipeline() {
        this("sklearn2pmml", "PMMLPipeline");
    }

    public PMMLPipeline(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes this pipeline as a nested step of another workflow.
     *
     * <p>Logs a warning, because a plain {@link Pipeline} should be used in nested workflows.
     */
    @Override // sklearn.Composite
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        logger.warn(ClassDictUtil.formatClass(this) + " should be replaced with " + ClassDictUtil.formatClass(new Pipeline()) + " in nested workflows");
        return super.encodeFeatures(features, encoder);
    }

    /**
     * Encodes this pipeline as a standalone PMML document.
     *
     * <p>For a supervised final estimator, a single target {@link DataField} is declared (named
     * after {@code target_fields}, defaulting to {@code "y"}): categorical for classification,
     * continuous (double) for regression. Input fields are declared from {@code active_fields}
     * (or generated names) using the type information of the first transformer, or of the final
     * estimator when there are no transformers. Optional {@code predict_transformer},
     * {@code predict_proba_transformer} and {@code apply_transformer} are encoded as additional
     * output fields, and verification data (when set) becomes a model verification element.
     *
     * @param skLearnEncoder the encoder that collects fields and functions for the document
     * @return the encoded PMML document
     * @throws IllegalArgumentException if the first step does not specify feature type
     *     information, if the number of input features cannot be determined, or if the
     *     estimator's mining function or the output transformers are unsupported
     */
    public PMML encodePMML(SkLearnEncoder skLearnEncoder) {
        List<? extends Transformer> transformers = getTransformers();
        Estimator finalEstimator = (hasFinalEstimator() ? getFinalEstimator() : null);

        Map<?, ?> header = getHeader();
        Transformer predictTransformer = getPredictTransformer();
        Transformer predictProbaTransformer = getPredictProbaTransformer();
        Transformer applyTransformer = getApplyTransformer();
        List<String> activeFields = getActiveFields();
        List<String> targetFields = getTargetFields();
        String repr = getRepr();
        Verification verification = getVerification();

        Label label = null;
        if (finalEstimator != null && finalEstimator.isSupervised()) {
            if (targetFields == null) {
                targetFields = initTargetFields();
            }
            ClassDictUtil.checkSize(1, new Collection<?>[]{targetFields});

            label = encodeTargetLabel(finalEstimator, FieldName.create(targetFields.get(0)), skLearnEncoder);
        }

        List<Feature> features = new ArrayList<>();

        // The step whose type information is queried; named in the error message below
        Step featureInitializer = finalEstimator;

        // Only feature initialization is guarded: an UnsupportedOperationException raised later
        // (e.g. by the estimator's own encoding) must not be misreported as missing type info
        try {
            Transformer head = getHead(transformers, finalEstimator);
            if (head != null) {
                featureInitializer = head;

                // Initializer heads receive an empty feature list (no DataFields are declared here)
                if (!(head instanceof Initializer)) {
                    if (activeFields == null) {
                        activeFields = initActiveFields(head);
                    }
                    features = initFeatures(activeFields, head.getOpType(), head.getDataType(), skLearnEncoder);
                }

                features = super.encodeFeatures(features, skLearnEncoder);
            } else if (finalEstimator != null) {
                if (activeFields == null) {
                    activeFields = initActiveFields(finalEstimator);
                }
                features = initFeatures(activeFields, finalEstimator.getOpType(), finalEstimator.getDataType(), skLearnEncoder);
            }
        } catch (UnsupportedOperationException e) {
            throw new IllegalArgumentException("The transformer object of the first step (" + ClassDictUtil.formatClass(featureInitializer) + ") does not specify feature type information", e);
        }

        if (finalEstimator == null) {
            return encodePMML(header, null, repr, skLearnEncoder);
        }

        StepUtil.checkNumberOfFeatures(finalEstimator, features);

        Model model = finalEstimator.encode(new Schema(skLearnEncoder, label, features));

        encodeOutputFields(model, label, predictTransformer, predictProbaTransformer, applyTransformer, skLearnEncoder);

        if (finalEstimator.isSupervised()) {
            if (verification == null) {
                logger.warn("Model verification data is not set. Use method '" + ClassDictUtil.formatMember(this, "verify(X)") + "' to correct this deficiency");
            } else {
                Map<VerificationField, List<?>> data = encodeVerificationData(verification, finalEstimator, label, activeFields, targetFields, skLearnEncoder);

                model.setModelVerification(ModelUtil.createModelVerification(data));
            }
        }

        return encodePMML(header, model, repr, skLearnEncoder);
    }

    /**
     * Declares the target field of a supervised estimator and wraps it in a label.
     *
     * <p>Classification targets are categorical, with the estimator's classes as valid values;
     * any per-class extensions from the {@code OPTION_CLASS_EXTENSIONS} option are attached to the
     * corresponding {@link Value} elements. Regression targets are continuous doubles.
     *
     * @throws IllegalArgumentException for any other mining function
     */
    private static Label encodeTargetLabel(Estimator estimator, FieldName name, SkLearnEncoder encoder) {
        MiningFunction miningFunction = estimator.getMiningFunction();

        switch (miningFunction) {
            case CLASSIFICATION: {
                List<?> classes = ClassifierUtil.getClasses(estimator);
                Map<?, ?> classExtensions = (Map<?, ?>) estimator.getOption(HasClassifierOptions.OPTION_CLASS_EXTENSIONS, null);

                DataField dataField = encoder.createDataField(name, OpType.CATEGORICAL, TypeUtil.getDataType(classes, DataType.STRING), classes);

                if (classExtensions != null) {
                    for (Map.Entry<?, ?> entry : classExtensions.entrySet()) {
                        Visitor extender = createClassExtender((String) entry.getKey(), (Map<?, ?>) entry.getValue());

                        extender.applyTo(dataField);
                    }
                }

                return new CategoricalLabel(dataField);
            }
            case REGRESSION:
                return new ContinuousLabel(encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE));
            default:
                throw new IllegalArgumentException("Unsupported mining function " + miningFunction);
        }
    }

    /**
     * Creates a visitor that adds an extension named {@code name} to every {@link Value} whose
     * value is a key of {@code valueExtensions}; the extension content is the mapped value,
     * decoded and converted to a string.
     */
    private static Visitor createClassExtender(String name, final Map<?, ?> valueExtensions) {
        return new AbstractExtender(name) {

            @Override
            public VisitorAction visit(Value value) {
                Object extension = valueExtensions.get(value.getValue());
                if (extension != null) {
                    addExtension(value, ValueUtil.asString(ScalarUtil.decode(extension)));
                }
                return super.visit(value);
            }
        };
    }

    /**
     * Appends the post-processing output fields of the optional predict / predict_proba / apply
     * transformers to the output of the final model. Does nothing when none is set.
     *
     * @throws IllegalArgumentException if a predict or predict_proba transformer is set but the
     *     label is missing or of an unsupported kind
     */
    private void encodeOutputFields(Model model, Label label, Transformer predictTransformer, Transformer predictProbaTransformer, Transformer applyTransformer, SkLearnEncoder encoder) {
        if (predictTransformer == null && predictProbaTransformer == null && applyTransformer == null) {
            // Avoid creating an empty Output element
            return;
        }

        Output output = ModelUtil.ensureOutput(MiningModelUtil.getFinalModel(model));

        if (predictTransformer != null) {
            if (label == null) {
                throw new IllegalArgumentException("Attribute '" + ClassDictUtil.formatMember(this, "predict_transformer") + "' requires a supervised final estimator");
            }

            FieldName name = FieldNameUtil.create("predict", new Object[]{label.getName()});

            OpType opType;
            if (label instanceof ContinuousLabel) {
                opType = OpType.CONTINUOUS;
            } else if (label instanceof CategoricalLabel) {
                opType = OpType.CATEGORICAL;
            } else {
                throw new IllegalArgumentException();
            }

            // Intermediate (non-final) result that feeds the predict transformer
            OutputField predictField = ModelUtil.createPredictedField(name, opType, label.getDataType()).setFinalResult(false);

            output.addOutputFields(new OutputField[]{predictField});

            encodeOutput(output, Collections.singletonList(predictField), predictTransformer, encoder);
        }

        if (predictProbaTransformer != null) {
            CategoricalLabel categoricalLabel = asCategoricalLabel(label);

            encodeOutput(output, ModelUtil.createProbabilityFields(DataType.DOUBLE, categoricalLabel.getValues()), predictProbaTransformer, encoder);
        }

        if (applyTransformer != null) {
            OutputField nodeIdField = ModelUtil.createEntityIdField(FieldName.create("nodeId")).setDataType(DataType.INTEGER);

            encodeOutput(output, Collections.singletonList(nodeIdField), applyTransformer, encoder);
        }
    }

    /**
     * Builds the verification table: one column of active values per active field, followed by
     * either one column per class probability (when the verification data has probabilities and
     * the estimator produces a probability distribution) or one column per target field.
     *
     * <p>The column-count check for active values is skipped when {@code activeFields} is null
     * (e.g. an {@link Initializer} head without {@code active_fields}), consistent with active
     * columns being omitted in that case.
     */
    private Map<VerificationField, List<?>> encodeVerificationData(Verification verification, Estimator estimator, Label label, List<String> activeFields, List<String> targetFields, SkLearnEncoder encoder) {
        int[] activeValuesShape = verification.getActiveValuesShape();

        ClassDictUtil.checkShapes(0, new int[][]{activeValuesShape, verification.getTargetValuesShape()});
        if (activeFields != null) {
            ClassDictUtil.checkShapes(1, activeFields.size(), new int[][]{activeValuesShape});
        }

        List<?> activeValues = verification.getActiveValues();
        List<?> targetValues = verification.getTargetValues();

        List<String> probabilityFields = null;
        List<? extends Number> probabilityValues = null;

        if (verification.hasProbabilityValues() && isProbabilistic(estimator)) {
            probabilityFields = initProbabilityFields(asCategoricalLabel(label));

            int[] probabilityValuesShape = verification.getProbabilityValuesShape();

            ClassDictUtil.checkShapes(0, new int[][]{activeValuesShape, probabilityValuesShape});
            ClassDictUtil.checkShapes(1, probabilityFields.size(), new int[][]{probabilityValuesShape});

            probabilityValues = verification.getProbabilityValues();
        }

        Number precision = verification.getPrecision();
        Number zeroThreshold = verification.getZeroThreshold();

        int rows = activeValuesShape[0];

        Map<VerificationField, List<?>> data = new LinkedHashMap<>();

        if (activeFields != null) {
            for (int column = 0; column < activeFields.size(); column++) {
                VerificationField field = ModelUtil.createVerificationField(FieldName.create(activeFields.get(column)));

                data.put(field, CMatrixUtil.getColumn(cleanValues(encoder.getDomain(field.getField()), activeValues), rows, activeFields.size(), column));
            }
        }

        if (probabilityFields != null) {
            for (int column = 0; column < probabilityFields.size(); column++) {
                VerificationField field = ModelUtil.createVerificationField(FieldName.create(probabilityFields.get(column)))
                    .setPrecision(precision)
                    .setZeroThreshold(zeroThreshold);

                data.put(field, CMatrixUtil.getColumn(cleanValues(null, probabilityValues), rows, probabilityFields.size(), column));
            }
        } else {
            DataType targetDataType = label.getDataType();

            // Tolerances only make sense for floating-point targets
            boolean floatingPoint = (targetDataType == DataType.DOUBLE || targetDataType == DataType.FLOAT);

            for (int column = 0; column < targetFields.size(); column++) {
                VerificationField field = ModelUtil.createVerificationField(FieldName.create(targetFields.get(column)));
                if (floatingPoint) {
                    field.setPrecision(precision).setZeroThreshold(zeroThreshold);
                }

                data.put(field, CMatrixUtil.getColumn(cleanValues(encoder.getDomain(field.getField()), targetValues), rows, targetFields.size(), column));
            }
        }

        return data;
    }

    /** Returns whether the estimator (H2O or scikit-learn classifier) reports a probability distribution. */
    private static boolean isProbabilistic(Estimator estimator) {
        if (estimator instanceof BaseEstimator) {
            return ((BaseEstimator) estimator).hasProbabilityDistribution();
        }
        if (estimator instanceof Classifier) {
            return ((Classifier) estimator).hasProbabilityDistribution();
        }
        return false;
    }

    /** Narrows a label to a categorical one, failing with a clear message otherwise. */
    private static CategoricalLabel asCategoricalLabel(Label label) {
        if (!(label instanceof CategoricalLabel)) {
            throw new IllegalArgumentException("Expected a categorical label");
        }
        return (CategoricalLabel) label;
    }

    /**
     * Wraps the encoder's PMML document with header attributes ({@code copyright},
     * {@code description}, {@code modelVersion}) and, when present, the textual representation
     * stored as an extension of the mining build task.
     */
    private PMML encodePMML(Map<?, ?> header, Model model, String repr, SkLearnEncoder encoder) {
        PMML pmml = encoder.encodePMML(model);

        if (header != null) {
            Header pmmlHeader = pmml.getHeader();
            pmmlHeader.setCopyright((String) header.get("copyright"));
            pmmlHeader.setDescription((String) header.get("description"));
            pmmlHeader.setModelVersion((String) header.get("modelVersion"));
        }

        if (repr != null) {
            Extension extension = new Extension().addContent(new Object[]{repr});
            pmml.setMiningBuildTask(new MiningBuildTask().addExtensions(new Extension[]{extension}));
        }

        return pmml;
    }

    /**
     * Encodes {@code transformer} on top of the given model output fields and appends every
     * derived field it produces to {@code output} as a transformed-value output field.
     *
     * <p>A private encoder is used so the derived fields do not leak into the main document;
     * only the define-functions it created are copied over to {@code encoder}.
     */
    private void encodeOutput(Output output, List<OutputField> outputFields, Transformer transformer, SkLearnEncoder encoder) {
        SkLearnEncoder outputEncoder = new SkLearnEncoder();

        List<Feature> features = new ArrayList<>(outputFields.size());
        for (OutputField outputField : outputFields) {
            DataField dataField = outputEncoder.createDataField(outputField.getName(), outputField.getOpType(), outputField.getDataType());

            features.add(new WildcardFeature(outputEncoder, dataField));
        }

        transformer.encode(features, outputEncoder);

        for (DerivedField derivedField : outputEncoder.getDerivedFields().values()) {
            OutputField transformedField = new OutputField(derivedField.getName(), derivedField.getOpType(), derivedField.getDataType())
                .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
                .setExpression(derivedField.getExpression());

            output.addOutputFields(new OutputField[]{transformedField});
        }

        for (Object defineFunction : outputEncoder.getDefineFunctions().values()) {
            encoder.addDefineFunction((DefineFunction) defineFunction);
        }
    }

    @Override // sklearn.pipeline.Pipeline
    public List<Object[]> getSteps() {
        return super.getSteps();
    }

    public PMMLPipeline setSteps(List<Object[]> steps) {
        put("steps", steps);
        return this;
    }

    /** Returns the {@code header} attribute, or null when it is not set. */
    public Map<?, ?> getHeader() {
        return (Map<?, ?>) getOptional("header", Map.class);
    }

    public Transformer getPredictTransformer() {
        return getTransformer("predict_transformer");
    }

    public Transformer getPredictProbaTransformer() {
        return getTransformer("predict_proba_transformer");
    }

    public Transformer getApplyTransformer() {
        return getTransformer("apply_transformer");
    }

    private Transformer getTransformer(String name) {
        return (Transformer) getOptional(name, Transformer.class);
    }

    /** Returns the {@code active_fields} attribute, or null when it is not set. */
    public List<String> getActiveFields() {
        if (containsKey("active_fields")) {
            return getListLike("active_fields", String.class);
        }
        return null;
    }

    public PMMLPipeline setActiveFields(List<String> activeFields) {
        put("active_fields", toArray(activeFields));
        return this;
    }

    /**
     * Returns the target field names: the singular {@code target_field} attribute takes
     * precedence over {@code target_fields}; null when neither is set.
     */
    public List<String> getTargetFields() {
        if (containsKey("target_field")) {
            return Collections.singletonList(getOptionalString("target_field"));
        }
        if (containsKey("target_fields")) {
            return getListLike("target_fields", String.class);
        }
        return null;
    }

    public PMMLPipeline setTargetFields(List<String> targetFields) {
        put("target_fields", toArray(targetFields));
        return this;
    }

    public String getRepr() {
        return getOptionalString("repr_");
    }

    public PMMLPipeline setRepr(String repr) {
        put("repr_", repr);
        return this;
    }

    public Verification getVerification() {
        return (Verification) getOptional("verification", Verification.class);
    }

    public PMMLPipeline setVerification(Verification verification) {
        put("verification", verification);
        return this;
    }

    /**
     * Generates default active field names {@code x1..xN} from the step's number of features.
     *
     * @throws IllegalArgumentException if the step does not report its number of features
     */
    private List<String> initActiveFields(Step step) {
        int numberOfFeatures = step.getNumberOfFeatures();
        // -1 is the "unknown" sentinel; any negative count is equally unusable
        if (numberOfFeatures < 0) {
            throw new IllegalArgumentException("The transformer object of the first step (" + ClassDictUtil.formatClass(step) + ") does not specify the number of input features");
        }

        List<String> activeFields = new ArrayList<>(numberOfFeatures);
        for (int i = 1; i <= numberOfFeatures; i++) {
            activeFields.add("x" + i);
        }

        logger.warn("Attribute '" + ClassDictUtil.formatMember(this, "active_fields") + "' is not set. Assuming {} as the names of active fields", activeFields);

        return activeFields;
    }

    /** Returns one {@code probability(<class>)} field name per class value of the label. */
    private List<String> initProbabilityFields(CategoricalLabel categoricalLabel) {
        List<String> probabilityFields = new ArrayList<>();
        for (Object value : categoricalLabel.getValues()) {
            probabilityFields.add("probability(" + value + ")");
        }
        return probabilityFields;
    }

    private List<String> initTargetFields() {
        logger.warn("Attribute '" + ClassDictUtil.formatMember(this, "target_fields") + "' is not set. Assuming {} as the name of the target field", "y");
        return Collections.singletonList("y");
    }

    /** Declares one DataField per name (with the given types) and wraps each in a wildcard feature. */
    private static List<Feature> initFeatures(List<String> names, OpType opType, DataType dataType, SkLearnEncoder encoder) {
        List<Feature> features = new ArrayList<>(names.size());
        for (String name : names) {
            DataField dataField = encoder.createDataField(FieldName.create(name), opType, dataType);

            features.add(new WildcardFeature(encoder, dataField));
        }
        return features;
    }

    /**
     * Finds the transformer that receives the raw input features.
     *
     * <p>Descends into the first transformer when it is a {@link FeatureUnion} or a
     * {@link PipelineTransformer}; when there are no transformers, descends into the estimator
     * if it wraps a pipeline. Returns null when no such transformer exists.
     */
    private static Transformer getHead(List<? extends Transformer> transformers, Estimator estimator) {
        if (!transformers.isEmpty()) {
            Transformer first = transformers.get(0);

            if (first instanceof FeatureUnion) {
                return getHead(((FeatureUnion) first).getTransformers(), null);
            }
            if (first instanceof PipelineTransformer) {
                return getHead(((PipelineTransformer) first).getPipeline().getTransformers(), null);
            }
            return first;
        }

        Pipeline pipeline;
        if (estimator instanceof PipelineClassifier) {
            pipeline = ((PipelineClassifier) estimator).getPipeline();
        } else if (estimator instanceof PipelineRegressor) {
            pipeline = ((PipelineRegressor) estimator).getPipeline();
        } else {
            return null;
        }

        return getHead(pipeline.getTransformers(), pipeline.getFinalEstimator());
    }

    /**
     * Returns a lazy view of {@code values} in which NaN entries are replaced by null; each
     * element is passed through {@code Domain.checkValue} when accessed.
     *
     * <p>NOTE(review): the {@code domain} parameter is currently unused — confirm whether
     * domain-specific cleaning was intended here.
     */
    private static List<?> cleanValues(Domain domain, List<?> values) {
        return Lists.transform(values, new Function<Object, Object>() {

            @Override
            public Object apply(Object value) {
                Domain.checkValue(value);

                if (ValueUtil.isNaN(value)) {
                    return null;
                }
                return value;
            }
        });
    }
}
