package org.opencb.oskar.spark.variant.transformers;

import com.databricks.spark.avro.SchemaConverters;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.ml.param.Param;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.opencb.biodata.models.variant.metadata.VariantFileMetadata;
import org.opencb.biodata.models.variant.metadata.VariantSetStats;
import org.opencb.biodata.models.variant.stats.VariantStats;
import org.opencb.oskar.core.exceptions.OskarException;
import org.opencb.oskar.spark.variant.VariantMetadataManager;
import org.opencb.oskar.spark.variant.converters.VariantToRowConverter;
import org.opencb.oskar.spark.variant.transformers.SampleVariantStatsTransformer;
import org.opencb.oskar.spark.variant.transformers.params.HasStudyId;
import org.opencb.oskar.spark.variant.udf.VariantUdfManager;
import scala.Option;
import scala.Tuple2;
import scala.collection.JavaConversions;
import scala.collection.Map;
import scala.collection.Seq;
import scala.collection.mutable.HashMap;

/* loaded from: input_file:org/opencb/oskar/spark/variant/transformers/VariantSetStatsTransformer.class */
/**
 * Spark ML transformer that aggregates a variant dataset into a single row of variant-set statistics
 * for one study.
 *
 * <p>The study is taken from {@link #getStudyId()}, or, when that is empty, from the dataset metadata,
 * which must then contain exactly one study. The rows that are aggregated can be narrowed with
 * {@link #setFileId(String)} and/or {@link #setSamples(List)}. The output is the expanded {@code stats}
 * struct produced by {@link VariantSetStatsFunction} plus a {@code numSamples} column.
 */
public class VariantSetStatsTransformer extends AbstractTransformer implements HasStudyId {
    private Param<String> fileIdParam;
    private Param<List<String>> samplesParam;

    /**
     * Static accessors for the aggregation buffer used by {@link VariantSetStatsFunction}.
     *
     * <p>Every accessor reads or writes one positional slot of {@link #VARIANT_SET_BUFFER_SCHEMA}. The
     * index constants below must stay in sync with the field order declared in that schema.
     */
    public static class VariantSetStatsBufferUtils {
        private static final int NUM_VARIANTS_IDX = 0;
        private static final int NUM_PASS_IDX = 1;
        private static final int TRANSITIONS_IDX = 2;
        private static final int TRANSVERSIONS_IDX = 3;
        private static final int QUAL_COUNT_IDX = 4;
        private static final int QUAL_SUM_IDX = 5;
        private static final int QUAL_SUM_SQ_IDX = 6;
        private static final int VARIANT_TYPE_COUNTS_IDX = 7;
        private static final int VARIANT_BIOTYPE_COUNTS_IDX = 8;
        private static final int CONSEQUENCE_TYPES_COUNTS_IDX = 9;
        private static final int BY_CHROMOSOME_COUNTS_IDX = 10;

        // Declared before VARIANT_SET_BUFFER_SCHEMA because static initializers run in textual order.
        private static final DataType COUNTS_MAP_TYPE =
                DataTypes.createMapType(DataTypes.StringType, DataTypes.IntegerType, false);

        /** Buffer layout: seven numeric accumulators followed by four string-to-count maps. */
        static final StructType VARIANT_SET_BUFFER_SCHEMA = DataTypes.createStructType(new StructField[]{
                DataTypes.createStructField(SampleVariantStatsTransformer.BufferUtils.NUM_VARIANTS_COLNAME,
                        DataTypes.IntegerType, false),
                DataTypes.createStructField(SampleVariantStatsTransformer.BufferUtils.NUM_PASS_COLNAME,
                        DataTypes.IntegerType, false),
                DataTypes.createStructField("transitionsCount", DataTypes.IntegerType, false),
                DataTypes.createStructField("transversionsCount", DataTypes.IntegerType, false),
                DataTypes.createStructField("qualCount", DataTypes.DoubleType, false),
                DataTypes.createStructField("qualSum", DataTypes.DoubleType, false),
                DataTypes.createStructField("qualSumSq", DataTypes.DoubleType, false),
                DataTypes.createStructField("variantTypeCounts", COUNTS_MAP_TYPE, false),
                DataTypes.createStructField("variantBiotypeCounts", COUNTS_MAP_TYPE, false),
                DataTypes.createStructField("consequenceTypesCounts", COUNTS_MAP_TYPE, false),
                DataTypes.createStructField("byChromosomeCounts", COUNTS_MAP_TYPE, false)});

        private VariantSetStatsBufferUtils() {
        }

