package sklego.meta;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.TypeUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Estimator;
import sklearn.EstimatorUtil;
import sklearn.HasApplyField;
import sklearn.HasDecisionFunctionField;
import sklearn.HasEstimator;
import sklearn.HasMultiApplyField;
import sklearn.HasPredictField;
import sklearn.Transformer;
import sklearn.tree.HasTreeOptions;

/* loaded from: input_file:sklego/meta/EstimatorTransformer.class */
/**
 * Transformer that wraps an estimator and exposes selected outputs of that
 * estimator as features.
 *
 * <p>The wrapped estimator is read from the {@code estimator_} attribute. The
 * {@code predict_func} attribute selects which output is exposed and must be one
 * of {@code "apply"}, {@code "decision_function"} or {@code "predict"}.
 */
public class EstimatorTransformer extends Transformer implements HasEstimator<Estimator> {

    /**
     * Creates the transformer; both arguments are forwarded unchanged to the
     * {@link Transformer} constructor.
     *
     * @param str first constructor argument of {@link Transformer}
     * @param str2 second constructor argument of {@link Transformer}
     */
    public EstimatorTransformer(String str, String str2) {
        super(str, str2);
    }

    /**
     * Encodes the wrapped estimator as a model, registers that model with the
     * encoder via {@code addTransformer}, and returns features backed by the
     * model's output fields.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>resolve the output field names requested by {@code predict_func};</li>
     *   <li>build a schema over {@code list} and encode the estimator;</li>
     *   <li>turn the model's output fields into derived fields (the output fields
     *       are removed from the model's {@code Output} while doing so);</li>
     *   <li>register the model with the encoder;</li>
     *   <li>map the requested names to features.</li>
     * </ol>
     *
     * <p>When {@code predict_func} is {@code "predict"} and the estimator does not
     * implement {@link HasPredictField}: a supervised estimator gets a freshly
     * created predicted-value field, an unsupervised estimator falls back to the
     * last derived output field.
     *
     * @param list input features handed to the wrapped estimator's schema
     * @param skLearnEncoder encoder used to create derived fields and to register the model
     * @return the features produced by the wrapped estimator
     * @throws IllegalArgumentException if {@code predict_func} is not recognised, if the
     *         estimator does not provide the requested field, or if a requested field
     *         is not among the model's derived output fields
     */
    public List<Feature> encodeFeatures(List<Feature> list, SkLearnEncoder skLearnEncoder) {
        Estimator estimator = getEstimator();
        String predictFunc = getPredictFunc();

        // Resolved first so that an unsupported predict_func fails before any encoding work.
        List<String> fieldNames = resolveFieldNames(estimator, predictFunc);

        Schema schema = createSchema(estimator, list, skLearnEncoder);
        Model model = estimator.encode(schema);

        LinkedHashMap<String, DerivedOutputField> derivedFields = collectDerivedFields(model, skLearnEncoder);

        skLearnEncoder.addTransformer(model);

        if (fieldNames == null) {
            if (estimator.isSupervised()) {
                if (!derivedFields.isEmpty()) {
                    throw new IllegalArgumentException(
                            "Unexpected output fields for a supervised estimator without a predict field: "
                                    + derivedFields.keySet());
                }
                return Collections.singletonList(encodePredictedFeature(estimator, schema, model, skLearnEncoder));
            }
            if (derivedFields.isEmpty()) {
                throw new IllegalArgumentException("The encoded model has no output fields to expose as features");
            }
            fieldNames = Collections.singletonList(Iterables.getLast(derivedFields.keySet()));
        }

        List<Feature> result = new ArrayList<>();
        for (String fieldName : fieldNames) {
            DerivedOutputField derivedField = derivedFields.get(fieldName);
            if (derivedField == null) {
                throw new IllegalArgumentException(
                        "Output field " + fieldName + " not found among " + derivedFields.keySet());
            }
            result.add(toFeature(derivedField, skLearnEncoder));
        }
        return result;
    }

    /**
     * Returns the wrapped estimator stored under the {@code estimator_} attribute.
     *
     * @return the wrapped estimator
     */
    public Estimator getEstimator() {
        return (Estimator) get("estimator_", Estimator.class);
    }

    /**
     * Returns the value of the {@code predict_func} attribute.
     *
     * @return the name of the estimator method whose output is exposed
     */
    public String getPredictFunc() {
        return getString("predict_func");
    }

    /**
     * Resolves the names of the output fields selected by {@code predictFunc}.
     *
     * <p>For {@code "apply"} on an estimator implementing {@link HasTreeOptions}
     * this also sets the {@code winner_id} option to {@code true} on the estimator.
     *
     * @return the field names, or {@code null} when {@code predictFunc} is
     *         {@code "predict"} and the estimator does not implement {@link HasPredictField}
     * @throws IllegalArgumentException if {@code predictFunc} is not recognised or the
     *         estimator does not implement the interface required for it
     */
    private static List<String> resolveFieldNames(Estimator estimator, String predictFunc) {
        switch (predictFunc) {
            case "apply":
                if (estimator instanceof HasTreeOptions) {
                    estimator.putOption("winner_id", Boolean.TRUE);
                }
                if (estimator instanceof HasApplyField) {
                    return Collections.singletonList(((HasApplyField) estimator).getApplyField());
                }
                if (estimator instanceof HasMultiApplyField) {
                    return ((HasMultiApplyField) estimator).getApplyFields();
                }
                throw new IllegalArgumentException(
                        "Estimator " + estimator.getClass().getName()
                                + " implements neither HasApplyField nor HasMultiApplyField");
            case "decision_function":
                if (estimator instanceof HasDecisionFunctionField) {
                    return Collections.singletonList(((HasDecisionFunctionField) estimator).getDecisionFunctionField());
                }
                throw new IllegalArgumentException(
                        "Estimator " + estimator.getClass().getName() + " does not implement HasDecisionFunctionField");
            case "predict":
                if (estimator instanceof HasPredictField) {
                    return Collections.singletonList(((HasPredictField) estimator).getPredictField());
                }
                // The caller decides how to obtain a prediction field in this case.
                return null;
            default:
                throw new IllegalArgumentException(predictFunc);
        }
    }

