package org.opencb.oskar.spark.variant.transformers;

import java.io.Serializable;
import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.ml.param.Param;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.MetadataBuilder;
import org.apache.spark.sql.types.StructType;
import org.opencb.commons.utils.ListUtils;
import org.opencb.oskar.spark.variant.converters.DataframeToFacetFieldConverter;
import org.opencb.oskar.spark.variant.transformers.SampleVariantStatsTransformer;
import org.opencb.oskar.spark.variant.udf.VariantUdfManager;
import scala.collection.mutable.ListBuffer;
import scala.collection.mutable.WrappedArray;
import scala.runtime.AbstractFunction1;

/**
 * Spark ML transformer that computes facet results (counts per category, counts per numeric
 * range, or aggregations) over a variant dataframe.
 *
 * <p>The facet definition is supplied through {@link #setFacet(String)}. Forms handled here:
 * <ul>
 *   <li>categorical: {@code field}, optionally with a bracketed list of values to keep (parsed
 *       with {@link DataframeToFacetFieldConverter#CATEGORICAL_PATTERN});</li>
 *   <li>range: {@code field[start..end]:step};</li>
 *   <li>aggregation: {@code function(field)}, where {@code function} must be listed in
 *       {@link DataframeToFacetFieldConverter#AGGREGATION_FUNCTIONS};</li>
 *   <li>nested: several of the above joined with
 *       {@link DataframeToFacetFieldConverter#NESTED_FACET_SEPARATOR}; an aggregation may only
 *       appear last and a field may not repeat.</li>
 * </ul>
 *
 * <p>The result always has a {@code count} column whose metadata key {@code "facet"} holds the
 * facet definition. An empty facet yields an empty dataframe.
 */
public class FacetTransformer extends AbstractTransformer {
    public static final String SEPARATOR = "__";
    public static final String POPFREQ_PREFIX = "popFreq__";
    public static final String STATS_PREFIX = "stats__";

    private Param<String> facetParam;
    /** Facet field name -> dataframe column path, for fields accepted as categorical facets. */
    private Map<String, String> validCategoricalFields;
    /** Facet field name -> dataframe column path, for score fields accepted as numeric facets. */
    private Map<String, String> validRangeFields;
    /** Fields whose facet column is produced with {@code functions.explode} (one row per element). */
    private Set<String> isExplode;
    private DataframeToFacetFieldConverter converter;

    /**
     * Spark function that picks one score out of an array of score structs.
     *
     * <p>Each element is read positionally: field 0 is the score value and field 1 is the score
     * source. The first element whose source equals the configured source wins.
     * {@link Double#NEGATIVE_INFINITY} is returned when the array is null, no element matches, or
     * the matching element has a null value.
     */
    public static class ScoreFunction extends AbstractFunction1<WrappedArray<GenericRowWithSchema>, Double>
            implements Serializable {
        private final String source;

        /**
         * @param str score source name to look for (e.g. {@code "gerp"} or {@code "sift"})
         */
        public ScoreFunction(String str) {
            this.source = str;
        }

        @Override
        public Double apply(WrappedArray<GenericRowWithSchema> wrappedArray) {
            if (wrappedArray == null) {
                return Double.NEGATIVE_INFINITY;
            }
            for (int i = 0; i < wrappedArray.length(); i++) {
                Row row = wrappedArray.apply(i);
                // Compare from the non-null side so rows with a null source do not throw.
                if (row != null && this.source.equals(row.get(1))) {
                    Object value = row.get(0);
                    return value == null ? Double.NEGATIVE_INFINITY : Double.parseDouble(value.toString());
                }
            }
            return Double.NEGATIVE_INFINITY;
        }
    }

    public FacetTransformer() {
        this(null);
    }

    /**
     * @param str identifier passed to the {@link AbstractTransformer} constructor (presumably the
     *            Spark ML uid; {@code null} from the no-arg constructor — TODO confirm handling)
     */
    public FacetTransformer(String str) {
        super(str);
        this.converter = new DataframeToFacetFieldConverter();
        this.facetParam = new Param<>(this, "facet", "");
        setDefault(facetParam(), null);
        init();
    }

    /** @return the parameter holding the facet definition (default {@code null}) */
    public Param<String> facetParam() {
        return this.facetParam;
    }

    /**
     * Sets the facet definition to compute.
     *
     * @param str facet definition (see the class documentation for the accepted forms)
     * @return this transformer, for chaining
     */
    public FacetTransformer setFacet(String str) {
        set(facetParam(), str);
        return this;
    }

