package org.jpmml.sparkml.model;

import java.util.ArrayList;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vector;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.general_regression.GeneralRegressionModel;
import org.dmg.pmml.general_regression.PPMatrix;
import org.dmg.pmml.general_regression.ParamMatrix;
import org.dmg.pmml.general_regression.ParameterList;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.general_regression.GeneralRegressionModelUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.sparkml.MatrixUtil;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.sparkml.VectorUtil;

/* loaded from: input_file:org/jpmml/sparkml/model/LinearModelUtil.class */
/**
 * Static factories that turn Spark ML linear-model parameters (coefficients and
 * intercepts) into PMML model elements.
 *
 * <p>For the regression and binary logistic factories the output element is selected by
 * the converter option {@link HasRegressionTableOptions#OPTION_REPRESENTATION}: when its
 * value equals the simple class name of {@link GeneralRegressionModel} (ignoring case), a
 * {@code GeneralRegressionModel} is produced; otherwise (including when the option is
 * unset) a {@code RegressionModel} is produced.
 *
 * <p>NOTE(review): this class was recovered by a decompiler. The recorded signatures
 * declare the converter argument as
 * {@code <C extends ModelConverter<?> & HasRegressionTableOptions>}; the erased
 * {@code ModelConverter} parameter type is kept here so existing callers are unaffected.
 * Collections are left as raw types because the element types expected by the external
 * helpers are not visible from this file.
 */