        /**
         * Resets every accumulator to zero and every count map to an empty map.
         *
         * @param buffer the aggregation buffer to reset
         */
        public static void initialize(MutableAggregationBuffer buffer) {
            setNumVariants(buffer, 0);
            setNumPass(buffer, 0);
            setTransitionsCount(buffer, 0);
            setTransversionsCount(buffer, 0);
            setQualCount(buffer, 0.0d);
            setQualSum(buffer, 0.0d);
            setQualSumSq(buffer, 0.0d);
            setVariantTypeCounts(buffer, new HashMap<String, Integer>());
            setVariantBiotypeCounts(buffer, new HashMap<String, Integer>());
            setConsequenceTypesCounts(buffer, new HashMap<String, Integer>());
            setByChromosomeCounts(buffer, new HashMap<String, Integer>());
        }

        /**
         * Folds another partial buffer into {@code buffer}.
         *
         * <p>Numeric accumulators are added. Count maps are merged key by key, with the counts summed.
         * A plain map union ({@code ++}) would keep only one side's count for keys present in both.
         *
         * @param buffer the buffer that receives the merged values
         * @param other  a partial buffer with the same layout
         */
        public static void merge(MutableAggregationBuffer buffer, Row other) {
            addNumVariants(buffer, getNumVariants(other));
            addNumPass(buffer, getNumPass(other));
            addTransitionsCount(buffer, getTransitionsCount(other));
            addTransversionsCount(buffer, getTransversionsCount(other));
            addQualCount(buffer, getQualCount(other));
            addQualSum(buffer, getQualSum(other));
            addQualSumSq(buffer, getQualSumSq(other));
            setConsequenceTypesCounts(buffer,
                    sumCounts(getConsequenceTypesCounts(buffer), getConsequenceTypesCounts(other)));
            setVariantBiotypeCounts(buffer,
                    sumCounts(getVariantBiotypeCounts(buffer), getVariantBiotypeCounts(other)));
            setVariantTypeCounts(buffer,
                    sumCounts(getVariantTypeCounts(buffer), getVariantTypeCounts(other)));
            setByChromosomeCounts(buffer,
                    sumCounts(getByChromosomeCounts(buffer), getByChromosomeCounts(other)));
        }

        public static int getNumVariants(Row row) {
            return row.getInt(NUM_VARIANTS_IDX);
        }

        public static void setNumVariants(MutableAggregationBuffer buffer, int value) {
            buffer.update(NUM_VARIANTS_IDX, Integer.valueOf(value));
        }

        public static void addNumVariants(MutableAggregationBuffer buffer, int delta) {
            setNumVariants(buffer, getNumVariants(buffer) + delta);
        }

        public static int getNumPass(Row row) {
            return row.getInt(NUM_PASS_IDX);
        }

        public static void setNumPass(MutableAggregationBuffer buffer, int value) {
            buffer.update(NUM_PASS_IDX, Integer.valueOf(value));
        }

        public static void addNumPass(MutableAggregationBuffer buffer, int delta) {
            setNumPass(buffer, getNumPass(buffer) + delta);
        }

        public static int getTransitionsCount(Row row) {
            return row.getInt(TRANSITIONS_IDX);
        }

        public static void setTransitionsCount(MutableAggregationBuffer buffer, int value) {
            buffer.update(TRANSITIONS_IDX, Integer.valueOf(value));
        }

        public static void addTransitionsCount(MutableAggregationBuffer buffer, int delta) {
            setTransitionsCount(buffer, getTransitionsCount(buffer) + delta);
        }

        public static int getTransversionsCount(Row row) {
            return row.getInt(TRANSVERSIONS_IDX);
        }

        public static void setTransversionsCount(MutableAggregationBuffer buffer, int value) {
            buffer.update(TRANSVERSIONS_IDX, Integer.valueOf(value));
        }

        public static void addTransversionsCount(MutableAggregationBuffer buffer, int delta) {
            setTransversionsCount(buffer, getTransversionsCount(buffer) + delta);
        }

        public static double getQualCount(Row row) {
            return row.getDouble(QUAL_COUNT_IDX);
        }

        public static void setQualCount(MutableAggregationBuffer buffer, double value) {
            buffer.update(QUAL_COUNT_IDX, Double.valueOf(value));
        }

        public static void addQualCount(MutableAggregationBuffer buffer, double delta) {
            setQualCount(buffer, getQualCount(buffer) + delta);
        }

        public static double getQualSum(Row row) {
            return row.getDouble(QUAL_SUM_IDX);
        }

