package org.opencb.oskar.spark.variant.transformers;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.opencb.biodata.models.variant.stats.IBDExpectedFrequencies;
import org.opencb.oskar.spark.variant.converters.RowToVariantConverter;
import org.opencb.oskar.spark.variant.transformers.SampleVariantStatsTransformer;
import scala.collection.mutable.WrappedArray;
import scala.runtime.AbstractFunction1;

/**
 * Spark ML transformer that estimates pairwise identity-by-descent (IBD) from the
 * identity-by-state (IBS) output of {@link IBSTransformer}.
 *
 * <p>On top of the parent's output it appends two columns:
 * <ul>
 *   <li>{@code IBD}: {@code [Z0, Z1, Z2]}, computed from the parent's {@code counts} column by
 *       {@link ZFunction};</li>
 *   <li>{@code PI_HAT}: {@code Z1 / 2 + Z2}, computed from {@code IBD} by {@link PiHatFunction}.</li>
 * </ul>
 *
 * <p>The expected frequencies consumed by {@link ZFunction} are aggregated over the whole input
 * dataset in a separate Spark job before the UDF columns are added.
 *
 * <p>NOTE(review): the equations and bounding rules appear to follow PLINK's {@code --genome}
 * method-of-moments estimator; confirm before relying on exact PLINK parity.
 */
public class IBDTransformer extends IBSTransformer {

    /**
     * Spark UDF body computing {@code PI_HAT = Z1 / 2 + Z2} from an {@code [Z0, Z1, Z2]} array.
     */
    public static class PiHatFunction extends AbstractFunction1<WrappedArray<Double>, Double>
            implements Serializable {

        /**
         * @param ibd array whose elements at indices 1 and 2 are Z1 and Z2
         * @return {@code Z1 / 2 + Z2}
         */
        @Override
        public Double apply(WrappedArray<Double> ibd) {
            double z1 = ((Double) ibd.apply(1)).doubleValue();
            double z2 = ((Double) ibd.apply(2)).doubleValue();
            return z1 / 2.0d + z2;
        }
    }

