package org.opencb.oskar.spark.variant.transformers;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang.StringUtils;
import org.apache.spark.ml.param.Param;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.opencb.biodata.models.variant.AllelesCode;
import org.opencb.biodata.models.variant.Genotype;
import org.opencb.oskar.core.exceptions.OskarException;
import org.opencb.oskar.spark.commons.converters.DataTypeUtils;
import org.opencb.oskar.spark.variant.VariantMetadataManager;
import org.opencb.oskar.spark.variant.converters.VariantToRowConverter;
import org.opencb.oskar.spark.variant.transformers.params.HasStudyId;
import org.opencb.oskar.spark.variant.udf.VariantUdfManager;
import scala.collection.mutable.WrappedArray;

/* loaded from: input_file:org/opencb/oskar/spark/variant/transformers/InbreedingCoefficientTransformer.class */
/**
 * Transformer that computes a per-sample inbreeding coefficient from the genotypes of one study.
 *
 * <p>For every sample the aggregation accumulates, over all variants that pass the MAF filter:
 * <ul>
 *   <li>{@code ExpectedHom}: the sum of {@code 1 - 2 * maf * (1 - maf)},</li>
 *   <li>{@code ObservedHom}: the number of homozygous genotypes,</li>
 *   <li>{@code GenotypesCount}: the number of genotypes taken into account,</li>
 * </ul>
 * and reports {@code F = (ObservedHom - ExpectedHom) / (GenotypesCount - ExpectedHom)}.
 *
 * <p>When the {@code step} param is greater than zero, the result is computed independently for
 * each {@code (chromosome, floor(start / step) * step)} group and the output carries the extra
 * {@code chromosome} and {@code start} columns; otherwise a single value per sample is produced.
 */
public class InbreedingCoefficientTransformer extends AbstractTransformer implements HasStudyId {

    /** Schema of one per-sample result row. */
    private static final StructType STRUCT_TYPE = DataTypes.createStructType(new StructField[]{
            DataTypes.createStructField("SampleId", DataTypes.StringType, false),
            DataTypes.createStructField("F", DataTypes.DoubleType, false),
            DataTypes.createStructField("ObservedHom", DataTypes.IntegerType, false),
            DataTypes.createStructField("ExpectedHom", DataTypes.DoubleType, false),
            DataTypes.createStructField("GenotypesCount", DataTypes.IntegerType, false)});

    private final Param<Boolean> missingGenotypesAsHomRefParam;
    private final Param<Boolean> includeMultiAllelicGenotypesParam;
    private final Param<Double> mafThresholdParam;
    private final Param<Integer> stepParam;

    /**
     * Aggregate function that folds the study column of many variants into one array of
     * per-sample results (see {@code STRUCT_TYPE} of the enclosing class).
     *
     * <p>The aggregation buffer holds three consecutive slots per sample, in sample order:
     * expected homozygous sum (double), observed homozygous count (int) and genotype count (int).
     */
    public static class InbreedingCoefficientUserDefinedAggregationFunction extends UserDefinedAggregateFunction {

        /** Number of buffer slots reserved for each sample. */
        private static final int SLOTS_PER_SAMPLE = 3;
        /** Offset, within a sample's slots, of the expected homozygous sum. */
        private static final int EXPECTED_HOM = 0;
        /** Offset, within a sample's slots, of the observed homozygous count. */
        private static final int OBSERVED_HOM = 1;
        /** Offset, within a sample's slots, of the genotype count. */
        private static final int COUNT = 2;

        private final int numSamples;
        private final boolean missingGenotypesAsHomRef;
        private final boolean includeMultiAllelicGenotypes;
        private final List<String> sampleNames;
        private final double mafThreshold;

        /**
         * @param numSamples                   number of samples in the study; fixes the buffer size
         * @param missingGenotypesAsHomRef     count missing genotypes as homozygous instead of skipping them
         * @param includeMultiAllelicGenotypes count genotypes whose code is {@code MULTIPLE_ALTERNATES}
         * @param mafThreshold                 a variant is used only if its MAF is strictly greater than this value
         * @param sampleNames                  sample names, indexed like the samples data of the study
         */
        public InbreedingCoefficientUserDefinedAggregationFunction(int numSamples, boolean missingGenotypesAsHomRef,
                                                                   boolean includeMultiAllelicGenotypes,
                                                                   double mafThreshold, List<String> sampleNames) {
            this.numSamples = numSamples;
            this.missingGenotypesAsHomRef = missingGenotypesAsHomRef;
            this.includeMultiAllelicGenotypes = includeMultiAllelicGenotypes;
            this.sampleNames = sampleNames;
            this.mafThreshold = mafThreshold;
        }

