package org.opencb.oskar.spark.variant.transformers;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;
import org.apache.spark.ml.param.Param;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.opencb.biodata.models.clinical.Phenotype;
import org.opencb.biodata.models.clinical.pedigree.Member;
import org.opencb.biodata.models.clinical.pedigree.Pedigree;
import org.opencb.biodata.models.metadata.Cohort;
import org.opencb.biodata.models.variant.Genotype;
import org.opencb.biodata.models.variant.metadata.VariantStudyMetadata;
import org.opencb.biodata.tools.pedigree.MendelianError;
import org.opencb.commons.utils.CollectionUtils;
import org.opencb.commons.utils.ListUtils;
import org.opencb.oskar.analysis.stats.ChiSquareTest;
import org.opencb.oskar.analysis.stats.ChiSquareTestResult;
import org.opencb.oskar.analysis.stats.FisherExactTest;
import org.opencb.oskar.analysis.stats.FisherTestResult;
import org.opencb.oskar.spark.variant.Oskar;
import org.opencb.oskar.spark.variant.transformers.SampleVariantStatsTransformer;
import org.opencb.oskar.spark.variant.transformers.params.HasStudyId;
import org.opencb.oskar.spark.variant.udf.StudyFunction;
import scala.collection.mutable.ListBuffer;
import scala.collection.mutable.WrappedArray;
import scala.runtime.AbstractFunction1;

/* loaded from: input_file:org/opencb/oskar/spark/variant/transformers/GwasTransformer.class */
/**
 * Spark transformer that runs a per-variant association test between two groups of samples of one study.
 *
 * <p>The two sample groups are resolved, in order of precedence, from:
 * <ol>
 *   <li>phenotype ids ({@code phenotype1} / {@code phenotype2}) carried by pedigree members,</li>
 *   <li>explicit sample lists ({@code sampleList1} / {@code sampleList2}),</li>
 *   <li>cohort ids ({@code cohort1} / {@code cohort2}) found in the study metadata.</li>
 * </ol>
 * When only one side of a pair is specified and it resolves to at least one sample, the other group becomes
 * every remaining sample of the study. Sample ids that are not part of the study are ignored.
 *
 * <p>Depending on {@code method}, either a Fisher exact test ({@code "fisher"} or {@code "fisher-test"}) or a
 * chi-square test ({@code "chi-square"} or {@code "chi-square-test"}) is computed from per-group allele counts.
 * The test result fields are appended to the input dataset as top-level columns.
 */
public class GwasTransformer extends AbstractTransformer implements HasStudyId {
    public static final String GWAS_COL_NAME = "gwas_stats";

    /**
     * Mode value passed to {@link FisherExactTest#fisherTest}.
     * NOTE(review): {@code fisherModeParam} ("two-side" by default) is not consulted here; the value 3 is
     * kept from the original implementation — confirm it is FisherExactTest's two-sided mode.
     */
    private static final int FISHER_MODE = 3;

    private final Param<List<String>> sampleList1Param;
    private final Param<List<String>> sampleList2Param;
    private final Param<String> phenotype1Param;
    private final Param<String> phenotype2Param;
    private final Param<String> cohort1Param;
    private final Param<String> cohort2Param;
    private final Param<String> methodParam;
    private final Param<String> fisherModeParam;