        public static void setQualSum(MutableAggregationBuffer buffer, double value) {
            buffer.update(QUAL_SUM_IDX, Double.valueOf(value));
        }

        public static void addQualSum(MutableAggregationBuffer buffer, double delta) {
            setQualSum(buffer, getQualSum(buffer) + delta);
        }

        public static double getQualSumSq(Row row) {
            return row.getDouble(QUAL_SUM_SQ_IDX);
        }

        public static void setQualSumSq(MutableAggregationBuffer buffer, double value) {
            buffer.update(QUAL_SUM_SQ_IDX, Double.valueOf(value));
        }

        public static void addQualSumSq(MutableAggregationBuffer buffer, double delta) {
            setQualSumSq(buffer, getQualSumSq(buffer) + delta);
        }

        public static Map<String, Integer> getVariantTypeCounts(Row row) {
            return row.getMap(VARIANT_TYPE_COUNTS_IDX);
        }

        public static void setVariantTypeCounts(MutableAggregationBuffer buffer, Map<String, Integer> map) {
            buffer.update(VARIANT_TYPE_COUNTS_IDX, map);
        }

        public static void addVariantTypeCounts(MutableAggregationBuffer buffer, String key, int delta) {
            addToMap(buffer, key, delta, VARIANT_TYPE_COUNTS_IDX);
        }

        public static Map<String, Integer> getVariantBiotypeCounts(Row row) {
            return row.getMap(VARIANT_BIOTYPE_COUNTS_IDX);
        }

        public static void setVariantBiotypeCounts(MutableAggregationBuffer buffer, Map<String, Integer> map) {
            buffer.update(VARIANT_BIOTYPE_COUNTS_IDX, map);
        }

        public static void addVariantBiotypeCounts(MutableAggregationBuffer buffer, String key, int delta) {
            addToMap(buffer, key, delta, VARIANT_BIOTYPE_COUNTS_IDX);
        }

        public static Map<String, Integer> getConsequenceTypesCounts(Row row) {
            return row.getMap(CONSEQUENCE_TYPES_COUNTS_IDX);
        }

        public static void setConsequenceTypesCounts(MutableAggregationBuffer buffer, Map<String, Integer> map) {
            buffer.update(CONSEQUENCE_TYPES_COUNTS_IDX, map);
        }

        public static void addConsequenceTypesCounts(MutableAggregationBuffer buffer, String key, int delta) {
            addToMap(buffer, key, delta, CONSEQUENCE_TYPES_COUNTS_IDX);
        }

        public static Map<String, Integer> getByChromosomeCounts(Row row) {
            return row.getMap(BY_CHROMOSOME_COUNTS_IDX);
        }

        public static void setByChromosomeCounts(MutableAggregationBuffer buffer, Map<String, Integer> map) {
            buffer.update(BY_CHROMOSOME_COUNTS_IDX, map);
        }

        public static void addByChromosomeCounts(MutableAggregationBuffer buffer, String key, int delta) {
            addToMap(buffer, key, delta, BY_CHROMOSOME_COUNTS_IDX);
        }

        /**
         * Adds {@code delta} to the count stored under {@code key} in the map held at buffer slot
         * {@code index}. A missing key is treated as a count of zero.
         */
        private static void addToMap(MutableAggregationBuffer buffer, String key, int delta, int index) {
            Map<String, Integer> counts = buffer.getMap(index);
            Option<Integer> current = counts.get(key);
            int total = current.isDefined() ? current.get() + delta : delta;
            buffer.update(index, counts.$plus(new Tuple2<String, Integer>(key, total)));
        }

        /**
         * Returns a map containing every key of {@code left} and {@code right}. For keys present in
         * both, the value is the sum of the two counts.
         */
        private static Map<String, Integer> sumCounts(Map<String, Integer> left, Map<String, Integer> right) {
            Map<String, Integer> result = left;
            for (java.util.Map.Entry<String, Integer> entry : JavaConversions.mapAsJavaMap(right).entrySet()) {
                Option<Integer> existing = result.get(entry.getKey());
                int total = entry.getValue() + (existing.isDefined() ? existing.get() : 0);
                result = result.$plus(new Tuple2<String, Integer>(entry.getKey(), total));
            }
            return result;
        }
    }

