package de.javagl.tsne;

import java.util.ArrayList;
import java.util.Locale;
import java.util.Random;
import java.util.function.Consumer;
import java.util.function.DoubleConsumer;
import java.util.logging.Logger;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;

/**
 * Barnes-Hut variant of t-SNE (t-distributed stochastic neighbor embedding).
 *
 * <p>Input similarities are computed sparsely from the {@code (int) (3 * perplexity)}
 * nearest neighbors of every point, found with a {@code VpTree}. The repulsive
 * ("non-edge") part of the gradient is approximated with an {@code SPTree}, whose
 * accuracy is controlled by the Barnes-Hut parameter {@code theta}.
 *
 * <p>All methods are static; the state of a run lives entirely in local variables.
 */
class BHTSne {
    /** Whether the per-point Barnes-Hut force computations run on a parallel stream. */
    private static final boolean PARALLEL = true;
    private static final Logger LOGGER = Logger.getLogger(BHTSne.class.getName());
    private static final Distance DISTANCE = new EuclideanDistance();

    /** Factor by which the input similarities are exaggerated during the early iterations. */
    private static final double EXAGGERATION = 12.0;
    /** Iteration after which exaggeration is removed and the momentum is raised. */
    private static final int STOP_LYING_ITER = 250;
    private static final double INITIAL_MOMENTUM = 0.5;
    private static final double FINAL_MOMENTUM = 0.8;
    private static final double LEARNING_RATE = 200.0;
    /** Scale of the random initial embedding coordinates. */
    private static final double INITIAL_SCALE = 1.0E-4;
    /** Maximum number of bisection steps when calibrating one point's kernel bandwidth. */
    private static final int MAX_BETA_SEARCH_STEPS = 200;
    /** Tolerance on the entropy difference that ends the bandwidth search. */
    private static final double ENTROPY_TOLERANCE = 1.0E-5;

    /**
     * A sparse matrix in compressed-row form: the entries of row {@code r} occupy
     * positions {@code [rowP[r], rowP[r + 1])} of {@code colP} (column indices) and
     * {@code valP} (values).
     */
    private static final class SymResult {
        final int[] rowP;
        final int[] colP;
        final double[] valP;

        SymResult(int[] rowP, int[] colP, double[] valP) {
            this.rowP = rowP;
            this.colP = colP;
            this.valP = valP;
        }
    }

    BHTSne() {
    }

    /**
     * Computes a Barnes-Hut t-SNE embedding.
     *
     * @param config the configuration (input data, dimensions, perplexity, theta, ...)
     * @param seed the seed for the random number generator
     * @param progress optional receiver of the progress, from 0.0 at the first
     *     iteration to 1.0 at the last one; may be {@code null}
     * @param log optional receiver of log messages; may be {@code null}
     * @return the embedding as an {@code nrRows x outputDims} array, or {@code null}
     *     if the calling thread was interrupted
     * @throws IllegalArgumentException if theta is 0, or if the number of points
     *     minus one is less than three times the perplexity
     */
    public static double[][] tsne(TSneConfiguration config, long seed,
            DoubleConsumer progress, Consumer<? super String> log) {
        return run(config, seed, progress, log);
    }

    /** Copies a row-major matrix into a flat array using the first row's length as stride. */
    private static double[] flatten(double[][] matrix) {
        int cols = matrix[0].length;
        double[] flat = new double[matrix.length * cols];
        for (int r = 0; r < matrix.length; r++) {
            System.arraycopy(matrix[r], 0, flat, r * cols, matrix[r].length);
        }
        return flat;
    }

    /** Inverse of {@link #flatten}: splits a flat row-major array into {@code rows x cols}. */
    private static double[][] expand(double[] flat, int rows, int cols) {
        double[][] matrix = new double[rows][cols];
        for (int r = 0; r < rows; r++) {
            System.arraycopy(flat, r * cols, matrix[r], 0, cols);
        }
        return matrix;
    }

    /** Returns -1, 0 or 1 by sign; NaN maps to 1 and both zeros map to 0. */
    private static double sign_tsne(double d) {
        if (d == 0.0d) {
            return 0.0d;
        }
        return d < 0.0d ? -1.0d : 1.0d;
    }

    /** Returns the indices 0..n-1 as a stream, parallel when {@link #PARALLEL} is set. */
    private static IntStream pointIndices(int n) {
        IntStream indices = IntStream.range(0, n);
        return PARALLEL ? indices.parallel() : indices;
    }

