package org.jpmml.evaluator.spark;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.HasProbability;
import org.jpmml.evaluator.TargetField;

/* loaded from: input_file:org/jpmml/evaluator/spark/ProbabilityColumnProducer.class */
/**
 * Produces a non-nullable Spark SQL column of {@link VectorUDT} type that holds, for each
 * evaluated row, a dense vector of probabilities.
 *
 * <p>The i-th vector element is the probability that the target result reports for the i-th
 * entry of the label list supplied at construction time.
 */
class ProbabilityColumnProducer extends ColumnProducer<TargetField> {

    /** Column name used when the caller passes {@code null}. */
    private static final String DEFAULT_COLUMN_NAME = "probability";

    /** Category labels in vector-element order; an unmodifiable snapshot taken at construction. */
    private final List<String> labels;

    /**
     * Creates a producer for the given target field.
     *
     * @param targetField the target field whose result this producer formats
     * @param columnName the output column name, or {@code null} to use {@code "probability"}
     * @param labels the category labels whose probabilities populate the vector, in order
     * @throws NullPointerException if {@code labels} is {@code null}
     */
    public ProbabilityColumnProducer(TargetField targetField, String columnName, List<String> labels) {
        super(targetField, columnName != null ? columnName : DEFAULT_COLUMN_NAME);
        // Snapshot the labels so that later mutation of the caller's list cannot change the
        // vector size or element order after the column has been declared.
        this.labels = Collections.unmodifiableList(
                new ArrayList<>(Objects.requireNonNull(labels, "labels")));
    }

    /**
     * Declares the output column as a non-nullable vector column.
     *
     * @param evaluator the model evaluator (not consulted by this producer)
     * @return the schema field for this producer's column
     */
    @Override // org.jpmml.evaluator.spark.ColumnProducer
    public StructField init(Evaluator evaluator) {
        return DataTypes.createStructField(getColumnName(), new VectorUDT(), false);
    }

    /**
     * Converts a target result into a dense probability vector.
     *
     * @param obj the evaluated target value; must implement {@link HasProbability}
     * @return a vector whose i-th element is the probability reported for the i-th label
     * @throws ClassCastException if {@code obj} does not implement {@code HasProbability}
     * @throws NullPointerException if the result reports no probability for one of the labels
     */
    @Override // org.jpmml.evaluator.spark.ColumnProducer
    public Vector format(Object obj) {
        List<String> labels = getLabels();
        HasProbability hasProbability = (HasProbability) obj;

        double[] probabilities = new double[labels.size()];
        for (int i = 0; i < probabilities.length; i++) {
            String label = labels.get(i);
            Number probability = hasProbability.getProbability(label);
            // Name the offending label instead of failing with a bare unboxing NPE.
            if (probability == null) {
                throw new NullPointerException("No probability reported for label \"" + label + "\"");
            }
            probabilities[i] = probability.doubleValue();
        }
        return new DenseVector(probabilities);
    }

    /**
     * Returns the category labels, in vector-element order.
     *
     * @return an unmodifiable view of the labels supplied at construction
     */
    public List<String> getLabels() {
        return this.labels;
    }
}