    /**
     * UDF body computing a chi-square test for one variant from its {@code studies} array.
     * Produces a row of (chiSquare, pValue, oddRatio), matching the chi-square result schema.
     */
    public static class ChiSquareFunction extends AbstractFunction1<WrappedArray<GenericRowWithSchema>, Row>
            implements Serializable {
        private final String studyId;
        private final Set<Integer> affectedIndexSet;
        private final Set<Integer> unaffectedIndexSet;

        /**
         * @param studyId            study whose sample data is tested
         * @param affectedIndexSet   sample indexes of group 1
         * @param unaffectedIndexSet sample indexes of group 2
         */
        public ChiSquareFunction(String studyId, Set<Integer> affectedIndexSet, Set<Integer> unaffectedIndexSet) {
            this.studyId = studyId;
            this.affectedIndexSet = affectedIndexSet;
            this.unaffectedIndexSet = unaffectedIndexSet;
        }

        @Override
        public Row apply(WrappedArray<GenericRowWithSchema> studies) {
            GenericRowWithSchema study = new StudyFunction().apply((WrappedArray<? extends Row>) studies, studyId);
            int[] counts = computeCounts(study, affectedIndexSet, unaffectedIndexSet);
            ChiSquareTestResult result = ChiSquareTest.chiSquareTest(counts[0], counts[1], counts[2], counts[3]);
            return RowFactory.create(result.getChiSquare(), result.getpValue(), result.getOddRatio());
        }
    }

    /**
     * UDF body computing a Fisher exact test for one variant from its {@code studies} array.
     * Produces a row of (pValue, oddRatio), matching the Fisher result schema.
     */
    public static class FisherFunction extends AbstractFunction1<WrappedArray<GenericRowWithSchema>, Row>
            implements Serializable {
        private final String studyId;
        private final int mode;
        private final Set<Integer> affectedIndexSet;
        private final Set<Integer> unaffectedIndexSet;

        /**
         * @param studyId            study whose sample data is tested
         * @param mode               mode value forwarded to {@link FisherExactTest#fisherTest}
         * @param affectedIndexSet   sample indexes of group 1
         * @param unaffectedIndexSet sample indexes of group 2
         */
        public FisherFunction(String studyId, int mode, Set<Integer> affectedIndexSet,
                              Set<Integer> unaffectedIndexSet) {
            this.studyId = studyId;
            this.mode = mode;
            this.affectedIndexSet = affectedIndexSet;
            this.unaffectedIndexSet = unaffectedIndexSet;
        }

        @Override
        public Row apply(WrappedArray<GenericRowWithSchema> studies) {
            GenericRowWithSchema study = new StudyFunction().apply((WrappedArray<? extends Row>) studies, studyId);
            int[] counts = computeCounts(study, affectedIndexSet, unaffectedIndexSet);
            FisherTestResult result = new FisherExactTest()
                    .fisherTest(counts[0], counts[1], counts[2], counts[3], mode);
            return RowFactory.create(result.getpValue(), result.getOddRatio());
        }
    }

    public GwasTransformer() {
        this(null);
    }

    public GwasTransformer(String uid) {
        super(uid);
        sampleList1Param = new Param<>(this, "sampleList1", "Sample list 1");
        sampleList2Param = new Param<>(this, "sampleList2", "Sample list 2");
        phenotype1Param = new Param<>(this, "phenotype1", "Phenotype 1");
        phenotype2Param = new Param<>(this, "phenotype2", "Phenotype 2");
        cohort1Param = new Param<>(this, "cohort1", "Cohort 1");
        cohort2Param = new Param<>(this, "cohort2", "Cohort 2");
        methodParam = new Param<>(this, "method", "Method: fisher or chi-square");
        fisherModeParam = new Param<>(this, "fisherMode", "Fisher exact test mode");
        // Use the fields directly rather than the overridable accessors while still constructing.
        setDefault(sampleList1Param, Collections.<String>emptyList());
        setDefault(sampleList2Param, Collections.<String>emptyList());
        setDefault(phenotype1Param, "");
        setDefault(phenotype2Param, "");
        setDefault(cohort1Param, "");
        setDefault(cohort2Param, "");
        setDefault(methodParam, "fisher");
        setDefault(fisherModeParam, "two-side");
    }

    @Override
    public GwasTransformer setStudyId(String studyId) {
        set(studyIdParam(), studyId);
        return this;
    }

    public Param<List<String>> sampleList1Param() {
        return sampleList1Param;
    }

    public GwasTransformer setSampleList1(List<String> sampleList1) {
        set(sampleList1Param, sampleList1);
        return this;
    }