    /**
     * Spark UDAF that accumulates variant-set statistics (counts, transition/transversion totals, QUAL
     * moments and per-key count maps) over variant rows.
     */
    public static class VariantSetStatsFunction extends UserDefinedAggregateFunction {
        /** Study to read file data from; {@code null} means "the single study present in each row". */
        private final String studyId;
        /** File ids/paths whose attributes are counted; {@code null} means every file. */
        private final Set<String> fileIds;

        VariantSetStatsFunction(String studyId, Collection<String> fileIds) {
            this.studyId = StringUtils.isEmpty(studyId) ? null : studyId;
            this.fileIds = fileIds == null || fileIds.isEmpty() ? null : new HashSet<>(fileIds);
        }

        public StructType inputSchema() {
            return DataTypes.createStructType(new StructField[]{
                    DataTypes.createStructField("chromosome", DataTypes.StringType, false),
                    DataTypes.createStructField("reference", DataTypes.StringType, false),
                    DataTypes.createStructField("alternate", DataTypes.StringType, false),
                    DataTypes.createStructField("type", DataTypes.StringType, false),
                    DataTypes.createStructField("studies",
                            DataTypes.createArrayType(VariantToRowConverter.STUDY_DATA_TYPE), false),
                    DataTypes.createStructField("annotation", VariantToRowConverter.ANNOTATION_DATA_TYPE, true)});
        }

        public StructType bufferSchema() {
            return VariantSetStatsBufferUtils.VARIANT_SET_BUFFER_SCHEMA;
        }

        public DataType dataType() {
            return SchemaConverters.toSqlType(VariantSetStats.getClassSchema()).dataType();
        }

        public boolean deterministic() {
            return true;
        }

        public void initialize(MutableAggregationBuffer buffer) {
            VariantSetStatsBufferUtils.initialize(buffer);
        }

        /**
         * Accumulates one input row, whose columns follow {@link #inputSchema()} order, into the buffer.
         */
        public void update(MutableAggregationBuffer buffer, Row input) {
            String chromosome = input.getString(0);
            String reference = input.getString(1);
            String alternate = input.getString(2);
            String type = input.getString(3);
            Seq<Row> studies = input.getSeq(4);
            Row annotation = input.getStruct(5);

            // NOTE(review): these variant-level counters are incremented even when updateFromStudies finds
            // no matching study for this row; confirm that is intended for multi-study datasets.
            VariantSetStatsBufferUtils.addNumVariants(buffer, 1);
            if (VariantStats.isTransition(reference, alternate)) {
                VariantSetStatsBufferUtils.addTransitionsCount(buffer, 1);
            }
            if (VariantStats.isTransversion(reference, alternate)) {
                VariantSetStatsBufferUtils.addTransversionsCount(buffer, 1);
            }
            VariantSetStatsBufferUtils.addByChromosomeCounts(buffer, chromosome, 1);
            VariantSetStatsBufferUtils.addVariantTypeCounts(buffer, type, 1);
            updateFromStudies(buffer, studies);
            updateFromAnnotation(buffer, annotation);
        }

        /**
         * Counts FILTER=PASS and accumulates QUAL statistics from the selected study's file entries,
         * restricted to {@link #fileIds} when that set is non-null.
         *
         * @throws IllegalArgumentException if no study id is configured and the row does not contain
         *                                  exactly one study
         */
        private void updateFromStudies(MutableAggregationBuffer buffer, Seq<Row> studies) {
            Row study = selectStudy(studies);
            if (study == null) {
                return;
            }
            Seq<Row> files = study.getSeq(study.fieldIndex("files"));
            for (int i = 0; i < files.length(); i++) {
                Row file = files.apply(i);
                if (fileIds != null && !fileIds.contains(file.getString(file.fieldIndex("fileId")))) {
                    continue;
                }
                Map<String, String> attributes = file.getMap(file.fieldIndex("attributes"));
                if (attributes == null) {
                    continue;
                }
                Option<String> filter = attributes.get("FILTER");
                if (filter.isDefined() && "PASS".equals(filter.get())) {
                    VariantSetStatsBufferUtils.addNumPass(buffer, 1);
                }
                Option<String> qualOption = attributes.get("QUAL");
                if (qualOption.isDefined()) {
                    String qual = qualOption.get();
                    // "." is the missing-value marker, so it is skipped together with empty values.
                    if (qual != null && !qual.isEmpty() && !qual.equals(".")) {
                        double value = Double.parseDouble(qual);
                        VariantSetStatsBufferUtils.addQualCount(buffer, 1.0d);
                        VariantSetStatsBufferUtils.addQualSum(buffer, value);
                        VariantSetStatsBufferUtils.addQualSumSq(buffer, value * value);
                    }
                }
            }
        }

