package org.opencb.oskar.spark.variant.transformers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.spark.ml.param.BooleanParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.opencb.biodata.models.variant.AllelesCode;
import org.opencb.biodata.models.variant.Genotype;
import org.opencb.biodata.models.variant.stats.IdentityByState;
import org.opencb.biodata.tools.variant.algorithm.IdentityByStateClustering;
import org.opencb.oskar.core.exceptions.OskarException;
import org.opencb.oskar.spark.variant.VariantMetadataManager;
import org.opencb.oskar.spark.variant.converters.DataframeToFacetFieldConverter;
import org.opencb.oskar.spark.variant.transformers.SampleVariantStatsTransformer;
import scala.collection.Seq;

/* loaded from: input_file:org/opencb/oskar/spark/variant/transformers/IBSTransformer.class */
/**
 * Spark ML transformer that computes pairwise Identity-By-State (IBS) statistics between samples.
 *
 * <p>Samples are resolved against the first study returned by {@link VariantMetadataManager#samples}
 * and genotypes are read from {@code studies[0].samplesData}. Sample pairs are split into groups of
 * at most {@code numPairs} pairs; each group is aggregated independently by a
 * {@link UserDefinedAggregateFunction}. The output has one row per sample pair, following
 * {@link #RETURN_SCHEMA_TYPE}, ordered by {@code samplePair}.
 */
public class IBSTransformer extends AbstractTransformer {

    /**
     * Separator between the two sample positions encoded in a "starting pair" string. Written by
     * {@code transform} and parsed by {@code IdentityByStateAggregateFunction#iterate}, so both sides
     * must use this single constant.
     */
    private static final String PAIR_SEPARATOR = ",";

    /** Separator used to pack all starting pairs into a single literal on the Spark 2.0 code path. */
    private static final String STARTING_PAIRS_JOIN_SEPARATOR = "_";

    private final Param<List<String>> samplesParam;
    private final BooleanParam skipReferenceParam;
    private final BooleanParam skipMultiAllelicParam;
    private final IntParam numPairsParam;

    /** Schema of every output row (one row per sample pair). */
    protected static final StructType RETURN_SCHEMA_TYPE = DataTypes.createStructType(Arrays.asList(
            DataTypes.createStructField("samplePair", DataTypes.createArrayType(DataTypes.StringType), false),
            DataTypes.createStructField("distance", DataTypes.DoubleType, false),
            DataTypes.createStructField("counts", DataTypes.createArrayType(DataTypes.IntegerType, false), false),
            DataTypes.createStructField("variants", DataTypes.IntegerType, false),
            DataTypes.createStructField("skip", DataTypes.IntegerType, false)));

    /**
     * Aggregates shared-allele counts for a fixed-size group of sample pairs.
     *
     * <p>Buffer layout: for each pair {@code p} in {@code [0, numPairs)} there are
     * {@link #BUFFER_SCHEMA_SIZE} int slots starting at {@code p * BUFFER_SCHEMA_SIZE}:
     * slots 0..2 are indexed directly by the result of
     * {@code IdentityByStateClustering#countSharedAlleles}, slot 3 counts evaluated variants and
     * slot 4 counts skipped variants. The last buffer slot stores the group's starting-pair string.
     */
    private static class IdentityByStateAggregateFunction extends UserDefinedAggregateFunction {
        /** Number of buffer slots used per sample pair. */
        public static final int BUFFER_SCHEMA_SIZE = 5;
        public static final String DEFAULT_UNKNOWN_GENOTYPE = "./.";
        public static final Genotype MISS = new Genotype(DEFAULT_UNKNOWN_GENOTYPE);
        public static final Genotype REF = new Genotype("0/0");
        public static final Genotype HET = new Genotype("0/1");
        public static final Genotype ALT = new Genotype("1/1");
        public static final Genotype MULTI = new Genotype("1/2");

        /** Per-pair slot counting variants that were evaluated (not skipped). */
        private static final int VARIANTS_OFFSET = 3;
        /** Per-pair slot counting variants that were skipped. */
        private static final int SKIP_OFFSET = 4;
        /** Sentinel returned by {@link #sharedAllelesOrSkip} when the variant must be skipped. */
        private static final int SKIPPED = -1;