    private static double[][] run(TSneConfiguration config, long seed,
            DoubleConsumer progress, Consumer<? super String> log) {
        int inputDims = config.getXStartDim();
        double[][] xin = config.getXin();
        final double theta = config.getTheta();
        if (theta == 0.0d) {
            throw new IllegalArgumentException("The Barnes Hut implementation does not support exact inference yet (theta==0.0), if you want exact t-SNE please use one of the standard t-SNE implementations (FastTSne for instance)");
        }
        int initialDims = config.getInitialDims();
        if (config.usePca() && inputDims > initialDims && initialDims > 0) {
            xin = new PrincipalComponentAnalysis().pca(xin, initialDims);
            inputDims = initialDims;
            printLog(log, "X:Shape after PCA is = " + xin.length + " x " + xin[0].length);
        }
        // flatten() copies, so the normalization below never touches the caller's data
        double[] x = flatten(xin);
        final int n = config.getNrRows();
        final int outDims = config.getOutputDims();
        printLog(log, "X:Shape is = " + n + " x " + inputDims);
        double perplexity = config.getPerplexity();
        if (n - 1 < 3.0d * perplexity) {
            throw new IllegalArgumentException("Perplexity too large for the number of data points! The number of data points should be more than three times the perplexity. There are " + n + " data points, and the perplexity is " + perplexity + ".");
        }
        printLog(log, "Using no_dims = %d, perplexity = %f, and theta = %f", outDims, perplexity, theta);
        Random random = new Random(seed);

        // --- Input similarities -------------------------------------------------
        printLog(log, "Computing input similarities...");
        long similarityStart = System.currentTimeMillis();
        normalizeByMaxAbs(x, n * inputDims);
        int k = (int) (3.0d * perplexity);
        int[] rowP = new int[n + 1];
        int[] colP = new int[n * k];
        double[] valP = new double[n * k];
        computeGaussianPerplexity(x, n, inputDims, rowP, colP, valP, perplexity, k, random, log);
        if (Thread.currentThread().isInterrupted()) {
            return null;
        }
        SymResult sym = symmetrizeMatrix(rowP, colP, valP, n);
        int[] symRowP = sym.rowP;
        int[] symColP = sym.colP;
        double[] symValP = sym.valP;
        int nnz = symRowP[n];

        // Normalize P so that it sums to 1
        double total = 0.0d;
        for (int e = 0; e < nnz; e++) {
            total += symValP[e];
        }
        for (int e = 0; e < nnz; e++) {
            symValP[e] /= total;
        }
        long similarityEnd = System.currentTimeMillis();

        // Early exaggeration; undone at STOP_LYING_ITER
        for (int e = 0; e < nnz; e++) {
            symValP[e] *= EXAGGERATION;
        }

        double[] y = new double[n * outDims];
        for (int i = 0; i < y.length; i++) {
            y[i] = random.nextDouble() * INITIAL_SCALE;
        }
        // Cast before dividing: integer division always reported sparsity 0, and
        // n * n overflows int for large n.
        double sparsity = (double) nnz / ((double) n * n);
        printLog(log, "Done in %4.2f seconds (sparsity = %f)!\nLearning embedding...",
                (similarityEnd - similarityStart) / 1000.0d, sparsity);

        // --- Gradient descent ---------------------------------------------------
        double[] grad = new double[n * outDims];
        double[] velocity = new double[n * outDims];
        double[] gains = new double[n * outDims];
        for (int i = 0; i < gains.length; i++) {
            gains[i] = 1.0d;
        }
        double momentum = INITIAL_MOMENTUM;
        double totalTime = 0.0d;
        int maxIter = config.getMaxIter();
        long chunkStart = System.currentTimeMillis();
        for (int iter = 0; iter < maxIter; iter++) {
            if (Thread.currentThread().isInterrupted()) {
                return null;
            }
            computeGradient(symRowP, symColP, symValP, y, n, outDims, grad, theta);
            updateGradient(n, outDims, y, momentum, LEARNING_RATE, grad, velocity, gains);
            zeroMean(y, n, outDims);
            if (iter == STOP_LYING_ITER) {
                for (int e = 0; e < nnz; e++) {
                    symValP[e] /= EXAGGERATION;
                }
                momentum = FINAL_MOMENTUM;
            }
            if (progress != null) {
                // Floating-point division; a single iteration counts as complete.
                progress.accept(maxIter > 1 ? (double) iter / (maxIter - 1) : 1.0d);
            }
            boolean lastIter = iter == maxIter - 1;
            if (((iter > 0 && iter % 50 == 0) || lastIter) && !config.silent()) {
                long now = System.currentTimeMillis();
                String error = config.printError()
                        ? String.valueOf(evaluateError(symRowP, symColP, symValP, y, n, outDims, theta))
                        : "not_calculated";
                double elapsed = (now - chunkStart) / 1000.0d;
                if (iter == 0) {
                    printLog(log, "Iteration %d: error is %s", iter + 1, error);
                } else {
                    totalTime += elapsed;
                    printLog(log, "Iteration %d: error is %s (50 iterations in %4.2f seconds)", iter, error, elapsed);
                }
                chunkStart = System.currentTimeMillis();
            }
        }
        printLog(log, "Fitting performed in %4.2f seconds.",
                totalTime + (System.currentTimeMillis() - chunkStart) / 1000.0d);
        return expand(y, n, outDims);
    }