    public List<String> getSampleList1() {
        return getOrDefault(sampleList1Param);
    }

    public Param<List<String>> sampleList2Param() {
        return sampleList2Param;
    }

    public GwasTransformer setSampleList2(List<String> sampleList2) {
        set(sampleList2Param, sampleList2);
        return this;
    }

    public List<String> getSampleList2() {
        return getOrDefault(sampleList2Param);
    }

    public Param<String> phenotype1Param() {
        return phenotype1Param;
    }

    public GwasTransformer setPhenotype1(String phenotype1) {
        set(phenotype1Param, phenotype1);
        return this;
    }

    public String getPhenotype1() {
        return getOrDefault(phenotype1Param);
    }

    public Param<String> phenotype2Param() {
        return phenotype2Param;
    }

    public GwasTransformer setPhenotype2(String phenotype2) {
        set(phenotype2Param, phenotype2);
        return this;
    }

    public String getPhenotype2() {
        return getOrDefault(phenotype2Param);
    }

    public Param<String> cohort1Param() {
        return cohort1Param;
    }

    public GwasTransformer setCohort1(String cohort1) {
        set(cohort1Param, cohort1);
        return this;
    }

    public String getCohort1() {
        return getOrDefault(cohort1Param);
    }

    public Param<String> cohort2Param() {
        return cohort2Param;
    }

    public GwasTransformer setCohort2(String cohort2) {
        set(cohort2Param, cohort2);
        return this;
    }

    public String getCohort2() {
        return getOrDefault(cohort2Param);
    }

    public Param<String> methodParam() {
        return methodParam;
    }

    public GwasTransformer setMethod(String method) {
        set(methodParam, method);
        return this;
    }

    public String getMethod() {
        return getOrDefault(methodParam);
    }

    public Param<String> fisherModeParam() {
        return fisherModeParam;
    }

    public GwasTransformer setFisherMode(String fisherMode) {
        set(fisherModeParam, fisherMode);
        return this;
    }

    public String getFisherMode() {
        return getOrDefault(fisherModeParam);
    }

    /**
     * Resolves the two sample groups and appends the association-test result columns.
     *
     * @param dataset variant dataset with a {@code studies} column
     * @return the input columns plus the fields of the selected test's result schema
     * @throws IllegalArgumentException if {@code method} is not a supported test name
     */
    public Dataset<Row> transform(Dataset<?> dataset) {
        Set<Integer> group1 = new HashSet<>();
        Set<Integer> group2 = new HashSet<>();
        List<String> samples = new Oskar().metadata().samples(dataset, getStudyId());

        if (StringUtils.isNotEmpty(getPhenotype1()) || StringUtils.isNotEmpty(getPhenotype2())) {
            collectByPhenotype(dataset, samples, group1, group2);
            completeGroups(samples, StringUtils.isEmpty(getPhenotype1()), StringUtils.isEmpty(getPhenotype2()),
                    group1, group2);
        } else if (CollectionUtils.isNotEmpty(getSampleList1()) || CollectionUtils.isNotEmpty(getSampleList2())) {
            getSampleList1().forEach(sampleId -> addSampleIndex(samples, sampleId, group1));
            getSampleList2().forEach(sampleId -> addSampleIndex(samples, sampleId, group2));
            completeGroups(samples, CollectionUtils.isEmpty(getSampleList1()),
                    CollectionUtils.isEmpty(getSampleList2()), group1, group2);
        } else if (StringUtils.isNotEmpty(getCohort1()) || StringUtils.isNotEmpty(getCohort2())) {
            collectByCohort(dataset, samples, group1, group2);
            completeGroups(samples, StringUtils.isEmpty(getCohort1()), StringUtils.isEmpty(getCohort2()),
                    group1, group2);
        }

        UserDefinedFunction gwasUdf = buildGwasUdf(group1, group2);
        return dataset.withColumn(GWAS_COL_NAME, gwasUdf.apply(new ListBuffer().$plus$eq(functions.col("studies"))))
                .selectExpr("*", GWAS_COL_NAME + ".*")
                .drop(GWAS_COL_NAME);
    }