        private final int numPairs;
        private final int numSamples;
        /** Study sample indexes (positions in {@code samplesData}) of the selected samples. */
        private final List<Integer> samples;
        private final boolean skipReference;
        private final boolean skipMultiAllelic;
        /** Study sample index to sample name, used to label output rows. */
        private final Map<Integer, String> sampleIdxMap;

        /** Receives one sample pair: two study sample indexes and the pair's position in the group. */
        @FunctionalInterface
        interface SamplePairConsumer {
            void accept(int sampleIdx1, int sampleIdx2, int pairIdx);
        }

        IdentityByStateAggregateFunction(int numPairs, boolean skipReference, boolean skipMultiAllelic,
                                         List<Integer> samples, Map<Integer, String> sampleIdxMap) {
            this.numPairs = numPairs;
            this.numSamples = samples.size();
            this.samples = samples;
            this.skipReference = skipReference;
            this.skipMultiAllelic = skipMultiAllelic;
            this.sampleIdxMap = sampleIdxMap;
        }

        @Override
        public StructType inputSchema() {
            return DataTypes.createStructType(new StructField[]{
                    DataTypes.createStructField("startingPair", DataTypes.StringType, true),
                    DataTypes.createStructField("samples",
                            DataTypes.createArrayType(DataTypes.createArrayType(DataTypes.StringType)), true)});
        }

        @Override
        public StructType bufferSchema() {
            List<StructField> fields = new ArrayList<>(BUFFER_SCHEMA_SIZE * numPairs + 1);
            for (int i = 0; i < numPairs; i++) {
                fields.add(DataTypes.createStructField("ibs1_" + i, DataTypes.IntegerType, false));
                fields.add(DataTypes.createStructField("ibs2_" + i, DataTypes.IntegerType, false));
                fields.add(DataTypes.createStructField("ibs3_" + i, DataTypes.IntegerType, false));
                fields.add(DataTypes.createStructField("variants_" + i, DataTypes.IntegerType, false));
                fields.add(DataTypes.createStructField("skip_" + i, DataTypes.IntegerType, false));
            }
            // Nullable: initialize() stores null here until the first input row is seen.
            fields.add(DataTypes.createStructField("startingPair", DataTypes.StringType, true));
            return DataTypes.createStructType(fields);
        }

        @Override
        public DataType dataType() {
            return DataTypes.createArrayType(structDataType());
        }

        private StructType structDataType() {
            return IBSTransformer.RETURN_SCHEMA_TYPE;
        }

        @Override
        public boolean deterministic() {
            return true;
        }

        @Override
        public void initialize(MutableAggregationBuffer buffer) {
            int startingPairIdx = startingPairIndex();
            for (int i = 0; i < startingPairIdx; i++) {
                buffer.update(i, 0);
            }
            buffer.update(startingPairIdx, (Object) null);
        }

        @Override
        public void update(MutableAggregationBuffer buffer, Row input) {
            String startingPair = input.getString(0);
            buffer.update(startingPairIndex(), startingPair);
            Seq<Seq<String>> samplesData = input.getSeq(1);
            iterate(startingPair, (sampleIdx1, sampleIdx2, pairIdx) ->
                    updatePair(buffer, pairIdx, firstValue(samplesData, sampleIdx1), firstValue(samplesData, sampleIdx2)));
        }

        /**
         * Returns the first value of a sample's data (presumably the GT field — TODO confirm against
         * the variant schema), or {@code null} when the study data or the sample entry is absent.
         */
        private static String firstValue(Seq<Seq<String>> samplesData, int sampleIdx) {
            if (samplesData == null) {
                return null;
            }
            Seq<String> sampleData = samplesData.apply(sampleIdx);
            if (sampleData == null || sampleData.isEmpty()) {
                return null;
            }
            return sampleData.apply(0);
        }

        /**
         * Enumerates up to {@code numPairs} sample pairs, starting at the pair encoded in
         * {@code startingPair} and continuing in row-major upper-triangular order.
         *
         * @param startingPair "pos1,pos2" where both values are positions in {@link #samples};
         *                     {@code null} means nothing to enumerate
         * @param consumer     receives the study sample indexes of each pair and the pair index in the group
         */
        public void iterate(String startingPair, SamplePairConsumer consumer) {
            if (startingPair == null) {
                return;
            }
            int separatorIdx = startingPair.indexOf(PAIR_SEPARATOR);
            int firstPosition = Integer.parseInt(startingPair.substring(0, separatorIdx));
            int secondPosition = Integer.parseInt(startingPair.substring(separatorIdx + PAIR_SEPARATOR.length()));
            int pairIdx = 0;
            for (int pos1 = firstPosition; pos1 < numSamples; pos1++) {
                // Only the first row resumes at the encoded second position; later rows start after pos1.
                int pos2Start = pos1 == firstPosition ? secondPosition : pos1 + 1;
                for (int pos2 = pos2Start; pos2 < numSamples; pos2++) {
                    consumer.accept(samples.get(pos1), samples.get(pos2), pairIdx);
                    pairIdx++;
                    if (pairIdx == numPairs) {
                        return;
                    }
                }
            }
        }

