package sklearn2pmml;

import com.google.common.base.CharMatcher;
import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import org.jpmml.sklearn.TupleUtil;
import sklearn.Classifier;
import sklearn.Estimator;
import sklearn.TypeUtil;
import sklearn.pipeline.Pipeline;
import sklearn_pandas.DataFrameMapper;

/* loaded from: input_file:sklearn2pmml/PMMLPipeline.class */
/**
 * A {@link Pipeline} that can be encoded as a complete PMML document.
 *
 * <p>If the first transformer step holds a {@link DataFrameMapper}, that mapper declares the
 * input features and the step is excluded from {@link #getTransformerSteps()}. Otherwise the
 * input fields are taken from the {@code active_fields} attribute, or are synthesized as
 * {@code x1}, {@code x2}, ... when that attribute is absent.
 */
public class PMMLPipeline extends Pipeline {

    /**
     * Creates a pipeline; both arguments are forwarded unchanged to the {@link Pipeline}
     * constructor (presumably the Python module and class name of the unpickled object).
     */
    public PMMLPipeline(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes this pipeline, including its final estimator, as a PMML document.
     *
     * <p>Nested pipelines in the estimator position are unwrapped. The innermost estimator
     * decides whether a target field is declared, and whether it is categorical or continuous.
     *
     * @return the encoded PMML document
     * @throws IllegalArgumentException if the final estimator is a classifier without classes,
     *     or if a class label cannot be used as a PMML category
     */
    public PMML encodePMML() {
        DataFrameMapper mapper = getMapper();
        Estimator estimator = findFinalEstimator(getEstimator());

        SkLearnEncoder encoder = new SkLearnEncoder();

        // The target data field is created before any feature data field (same order as before).
        Label label = encodeTargetLabel(estimator, encoder);

        if (mapper != null) {
            mapper.encodeFeatures(encoder);
        } else {
            encodeActiveFields(encoder);
        }

        for (DefineFunction defineFunction : encodeDefineFunctions()) {
            encoder.addDefineFunction(defineFunction);
        }

        Schema schema = new Schema(label, encoder.getFeatures());

        return encoder.encodePMML(encodeModel(schema, encoder));
    }

    /** Descends through nested pipelines in the estimator position and returns the innermost estimator. */
    private static Estimator findFinalEstimator(Estimator estimator) {
        while (estimator instanceof Pipeline) {
            estimator = ((Pipeline) estimator).getEstimator();
        }
        return estimator;
    }

    /**
     * Declares the target data field for a supervised estimator.
     *
     * <p>The field is named after {@code target_field}, defaulting to {@code "y"}. Classifiers get a
     * categorical field whose categories are the formatted class labels. Any other supervised
     * estimator gets a continuous {@code double} field.
     *
     * @return the target label, or {@code null} when the estimator is not supervised
     * @throws IllegalArgumentException if a classifier declares no classes
     */
    private Label encodeTargetLabel(Estimator estimator, SkLearnEncoder encoder) {
        if (!estimator.isSupervised()) {
            return null;
        }

        String targetField = getTargetField();
        if (targetField == null) {
            targetField = "y";
        }

        OpType opType = OpType.CONTINUOUS;
        DataType dataType = DataType.DOUBLE;
        List<String> categories = null;

        if (estimator instanceof Classifier) {
            List<?> classes = ((Classifier) estimator).getClasses();
            if (classes == null || classes.isEmpty()) {
                throw new IllegalArgumentException("Classifier does not declare any classes");
            }
            opType = OpType.CATEGORICAL;
            dataType = TypeUtil.getDataType(classes, DataType.STRING);
            categories = formatTargetCategories(classes);
        }

        DataField dataField = encoder.createDataField(FieldName.create(targetField), opType, dataType, categories);

        // A categorical op type implies a non-empty category list (empty class lists are rejected above).
        if (opType == OpType.CATEGORICAL) {
            return new CategoricalLabel(dataField);
        }
        return new ContinuousLabel(dataField);
    }

    /**
     * Declares one wildcard feature per active field; used when there is no leading
     * {@link DataFrameMapper}.
     *
     * <p>Field names come from {@code active_fields}. When that attribute is absent, the names
     * {@code x1..xN} are generated, with {@code N = getNumberOfFeatures()}. Every field shares this
     * pipeline's {@link #getOpType()} and {@link #getDataType()}.
     */
    private void encodeActiveFields(SkLearnEncoder encoder) {
        List<String> activeFields = getActiveFields();
        if (activeFields == null) {
            activeFields = new ArrayList<>();
            int numberOfFeatures = getNumberOfFeatures();
            for (int i = 0; i < numberOfFeatures; i++) {
                activeFields.add("x" + (i + 1));
            }
        }

        OpType opType = getOpType();
        DataType dataType = getDataType();

        for (String activeField : activeFields) {
            DataField dataField = encoder.createDataField(FieldName.create(activeField), opType, dataType);
            encoder.addRow(Collections.singletonList(activeField), Collections.singletonList(new WildcardFeature(encoder, dataField)));
        }
    }

    /**
     * Returns the {@link DataFrameMapper} held by the first transformer step.
     *
     * @return the mapper, or {@code null} if the first step is not a mapper step
     */
    public DataFrameMapper getMapper() {
        Object[] mapperStep = getMapperStep();
        if (mapperStep == null) {
            return null;
        }
        return (DataFrameMapper) TupleUtil.extractElement(mapperStep, 1);
    }

    /**
     * Returns the first transformer step when its element at index 1 is a {@link DataFrameMapper}.
     *
     * @return the mapper step tuple, or {@code null} if there is none
     */
    public Object[] getMapperStep() {
        // Must consult the superclass view: this class's override hides the mapper step.
        List<Object[]> transformerSteps = super.getTransformerSteps();
        return startsWithMapper(transformerSteps) ? transformerSteps.get(0) : null;
    }

    /**
     * Returns the transformer steps, excluding a leading {@link DataFrameMapper} step, which
     * {@link #encodePMML()} handles separately.
     *
     * @return the remaining transformer steps (a sub-list view when the mapper step is skipped)
     */
    @Override // sklearn.pipeline.Pipeline
    public List<Object[]> getTransformerSteps() {
        List<Object[]> transformerSteps = super.getTransformerSteps();
        if (startsWithMapper(transformerSteps)) {
            return transformerSteps.subList(1, transformerSteps.size());
        }
        return transformerSteps;
    }

    /** Tells whether the first step's element at index 1 is a {@link DataFrameMapper}. */
    private static boolean startsWithMapper(List<Object[]> steps) {
        return !steps.isEmpty() && (TupleUtil.extractElement(steps.get(0), 1) instanceof DataFrameMapper);
    }

    /**
     * Returns the {@code active_fields} attribute. {@link #encodePMML()} treats a {@code null}
     * result as "not set".
     */
    public List<String> getActiveFields() {
        return ClassDictUtil.getArray(this, "active_fields");
    }

    /**
     * Returns the {@code target_field} attribute. {@link #encodePMML()} falls back to {@code "y"}
     * when this is {@code null}.
     */
    public String getTargetField() {
        return (String) get("target_field");
    }

    /**
     * Formats class labels as PMML category strings, in the original order.
     *
     * @param classes the classifier's class labels
     * @return an independent (eagerly computed) list of formatted categories
     * @throws IllegalArgumentException if a label formats to {@code null} or contains whitespace
     */
    private static List<String> formatTargetCategories(List<?> classes) {
        Function<Object, String> formatter = new Function<Object, String>() {
            @Override
            public String apply(Object value) {
                String category = ValueUtil.formatValue(value);
                // NOTE(review): whitespace is rejected, presumably because categories may later be
                // serialized in whitespace-separated form - confirm against the encoder.
                if (category == null || CharMatcher.WHITESPACE.matchesAnyOf(category)) {
                    throw new IllegalArgumentException("Invalid target category '" + category + "': must be non-null and must not contain whitespace");
                }
                return category;
            }
        };
        // Lists.transform is a lazy view; copying forces validation of every element right away.
        return new ArrayList<>(Lists.transform(classes, formatter));
    }
}
