package org.jpmml.evaluator.spark;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.catalyst.expressions.CreateStruct;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.InputField;
import scala.runtime.AbstractFunction1;

/* loaded from: input_file:org/jpmml/evaluator/spark/PMMLTransformer.class */
public class PMMLTransformer extends Transformer {
    private String outputCol = "pmml";
    private Evaluator evaluator = null;
    private List<ColumnProducer<?>> columnProducers = null;
    private StructType outputSchema = null;

    /* loaded from: input_file:org/jpmml/evaluator/spark/PMMLTransformer$SerializableAbstractFunction1.class */
    public static abstract class SerializableAbstractFunction1<T1, R> extends AbstractFunction1<T1, R> implements Serializable {
    }

    public PMMLTransformer(Evaluator evaluator, List<ColumnProducer<?>> list) {
        StructType structType = new StructType();
        Iterator<ColumnProducer<?>> it = list.iterator();
        while (it.hasNext()) {
            structType = structType.add(it.next().init(evaluator));
        }
        setEvaluator(evaluator);
        setColumnProducers(list);
        setOutputSchema(structType);
    }

    public String uid() {
        return null;
    }

    /* renamed from: copy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public PMMLTransformer m5copy(ParamMap paramMap) {
        throw new UnsupportedOperationException();
    }

    public StructType transformSchema(StructType structType) {
        return structType.add(DataTypes.createStructField(getOutputCol(), getOutputSchema(), false));
    }

    public DataFrame transform(final DataFrame dataFrame) {
        final Evaluator evaluator = getEvaluator();
        final List<ColumnProducer<?>> columnProducers = getColumnProducers();
        final List inputFields = evaluator.getInputFields();
        return dataFrame.withColumn(DataFrameUtil.escapeColumnName(getOutputCol()), new Column(new ScalaUDF(new SerializableAbstractFunction1<Row, Row>() { // from class: org.jpmml.evaluator.spark.PMMLTransformer.2
            public Row apply(Row row) {
                LinkedHashMap linkedHashMap = new LinkedHashMap();
                for (int i = 0; i < inputFields.size(); i++) {
                    InputField inputField = (InputField) inputFields.get(i);
                    linkedHashMap.put(inputField.getName(), inputField.prepare(row.get(i)));
                }
                Map evaluate = evaluator.evaluate(linkedHashMap);
                ArrayList arrayList = new ArrayList(columnProducers.size());
                for (int i2 = 0; i2 < columnProducers.size(); i2++) {
                    ColumnProducer columnProducer = (ColumnProducer) columnProducers.get(i2);
                    arrayList.add(columnProducer.format(evaluate.get(columnProducer.getField().getName())));
                }
                return RowFactory.create(arrayList.toArray());
            }
        }, getOutputSchema(), ScalaUtil.singletonSeq(new CreateStruct(ScalaUtil.toSeq(Lists.newArrayList(Lists.transform(inputFields, new Function<InputField, Expression>() { // from class: org.jpmml.evaluator.spark.PMMLTransformer.1
            public Expression apply(InputField inputField) {
                return dataFrame.apply(DataFrameUtil.escapeColumnName(inputField.getName().getValue())).expr();
            }
        }))))), ScalaUtil.emptySeq())));
    }

    public String[] getInputCols() {
        ArrayList newArrayList = Lists.newArrayList(Lists.transform(getEvaluator().getActiveFields(), new Function<InputField, String>() { // from class: org.jpmml.evaluator.spark.PMMLTransformer.3
            public String apply(InputField inputField) {
                return inputField.getName().getValue();
            }
        }));
        return (String[]) newArrayList.toArray(new String[newArrayList.size()]);
    }

    public String getOutputCol() {
        return this.outputCol;
    }

    public void setOutputCol(String str) {
        if (str == null) {
            throw new IllegalArgumentException();
        }
        this.outputCol = str;
    }

    public Evaluator getEvaluator() {
        return this.evaluator;
    }

    private void setEvaluator(Evaluator evaluator) {
        this.evaluator = evaluator;
    }

    public List<ColumnProducer<?>> getColumnProducers() {
        return this.columnProducers;
    }

    private void setColumnProducers(List<ColumnProducer<?>> list) {
        this.columnProducers = list;
    }

    public StructType getOutputSchema() {
        return this.outputSchema;
    }

    private void setOutputSchema(StructType structType) {
        this.outputSchema = structType;
    }
}