    /**
     * Computes the configured facet over {@code dataset}.
     *
     * @param dataset variant dataframe
     * @return one row per facet bucket with a {@code count} column; an empty dataframe when no
     *         facet is set
     * @throws InvalidParameterException if the facet definition is malformed or unsupported
     */
    @Override
    public Dataset<Row> transform(Dataset<?> dataset) {
        String facet = getOrDefault(facetParam());
        if (StringUtils.isEmpty(facet)) {
            return dataset.sparkSession().emptyDataFrame();
        }
        // The helpers operate on Dataset<Row>; toDF() provides that view without an unchecked cast.
        Dataset<Row> input = dataset.toDF();
        Dataset<Row> result = facet.contains(DataframeToFacetFieldConverter.NESTED_FACET_SEPARATOR)
                ? transformNested(facet, input)
                : transformSingle(facet, input);
        // Attach the facet definition as metadata on the count column.
        return result.withColumn("count", functions.col("count").as("count",
                new MetadataBuilder().putString("facet", facet).build()));
    }

    /** Computes a nested facet: groups by every non-aggregation facet and optionally aggregates the last one. */
    private Dataset<Row> transformNested(String facet, Dataset<Row> dataset) {
        String[] facets = facet.split(DataframeToFacetFieldConverter.NESTED_FACET_SEPARATOR);
        for (int i = 0; i < facets.length - 1; i++) {
            if (DataframeToFacetFieldConverter.getFacetType(facets[i])
                    == DataframeToFacetFieldConverter.FacetType.AGGREGATION) {
                throw new InvalidParameterException("In nested facets, aggregations must be in last place: " + facet);
            }
        }

        List<String> fieldNames = new ArrayList<>(facets.length);
        for (String nested : facets) {
            String fieldName = DataframeToFacetFieldConverter.getFieldName(nested);
            if (fieldNames.contains(fieldName)) {
                throw new InvalidParameterException("In nested facets, repeating facets are not allowed: " + facet);
            }
            fieldNames.add(fieldName);
        }

        List<String> groupByNames = new ArrayList<>();
        boolean hasAggregation = false;
        Dataset<Row> current = dataset;
        for (int i = 0; i < facets.length; i++) {
            String fieldName = fieldNames.get(i);
            switch (DataframeToFacetFieldConverter.getFacetType(facets[i])) {
                case CATEGORICAL:
                    current = processCategoricalFacet(facets[i], fieldName, fieldName, current);
                    groupByNames.add(fieldName);
                    break;
                case RANGE: {
                    String rangeColumn = fieldName + "Range";
                    current = processRangeFacet(facets[i], fieldName, rangeColumn, current);
                    groupByNames.add(rangeColumn);
                    break;
                }
                case AGGREGATION:
                    current = processAggregationFacet(facets[i], fieldName, current);
                    hasAggregation = true;
                    break;
                default:
                    throw new InvalidParameterException("In nested facets, unknown facet in middle position: " + facet);
            }
        }

        Column[] groupBy = new Column[groupByNames.size()];
        for (int i = 0; i < groupBy.length; i++) {
            groupBy[i] = new Column(groupByNames.get(i));
        }
        if (!hasAggregation) {
            return current.groupBy(groupBy).count().orderBy(groupBy);
        }

        // The aggregation is guaranteed (by the check above) to be the last facet.
        int last = facets.length - 1;
        String fieldName = fieldNames.get(last);
        String function = facets[last].substring(0, facets[last].indexOf("("));
        Column countColumn = functions.count(functions.lit(1)).as("count");
        if (fieldName.startsWith(STATS_PREFIX)) {
            String safeName = toSafeColumnName(fieldName);
            // NOTE(review): unlike the non-stats branch and the single-facet stats branch, the
            // aggregate here is aliased to the bare field name rather than "function(field)".
            // Behaviour kept as-is; confirm against DataframeToFacetFieldConverter.
            return current.withColumnRenamed(fieldName, safeName)
                    .groupBy(groupBy)
                    .agg(getAggregationExpr(function, safeName).as(fieldName), countColumn)
                    .orderBy(groupBy);
        }
        return current.groupBy(groupBy)
                .agg(getAggregationExpr(function, fieldName), countColumn)
                .orderBy(groupBy);
    }

