package org.opencb.oskar.spark.variant.transformers;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.FlatMapGroupsFunction;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.ml.feature.PCA;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.opencb.biodata.models.variant.Genotype;
import org.opencb.biodata.tools.pedigree.MendelianError;
import org.opencb.oskar.spark.variant.transformers.SampleVariantStatsTransformer;
import org.opencb.oskar.spark.variant.transformers.params.HasStudyId;
import org.opencb.oskar.spark.variant.udf.StudyFunction;
import scala.collection.mutable.WrappedArray;

/* loaded from: input_file:org/opencb/oskar/spark/variant/transformers/PCATransformer.class */
public class PCATransformer extends AbstractTransformer implements HasStudyId {
    private final IntParam kParam;
    private final String GENOTYPES_COLUMN_NAME = "genotypes";
    private final String GENOTYPE_COLUMN_NAME = "genotype";
    private final String ROW_COLUMN_NAME = "rowIndex";
    private final String COLUMN_COLUMN_NAME = "colIndex";
    private final String PCA_COLUMN_NAME = "PCA";

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.opencb.oskar.spark.variant.transformers.PCATransformer$5, reason: invalid class name */
    /* loaded from: input_file:org/opencb/oskar/spark/variant/transformers/PCATransformer$5.class */
    public static /* synthetic */ class AnonymousClass5 {
        static final /* synthetic */ int[] $SwitchMap$org$opencb$biodata$tools$pedigree$MendelianError$GenotypeCode = new int[MendelianError.GenotypeCode.values().length];

        static {
            try {
                $SwitchMap$org$opencb$biodata$tools$pedigree$MendelianError$GenotypeCode[MendelianError.GenotypeCode.HET.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$opencb$biodata$tools$pedigree$MendelianError$GenotypeCode[MendelianError.GenotypeCode.HOM_VAR.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$opencb$biodata$tools$pedigree$MendelianError$GenotypeCode[MendelianError.GenotypeCode.HOM_REF.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
        }
    }

    public PCATransformer() {
        this(null);
    }

    public PCATransformer(String str) {
        super(str);
        this.GENOTYPES_COLUMN_NAME = "genotypes";
        this.GENOTYPE_COLUMN_NAME = "genotype";
        this.ROW_COLUMN_NAME = "rowIndex";
        this.COLUMN_COLUMN_NAME = "colIndex";
        this.PCA_COLUMN_NAME = "PCA";
        this.kParam = new IntParam(this, "k", "");
        setDefault(kParam(), 2);
    }

    @Override // org.opencb.oskar.spark.variant.transformers.params.HasStudyId
    public PCATransformer setStudyId(String str) {
        set(studyIdParam(), str);
        return this;
    }

    public IntParam kParam() {
        return this.kParam;
    }

    public PCATransformer setK(int i) {
        set(this.kParam, Integer.valueOf(i));
        return this;
    }

    public int getK() {
        return ((Integer) getOrDefault(this.kParam)).intValue();
    }

    public Dataset<Row> transform(Dataset<?> dataset) {
        Dataset flatMapGroups = dataset.map(new MapFunction<Row, Row>() { // from class: org.opencb.oskar.spark.variant.transformers.PCATransformer.4
            public Row call(Row row) {
                double d;
                GenericRowWithSchema apply = new StudyFunction().apply((WrappedArray<? extends Row>) row.apply(row.fieldIndex("studies")), PCATransformer.this.getStudyId());
                List list = apply.getList(apply.fieldIndex("samplesData"));
                double[] dArr = new double[list.size()];
                for (int i = 0; i < list.size(); i++) {
                    switch (AnonymousClass5.$SwitchMap$org$opencb$biodata$tools$pedigree$MendelianError$GenotypeCode[MendelianError.getAlternateAlleleCount(new Genotype((String) ((WrappedArray) list.get(i)).apply(0))).ordinal()]) {
                        case SampleVariantStatsTransformer.BufferUtils.NUM_VARIANTS_INDEX /* 1 */:
                            d = 1.0d;
                            break;
                        case SampleVariantStatsTransformer.BufferUtils.CHROMOSOME_COUNT_INDEX /* 2 */:
                            d = 2.0d;
                            break;
                        case SampleVariantStatsTransformer.BufferUtils.TYPE_COUNT_INDEX /* 3 */:
                        default:
                            d = 0.0d;
                            break;
                    }
                    dArr[i] = d;
                }
                return RowFactory.create(new Object[]{Vectors.dense(dArr)});
            }
        }, RowEncoder.apply(createSchema("genotypes"))).withColumn("rowIndex", functions.row_number().over(Window.orderBy("genotypes", new String[0]))).flatMap(new FlatMapFunction<Row, Row>() { // from class: org.opencb.oskar.spark.variant.transformers.PCATransformer.3
            public Iterator<Row> call(Row row) throws Exception {
                ArrayList arrayList = new ArrayList();
                DenseVector denseVector = (DenseVector) row.get(row.fieldIndex("genotypes"));
                for (int i = 0; i < denseVector.size(); i++) {
                    arrayList.add(RowFactory.create(new Object[]{Integer.valueOf(row.getInt(row.fieldIndex("rowIndex"))), Integer.valueOf(i), Double.valueOf(denseVector.values()[i])}));
                }
                return arrayList.iterator();
            }
        }, RowEncoder.apply(createSchema3())).groupByKey(new MapFunction<Row, Integer>() { // from class: org.opencb.oskar.spark.variant.transformers.PCATransformer.2
            public Integer call(Row row) throws Exception {
                return Integer.valueOf(row.getInt(1));
            }
        }, Encoders.INT()).flatMapGroups(new FlatMapGroupsFunction<Integer, Row, Row>() { // from class: org.opencb.oskar.spark.variant.transformers.PCATransformer.1
            public Iterator<Row> call(Integer num, Iterator<Row> it) throws Exception {
                HashMap hashMap = new HashMap();
                while (it.hasNext()) {
                    Row next = it.next();
                    hashMap.put(Integer.valueOf(next.getInt(0)), Double.valueOf(next.getDouble(2)));
                }
                double[] dArr = new double[hashMap.size()];
                for (Integer num2 : hashMap.keySet()) {
                    dArr[num2.intValue() - 1] = ((Double) hashMap.get(num2)).doubleValue();
                }
                return Collections.singletonList(RowFactory.create(new Object[]{Vectors.dense(dArr)})).iterator();
            }

            public /* bridge */ /* synthetic */ Iterator call(Object obj, Iterator it) throws Exception {
                return call((Integer) obj, (Iterator<Row>) it);
            }
        }, RowEncoder.apply(createSchema("genotypes")));
        return new PCA().setInputCol("genotypes").setOutputCol("PCA").setK(getK()).fit(flatMapGroups).transform(flatMapGroups).select("PCA", new String[0]);
    }

    @Override // org.opencb.oskar.spark.variant.transformers.AbstractTransformer
    public StructType transformSchema(StructType structType) {
        return createSchema("PCA");
    }

    private StructType createSchema(String str) {
        return new StructType(new StructField[]{new StructField(str, new VectorUDT(), false, Metadata.empty())});
    }

    private StructType createSchema3() {
        return new StructType(new StructField[]{new StructField("rowIndex", DataTypes.IntegerType, false, Metadata.empty()), new StructField("colIndex", DataTypes.IntegerType, false, Metadata.empty()), new StructField("genotype", DataTypes.DoubleType, false, Metadata.empty())});
    }
}