    /**
     * Divides the first {@code count} values by their largest absolute value, to
     * keep the similarity computation numerically tame. Uses |x| (as the reference
     * bhtsne does) so non-positive data is not divided by 0 or flipped in sign.
     * All-zero input is left unchanged. NaN entries are ignored when finding the maximum.
     */
    private static void normalizeByMaxAbs(double[] values, int count) {
        double maxAbs = 0.0d;
        for (int i = 0; i < count; i++) {
            double a = Math.abs(values[i]);
            if (a > maxAbs) {
                maxAbs = a;
            }
        }
        if (maxAbs == 0.0d) {
            return;
        }
        for (int i = 0; i < count; i++) {
            values[i] /= maxAbs;
        }
    }

    /**
     * Performs one gradient-descent step with momentum and per-coordinate adaptive gains.
     *
     * @param n number of points
     * @param dims embedding dimensionality
     * @param y flat embedding, updated in place
     * @param momentum momentum factor for the velocity
     * @param eta learning rate
     * @param grad current gradient (read only)
     * @param velocity per-coordinate velocity, updated in place
     * @param gains per-coordinate gains, updated in place (never below 0.01)
     */
    private static void updateGradient(int n, int dims, double[] y, double momentum, double eta,
            double[] grad, double[] velocity, double[] gains) {
        for (int i = 0; i < n * dims; i++) {
            // The velocity moves against the gradient, so differing signs mean the
            // descent direction is stable: grow the gain; otherwise shrink it.
            gains[i] = sign_tsne(grad[i]) != sign_tsne(velocity[i]) ? gains[i] + 0.2d : gains[i] * 0.8d;
            if (gains[i] < 0.01d) {
                gains[i] = 0.01d;
            }
            // Update the velocity first, then move with the NEW velocity (as in the
            // reference bhtsne). The previous order applied the velocity one step late.
            velocity[i] = momentum * velocity[i] - eta * gains[i] * grad[i];
            y[i] += velocity[i];
        }
    }

    /**
     * Computes the Barnes-Hut approximation of the t-SNE gradient into {@code grad}.
     * The gradient is the attractive (edge) forces minus the repulsive (non-edge)
     * forces normalized by the sum of the per-point Q values.
     */
    private static void computeGradient(int[] rowP, int[] colP, double[] valP, double[] y,
            int n, int dims, double[] grad, double theta) {
        SPTree tree = new SPTree(dims, y, n);
        double[] sumQ = new double[n];
        double[] posF = new double[n * dims];
        double[][] negF = new double[n][dims];
        double[][] buff = new double[n][dims];
        tree.computeEdgeForces(rowP, colP, valP, n, posF);
        // Each task writes only its own negF/buff row.
        pointIndices(n).forEach(i -> tree.computeNonEdgeForces(i, theta, negF[i], buff[i], sumQ));
        double totalQ = DoubleStream.of(sumQ).sum();
        for (int i = 0; i < n; i++) {
            int base = i * dims;
            for (int d = 0; d < dims; d++) {
                grad[base + d] = posF[base + d] - (negF[i][d] / totalQ);
            }
        }
    }