    /** Computes a single (non-nested) categorical, range or aggregation facet. */
    private Dataset<Row> transformSingle(String facet, Dataset<Row> dataset) {
        String fieldName = DataframeToFacetFieldConverter.getFieldName(facet);
        switch (DataframeToFacetFieldConverter.getFacetType(facet)) {
            case CATEGORICAL:
                return processCategoricalFacet(facet, fieldName, fieldName, dataset)
                        .groupBy(fieldName).count().orderBy(fieldName);
            case RANGE: {
                String rangeColumn = fieldName + "Range";
                return processRangeFacet(facet, fieldName, rangeColumn, dataset)
                        .groupBy(rangeColumn).count().orderBy(rangeColumn);
            }
            case AGGREGATION: {
                Dataset<Row> values = processAggregationFacet(facet, fieldName, dataset);
                String function = facet.substring(0, facet.indexOf("("));
                // Count rows in the same aggregation instead of triggering a separate count() job.
                Column countColumn = functions.count(functions.lit(1)).as("count");
                if (fieldName.startsWith(STATS_PREFIX)) {
                    String safeName = toSafeColumnName(fieldName);
                    return values.withColumnRenamed(fieldName, safeName)
                            .agg(getAggregationExpr(function, safeName), countColumn)
                            .withColumnRenamed(function + "(" + safeName + ")", function + "(" + fieldName + ")");
                }
                return values.agg(getAggregationExpr(function, fieldName), countColumn);
            }
            default:
                throw new InvalidParameterException("Unknown facet type: " + facet);
        }
    }

    /**
     * Replaces ':' and '@' in a column name so it can be embedded in a Spark SQL expression
     * string. Used for stats field names, which may contain those characters.
     */
    private static String toSafeColumnName(String name) {
        return name.replace(":", DataframeToFacetFieldConverter.LABEL_SEPARATOR).replace("@", "____");
    }

    /**
     * Builds the aggregation expression for {@code function} over column {@code column}.
     *
     * <p>{@code sumsq}, {@code percentile} and {@code unique} are rewritten to Spark SQL
     * equivalents and aliased to {@code function(column)}; any other function is passed to Spark
     * SQL verbatim.
     */
    private Column getAggregationExpr(String function, String column) {
        String alias = function + "(" + column + ")";
        switch (function) {
            case "sumsq":
                return functions.expr("sum(power(" + column + ", 2))").as(alias);
            case "percentile":
                return functions.expr("percentile(" + column + ", array("
                        + DataframeToFacetFieldConverter.PERCENTILE_PARAMS + "))").as(alias);
            case "unique":
                return functions.expr("collect_set(" + column + ")").as(alias);
            default:
                return functions.expr(alias);
        }
    }

    /**
     * Adds the categorical facet column and applies the facet's include-value filter.
     *
     * @param facet      facet definition (used to read include values)
     * @param fieldName  facet field name
     * @param columnName name of the column to create / filter on
     * @param dataset    input dataframe
     * @return the dataframe with the facet column; unchanged when the field is not recognised
     */
    private Dataset<Row> processCategoricalFacet(String facet, String fieldName, String columnName,
                                                 Dataset<Row> dataset) {
        if (!isValidField(fieldName)) {
            return dataset;
        }
        Dataset<Row> result = dataset;
        if (isNumeric(fieldName)) {
            result = withNumericSourceColumns(result, fieldName).withColumn(columnName, numericValue(fieldName));
        } else if (this.isExplode.contains(fieldName)) {
            result = result.withColumn(columnName, getColumn(fieldName));
        }

        List<String> includeValues = getIncludeValues(facet);
        if (ListUtils.isNotEmpty(includeValues)) {
            // Build the OR of equality checks through the Column API rather than a SQL string, so
            // quotes in values or special characters in column names cannot break parsing.
            Column condition = null;
            for (String value : includeValues) {
                Column equals = functions.col(columnName).equalTo(value);
                condition = condition == null ? equals : condition.or(equals);
            }
            result = result.filter(condition);
        }
        return result;
    }

    /**
     * Adds a column with the numeric field value binned to multiples of {@code step}, keeping only
     * bins within {@code [start, end]}.
     *
     * @throws InvalidParameterException if the field is not numeric or the range spec is malformed
     */
    private Dataset<Row> processRangeFacet(String facet, String fieldName, String columnName, Dataset<Row> dataset) {
        if (!isNumeric(fieldName)) {
            throw new InvalidParameterException("Range facets are only supported on numeric fields: " + facet);
        }
        double[] range = parseRange(facet);
        double start = range[0];
        double end = range[1];
        double step = range[2];

        // Stats column names may contain ':' / '@', which the SQL filter string cannot parse, so
        // work under a sanitised name and rename back at the end.
        String workColumn = fieldName.startsWith(STATS_PREFIX) ? toSafeColumnName(columnName) : columnName;
        // floor() instead of cast(int): a cast truncates toward zero, which put values in
        // (-step, 0) into the [0, step) bin for scores that can be negative.
        Column binned = functions.floor(numericValue(fieldName).divide(step)).multiply(step);
        Dataset<Row> filtered = withNumericSourceColumns(dataset, fieldName)
                .withColumn(workColumn, binned)
                .filter(workColumn + ">= " + start + " AND " + workColumn + " <= " + end);
        return workColumn.equals(columnName) ? filtered : filtered.withColumnRenamed(workColumn, columnName);
    }