    /**
     * Describes the output of {@link #transform}: the input fields followed by the result fields of the
     * selected test. {@code transform} flattens the intermediate struct, so the struct itself is not listed.
     *
     * @throws IllegalArgumentException if {@code method} is not a supported test name
     */
    @Override
    public StructType transformSchema(StructType structType) {
        StructField[] inputFields = structType.fields();
        StructField[] gwasFields = resultSchema().fields();
        StructField[] outputFields = Arrays.copyOf(inputFields, inputFields.length + gwasFields.length);
        System.arraycopy(gwasFields, 0, outputFields, inputFields.length, gwasFields.length);
        return DataTypes.createStructType(outputFields);
    }

    /** Adds to the groups every pedigree member whose first matching phenotype is phenotype1 or phenotype2. */
    private void collectByPhenotype(Dataset<?> dataset, List<String> samples,
                                    Set<Integer> group1, Set<Integer> group2) {
        String phenotype1 = getPhenotype1();
        String phenotype2 = getPhenotype2();
        for (Pedigree pedigree : new Oskar().metadata().pedigrees(dataset, getStudyId())) {
            for (Member member : pedigree.getMembers()) {
                if (!ListUtils.isNotEmpty(member.getPhenotypes())) {
                    continue;
                }
                // The first phenotype matching either id decides the group; phenotype1 is checked first.
                for (Phenotype phenotype : member.getPhenotypes()) {
                    if (StringUtils.isNotEmpty(phenotype1) && phenotype1.equals(phenotype.getId())) {
                        addSampleIndex(samples, member.getId(), group1);
                        break;
                    }
                    if (StringUtils.isNotEmpty(phenotype2) && phenotype2.equals(phenotype.getId())) {
                        addSampleIndex(samples, member.getId(), group2);
                        break;
                    }
                }
            }
        }
    }

    /** Adds to the groups the samples of cohort1 / cohort2 as declared in this study's metadata. */
    private void collectByCohort(Dataset<?> dataset, List<String> samples,
                                 Set<Integer> group1, Set<Integer> group2) {
        List<Cohort> cohorts = null;
        for (VariantStudyMetadata study : new Oskar().metadata().variantMetadata(dataset).getStudies()) {
            if (study.getId().equals(getStudyId())) {
                cohorts = study.getCohorts();
                break;
            }
        }
        if (CollectionUtils.isEmpty(cohorts)) {
            return;
        }
        for (Cohort cohort : cohorts) {
            if (cohort.getId().equals(getCohort1())) {
                cohort.getSampleIds().forEach(sampleId -> addSampleIndex(samples, sampleId, group1));
            } else if (cohort.getId().equals(getCohort2())) {
                cohort.getSampleIds().forEach(sampleId -> addSampleIndex(samples, sampleId, group2));
            }
        }
    }

    /** When exactly one side was left unspecified, fills it with every sample not in the other group. */
    private void completeGroups(List<String> samples, boolean group1Unspecified, boolean group2Unspecified,
                                Set<Integer> group1, Set<Integer> group2) {
        if (group1Unspecified && CollectionUtils.isNotEmpty(group2)) {
            populateTargetIndexSet(samples, group2, group1);
        } else if (group2Unspecified && CollectionUtils.isNotEmpty(group1)) {
            populateTargetIndexSet(samples, group1, group2);
        }
    }

    /** Adds the index of {@code sampleId} in {@code samples} to {@code target}; unknown ids are skipped. */
    private static void addSampleIndex(List<String> samples, String sampleId, Set<Integer> target) {
        int index = samples.indexOf(sampleId);
        if (index >= 0) {
            target.add(index);
        }
    }