public class LinearModelUtil {
    /**
     * Creates a PMML regression model from a coefficient vector and an intercept.
     *
     * @param modelConverter converter whose options are consulted; per the recorded
     *        signature it is expected to also implement {@code HasRegressionTableOptions}
     * @param vector coefficients, converted to a list via {@code VectorUtil.toList}
     * @param d the intercept
     * @param schema supplies the label and the features
     * @return a {@code GeneralRegressionModel} when that representation is requested,
     *         otherwise the result of {@code RegressionModelUtil.createRegression} with
     *         normalization method {@code NONE}
     * @throws ClassCastException if the schema label is not a {@code ContinuousLabel}, or
     *         the representation option is not a {@code String}
     */
    public static Model createRegression(ModelConverter modelConverter, Vector vector, double d, Schema schema) {
        // Explicit downcast: compiles even if getLabel() is declared with the base label type.
        ContinuousLabel label = (ContinuousLabel) schema.getLabel();
        boolean generalRegressionRequested = isGeneralRegressionRequested(modelConverter);
        // Fresh copies are handed to simplify(); presumably it edits them in place - TODO confirm.
        ArrayList features = new ArrayList(schema.getFeatures());
        ArrayList coefficients = new ArrayList(VectorUtil.toList(vector));
        RegressionTableUtil.simplify(modelConverter, null, features, coefficients);
        if (!generalRegressionRequested) {
            return RegressionModelUtil.createRegression(features, coefficients, Double.valueOf(d), RegressionModel.NormalizationMethod.NONE, schema);
        }
        GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.REGRESSION, MiningFunction.REGRESSION, ModelUtil.createMiningSchema(label), (ParameterList) null, (PPMatrix) null, (ParamMatrix) null);
        // No target category is passed for a regression target.
        GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, features, coefficients, Double.valueOf(d), (Object) null);
        return generalRegressionModel;
    }

    /**
     * Creates a PMML binary logistic classification model from a coefficient vector and
     * an intercept.
     *
     * @param modelConverter converter whose options are consulted; per the recorded
     *        signature it is expected to also implement {@code HasRegressionTableOptions}
     * @param vector coefficients, converted to a list via {@code VectorUtil.toList}
     * @param d the intercept
     * @param schema supplies the label and the features
     * @return a {@code GeneralRegressionModel} with a {@code LOGIT} link function when that
     *         representation is requested, otherwise the result of
     *         {@code RegressionModelUtil.createBinaryLogisticClassification} with
     *         normalization method {@code LOGIT}
     * @throws ClassCastException if the schema label is not a {@code CategoricalLabel}, or
     *         the representation option is not a {@code String}
     */
    public static Model createBinaryLogisticClassification(ModelConverter modelConverter, Vector vector, double d, Schema schema) {
        // Explicit downcast: compiles even if getLabel() is declared with the base label type.
        CategoricalLabel label = (CategoricalLabel) schema.getLabel();
        boolean generalRegressionRequested = isGeneralRegressionRequested(modelConverter);
        // Fresh copies are handed to simplify(); presumably it edits them in place - TODO confirm.
        ArrayList features = new ArrayList(schema.getFeatures());
        ArrayList coefficients = new ArrayList(VectorUtil.toList(vector));
        RegressionTableUtil.simplify(modelConverter, null, features, coefficients);
        if (!generalRegressionRequested) {
            return RegressionModelUtil.createBinaryLogisticClassification(features, coefficients, Double.valueOf(d), RegressionModel.NormalizationMethod.LOGIT, true, schema);
        }
        // The regression table is encoded against the label value at index 1.
        Object targetCategory = label.getValue(1);
        GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(label), (ParameterList) null, (PPMatrix) null, (ParamMatrix) null).setLinkFunction(GeneralRegressionModel.LinkFunction.LOGIT);
        GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, features, coefficients, Double.valueOf(d), targetCategory);
        return generalRegressionModel;
    }

    /**
     * Creates a PMML softmax classification model with one regression table per label
     * value.
     *
     * <p>For label index {@code i}, the coefficients are row {@code i} of {@code matrix}
     * and the intercept is element {@code i} of {@code vector}. The matrix row count is
     * checked against the label size via {@code MatrixUtil.checkRows} before any table is
     * built.
     *
     * @param modelConverter converter passed through to {@code RegressionTableUtil.simplify};
     *        per the recorded signature it is expected to also implement
     *        {@code HasRegressionTableOptions}
     * @param matrix coefficient matrix, one row per label value
     * @param vector intercepts, one element per label value
     * @param schema supplies the label and the features
     * @return a classification {@code RegressionModel} with normalization method
     *         {@code SOFTMAX}
     * @throws ClassCastException if the schema label is not a {@code CategoricalLabel}
     */
    public static Model createSoftmaxClassification(ModelConverter modelConverter, Matrix matrix, Vector vector, Schema schema) {
        // Explicit downcast: compiles even if getLabel() is declared with the base label type.
        CategoricalLabel label = (CategoricalLabel) schema.getLabel();
        MatrixUtil.checkRows(label.size(), matrix);
        ArrayList regressionTables = new ArrayList();
        for (int i = 0; i < label.size(); i++) {
            Object targetCategory = label.getValue(i);
            // Each category gets its own copies, so simplifying one table cannot affect another.
            ArrayList features = new ArrayList(schema.getFeatures());
            ArrayList coefficients = new ArrayList(MatrixUtil.getRow(matrix, i));
            RegressionTableUtil.simplify(modelConverter, targetCategory, features, coefficients);
            regressionTables.add(RegressionModelUtil.createRegressionTable(features, coefficients, Double.valueOf(vector.apply(i))).setTargetCategory(targetCategory));
        }
        return new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(label), regressionTables).setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);
    }

    /**
     * Tells whether the converter's representation option asks for a
     * {@code GeneralRegressionModel}.
     *
     * @return {@code true} only when the option value equals the simple class name of
     *         {@code GeneralRegressionModel}, ignoring case; {@code false} when the option
     *         is unset or holds any other value
     */
    private static boolean isGeneralRegressionRequested(ModelConverter modelConverter) {
        String representation = (String) modelConverter.getOption(HasRegressionTableOptions.OPTION_REPRESENTATION, null);
        // String.equalsIgnoreCase(null) is false, so an unset option needs no separate null check.
        return GeneralRegressionModel.class.getSimpleName().equalsIgnoreCase(representation);
    }
}