    /**
     * Parses {@code field[start..end]:step} into {@code {start, end, step}}.
     *
     * @throws InvalidParameterException if a part is missing or step is not positive
     * @throws NumberFormatException     if a part is not a number
     */
    private static double[] parseRange(String facet) {
        String[] parts = facet.substring(facet.indexOf("[") + 1)
                .replace("[", ":")
                .replace("..", ":")
                .replace("]", "")
                .split(":");
        if (parts.length < 3) {
            throw new InvalidParameterException("Invalid range facet, expected field[start..end]:step: " + facet);
        }
        double start = Double.parseDouble(parts[0]);
        double end = Double.parseDouble(parts[1]);
        double step = Double.parseDouble(parts[2]);
        if (!(step > 0)) {
            throw new InvalidParameterException("Range facet step must be positive: " + facet);
        }
        return new double[]{start, end, step};
    }

    /**
     * Validates the aggregation function and, for numeric fields, adds a column named
     * {@code fieldName} holding the field value.
     *
     * @throws InvalidParameterException if the function is not in AGGREGATION_FUNCTIONS
     */
    private Dataset<Row> processAggregationFacet(String facet, String fieldName, Dataset<Row> dataset) {
        String function = facet.substring(0, facet.indexOf("("));
        if (!Arrays.asList(DataframeToFacetFieldConverter.AGGREGATION_FUNCTIONS).contains(function)) {
            throw new InvalidParameterException("Aggregation function unknown: " + function);
        }
        if (!isNumeric(fieldName)) {
            return dataset;
        }
        return withNumericSourceColumns(dataset, fieldName).withColumn(fieldName, numericValue(fieldName));
    }

    /**
     * Adds the temporary columns that {@link #numericValue(String)} reads: {@code tmp} (selected
     * study) for stats fields, {@code tmp1} (exploded substitution scores) for sift/polyphen.
     */
    private Dataset<Row> withNumericSourceColumns(Dataset<Row> dataset, String fieldName) {
        if (fieldName.startsWith(POPFREQ_PREFIX)) {
            return dataset;
        }
        if (fieldName.startsWith(STATS_PREFIX)) {
            return dataset.withColumn("tmp", VariantUdfManager.study("studies", splitField(fieldName)[1]));
        }
        if (isSubstitutionScore(fieldName)) {
            return dataset.withColumn("tmp1", getColumn(fieldName));
        }
        return dataset;
    }

    /**
     * Returns the column expression computing a numeric field's value. Must be applied to a
     * dataframe prepared by {@link #withNumericSourceColumns(Dataset, String)}.
     */
    private Column numericValue(String fieldName) {
        if (fieldName.startsWith(POPFREQ_PREFIX)) {
            String[] parts = splitField(fieldName);
            return VariantUdfManager.population_frequency("annotation", parts[1], parts[2]);
        }
        if (fieldName.startsWith(STATS_PREFIX)) {
            return functions.col("tmp.stats." + splitField(fieldName)[2] + ".altAlleleFreq");
        }
        UserDefinedFunction udf = functions.udf(new ScoreFunction(fieldName), DataTypes.DoubleType);
        return udf.apply(createFunctScoreSeq(fieldName, "tmp1"));
    }

    /**
     * Splits a {@code popFreq__a__b} / {@code stats__a__b} field name on {@link #SEPARATOR}.
     *
     * @throws InvalidParameterException if fewer than three parts are present
     */
    private static String[] splitField(String fieldName) {
        String[] parts = fieldName.split(SEPARATOR);
        if (parts.length < 3) {
            throw new InvalidParameterException("Invalid field name, expected three '" + SEPARATOR
                    + "'-separated parts: " + fieldName);
        }
        return parts;
    }

    /**
     * Returns the single-column argument list for {@link ScoreFunction} for a score field.
     *
     * @param field              score field name
     * @param substitutionColumn column holding exploded substitution scores (sift/polyphen)
     * @throws InvalidParameterException if the field is not a known score
     */
    private ListBuffer<Column> createFunctScoreSeq(String field, String substitutionColumn) {
        Column scores;
        if (isFunctionalScore(field)) {
            scores = functions.col("annotation.functionalScore");
        } else if (isSubstitutionScore(field)) {
            scores = functions.col(substitutionColumn);
        } else if (isConservationScore(field)) {
            scores = functions.col("annotation.conservation");
        } else {
            throw new InvalidParameterException("Not a score field: " + field);
        }
        ListBuffer<Column> seq = new ListBuffer<>();
        seq.$plus$eq(scores);
        return seq;
    }