        /**
         * Picks the study entry to read from. Returns the last entry whose {@code studyId} matches the
         * configured study, or {@code null} if none matches. If no study is configured, the row must
         * contain exactly one study.
         */
        private Row selectStudy(Seq<Row> studies) {
            if (StringUtils.isEmpty(studyId)) {
                if (studies.length() != 1) {
                    throw new IllegalArgumentException("Only 1 study expected. Found " + studies.length());
                }
                return studies.apply(0);
            }
            Row match = null;
            for (int i = 0; i < studies.length(); i++) {
                Row candidate = studies.apply(i);
                if (studyId.equals(candidate.getString(candidate.fieldIndex("studyId")))) {
                    match = candidate;
                }
            }
            return match;
        }

        /**
         * Counts each distinct non-empty biotype and sequence-ontology term once per variant.
         * A null annotation or a null consequence-type list is ignored.
         */
        private void updateFromAnnotation(MutableAggregationBuffer buffer, Row annotation) {
            if (annotation == null) {
                return;
            }
            Seq<Row> consequenceTypes = annotation.getSeq(annotation.fieldIndex("consequenceTypes"));
            if (consequenceTypes == null) {
                return;
            }
            Set<String> biotypes = new HashSet<>();
            Set<String> soTerms = new HashSet<>();
            for (int i = 0; i < consequenceTypes.length(); i++) {
                Row consequenceType = consequenceTypes.apply(i);
                String biotype = consequenceType.getString(consequenceType.fieldIndex("biotype"));
                if (StringUtils.isNotEmpty(biotype)) {
                    biotypes.add(biotype);
                }
                Seq<Row> terms = consequenceType.getSeq(consequenceType.fieldIndex("sequenceOntologyTerms"));
                if (terms != null) {
                    for (int j = 0; j < terms.length(); j++) {
                        // Field 1 of the term struct is read positionally (presumably the term name;
                        // confirm against ANNOTATION_DATA_TYPE). Empty names are skipped because map
                        // keys must be non-null.
                        String term = terms.apply(j).getString(1);
                        if (StringUtils.isNotEmpty(term)) {
                            soTerms.add(term);
                        }
                    }
                }
            }
            for (String biotype : biotypes) {
                VariantSetStatsBufferUtils.addVariantBiotypeCounts(buffer, biotype, 1);
            }
            for (String term : soTerms) {
                VariantSetStatsBufferUtils.addConsequenceTypesCounts(buffer, term, 1);
            }
        }

        public void merge(MutableAggregationBuffer buffer, Row other) {
            VariantSetStatsBufferUtils.merge(buffer, other);
        }

        /**
         * Produces the final aggregated value.
         *
         * <p>NOTE(review): not implemented. This returns {@code null}, so the {@code stats} struct
         * produced by {@link VariantSetStatsTransformer#transform} is null. A value matching
         * {@link #dataType()} (built from the buffer's counters, QUAL mean/variance and count maps)
         * still needs to be constructed here.
         */
        public Object evaluate(Row buffer) {
            return null;
        }
    }

    /** Creates a transformer with default parameters. */
    public VariantSetStatsTransformer() {
        this(null);
    }

    /**
     * Creates a transformer for the given study and file.
     *
     * @param studyId study to aggregate, or {@code null} to resolve it from the dataset metadata
     * @param fileId  file id or path to restrict to, or {@code null} for no file restriction
     */
    public VariantSetStatsTransformer(String studyId, String fileId) {
        // Delegate first so the parameter defaults are registered; without them getOrDefault fails for
        // any param left unset.
        this(null);
        if (studyId != null) {
            setStudyId(studyId);
        }
        if (fileId != null) {
            setFileId(fileId);
        }
    }

    /**
     * Creates a transformer with the given uid and registers defaults: empty study, empty file and no
     * samples.
     */
    public VariantSetStatsTransformer(String uid) {
        super(uid);
        setDefault(studyIdParam(), "");
        setDefault(fileIdParam(), "");
        setDefault(samplesParam(), Collections.emptyList());
    }

    @Override // org.opencb.oskar.spark.variant.transformers.params.HasStudyId
    public VariantSetStatsTransformer setStudyId(String studyId) {
        set(studyIdParam(), studyId);
        return this;
    }

    /** Lazily creates the {@code fileId} param. It may be requested from the constructor. */
    public Param<String> fileIdParam() {
        if (fileIdParam == null) {
            fileIdParam = new Param<>(this, "fileId", "");
        }
        return fileIdParam;
    }