    /**
     * Evaluates the KL divergence between the sparse input similarities P and the
     * embedding similarities Q, with Q's normalization approximated by the tree.
     */
    private static double evaluateError(int[] rowP, int[] colP, double[] valP, double[] y,
            int n, int dims, double theta) {
        SPTree tree = new SPTree(dims, y, n);
        // Per-point scratch rows: the original passed ONE shared array to every
        // parallel task as the force accumulator (a data race).
        double[][] negF = new double[n][dims];
        double[][] buff = new double[n][dims];
        double[] sumQ = new double[n];
        pointIndices(n).forEach(i -> tree.computeNonEdgeForces(i, theta, negF[i], buff[i], sumQ));
        double totalQ = DoubleStream.of(sumQ).sum();
        double error = 0.0d;
        for (int i = 0; i < n; i++) {
            int base = i * dims;
            for (int e = rowP[i]; e < rowP[i + 1]; e++) {
                int other = colP[e] * dims;
                double sqDist = 0.0d;
                for (int d = 0; d < dims; d++) {
                    double diff = y[base + d] - y[other + d];
                    sqDist += diff * diff;
                }
                double q = (1.0d / (1.0d + sqDist)) / totalQ;
                // MIN_VALUE guards against log(0) and division by zero
                error += valP[e] * Math.log((valP[e] + Double.MIN_VALUE) / (q + Double.MIN_VALUE));
            }
        }
        return error;
    }

    /**
     * Fills the row-sparse conditional similarities of every point to its {@code k}
     * nearest neighbors. For each point, the kernel precision {@code beta} is found
     * by bisection so that the row's entropy matches {@code log(perplexity)}.
     * Returns early, leaving the output partially filled, if the thread is interrupted.
     *
     * @param x flat input data, {@code n x dims}
     * @param rowP output row pointers (length n + 1); row i gets k entries
     * @param colP output neighbor indices (length n * k)
     * @param valP output similarities (length n * k), each row summing to 1
     */
    private static void computeGaussianPerplexity(double[] x, int n, int dims, int[] rowP,
            int[] colP, double[] valP, double perplexity, int k, Random random,
            Consumer<? super String> log) {
        if (perplexity > k) {
            LOGGER.warning("Perplexity should be lower than K!");
            LOGGER.warning("perplexity=" + perplexity + ", K=" + k);
        }
        double[] curP = new double[n - 1];
        rowP[0] = 0;
        for (int i = 0; i < n; i++) {
            rowP[i + 1] = rowP[i] + k;
        }
        VpTree vpTree = new VpTree(DISTANCE, random);
        DataPoint[] points = new DataPoint[n];
        for (int i = 0; i < n; i++) {
            points[i] = new DataPoint(dims, i, MatrixOps.extractRowViewFromFlatMatrix(x, i, dims));
        }
        vpTree.create(points);
        printLog(log, "Building tree...");
        long start = System.nanoTime();
        ArrayList<DataPoint> neighbors = new ArrayList<>();
        ArrayList<Double> distances = new ArrayList<>();
        double logPerplexity = Math.log(perplexity);
        for (int i = 0; i < n; i++) {
            if (Thread.currentThread().isInterrupted()) {
                return;
            }
            if (i % 10000 == 0) {
                printLog(log, " - point %d of %d", i, n);
            }
            neighbors.clear();
            distances.clear();
            // k + 1 results; result 0 is skipped below (presumably the query point itself)
            vpTree.search(points[i], k + 1, neighbors, distances);

            boolean found = false;
            double beta = 1.0d;
            double minBeta = -Double.MAX_VALUE;
            double maxBeta = Double.MAX_VALUE;
            double sumP = 0.0d;
            for (int step = 0; !found && step < MAX_BETA_SEARCH_STEPS; step++) {
                sumP = Double.MIN_VALUE;
                double entropyTerm = 0.0d;
                for (int m = 0; m < k; m++) {
                    double dist = distances.get(m + 1);
                    // NOTE(review): the reference bhtsne exponentiates SQUARED Euclidean
                    // distances; whether DISTANCE already returns squared values is not
                    // visible here — confirm.
                    curP[m] = Math.exp(-beta * dist);
                    sumP += curP[m];
                    entropyTerm += beta * dist * curP[m];
                }
                double hDiff = ((entropyTerm / sumP) + Math.log(sumP)) - logPerplexity;
                if (Math.abs(hDiff) < ENTROPY_TOLERANCE) {
                    found = true;
                } else if (hDiff > 0.0d) {
                    // Entropy too high: sharpen the kernel (increase beta)
                    minBeta = beta;
                    beta = (maxBeta == Double.MAX_VALUE || maxBeta == -Double.MAX_VALUE)
                            ? beta * 2.0d
                            : (beta + maxBeta) / 2.0d;
                } else {
                    // Entropy too low: widen the kernel (decrease beta)
                    maxBeta = beta;
                    beta = (minBeta == -Double.MAX_VALUE || minBeta == Double.MAX_VALUE)
                            ? beta / 2.0d
                            : (beta + minBeta) / 2.0d;
                }
            }
            int offset = rowP[i];
            for (int m = 0; m < k; m++) {
                curP[m] /= sumP;
                colP[offset + m] = neighbors.get(m + 1).index();
                valP[offset + m] = curP[m];
            }
        }
        printLog(log, "Building tree took %f ms", (System.nanoTime() - start) / 1000000.0d);
    }