        /** Records one variant for one pair: either a shared-allele count plus "variants", or "skip". */
        private void updatePair(MutableAggregationBuffer buffer, int pairIdx, String gt1, String gt2) {
            int offset = getBufferOffset(pairIdx);
            int sharedAlleles = sharedAllelesOrSkip(gt1, gt2);
            if (sharedAlleles == SKIPPED) {
                increment(buffer, offset + SKIP_OFFSET);
            } else {
                increment(buffer, offset + sharedAlleles);
                increment(buffer, offset + VARIANTS_OFFSET);
            }
        }

        /**
         * Returns the number of shared alleles between two genotypes, or {@link #SKIPPED} when the
         * pair must not be counted: a missing/empty GT, non-diploid GT, any missing alleles,
         * both HOM_REF (if {@code skipReference}) or any multi-allelic GT (if {@code skipMultiAllelic}).
         */
        private int sharedAllelesOrSkip(String gt1, String gt2) {
            if (gt1 == null || gt1.isEmpty() || gt2 == null || gt2.isEmpty()) {
                return SKIPPED;
            }
            Genotype genotype1 = buildGenotype(gt1);
            Genotype genotype2 = buildGenotype(gt2);
            if (genotype1.getPloidy() != 2 || genotype2.getPloidy() != 2
                    || anyGtMissing(genotype1, genotype2)
                    || (skipReference && allReference(genotype1, genotype2))
                    || (skipMultiAllelic && anyGtMultiAllelic(genotype1, genotype2))) {
                return SKIPPED;
            }
            return new IdentityByStateClustering().countSharedAlleles(genotype1.getPloidy(), genotype1, genotype2);
        }

        private static boolean allReference(Genotype genotype1, Genotype genotype2) {
            int[] alleles1 = genotype1.getAllelesIdx();
            int[] alleles2 = genotype2.getAllelesIdx();
            return alleles1[0] == 0 && alleles1[1] == 0 && alleles2[0] == 0 && alleles2[1] == 0;
        }

        private static boolean anyGtMultiAllelic(Genotype genotype1, Genotype genotype2) {
            return genotype1.getCode() == AllelesCode.MULTIPLE_ALTERNATES
                    || genotype2.getCode() == AllelesCode.MULTIPLE_ALTERNATES;
        }

        private static boolean anyGtMissing(Genotype genotype1, Genotype genotype2) {
            return genotype1.getCode() == AllelesCode.ALLELES_MISSING
                    || genotype2.getCode() == AllelesCode.ALLELES_MISSING;
        }

        /** Parses a GT string, reusing shared instances for the most common values. */
        private static Genotype buildGenotype(String gt) {
            switch (gt) {
                case "0/0":
                case "0|0":
                    return REF;
                case "0/1":
                case "0|1":
                case "1|0":
                    return HET;
                case "1/1":
                case "1|1":
                    return ALT;
                case "1/2":
                    return MULTI;
                case "?/?":
                case DEFAULT_UNKNOWN_GENOTYPE:
                case ".":
                    return MISS;
                default:
                    return new Genotype(gt);
            }
        }

        private static void increment(MutableAggregationBuffer buffer, int idx) {
            buffer.update(idx, buffer.getInt(idx) + 1);
        }

        @Override
        public void merge(MutableAggregationBuffer buffer, Row other) {
            int startingPairIdx = startingPairIndex();
            for (int i = 0; i < startingPairIdx; i++) {
                buffer.update(i, buffer.getInt(i) + other.getInt(i));
            }
            // All rows of a group share the same starting pair; keep the first non-null one seen.
            if (buffer.get(startingPairIdx) == null) {
                buffer.update(startingPairIdx, other.get(startingPairIdx));
            }
        }