    @Override // org.opencb.oskar.spark.variant.transformers.AbstractTransformer
    public StructType transformSchema(StructType structType) {
        return structType;
    }

    /** Populates the field lookup tables. Every score field is also a valid categorical field and is exploded. */
    private void init() {
        this.validRangeFields = new HashMap<>();
        this.validRangeFields.put("gerp", "annotation.conservation");
        this.validRangeFields.put("phylop", "annotation.conservation");
        this.validRangeFields.put("phastCons", "annotation.conservation");
        this.validRangeFields.put("cadd_scaled", "annotation.functionalScore");
        this.validRangeFields.put("cadd_raw", "annotation.functionalScore");
        this.validRangeFields.put("sift", "annotation.consequenceTypes.proteinVariantAnnotation.substitutionScores");
        this.validRangeFields.put("polyphen", "annotation.consequenceTypes.proteinVariantAnnotation.substitutionScores");

        this.validCategoricalFields = new HashMap<>();
        this.validCategoricalFields.put("chromosome", "chromosome");
        this.validCategoricalFields.put("type", "type");
        this.validCategoricalFields.put("studies", "studies.studyId");
        this.validCategoricalFields.put("biotype", "annotation.consequenceTypes.biotype");
        this.validCategoricalFields.put("ct", "annotation.consequenceTypes.sequenceOntologyTerms.name");
        this.validCategoricalFields.put("gene", "annotation.consequenceTypes.geneName");
        this.validCategoricalFields.put("ensemblGeneId", "annotation.consequenceTypes.ensemblGeneId");
        this.validCategoricalFields.put("ensemblTranscriptId", "annotation.consequenceTypes.ensemblTranscriptId");
        this.validCategoricalFields.putAll(this.validRangeFields);

        this.isExplode = new HashSet<>(Arrays.asList(
                "studies", "biotype", "gene", "ensemblGeneId", "ensemblTranscriptId", "ct"));
        this.isExplode.addAll(this.validRangeFields.keySet());
    }

    private boolean isValidField(String str) {
        return this.validCategoricalFields.containsKey(str) || isNumeric(str);
    }

    /** True for score fields and for {@code popFreq__} / {@code stats__} prefixed fields. */
    private boolean isNumeric(String str) {
        return this.validRangeFields.containsKey(str) || str.startsWith(POPFREQ_PREFIX) || str.startsWith(STATS_PREFIX);
    }

    private boolean isFunctionalScore(String str) {
        return str.equals("cadd_scaled") || str.equals("cadd_raw");
    }

    private boolean isConservationScore(String str) {
        return str.equals("gerp") || str.equals("phylop") || str.equals("phastCons");
    }

    private boolean isSubstitutionScore(String str) {
        return str.equals("sift") || str.equals("polyphen");
    }

    /**
     * Returns the source column for a categorical field. Exploded fields yield one row per
     * element; gene, biotype and ct are first extracted through VariantUdfManager UDFs.
     */
    private Column getColumn(String str) {
        if (!this.isExplode.contains(str)) {
            return functions.col(this.validCategoricalFields.get(str));
        }
        switch (str) {
            case "gene":
                return functions.explode(VariantUdfManager.genes("annotation"));
            case "biotype":
                return functions.explode(VariantUdfManager.biotypes("annotation"));
            case "ct":
                return functions.explode(VariantUdfManager.consequence_types("annotation"));
            default:
                return functions.explode(functions.col(this.validCategoricalFields.get(str)));
        }
    }

    /**
     * Extracts the bracketed include values of a categorical facet.
     *
     * @return the values split on INCLUDE_SEPARATOR; empty when there are none or they contain '*'
     */
    private List<String> getIncludeValues(String str) {
        if (str.contains("[")) {
            Matcher matcher = DataframeToFacetFieldConverter.CATEGORICAL_PATTERN.matcher(str);
            if (matcher.find()) {
                String values = matcher.group(2).replace("[", "").replace("]", "");
                if (StringUtils.isNotEmpty(values) && !values.contains("*")) {
                    return Arrays.asList(values.split(DataframeToFacetFieldConverter.INCLUDE_SEPARATOR));
                }
            }
        }
        return new ArrayList<>();
    }
}