    public VariantSetStatsTransformer setFileId(String fileId) {
        set(fileIdParam(), fileId);
        return this;
    }

    public String getFileId() {
        return (String) getOrDefault(fileIdParam());
    }

    /** Lazily creates the {@code samples} param. It may be requested from the constructor. */
    public Param<List<String>> samplesParam() {
        if (samplesParam == null) {
            samplesParam = new Param<>(this, "samples", "");
        }
        return samplesParam;
    }

    public VariantSetStatsTransformer setSamples(List<String> samples) {
        set(samplesParam(), samples);
        return this;
    }

    public List<String> getSamples() {
        return (List<String>) getOrDefault(samplesParam());
    }

    /**
     * Aggregates {@code dataset} into a single row of variant-set statistics.
     *
     * <p>If samples are set, only variants where some requested sample has a genotype matching
     * {@code "1"} are kept. If a file is set, or can be derived from the samples, only variants present
     * in one of those files are kept. {@code numSamples} is the number of distinct samples considered.
     */
    public Dataset<Row> transform(Dataset<?> dataset) {
        @SuppressWarnings("unchecked")
        Dataset<Row> df = (Dataset<Row>) dataset;
        VariantMetadataManager metadataManager = new VariantMetadataManager();
        String studyId = resolveStudyId(metadataManager, df);

        // File ids and/or paths; used both to pre-filter rows and inside the aggregation function.
        List<String> fileIds = new ArrayList<>();
        if (StringUtils.isNotEmpty(getFileId())) {
            fileIds.add(getFileId());
        }

        List<String> samples = getSamples();
        int numSamples;
        if (CollectionUtils.isNotEmpty(samples)) {
            Column anySampleMatches = null;
            for (String sample : samples) {
                Column matches = VariantUdfManager.genotype("studies", sample).rlike("1");
                anySampleMatches = anySampleMatches == null ? matches : anySampleMatches.or(matches);
            }
            df = df.where(anySampleMatches);
            numSamples = new HashSet<>(samples).size();
            if (fileIds.isEmpty()) {
                // No explicit file: restrict to the files that contain any of the requested samples.
                for (VariantFileMetadata fileMetadata : metadataManager.variantStudyMetadata(df, studyId).getFiles()) {
                    if (CollectionUtils.containsAny(fileMetadata.getSampleIds(), samples)) {
                        fileIds.add(fileMetadata.getId());
                        fileIds.add(fileMetadata.getPath());
                    }
                }
            }
        } else if (fileIds.isEmpty()) {
            numSamples = metadataManager.samples(df, studyId).size();
        } else {
            // A Set, so samples shared by several selected files are counted once.
            Set<String> fileSamples = new HashSet<>();
            for (VariantFileMetadata fileMetadata : metadataManager.variantStudyMetadata(df, studyId).getFiles()) {
                if (fileIds.contains(fileMetadata.getId()) || fileIds.contains(fileMetadata.getPath())) {
                    fileSamples.addAll(fileMetadata.getSampleIds());
                }
            }
            numSamples = fileSamples.size();
        }

        if (!fileIds.isEmpty()) {
            Column inAnyFile = null;
            for (String fileId : fileIds) {
                Column present = VariantUdfManager.file("studies", fileId).isNotNull();
                inAnyFile = inAnyFile == null ? present : inAnyFile.or(present);
            }
            df = df.where(inAnyFile);
        }

        Column stats = new VariantSetStatsFunction(studyId, fileIds).apply(new Column[]{
                functions.col("chromosome"),
                functions.col("reference"),
                functions.col("alternate"),
                functions.col("type"),
                functions.col("studies"),
                functions.col("annotation")}).alias("stats");
        return df.agg(stats, new Column[0])
                .selectExpr(new String[]{"stats.*"})
                .withColumn("numSamples", functions.lit(Integer.valueOf(numSamples)));
    }

    /**
     * Returns the configured study, validated against the dataset metadata. If no study is configured,
     * returns the single study present in the metadata.
     */
    private String resolveStudyId(VariantMetadataManager metadataManager, Dataset<Row> df) {
        List<String> studies = metadataManager.studies(df);
        String studyId = getStudyId();
        if (StringUtils.isEmpty(studyId)) {
            if (studies.size() != 1) {
                throw OskarException.missingStudy(studies);
            }
            return studies.get(0);
        }
        if (!studies.contains(studyId)) {
            throw OskarException.unknownStudy(studyId, studies);
        }
        return studyId;
    }
}