    /** Builds the UDF for the configured method. */
    private UserDefinedFunction buildGwasUdf(Set<Integer> group1, Set<Integer> group2) {
        StructType schema = resultSchema();
        if (isFisherMethod()) {
            return functions.udf(new FisherFunction(getStudyId(), FISHER_MODE, group1, group2), schema);
        }
        return functions.udf(new ChiSquareFunction(getStudyId(), group1, group2), schema);
    }

    /** Returns the result schema of the configured method, or throws if the method is unknown. */
    private StructType resultSchema() {
        if (isFisherMethod()) {
            return fisherSchema();
        }
        if (isChiSquareMethod()) {
            return chiSquareSchema();
        }
        throw new IllegalArgumentException("Unknown GWAS method '" + getMethod()
                + "'. Expected one of: fisher, fisher-test, chi-square, chi-square-test");
    }

    /** Accepts both the documented name ("fisher", also the default) and the legacy "fisher-test". */
    private boolean isFisherMethod() {
        String method = getMethod();
        return "fisher".equals(method) || "fisher-test".equals(method);
    }

    /** Accepts both the documented name ("chi-square") and the legacy "chi-square-test". */
    private boolean isChiSquareMethod() {
        String method = getMethod();
        return "chi-square".equals(method) || "chi-square-test".equals(method);
    }

    private StructType fisherSchema() {
        return DataTypes.createStructType(new StructField[]{
                DataTypes.createStructField("pValue", DataTypes.DoubleType, false),
                DataTypes.createStructField("oddRatio", DataTypes.DoubleType, false)});
    }

    private StructType chiSquareSchema() {
        return DataTypes.createStructType(new StructField[]{
                DataTypes.createStructField("chiSquare", DataTypes.DoubleType, false),
                DataTypes.createStructField("pValue", DataTypes.DoubleType, false),
                DataTypes.createStructField("oddRatio", DataTypes.DoubleType, false)});
    }

    /**
     * Adds to {@code target} every index in {@code [0, samples.size())} that is not in {@code source}.
     */
    private void populateTargetIndexSet(List<String> samples, Set<Integer> source, Set<Integer> target) {
        for (int i = 0; i < samples.size(); i++) {
            if (!source.contains(i)) {
                target.add(i);
            }
        }
    }

    /**
     * Counts alleles per group from the study's {@code samplesData}, whose entries' first element is parsed as
     * a genotype.
     *
     * @param study  study row containing a {@code samplesData} list
     * @param group1 sample indexes of group 1 (takes precedence if an index is in both groups)
     * @param group2 sample indexes of group 2
     * @return {@code {refGroup1, refGroup2, altGroup1, altGroup2}}. HOM_REF adds 2 ref alleles, HET adds one
     *         ref and one alt allele, and HOM_VAR adds 2 alt alleles. Other genotype codes are not counted.
     */
    public static int[] computeCounts(GenericRowWithSchema study, Set<Integer> group1, Set<Integer> group2) {
        int[] counts = new int[4];
        List<WrappedArray<String>> samplesData = study.getList(study.fieldIndex("samplesData"));
        for (int i = 0; i < samplesData.size(); i++) {
            boolean inGroup1 = group1.contains(i);
            if (!inGroup1 && !group2.contains(i)) {
                continue; // sample belongs to neither group: its genotype is irrelevant
            }
            String gt = samplesData.get(i).apply(0);
            int refAlleles;
            int altAlleles;
            switch (MendelianError.getAlternateAlleleCount(new Genotype(gt))) {
                case HOM_REF:
                    refAlleles = 2;
                    altAlleles = 0;
                    break;
                case HET:
                    refAlleles = 1;
                    altAlleles = 1;
                    break;
                case HOM_VAR:
                    refAlleles = 0;
                    altAlleles = 2;
                    break;
                default:
                    continue; // any other genotype code contributes no alleles
            }
            if (inGroup1) {
                counts[0] += refAlleles;
                counts[2] += altAlleles;
            } else {
                counts[1] += refAlleles;
                counts[3] += altAlleles;
            }
        }
        return counts;
    }
}