    /**
     * Spark UDF body turning an IBS count array into bounded IBD estimates {@code [Z0, Z1, Z2]}.
     *
     * <p>With {@code IBSk = counts[k]} and {@code N = IBS0 + IBS1 + IBS2}:
     * <pre>
     *   Z0 = IBS0 / (E00 * N)
     *   Z1 = (IBS1 - Z0 * E01 * N) / (E11 * N)
     *   Z2 = (IBS2 - Z0 * E02 * N - Z1 * E12 * N) / (E22 * N)
     * </pre>
     * The raw estimates are then bounded, checking Z0, Z1, Z2 in that order:
     * <ul>
     *   <li>a value above 1 is set to 1 and the other two to 0;</li>
     *   <li>afterwards, a negative value is set to 0 and the other two are divided by their sum.</li>
     * </ul>
     * There is no guard for {@code N == 0} or zero expected frequencies, so the result may be NaN.
     *
     * <p>Only indices 0..2 of the input are read. Presumably they are the IBS0/IBS1/IBS2 totals
     * produced by {@link IBSTransformer}; TODO confirm.
     */
    public static class ZFunction extends AbstractFunction1<WrappedArray<Integer>, WrappedArray<Double>>
            implements Serializable {

        /**
         * Dataset-wide expected frequencies. Spark ships this object to executors together with the
         * UDF. NOTE(review): that requires IBDExpectedFrequencies to be Serializable; confirm.
         */
        private final IBDExpectedFrequencies expFreqs;

        /**
         * @param expFreqs aggregated (and {@code done()}-finalised) expected frequencies
         */
        public ZFunction(IBDExpectedFrequencies expFreqs) {
            this.expFreqs = expFreqs;
        }

        /**
         * @param counts array whose first three elements are the IBS0, IBS1 and IBS2 counts
         * @return a new three-element array {@code [Z0, Z1, Z2]}
         */
        @Override
        public WrappedArray<Double> apply(WrappedArray<Integer> counts) {
            int ibs0 = ((Integer) counts.apply(0)).intValue();
            int ibs1 = ((Integer) counts.apply(1)).intValue();
            int ibs2 = ((Integer) counts.apply(2)).intValue();
            // Widen before summing so very large counts cannot overflow int.
            double total = (double) ibs0 + ibs1 + ibs2;

            // Expected IBS counts for each IBD state, scaled to the number of compared sites.
            // The E10/E20/E21 terms are not needed by the equations below.
            double e00 = expFreqs.E00 * total;
            double e01 = expFreqs.E01 * total;
            double e11 = expFreqs.E11 * total;
            double e02 = expFreqs.E02 * total;
            double e12 = expFreqs.E12 * total;
            double e22 = expFreqs.E22 * total;

            // Solve the triangular system for Z0, then Z1, then Z2.
            double z0 = ibs0 / e00;
            double z1 = (ibs1 - z0 * e01) / e11;
            double z2 = (ibs2 - z0 * e02 - z1 * e12) / e22;

            // Cap at 1: a saturated component takes all of the mass.
            if (z0 > 1.0d) {
                z0 = 1.0d;
                z1 = 0.0d;
                z2 = 0.0d;
            }
            if (z1 > 1.0d) {
                z1 = 1.0d;
                z0 = 0.0d;
                z2 = 0.0d;
            }
            if (z2 > 1.0d) {
                z2 = 1.0d;
                z0 = 0.0d;
                z1 = 0.0d;
            }

            // Floor at 0: drop the negative component and renormalise the remaining two.
            if (z0 < 0.0d) {
                double sum = z1 + z2;
                z1 /= sum;
                z2 /= sum;
                z0 = 0.0d;
            }
            if (z1 < 0.0d) {
                double sum = z0 + z2;
                z0 /= sum;
                z2 /= sum;
                z1 = 0.0d;
            }
            if (z2 < 0.0d) {
                double sum = z0 + z1;
                z0 /= sum;
                z1 /= sum;
                z2 = 0.0d;
            }

            // Always exactly three entries. Sizing by the input length would leave trailing nulls,
            // which contradicts the non-null element type declared in transformSchema.
            return WrappedArray.make(new Double[] {z0, z1, z2});
        }
    }

    /**
     * Creates a transformer, delegating {@code null} as the uid to the {@link IBSTransformer}
     * constructor.
     */
    public IBDTransformer() {
        this(null);
    }

    /**
     * @param uid identifier passed to the {@link IBSTransformer} constructor
     */
    public IBDTransformer(String uid) {
        super(uid);
    }

    /** {@inheritDoc} Overridden only to return {@code IBDTransformer} for fluent chaining. */
    @Override
    public IBDTransformer setSamples(List<String> samples) {
        super.setSamples(samples);
        return this;
    }

    /** {@inheritDoc} Overridden only to return {@code IBDTransformer} for fluent chaining. */
    @Override
    public IBDTransformer setSamples(String... samples) {
        super.setSamples(samples);
        return this;
    }

    /** {@inheritDoc} Overridden only to return {@code IBDTransformer} for fluent chaining. */
    @Override
    public IBDTransformer setSkipReference(boolean skipReference) {
        super.setSkipReference(skipReference);
        return this;
    }

    /** {@inheritDoc} Overridden only to return {@code IBDTransformer} for fluent chaining. */
    @Override
    public IBDTransformer setSkipMultiAllelic(boolean skipMultiAllelic) {
        super.setSkipMultiAllelic(skipMultiAllelic);
        return this;
    }

    /** {@inheritDoc} Overridden only to return {@code IBDTransformer} for fluent chaining. */
    @Override
    public IBDTransformer setNumPairs(int numPairs) {
        super.setNumPairs(numPairs);
        return this;
    }

