package org.opencb.oskar.spark.variant.transformers;

import org.apache.spark.ml.param.Param;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.NumericType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.opencb.oskar.spark.commons.converters.DataTypeUtils;

/* loaded from: input_file:org/opencb/oskar/spark/variant/transformers/HistogramTransformer.class */
/**
 * Transformer that builds a histogram of a single numeric column.
 *
 * <p>Each value of the input column is mapped to the lower bound of its bin,
 * {@code floor(value / step) * step}, cast back to the column's original type. Rows are then
 * grouped by bin and counted. The output has two columns, sorted by bin:
 * <ul>
 *   <li>the bin, named after the input column;</li>
 *   <li>{@code count}, the number of rows in that bin.</li>
 * </ul>
 *
 * <p>Parameters: {@code inputCol} (required, no default) and {@code step} (bin width, default
 * {@code 0.1}).
 */
public class HistogramTransformer extends AbstractTransformer {
    /** Bin width used when {@link #setStep(double)} is never called. */
    private static final double DEFAULT_STEP = 0.1d;

    private Param<Double> stepParam;
    private Param<String> inputColParam;

    /** Creates a transformer, passing a {@code null} uid to {@link AbstractTransformer}. */
    public HistogramTransformer() {
        this(null);
    }

    /**
     * Creates a transformer with the given uid and registers the default bin width.
     *
     * @param uid identifier handed to {@link AbstractTransformer} (may be {@code null})
     */
    public HistogramTransformer(String uid) {
        super(uid);
        setDefault(stepParam(), Double.valueOf(DEFAULT_STEP));
    }

    /**
     * Sets the bin width.
     *
     * @param step bin width; must be positive and finite
     * @return this transformer, for chaining
     * @throws IllegalArgumentException if {@code step} is not a positive, finite number
     */
    public HistogramTransformer setStep(double step) {
        set(stepParam(), Double.valueOf(requireValidStep(step)));
        return this;
    }

    /**
     * Sets the name of the numeric column to build the histogram from.
     *
     * @param inputCol column name, resolved against the dataset schema in {@link #transform}
     * @return this transformer, for chaining
     */
    public HistogramTransformer setInputCol(String inputCol) {
        set(inputColParam(), inputCol);
        return this;
    }

    /**
     * Returns the {@code step} (bin width) parameter.
     *
     * @return the parameter, created on first access
     */
    public Param<Double> stepParam() {
        if (stepParam == null) {
            stepParam = new Param<>(uid(), "step", "");
        }
        return stepParam;
    }

    /**
     * Returns the {@code inputCol} parameter.
     *
     * @return the parameter, created on first access
     */
    public Param<String> inputColParam() {
        if (inputColParam == null) {
            inputColParam = new Param<>(uid(), "inputCol", "");
        }
        return inputColParam;
    }

    /**
     * Computes the histogram of the configured input column.
     *
     * @param dataset dataset containing the input column
     * @return one row per non-empty bin with columns {@code <inputCol>} (bin lower bound) and
     *         {@code count}, ordered by bin
     * @throws java.util.NoSuchElementException if {@code inputCol} was never set
     * @throws IllegalArgumentException if the input column is not numeric or the step is invalid
     */
    public Dataset<Row> transform(Dataset<?> dataset) {
        // getOrDefault (not get) so the default step registered in the constructor is honoured.
        String inputCol = getOrDefault(inputColParam());
        double step = requireValidStep(getOrDefault(stepParam()));

        DataType dataType = DataTypeUtils.getField(dataset.schema(), inputCol).dataType();
        if (!(dataType instanceof NumericType)) {
            throw new IllegalArgumentException("Input column must be NumericalType. Input column '" + inputCol
                    + "' is type " + dataType.typeName());
        }

        // floor() rather than a cast to int. A cast truncates toward zero, so negative values
        // would be merged into bin 0 (which would then span (-step, step)). A cast also cannot
        // represent quotients outside the int range.
        Column bin = functions.floor(functions.expr(inputCol).divide(step))
                .multiply(step)
                .cast(dataType)
                .alias(inputCol);

        return dataset.select(bin).groupBy(inputCol).count().orderBy(inputCol);
    }

    /**
     * Describes the schema produced by {@link #transform}: the bin column followed by
     * {@code count}.
     *
     * @param schema schema of the input dataset
     * @return the output schema
     * @throws java.util.NoSuchElementException if {@code inputCol} was never set
     */
    @Override
    public StructType transformSchema(StructType schema) {
        String inputCol = getOrDefault(inputColParam());

        StructField binField = null;
        for (StructField field : schema.fields()) {
            if (field.name().equals(inputCol)) {
                binField = field;
                break;
            }
        }

        // StructType is immutable: add(...) returns a new instance, so every result must be kept.
        StructType result = new StructType();
        if (binField != null) {
            result = result.add(binField);
        } else {
            // Not a top-level field. transform still names the bin column after inputCol;
            // the type is not resolved here and DoubleType is assumed (TODO: confirm against
            // DataTypeUtils.getField for nested paths).
            result = result.add(inputCol, DataTypes.DoubleType, false);
        }
        // groupBy(...).count() yields a non-nullable LongType column named "count".
        return result.add("count", DataTypes.LongType, false);
    }

    /**
     * Validates a bin width.
     *
     * @param step candidate bin width
     * @return {@code step}, unchanged
     * @throws IllegalArgumentException if {@code step} is not positive and finite (NaN included)
     */
    private static double requireValidStep(double step) {
        if (!(step > 0) || Double.isInfinite(step)) {
            throw new IllegalArgumentException("Step must be a positive, finite number. Got " + step);
        }
        return step;
    }
}
