package sklearn.linear_model;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.MatrixUtil;
import sklearn.Classifier;
import sklearn.EstimatorUtil;

/* loaded from: input_file:sklearn/linear_model/BaseLinearClassifier.class */
public abstract class BaseLinearClassifier extends Classifier {

    /**
     * Creates the converter.
     *
     * @param module forwarded unchanged to {@link Classifier}; presumably the Python module of the
     *               pickled estimator — NOTE(review): confirm against Classifier
     * @param name   forwarded unchanged to {@link Classifier}; presumably the Python class name —
     *               NOTE(review): confirm against Classifier
     */
    public BaseLinearClassifier(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the number of input features, taken as the column count of {@code coef_}.
     */
    @Override // sklearn.Estimator, sklearn.HasNumberOfFeatures
    public int getNumberOfFeatures() {
        return getCoefShape()[1];
    }

    /**
     * Encodes this linear classifier as a PMML {@link MiningModel}.
     *
     * <p>If {@code coef_} has a single row, the label must have exactly two categories. In that case
     * one regression model is built for the second category ({@code getValue(1)}) and wrapped by
     * {@link MiningModelUtil#createBinaryLogisticClassification}. Otherwise one regression model is
     * built per category. Each model's decision function is additionally exposed through a
     * {@code "logit"} transformed output field. The models are then combined with
     * {@link RegressionModel.NormalizationMethod#SIMPLEMAX} normalization.
     *
     * @param schema the input schema; its label is expected to be a {@link CategoricalLabel}
     * @return the encoded mining model
     * @throws IllegalArgumentException if {@code coef_} has no rows
     * @throws ClassCastException if the schema's label is not categorical
     */
    @Override // sklearn.Estimator
    /* renamed from: encodeModel, reason: merged with bridge method [inline-methods] */
    public MiningModel mo20encodeModel(Schema schema) {
        int[] coefShape = getCoefShape();
        int numberOfClasses = coefShape[0];
        int numberOfFeatures = coefShape[1];

        boolean hasProbabilityDistribution = hasProbabilityDistribution();

        List<? extends Number> coef = getCoef();
        List<? extends Number> intercept = getIntercept();

        // The per-category regression models are encoded against an anonymized copy of the schema.
        Schema segmentSchema = schema.toAnonymousSchema();
        // Explicit downcast: the label of a classifier's schema must be categorical.
        CategoricalLabel categoricalLabel = (CategoricalLabel) segmentSchema.getLabel();

        if (numberOfClasses == 1) {
            // A single coefficient row means a binary problem with a two-category label.
            EstimatorUtil.checkSize(2, categoricalLabel);

            RegressionModel regressionModel = encodeCategoryRegressor(
                categoricalLabel.getValue(1),
                MatrixUtil.getRow(coef, numberOfClasses, numberOfFeatures, 0),
                intercept.get(0),
                null,
                segmentSchema);

            // NOTE(review): -1.0 is forwarded as-is. Its meaning is defined by
            // MiningModelUtil.createBinaryLogisticClassification — confirm there.
            return MiningModelUtil.createBinaryLogisticClassification(schema, regressionModel, -1.0d, hasProbabilityDistribution);
        }

        if (numberOfClasses < 2) {
            throw new IllegalArgumentException("Expected coef_ to have at least one row, got " + numberOfClasses);
        }

        EstimatorUtil.checkSize(numberOfClasses, categoricalLabel);

        int numberOfCategories = categoricalLabel.size();
        List<RegressionModel> regressionModels = new ArrayList<>(numberOfCategories);
        for (int row = 0; row < numberOfCategories; row++) {
            RegressionModel regressionModel = encodeCategoryRegressor(
                categoricalLabel.getValue(row),
                MatrixUtil.getRow(coef, numberOfClasses, numberOfFeatures, row),
                intercept.get(row),
                "logit",
                segmentSchema);

            regressionModels.add(regressionModel);
        }

        return MiningModelUtil.createClassification(schema, regressionModels, RegressionModel.NormalizationMethod.SIMPLEMAX, hasProbabilityDistribution);
    }

    /**
     * Returns the single PMML function produced by {@link EstimatorUtil#encodeLogitFunction()}.
     * Presumably this defines the {@code "logit"} function applied to the per-category decision
     * functions in the multiclass encoding.
     */
    @Override // sklearn.Estimator
    public Set<DefineFunction> encodeDefineFunctions() {
        return Collections.singleton(EstimatorUtil.encodeLogitFunction());
    }

    /**
     * Returns the flattened {@code coef_} array of the pickled estimator.
     */
    public List<? extends Number> getCoef() {
        return ClassDictUtil.getArray(this, "coef_");
    }

    /**
     * Returns the flattened {@code intercept_} array of the pickled estimator.
     */
    public List<? extends Number> getIntercept() {
        return ClassDictUtil.getArray(this, "intercept_");
    }

    /**
     * Returns the shape of {@code coef_} as {@code [rows, columns]}. The literal {@code 2} is
     * presumably the expected rank — NOTE(review): confirm against ClassDictUtil.getShape.
     */
    private int[] getCoefShape() {
        return ClassDictUtil.getShape(this, "coef_", 2);
    }

    /**
     * Encodes one linear decision function as a regression model with non-final output fields.
     *
     * <p>The output always contains a {@code decisionFunction_<category>} field holding the predicted
     * value. If {@code transformation} is non-null, it also contains a
     * {@code <transformation>DecisionFunction_<category>} field. That field applies the named PMML
     * function to the first field.
     *
     * @param category       the label category this model scores; used in the output field names
     * @param coefficients   one row of {@code coef_}
     * @param intercept      the matching entry of {@code intercept_}
     * @param transformation name of a PMML function to apply to the decision function, or
     *                       {@code null} for none
     * @param schema         the schema the regression model is encoded against
     * @return the regression model with its output attached
     */
    private static RegressionModel encodeCategoryRegressor(String category, List<? extends Number> coefficients, Number intercept, String transformation, Schema schema) {
        OutputField decisionFunction = new OutputField(FieldName.create("decisionFunction_" + category), DataType.DOUBLE)
            .setOpType(OpType.CONTINUOUS)
            .setResultFeature(ResultFeature.PREDICTED_VALUE)
            .setFinalResult(false);

        Output output = new Output()
            .addOutputFields(new OutputField[]{decisionFunction});

        if (transformation != null) {
            // Must be added after decisionFunction, because it references that field by name.
            Expression transformExpression = PMMLUtil.createApply(transformation, new Expression[]{new FieldRef(decisionFunction.getName())});

            OutputField transformedDecisionFunction = new OutputField(FieldName.create(transformation + "DecisionFunction_" + category), DataType.DOUBLE)
                .setOpType(OpType.CONTINUOUS)
                .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
                .setFinalResult(false)
                .setExpression(transformExpression);

            output.addOutputFields(new OutputField[]{transformedDecisionFunction});
        }

        return BaseLinearUtil.encodeRegressionModel(intercept, coefficients, schema)
            .setOutput(output);
    }
}