    /**
     * Runs the IBS transformation, then appends the {@code IBD} and {@code PI_HAT} columns.
     *
     * <p>Computing the expected frequencies triggers an eager Spark job ({@code reduce}) over
     * {@code dataset}.
     *
     * @param dataset variant dataset, convertible row-by-row by {@link RowToVariantConverter}
     * @return the {@link IBSTransformer} output with {@code IBD} and {@code PI_HAT} added
     */
    @Override
    public Dataset<Row> transform(Dataset<?> dataset) {
        Dataset<Row> ibs = super.transform(dataset);
        // toDF() gives the Dataset<Row> view the aggregation needs; a Dataset<?> cannot be passed as-is.
        IBDExpectedFrequencies expected = computeExpectedFrequencies(dataset.toDF());

        Column ibd = functions.udf(new ZFunction(expected), DataTypes.createArrayType(DataTypes.DoubleType))
                .apply(new Column[] {functions.col("counts")});
        Column piHat = functions.udf(new PiHatFunction(), DataTypes.DoubleType)
                .apply(new Column[] {functions.col("IBD")});

        return ibs.withColumn("IBD", ibd).withColumn("PI_HAT", piHat);
    }

    /**
     * Aggregates {@link IBDExpectedFrequencies} over every row of {@code dataset}.
     *
     * <p>Each row is converted to a variant and fed to a fresh accumulator's {@code update}. The
     * per-row accumulators are summed field by field with {@link #mergeFrequencies}, and
     * {@code done()} is called on the total. NOTE(review): {@code done()} presumably normalises by
     * the counter; confirm in biodata.
     *
     * <p>Spark's {@code reduce} fails on an empty dataset, and that exception propagates from here.
     *
     * @param dataset variant rows
     * @return the finalised dataset-wide expected frequencies
     */
    private static IBDExpectedFrequencies computeExpectedFrequencies(Dataset<Row> dataset) {
        // The casts select Spark's Java-friendly overloads. Without them, the lambdas are ambiguous
        // with the Scala Function1/Function2 overloads.
        IBDExpectedFrequencies total = dataset
                .map((org.apache.spark.api.java.function.MapFunction<Row, IBDExpectedFrequencies>) row -> {
                    IBDExpectedFrequencies partial = new IBDExpectedFrequencies();
                    partial.update(new RowToVariantConverter().convert(row));
                    return partial;
                }, Encoders.bean(IBDExpectedFrequencies.class))
                .reduce((org.apache.spark.api.java.function.ReduceFunction<IBDExpectedFrequencies>)
                        IBDTransformer::mergeFrequencies);
        total.done();
        return total;
    }

    /**
     * Returns a new accumulator whose E-fields and counter are the element-wise sums of
     * {@code a} and {@code b}. Neither input is modified.
     */
    private static IBDExpectedFrequencies mergeFrequencies(IBDExpectedFrequencies a, IBDExpectedFrequencies b) {
        return new IBDExpectedFrequencies()
                .setE00(a.E00 + b.E00)
                .setE10(a.E10 + b.E10)
                .setE20(a.E20 + b.E20)
                .setE01(a.E01 + b.E01)
                .setE11(a.E11 + b.E11)
                .setE21(a.E21 + b.E21)
                .setE02(a.E02 + b.E02)
                .setE12(a.E12 + b.E12)
                .setE22(a.E22 + b.E22)
                .setCounter(a.getCounter() + b.getCounter());
    }

    /**
     * Returns {@link IBSTransformer#RETURN_SCHEMA_TYPE} extended with the non-nullable {@code IBD}
     * (array of non-null doubles) and {@code PI_HAT} (double) fields.
     *
     * <p>The incoming {@code schema} is not inspected.
     */
    @Override
    public StructType transformSchema(StructType schema) {
        // StructType.add returns a new StructType, so no mutable field list is needed.
        return IBSTransformer.RETURN_SCHEMA_TYPE
                .add("IBD", DataTypes.createArrayType(DataTypes.DoubleType, false), false)
                .add("PI_HAT", DataTypes.DoubleType, false);
    }
}