        @Override
        public Object evaluate(Row buffer) {
            GenericRowWithSchema[] rows = new GenericRowWithSchema[numPairs];
            int[] produced = {0};
            iterate(buffer.getString(startingPairIndex()), (sampleIdx1, sampleIdx2, pairIdx) -> {
                int offset = getBufferOffset(pairIdx);
                IdentityByState identityByState = new IdentityByState();
                identityByState.ibs = new int[]{buffer.getInt(offset), buffer.getInt(offset + 1), buffer.getInt(offset + 2)};
                rows[pairIdx] = new GenericRowWithSchema(new Object[]{
                        new String[]{sampleIdxMap.get(sampleIdx1), sampleIdxMap.get(sampleIdx2)},
                        identityByState.getDistance(),
                        identityByState.ibs,
                        buffer.getInt(offset + VARIANTS_OFFSET),
                        buffer.getInt(offset + SKIP_OFFSET)}, structDataType());
                produced[0]++;
            });
            // The last group may hold fewer than numPairs pairs; iterate() fills indexes contiguously.
            return produced[0] == rows.length ? rows : Arrays.copyOf(rows, produced[0]);
        }

        /** Index of the buffer slot holding the starting-pair string (after all per-pair slots). */
        private int startingPairIndex() {
            return numPairs * BUFFER_SCHEMA_SIZE;
        }

        private static int getBufferOffset(int pairIdx) {
            return pairIdx * BUFFER_SCHEMA_SIZE;
        }
    }

    public IBSTransformer() {
        this(null);
    }

    public IBSTransformer(String uid) {
        super(uid);
        this.samplesParam = new Param<>(this, "samples", "List of samples to use for calculating the IBS");
        this.skipMultiAllelicParam = new BooleanParam(this, "skipMultiAllelic",
                "Skip variants where any of the samples has a secondary alternate");
        this.skipReferenceParam = new BooleanParam(this, "skipReference",
                "Skip variants where both samples of the pair are HOM_REF");
        this.numPairsParam = new IntParam(this, "numPairs",
                "Number of sample pairs computed by each aggregation group. Must be greater than 0");
        setDefault(samplesParam(), Collections.emptyList());
        setDefault(skipReferenceParam(), false);
        setDefault(skipMultiAllelicParam(), true);
        setDefault(numPairsParam(), 5);
    }

    /** Samples to compare; an empty list (the default) means all samples of the study. */
    public Param<List<String>> samplesParam() {
        return this.samplesParam;
    }

    public IBSTransformer setSamples(List<String> samples) {
        set(this.samplesParam, samples);
        return this;
    }

    public IBSTransformer setSamples(String... samples) {
        set(this.samplesParam, Arrays.asList(samples));
        return this;
    }

    public List<String> getSamples() {
        return getOrDefault(this.samplesParam);
    }

    /** Whether to skip variants where both samples of a pair are HOM_REF (default {@code false}). */
    public BooleanParam skipReferenceParam() {
        return this.skipReferenceParam;
    }

    public IBSTransformer setSkipReference(boolean skipReference) {
        set(this.skipReferenceParam, Boolean.valueOf(skipReference));
        return this;
    }

    public boolean getSkipReference() {
        return ((Boolean) getOrDefault(this.skipReferenceParam)).booleanValue();
    }

    /** Whether to skip variants where either sample has a secondary alternate (default {@code true}). */
    public BooleanParam skipMultiAllelicParam() {
        return this.skipMultiAllelicParam;
    }

    public IBSTransformer setSkipMultiAllelic(boolean skipMultiAllelic) {
        set(this.skipMultiAllelicParam, Boolean.valueOf(skipMultiAllelic));
        return this;
    }

    public boolean getSkipMultiAllelic() {
        return ((Boolean) getOrDefault(this.skipMultiAllelicParam)).booleanValue();
    }

    /** Number of sample pairs aggregated per group (default 5). */
    public IntParam numPairsParam() {
        return this.numPairsParam;
    }

    public IBSTransformer setNumPairs(int numPairs) {
        set(this.numPairsParam, Integer.valueOf(numPairs));
        return this;
    }

    public int getNumPairs() {
        return ((Integer) getOrDefault(this.numPairsParam)).intValue();
    }