        /** Single input column: the study struct. */
        @Override
        public StructType inputSchema() {
            return DataTypes.createStructType(new StructField[]{
                    DataTypes.createStructField("study", VariantToRowConverter.STUDY_DATA_TYPE, false)});
        }

        /** Three fields per sample; see the class description for the layout. */
        @Override
        public StructType bufferSchema() {
            List<StructField> fields = new ArrayList<>(numSamples * SLOTS_PER_SAMPLE);
            for (int sample = 0; sample < numSamples; sample++) {
                fields.add(DataTypes.createStructField("expectedHomCount_" + sample, DataTypes.DoubleType, false));
                fields.add(DataTypes.createStructField("observedHomCount_" + sample, DataTypes.IntegerType, false));
                fields.add(DataTypes.createStructField("count_" + sample, DataTypes.IntegerType, false));
            }
            return DataTypes.createStructType(fields);
        }

        /** Result type: a struct with one field, {@code values}, holding an array of per-sample rows. */
        @Override
        public DataType dataType() {
            return DataTypes.createStructType(Collections.singletonList(
                    DataTypes.createStructField("values",
                            DataTypes.createArrayType(InbreedingCoefficientTransformer.STRUCT_TYPE), false)));
        }

        @Override
        public boolean deterministic() {
            return true;
        }

        /** Zeroes every accumulator of every sample. */
        @Override
        public void initialize(MutableAggregationBuffer buffer) {
            for (int sample = 0; sample < numSamples; sample++) {
                int base = sample * SLOTS_PER_SAMPLE;
                buffer.update(base + EXPECTED_HOM, 0.0d);
                buffer.update(base + OBSERVED_HOM, 0);
                buffer.update(base + COUNT, 0);
            }
        }

        /**
         * Adds one variant to the buffer.
         *
         * <p>The MAF is derived from the study stats entry keyed {@code "ALL"}. Variants whose MAF
         * is not strictly greater than the threshold are ignored. For the remaining variants every
         * accepted sample genotype adds the expected homozygosity to the sample's sum, and
         * increments its genotype count and, if homozygous, its observed count.
         */
        @Override
        public void update(MutableAggregationBuffer buffer, Row input) {
            Row study = input.getStruct(0);
            Row stats = (Row) study.getMap(VariantToRowConverter.STATS_IDX).get("ALL").get();

            float maf;
            if (missingGenotypesAsHomRef) {
                // Missing alleles are counted as reference alleles, so frequencies are recomputed from counts.
                int missingAlleleCount = stats.getInt(stats.fieldIndex("missingAlleleCount"));
                int totalAlleleCount = stats.getInt(stats.fieldIndex("alleleCount")) + missingAlleleCount;
                if (totalAlleleCount <= 0) {
                    // No alleles at all: a frequency cannot be computed, nothing to accumulate.
                    return;
                }
                int refAlleleCount = stats.getInt(stats.fieldIndex("refAlleleCount")) + missingAlleleCount;
                int altAlleleCount = stats.getInt(stats.fieldIndex("altAlleleCount"));
                // Floating point division on purpose: integer division would truncate the frequencies to 0 or 1.
                maf = Math.min((float) refAlleleCount / totalAlleleCount, (float) altAlleleCount / totalAlleleCount);
            } else {
                maf = Math.min(stats.getFloat(stats.fieldIndex("refAlleleFreq")),
                        stats.getFloat(stats.fieldIndex("altAlleleFreq")));
            }
            if (!(maf > mafThreshold)) {
                return;
            }

            double expectedHom = 1.0f - ((2.0f * maf) * (1.0f - maf));
            List<WrappedArray<?>> samplesData = study.getList(VariantToRowConverter.SAMPLES_DATA_IDX);
            int sample = 0;
            for (WrappedArray<?> sampleData : samplesData) {
                int base = sample * SLOTS_PER_SAMPLE;
                sample++;

                // The first element of the sample data is read as the genotype string.
                Genotype genotype = new Genotype((String) sampleData.apply(0));
                AllelesCode code = genotype.getCode();
                if (!includeMultiAllelicGenotypes && code == AllelesCode.MULTIPLE_ALTERNATES) {
                    continue;
                }
                if (!missingGenotypesAsHomRef && code == AllelesCode.ALLELES_MISSING) {
                    continue;
                }

                buffer.update(base + EXPECTED_HOM, buffer.getDouble(base + EXPECTED_HOM) + expectedHom);
                if (isHom(genotype)) {
                    buffer.update(base + OBSERVED_HOM, buffer.getInt(base + OBSERVED_HOM) + 1);
                }
                buffer.update(base + COUNT, buffer.getInt(base + COUNT) + 1);
            }
        }