    /**
     * Returns the position in {@code colP} of the last entry equal to {@code target}
     * within row {@code row}, or -1 if the row has no such entry.
     */
    private static int lastIndexInRow(int[] rowP, int[] colP, int row, int target) {
        for (int m = rowP[row + 1] - 1; m >= rowP[row]; m--) {
            if (colP[m] == target) {
                return m;
            }
        }
        return -1;
    }

    /**
     * Builds the symmetric matrix (P + P^T) / 2 from the compressed-row matrix
     * {@code (rowP, colP, valP)} with {@code n} rows.
     */
    private static SymResult symmetrizeMatrix(int[] rowP, int[] colP, double[] valP, int n) {
        // Pass 1: count the entries of every row of the symmetric matrix
        int[] rowCounts = new int[n];
        for (int row = 0; row < n; row++) {
            for (int e = rowP[row]; e < rowP[row + 1]; e++) {
                int col = colP[e];
                rowCounts[row]++;
                // A missing mirror entry (col, row) has to be created in row col
                if (lastIndexInRow(rowP, colP, col, row) < 0) {
                    rowCounts[col]++;
                }
            }
        }
        int[] symRowP = new int[n + 1];
        for (int row = 0; row < n; row++) {
            symRowP[row + 1] = symRowP[row] + rowCounts[row];
        }
        int nnz = symRowP[n];
        int[] symColP = new int[nnz];
        double[] symValP = new double[nnz];

        // Pass 2: fill both (row, col) and (col, row)
        int[] offset = new int[n];
        for (int row = 0; row < n; row++) {
            for (int e = rowP[row]; e < rowP[row + 1]; e++) {
                int col = colP[e];
                int mirror = lastIndexInRow(rowP, colP, col, row);
                boolean present = mirror >= 0;
                // A pair present in both directions is written only once, from the
                // smaller row index, so it is not added twice.
                if (!present || row <= col) {
                    double value = present ? valP[e] + valP[mirror] : valP[e];
                    symColP[symRowP[row] + offset[row]] = col;
                    symColP[symRowP[col] + offset[col]] = row;
                    symValP[symRowP[row] + offset[row]] = value;
                    symValP[symRowP[col] + offset[col]] = value;
                    offset[row]++;
                    if (col != row) {
                        offset[col]++;
                    }
                }
            }
        }
        for (int e = 0; e < nnz; e++) {
            symValP[e] /= 2.0d;
        }
        return new SymResult(symRowP, symColP, symValP);
    }

    /** Subtracts the per-column mean from the flat {@code n x dims} matrix {@code y}. */
    private static void zeroMean(double[] y, int n, int dims) {
        double[] mean = new double[dims];
        for (int i = 0; i < n; i++) {
            int base = i * dims;
            for (int d = 0; d < dims; d++) {
                mean[d] += y[base + d];
            }
        }
        for (int d = 0; d < dims; d++) {
            mean[d] /= n;
        }
        for (int i = 0; i < n; i++) {
            int base = i * dims;
            for (int d = 0; d < dims; d++) {
                y[base + d] -= mean[d];
            }
        }
    }

    /** Formats {@code format} with {@code args} (English locale) and passes it to {@code consumer}, if non-null. */
    private static void printLog(Consumer<? super String> consumer, String format, Object... args) {
        if (consumer != null) {
            consumer.accept(String.format(Locale.ENGLISH, format, args));
        }
    }
}