    /**
     * Computes IBS statistics for every pair of selected samples.
     *
     * @throws IllegalArgumentException if {@code numPairs <= 0} or fewer than two samples are selected
     */
    public Dataset<Row> transform(Dataset<?> dataset) {
        int numPairs = getNumPairs();
        if (numPairs <= 0) {
            throw new IllegalArgumentException("numPairs must be greater than 0, got " + numPairs);
        }

        Map.Entry<String, List<String>> study = new VariantMetadataManager().samples(dataset).entrySet().iterator().next();
        String studyId = study.getKey();
        List<String> studySamples = study.getValue();

        List<String> samples = getSamples();
        if (samples.isEmpty()) {
            samples = studySamples;
        }

        // Position of each study sample (first occurrence, matching List.indexOf semantics).
        Map<String, Integer> studySamplePositions = new HashMap<>(studySamples.size() * 2);
        int position = 0;
        for (String studySample : studySamples) {
            studySamplePositions.putIfAbsent(studySample, position++);
        }

        List<Integer> sampleIndexes = new ArrayList<>(samples.size());
        Map<Integer, String> sampleIdxMap = new HashMap<>(samples.size() * 2);
        for (String sample : samples) {
            Integer sampleIdx = studySamplePositions.get(sample);
            if (sampleIdx == null) {
                throw OskarException.unknownSample(studyId, sample, studySamples);
            }
            sampleIndexes.add(sampleIdx);
            sampleIdxMap.put(sampleIdx, sample);
        }
        if (sampleIndexes.size() < 2) {
            throw new IllegalArgumentException("At least 2 samples are required to compute IBS, got "
                    + sampleIndexes.size());
        }

        IdentityByStateAggregateFunction ibsFunction = new IdentityByStateAggregateFunction(
                numPairs, getSkipReference(), getSkipMultiAllelic(), sampleIndexes, sampleIdxMap);
        List<String> startingPairs = buildStartingPairs(sampleIndexes.size(), numPairs);

        return withStartingPairColumn(dataset, startingPairs)
                .select(new Column[]{
                        functions.col("startingPair"),
                        functions.col("studies").getItem(0).getField("samplesData").as("samples")})
                .groupBy("startingPair", new String[0])
                .agg(ibsFunction.apply(new Column[]{functions.col("startingPair"), functions.col("samples")}).alias("ibs"),
                        new Column[0])
                .withColumn("ibs", functions.explode(functions.col("ibs")))
                .select("ibs.*", new String[0])
                .orderBy("samplePair", new String[0]);
    }

    /**
     * Returns the first pair of every group of {@code numPairs} consecutive pairs, enumerating pairs
     * {@code (pos1, pos2)} with {@code pos1 < pos2} in row-major order. Each pair is encoded as
     * "pos1,pos2", where the values are POSITIONS in the selected-samples list (not study indexes),
     * because that is how the aggregate function's {@code iterate} decodes them.
     */
    private static List<String> buildStartingPairs(int numSamples, int numPairs) {
        List<String> startingPairs = new ArrayList<>();
        long pairCount = 0; // long: n*(n-1)/2 overflows int for very large cohorts
        for (int pos1 = 0; pos1 < numSamples; pos1++) {
            for (int pos2 = pos1 + 1; pos2 < numSamples; pos2++) {
                if (pairCount % numPairs == 0) {
                    startingPairs.add(pos1 + PAIR_SEPARATOR + pos2);
                }
                pairCount++;
            }
        }
        return startingPairs;
    }

    /**
     * Adds a {@code startingPair} column, exploding every input row once per starting pair.
     * On Spark 2.0.x the pairs are packed into one string literal and split back; other versions use
     * an array literal directly (presumably array literals were not supported on 2.0 — TODO confirm).
     */
    private static Dataset<Row> withStartingPairColumn(Dataset<?> dataset, List<String> startingPairs) {
        if (dataset.sparkSession().sparkContext().version().startsWith("2.0")) {
            return dataset
                    .withColumn("startingPairs", functions.lit(String.join(STARTING_PAIRS_JOIN_SEPARATOR, startingPairs)))
                    .withColumn("startingPair",
                            functions.explode(functions.split(functions.col("startingPairs"), STARTING_PAIRS_JOIN_SEPARATOR)));
        }
        return dataset
                .withColumn("startingPair", functions.lit(startingPairs.toArray(new String[0])))
                .withColumn("startingPair", functions.explode(functions.col("startingPair")));
    }

    @Override // org.opencb.oskar.spark.variant.transformers.AbstractTransformer
    public StructType transformSchema(StructType structType) {
        return RETURN_SCHEMA_TYPE;
    }
}