        /**
         * Tells whether all the alleles of the genotype are the same. A missing genotype is
         * homozygous only when missing genotypes are being treated as HomRef.
         */
        private boolean isHom(Genotype genotype) {
            if (genotype.getCode() == AllelesCode.ALLELES_MISSING) {
                return missingGenotypesAsHomRef;
            }
            int[] alleles = genotype.getAllelesIdx();
            int first = alleles[0];
            for (int allele : alleles) {
                if (allele != first) {
                    return false;
                }
            }
            return true;
        }

        /** Adds, slot by slot, the accumulators of {@code other} into {@code buffer}. */
        @Override
        public void merge(MutableAggregationBuffer buffer, Row other) {
            for (int sample = 0; sample < numSamples; sample++) {
                int base = sample * SLOTS_PER_SAMPLE;
                buffer.update(base + EXPECTED_HOM,
                        buffer.getDouble(base + EXPECTED_HOM) + other.getDouble(base + EXPECTED_HOM));
                buffer.update(base + OBSERVED_HOM,
                        buffer.getInt(base + OBSERVED_HOM) + other.getInt(base + OBSERVED_HOM));
                buffer.update(base + COUNT,
                        buffer.getInt(base + COUNT) + other.getInt(base + COUNT));
            }
        }

        /**
         * Builds the final result: one row per sample with its name, F, observed homozygous count,
         * expected homozygous sum and genotype count.
         *
         * <p>F is computed in double precision. If the genotype count equals the expected sum (for
         * instance when no genotype was counted for the sample) the division yields a non-finite
         * value (NaN or an infinity), which is reported as is.
         */
        @Override
        public Object evaluate(Row buffer) {
            List<Row> values = new ArrayList<>(numSamples);
            for (int sample = 0; sample < numSamples; sample++) {
                int base = sample * SLOTS_PER_SAMPLE;
                double expectedHom = buffer.getDouble(base + EXPECTED_HOM);
                int observedHom = buffer.getInt(base + OBSERVED_HOM);
                int count = buffer.getInt(base + COUNT);
                double f = (observedHom - expectedHom) / (count - expectedHom);
                values.add(new GenericRow(new Object[]{sampleNames.get(sample), f, observedHom, expectedHom, count}));
            }
            return new GenericRow(new Object[]{values});
        }
    }

    /** Creates the transformer passing a {@code null} uid to the parent constructor. */
    public InbreedingCoefficientTransformer() {
        this(null);
    }

    /**
     * Creates the transformer and registers its params with their defaults:
     * {@code missingGenotypesAsHomRef=false}, {@code includeMultiAllelicGenotypes=false},
     * {@code mafThreshold=0.0}, {@code studyId=""} and {@code step=-1}.
     *
     * @param str uid forwarded to the parent constructor
     */
    public InbreedingCoefficientTransformer(String str) {
        super(str);
        this.missingGenotypesAsHomRefParam = new Param<>(this, "missingGenotypesAsHomRef",
                "Treat missing genotypes as HomRef genotypes");
        this.includeMultiAllelicGenotypesParam = new Param<>(this, "includeMultiAllelicGenotypes",
                "Include multi-allelic variants in the calculation");
        this.mafThresholdParam = new Param<>(this, "mafThreshold",
                "Minor allele frequency (MAF) threshold. Only variants with a MAF greater than this value are"
                        + " included in the calculation");
        this.stepParam = new Param<>(this, "step",
                "Calculate inbreeding coefficient grouping by regions of size step");

        setDefault(this.missingGenotypesAsHomRefParam, false);
        setDefault(this.includeMultiAllelicGenotypesParam, false);
        setDefault(this.mafThresholdParam, 0.0d);
        setDefault(studyIdParam(), "");
        setDefault(this.stepParam, -1);
    }

    public Param<Boolean> missingGenotypesAsHomRefParam() {
        return this.missingGenotypesAsHomRefParam;
    }