    /**
     * Converts the model's output fields into derived fields, keyed by field name
     * in encounter order.
     *
     * <p>Only predicted-value, transformed-value, decision and entity-id fields are
     * converted; every output field is removed from the model's {@code Output}
     * regardless of its result feature.
     */
    private static LinkedHashMap<String, DerivedOutputField> collectDerivedFields(Model model, SkLearnEncoder skLearnEncoder) {
        LinkedHashMap<String, DerivedOutputField> derivedFields = new LinkedHashMap<>();

        Output output = model.getOutput();
        if (output == null || !output.hasOutputFields()) {
            return derivedFields;
        }

        // An explicit iterator is required because fields are removed while iterating.
        for (Iterator<OutputField> it = output.getOutputFields().iterator(); it.hasNext(); ) {
            OutputField outputField = it.next();
            switch (outputField.getResultFeature()) {
                case PREDICTED_VALUE:
                case TRANSFORMED_VALUE:
                case DECISION:
                case ENTITY_ID:
                    DerivedOutputField derivedField = skLearnEncoder.createDerivedField(model, outputField, true);
                    derivedFields.put(derivedField.getName(), derivedField);
                    break;
                default:
                    break;
            }
            it.remove();
        }
        return derivedFields;
    }

    /**
     * Creates a new predicted-value field for a supervised estimator and wraps it
     * in a feature whose kind follows the estimator's mining function.
     *
     * @throws IllegalArgumentException if the mining function is neither
     *         classification nor regression
     */
    private Feature encodePredictedFeature(Estimator estimator, Schema schema, Model model, SkLearnEncoder skLearnEncoder) {
        String fieldName = createFieldName("predict", new Object[0]);

        MiningFunction miningFunction = estimator.getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION: {
                CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
                OutputField predictedField =
                        ModelUtil.createPredictedField(fieldName, OpType.CATEGORICAL, categoricalLabel.getDataType());
                DerivedOutputField derivedField = skLearnEncoder.createDerivedField(model, predictedField, false);
                return new CategoricalFeature(skLearnEncoder, derivedField, categoricalLabel.getValues());
            }
            case REGRESSION: {
                ContinuousLabel continuousLabel = (ContinuousLabel) schema.getLabel();
                OutputField predictedField =
                        ModelUtil.createPredictedField(fieldName, OpType.CONTINUOUS, continuousLabel.getDataType());
                DerivedOutputField derivedField = skLearnEncoder.createDerivedField(model, predictedField, false);
                return new ContinuousFeature(skLearnEncoder, derivedField);
            }
            default:
                throw new IllegalArgumentException(String.valueOf(miningFunction));
        }
    }

    /**
     * Wraps a derived output field in a categorical or continuous feature,
     * depending on the field's operational type.
     *
     * @throws IllegalArgumentException if the operational type is neither
     *         categorical nor continuous
     */
    private static Feature toFeature(DerivedOutputField derivedField, SkLearnEncoder skLearnEncoder) {
        OpType opType = derivedField.getOpType();
        switch (opType) {
            case CATEGORICAL:
                return new CategoricalFeature(skLearnEncoder, derivedField.getOutputField());
            case CONTINUOUS:
                return new ContinuousFeature(skLearnEncoder, derivedField);
            default:
                throw new IllegalArgumentException(String.valueOf(opType));
        }
    }

    /**
     * Builds the schema used to encode the wrapped estimator.
     *
     * <p>An unsupervised estimator gets a schema without a label. A classifier gets
     * a categorical label over its classes (data type derived from the class
     * values, with {@code STRING} passed as the default); a regressor gets a
     * continuous {@code DOUBLE} label.
     *
     * @throws IllegalArgumentException if a supervised estimator's mining function
     *         is neither classification nor regression
     */
    private static Schema createSchema(Estimator estimator, List<Feature> list, SkLearnEncoder skLearnEncoder) {
        if (!estimator.isSupervised()) {
            return new Schema(skLearnEncoder, null, list);
        }

        MiningFunction miningFunction = estimator.getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION:
                List<?> classes = EstimatorUtil.getClasses(estimator);
                return new Schema(
                        skLearnEncoder,
                        new CategoricalLabel(TypeUtil.getDataType(classes, DataType.STRING), classes),
                        list);
            case REGRESSION:
                return new Schema(skLearnEncoder, new ContinuousLabel(DataType.DOUBLE), list);
            default:
                throw new IllegalArgumentException(String.valueOf(miningFunction));
        }
    }
}