    public InbreedingCoefficientTransformer setMissingGenotypesAsHomRef(boolean z) {
        set(this.missingGenotypesAsHomRefParam, z);
        return this;
    }

    public Param<Boolean> includeMultiAllelicGenotypesParam() {
        return this.includeMultiAllelicGenotypesParam;
    }

    public InbreedingCoefficientTransformer setIncludeMultiAllelicGenotypes(boolean z) {
        set(this.includeMultiAllelicGenotypesParam, z);
        return this;
    }

    public Param<Double> mafThresholdParam() {
        return this.mafThresholdParam;
    }

    public InbreedingCoefficientTransformer setMafThreshold(double d) {
        set(this.mafThresholdParam, d);
        return this;
    }

    public Param<Integer> stepParam() {
        return this.stepParam;
    }

    public InbreedingCoefficientTransformer setStep(int i) {
        set(this.stepParam, i);
        return this;
    }

    /**
     * Computes the inbreeding coefficient of every sample of the selected study.
     *
     * <p>The study is the one named by the {@code studyId} param. If that param is empty the
     * dataset must contain exactly one study, which is then used.
     *
     * @param dataset variants dataset
     * @return one row per sample (per region when {@code step > 0}) following {@link #transformSchema}
     * @throws OskarException if the requested study does not exist, or if no study was given and
     *                        the dataset does not contain exactly one study
     */
    @Override
    public Dataset<Row> transform(Dataset<?> dataset) {
        Map<String, List<String>> samplesByStudy = new VariantMetadataManager().samples(dataset);
        boolean multiStudy = samplesByStudy.size() != 1;

        String studyId = getStudyId();
        List<String> sampleNames;
        if (StringUtils.isNotEmpty(studyId)) {
            sampleNames = samplesByStudy.get(studyId);
            if (sampleNames == null) {
                throw OskarException.unknownStudy(studyId, samplesByStudy.keySet());
            }
        } else {
            if (multiStudy) {
                throw OskarException.missingStudy(samplesByStudy.keySet());
            }
            Map.Entry<String, List<String>> onlyStudy = samplesByStudy.entrySet().iterator().next();
            studyId = onlyStudy.getKey();
            sampleNames = onlyStudy.getValue();
        }

        boolean missingGenotypesAsHomRef = getOrDefault(this.missingGenotypesAsHomRefParam);
        boolean includeMultiAllelicGenotypes = getOrDefault(this.includeMultiAllelicGenotypesParam);
        double mafThreshold = getOrDefault(this.mafThresholdParam);
        InbreedingCoefficientUserDefinedAggregationFunction udaf =
                new InbreedingCoefficientUserDefinedAggregationFunction(sampleNames.size(),
                        missingGenotypesAsHomRef, includeMultiAllelicGenotypes, mafThreshold, sampleNames);

        // With a single study the first element of "studies" is used directly.
        Column study = multiStudy
                ? VariantUdfManager.study("studies", studyId)
                : functions.col("studies").apply(0);
        Column aggregated = udaf.apply(study).as("r");
        Column exploded = functions.explode(functions.col("r").apply("values")).as("value");

        Integer step = getOrDefault(this.stepParam);
        if (step > 0) {
            // Snap every variant start to the beginning of its region of size "step".
            Column regionStart = functions.floor(functions.col("start").divide(step))
                    .multiply(step)
                    .cast(DataTypes.IntegerType)
                    .alias("start");
            return dataset
                    .groupBy(functions.col("chromosome"), regionStart)
                    .agg(aggregated)
                    .select(functions.col("chromosome"), functions.col("start"), exploded)
                    .selectExpr("chromosome", "start", "value.*")
                    .orderBy("chromosome", "start");
        }
        return dataset
                .agg(aggregated)
                .select(exploded)
                .selectExpr("value.*");
    }

    /**
     * Output schema: the per-sample result fields, preceded by {@code chromosome} and
     * {@code start} when {@code step > 0}. The input schema is not used.
     */
    @Override
    public StructType transformSchema(StructType structType) {
        Integer step = getOrDefault(this.stepParam);
        if (step > 0) {
            StructType withChromosome = DataTypeUtils.addField(STRUCT_TYPE,
                    DataTypes.createStructField("chromosome", DataTypes.StringType, false), 0);
            return DataTypeUtils.addField(withChromosome,
                    DataTypes.createStructField("start", DataTypes.IntegerType, false), 1);
        }
        return STRUCT_TYPE;
    }
}
