package com.criteo.rsvd;

import breeze.linalg.DenseMatrix;
import breeze.linalg.DenseMatrix$;
import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.linalg.ImmutableNumericOps;
import breeze.linalg.diag$;
import breeze.linalg.max$;
import breeze.linalg.svd;
import breeze.linalg.svd$reduced$;
import breeze.linalg.svd$reduced$reduced_Svd_DM_Impl$;
import breeze.storage.Zero$DoubleZero$;
import com.criteo.rsvd.ReconstructionError;
import com.google.common.math.LongMath;
import com.google.common.primitives.Ints;
import com.typesafe.scalalogging.slf4j.Logger;
import com.typesafe.scalalogging.slf4j.StrictLogging;
import java.math.RoundingMode;
import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkContext;
import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDD$;
import scala.Array$;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Iterable;
import scala.collection.Iterator;
import scala.collection.Seq$;
import scala.collection.immutable.Map;
import scala.math.Numeric$DoubleIsFractional$;
import scala.math.Ordering$Double$;
import scala.math.Ordering$Long$;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;

/* compiled from: ReconstructionError.scala */
/* loaded from: input_file:com/criteo/rsvd/ReconstructionError$.class */
/**
 * Module (singleton) class of the Scala {@code object ReconstructionError}. Its helpers estimate how
 * well a factorization given as (left factor, singular values, right factor) reconstructs a
 * distributed {@link BlockMatrix}. There are two approaches:
 * <ul>
 *   <li>a randomized estimate of the largest singular value of the residual
 *       ({@link #computeReconstructionErrorByRSVD});</li>
 *   <li>per-entry errors evaluated on generated or sampled matrix entries
 *       ({@link #computeReconstructionError}, {@link #computeStatsForReconstructionError}).</li>
 * </ul>
 *
 * <p>NOTE(review): this is decompiler output (see the header comments). Several constructs are
 * bytecode artifacts that are not valid Java source. Examples: assigning the {@code final} fields
 * {@code MODULE$} and {@code logger} outside an initializer, and {@code StrictLogging.class.$init$}.
 * Much of the per-element logic lives in the synthetic {@code ReconstructionError$$anonfun$*}
 * classes, which are not visible here. ReconstructionError.scala remains authoritative, and fixes
 * made here should be ported back to it.
 */
public final class ReconstructionError$ implements StrictLogging {
    /** The singleton instance; assigned by the private constructor run from the static initializer. */
    public static final ReconstructionError$ MODULE$ = null;
    /** Logger installed by the StrictLogging trait initializer through the setter below. */
    private final Logger logger;

    static {
        // Eagerly create the singleton; the constructor stores itself into MODULE$.
        new ReconstructionError$();
    }

    /**
     * Returns the logger provided by {@code StrictLogging}.
     * The decompiler renamed this from {@code logger} and merged it with its bridge method.
     */
    /* renamed from: logger, reason: merged with bridge method [inline-methods] */
    public Logger m23logger() {
        return this.logger;
    }

    /** Trait field setter used by the StrictLogging initializer to install the logger. */
    public void com$typesafe$scalalogging$slf4j$StrictLogging$_setter_$logger_$eq(Logger logger) {
        this.logger = logger;
    }

    /**
     * Applies a factorization, stored as two skinny factors plus a vector of singular values, to a
     * third skinny matrix.
     *
     * <p>With {@code z == false} this computes
     * {@code skinnyBlockMatrix.singleBlockMultiply(diag(denseVector) * skinnyBlockMatrix2.dot(skinnyBlockMatrix3))}.
     * With {@code z == true} the two factors swap roles:
     * {@code skinnyBlockMatrix2.singleBlockMultiply(diag(denseVector) * skinnyBlockMatrix.dot(skinnyBlockMatrix3))}.
     * If {@code dot} is {@code A^T * B} (confirm in SkinnyBlockMatrix), the result is
     * {@code U * diag(s) * V^T * X}, or {@code V * diag(s) * U^T * X} in the transposed case.
     *
     * @param skinnyBlockMatrix first factor (presumably U)
     * @param denseVector values placed on the diagonal (presumably the singular values)
     * @param skinnyBlockMatrix2 second factor (presumably V)
     * @param skinnyBlockMatrix3 matrix the factorization is applied to
     * @param z when true, apply the transposed factorization (factors swapped)
     * @return the product as a SkinnyBlockMatrix
     */
    public SkinnyBlockMatrix timesSVD(SkinnyBlockMatrix skinnyBlockMatrix, DenseVector<Object> denseVector, SkinnyBlockMatrix skinnyBlockMatrix2, SkinnyBlockMatrix skinnyBlockMatrix3, boolean z) {
        return z
                ? skinnyBlockMatrix2.singleBlockMultiply(
                        (DenseMatrix) ((ImmutableNumericOps) diag$.MODULE$.apply(denseVector, diag$.MODULE$.diagDVDMImpl(ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$)))
                                .$times(skinnyBlockMatrix.dot(skinnyBlockMatrix3), DenseMatrix$.MODULE$.implOpMulMatrix_DMD_DMD_eq_DMD()),
                        false)
                : skinnyBlockMatrix.singleBlockMultiply(
                        (DenseMatrix) ((ImmutableNumericOps) diag$.MODULE$.apply(denseVector, diag$.MODULE$.diagDVDMImpl(ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$)))
                                .$times(skinnyBlockMatrix2.dot(skinnyBlockMatrix3), DenseMatrix$.MODULE$.implOpMulMatrix_DMD_DMD_eq_DMD()),
                        false);
    }

    /**
     * Builds a basis for the range of the residual between {@code blockMatrix} and the factorization
     * (skinnyBlockMatrix, denseVector, skinnyBlockMatrix2), starting from the sketch
     * {@code skinnyBlockMatrix3}.
     *
     * <p>The starting basis is the first QR factor of
     * {@code blockMatrix.skinnyMultiplyMinusSVD(skinnyBlockMatrix3, U, s, V, true, false)}. That basis
     * is then refined by {@code i} iterations of {@code anonfun$findGoodBasisRSVDError$1}, which
     * receives both block matrices and all four repartitioned factors and updates {@code create}.
     * These are presumably power iterations; confirm in the Scala source. The result is
     * re-orthonormalized with a final QR.
     *
     * <p>NOTE(review): the right factor is checked against {@code matHeight} (not {@code matWidth}).
     * Both factors are also repartitioned with both the height and the width partitioner. This looks
     * like it assumes a square matrix; confirm before using it on rectangular inputs.
     *
     * @param blockMatrix matrix whose residual is analysed
     * @param blockMatrix2 second block matrix used by the refinement iterations (presumably the transpose)
     * @param skinnyBlockMatrix left factor; must have matHeight rows and as many columns as denseVector
     * @param denseVector diagonal values of the factorization
     * @param skinnyBlockMatrix2 right factor; asserted to have matHeight rows
     * @param skinnyBlockMatrix3 starting (random) sketch
     * @param i number of refinement iterations
     * @return the first QR factor of the refined sketch
     */
    public SkinnyBlockMatrix findGoodBasisRSVDError(BlockMatrix blockMatrix, BlockMatrix blockMatrix2, SkinnyBlockMatrix skinnyBlockMatrix, DenseVector<Object> denseVector, SkinnyBlockMatrix skinnyBlockMatrix2, SkinnyBlockMatrix skinnyBlockMatrix3, int i) {
        Predef$.MODULE$.assert(skinnyBlockMatrix.numRows() == blockMatrix.matHeight());
        Predef$.MODULE$.assert(skinnyBlockMatrix2.numRows() == blockMatrix.matHeight());
        Predef$.MODULE$.assert(skinnyBlockMatrix.numCols() == skinnyBlockMatrix2.numCols());
        Predef$.MODULE$.assert(skinnyBlockMatrix.numCols() == denseVector.length());
        // One partitioner along the row-block dimension, one along the column-block dimension.
        SingleDimensionPartitioner singleDimensionPartitioner = new SingleDimensionPartitioner(blockMatrix.numDimBlocksHeight(), blockMatrix.partitionWidthInBlocks());
        SingleDimensionPartitioner singleDimensionPartitioner2 = new SingleDimensionPartitioner(blockMatrix.numDimBlocksWidth(), blockMatrix.partitionWidthInBlocks());
        // Each factor is prepared with both partitionings, for use against either block matrix.
        SkinnyBlockMatrix repartitionBy = skinnyBlockMatrix.repartitionBy(singleDimensionPartitioner);
        SkinnyBlockMatrix repartitionBy2 = skinnyBlockMatrix2.repartitionBy(singleDimensionPartitioner2);
        SkinnyBlockMatrix repartitionBy3 = skinnyBlockMatrix.repartitionBy(singleDimensionPartitioner2);
        SkinnyBlockMatrix repartitionBy4 = skinnyBlockMatrix2.repartitionBy(singleDimensionPartitioner);
        // Mutable cell (Scala `var`) holding the current basis; updated by the loop body.
        ObjectRef create = ObjectRef.create((SkinnyBlockMatrix) blockMatrix
                .skinnyMultiplyMinusSVD(skinnyBlockMatrix3, repartitionBy, denseVector, repartitionBy2, true, false)
                .qr()
                ._1());
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), i).foreach$mVc$sp(new ReconstructionError$$anonfun$findGoodBasisRSVDError$1(blockMatrix, blockMatrix2, denseVector, repartitionBy, repartitionBy2, repartitionBy3, repartitionBy4, create));
        return (SkinnyBlockMatrix) ((SkinnyBlockMatrix) create.elem).qr()._1();
    }

    /**
     * Returns the first {@code i} singular values of the residual
     * {@code blockMatrix.skinnyMultiply(skinnyBlockMatrix, true) - timesSVD(skinnyBlockMatrix2, denseVector, skinnyBlockMatrix3, skinnyBlockMatrix, false)}.
     *
     * <p>The residual is not decomposed directly. It is QR-factorized, and only the small R factor goes
     * through a reduced SVD. R's singular values equal the residual's when the Q factor has
     * orthonormal columns. Breeze returns singular values in descending order, so these are the
     * largest ones.
     *
     * @param blockMatrix block matrix multiplied by the basis (with the transpose flag set)
     * @param skinnyBlockMatrix basis, e.g. from {@link #findGoodBasisRSVDError}
     * @param skinnyBlockMatrix2 first factor passed to timesSVD
     * @param denseVector diagonal values of the factorization
     * @param skinnyBlockMatrix3 second factor passed to timesSVD
     * @param i number of singular values to return; slicing fails if it exceeds the number available
     * @return the first {@code i} singular values
     * @throws MatchError if the QR or SVD result is null (decompiled Scala pattern match)
     */
    public DenseVector<Object> computeSingularValueRSVDResidual(BlockMatrix blockMatrix, SkinnyBlockMatrix skinnyBlockMatrix, SkinnyBlockMatrix skinnyBlockMatrix2, DenseVector<Object> denseVector, SkinnyBlockMatrix skinnyBlockMatrix3, int i) {
        SingleDimensionPartitioner singleDimensionPartitioner = new SingleDimensionPartitioner(blockMatrix.numDimBlocksWidth(), blockMatrix.partitionWidthInBlocks());
        Tuple2<SkinnyBlockMatrix, DenseMatrix<Object>> qr = blockMatrix.skinnyMultiply(skinnyBlockMatrix, true)
                .repartitionBy(singleDimensionPartitioner)
                .minus(timesSVD(skinnyBlockMatrix2.repartitionBy(singleDimensionPartitioner), denseVector, skinnyBlockMatrix3.repartitionBy(singleDimensionPartitioner), skinnyBlockMatrix, false))
                .qr();
        if (qr == null) {
            throw new MatchError(qr);
        }
        // SVD of the small R factor only.
        svd.SVD svd = (svd.SVD) svd$reduced$.MODULE$.apply((DenseMatrix) qr._2(), svd$reduced$reduced_Svd_DM_Impl$.MODULE$);
        if (svd != null) {
            return (DenseVector) ((DenseVector) svd.singularValues()).apply(RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), i), DenseVector$.MODULE$.canSlice());
        }
        throw new MatchError(svd);
    }

    /**
     * Estimates the reconstruction error of the factorization (skinnyBlockMatrix, denseVector,
     * skinnyBlockMatrix2) as the largest singular value of the residual, using a randomized range
     * finder. This is presumably an estimate of the residual's spectral norm.
     *
     * <p>Steps:
     * <ol>
     *   <li>Draw a random sketch of size matHeight x (embeddingDim + oversample), seeded with
     *       {@code seed + 1234}. The offset presumably decorrelates it from the sketch used when the
     *       factorization was computed.</li>
     *   <li>Refine a basis with {@code powerIter} iterations.</li>
     *   <li>Compute the first {@code embeddingDim} residual singular values.</li>
     *   <li>Return their maximum.</li>
     * </ol>
     *
     * @param blockMatrix matrix that was factorized
     * @param blockMatrix2 second block matrix used during refinement and the residual SVD
     * @param skinnyBlockMatrix left factor
     * @param denseVector singular values
     * @param skinnyBlockMatrix2 right factor
     * @param rSVDConfig supplies embeddingDim, oversample, blockSize, partitionWidthInBlocks, seed, powerIter
     * @return the largest residual singular value found
     */
    public double computeReconstructionErrorByRSVD(BlockMatrix blockMatrix, BlockMatrix blockMatrix2, SkinnyBlockMatrix skinnyBlockMatrix, DenseVector<Object> denseVector, SkinnyBlockMatrix skinnyBlockMatrix2, RSVDConfig rSVDConfig) {
        SkinnyBlockMatrix randomSketch = SkinnyBlockMatrix$.MODULE$.randomMatrix(blockMatrix.matHeight(), rSVDConfig.embeddingDim() + rSVDConfig.oversample(), rSVDConfig.blockSize(), rSVDConfig.partitionWidthInBlocks(), skinnyBlockMatrix.blocks().context(), rSVDConfig.seed() + 1234);
        SkinnyBlockMatrix residualBasis = findGoodBasisRSVDError(blockMatrix, blockMatrix2, skinnyBlockMatrix, denseVector, skinnyBlockMatrix2, randomSketch, rSVDConfig.powerIter());
        DenseVector<Object> residualSingularValues = computeSingularValueRSVDResidual(blockMatrix2, residualBasis, skinnyBlockMatrix, denseVector, skinnyBlockMatrix2, rSVDConfig.embeddingDim());
        return BoxesRunTime.unboxToDouble(max$.MODULE$.apply(residualSingularValues, max$.MODULE$.reduce_Double(DenseVector$.MODULE$.canIterateValues())));
    }

    /**
     * For every RDD in {@code blockMatrix.blocksRDDs()}, derives two RDDs of per-block entry iterators,
     * then unions all of them into one RDD. Each element is keyed by a pair, presumably the block
     * coordinates; confirm.
     *
     * <p>The first group comes from {@code anonfun$2}, presumably the entries already stored in each
     * block. The second comes from {@code anonfun$3(i, j, z, numColBlocks)} with
     * {@code numColBlocks = ceil(matWidth / blockSize)}, presumably {@code i} generated entries per
     * block seeded by {@code j}. Confirm both in the Scala source.
     *
     * @throws ArithmeticException if blockSize is 0 (LongMath.divide)
     * @throws IllegalArgumentException if the number of column blocks does not fit in an int
     * @throws ArrayIndexOutOfBoundsException if blocksRDDs() is empty
     */
    public RDD<Tuple2<Tuple2<Object, Object>, Iterator<MatrixEntry>>> generateBlockedEntries(BlockMatrix blockMatrix, int i, long j, boolean z) {
        RDD[] rddArr = (RDD[]) Predef$.MODULE$.refArrayOps(blockMatrix.blocksRDDs()).map(
                new ReconstructionError$$anonfun$2(),
                Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(RDD.class)));
        RDD[] rddArr2 = (RDD[]) Predef$.MODULE$.refArrayOps(blockMatrix.blocksRDDs()).map(
                new ReconstructionError$$anonfun$3(i, j, z, Ints.checkedCast(LongMath.divide(blockMatrix.matWidth(), blockMatrix.blockSize(), RoundingMode.CEILING))),
                Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(RDD.class)));
        SparkContext sparkContext = blockMatrix.blocksRDDs()[0].sparkContext();
        return sparkContext.union(
                Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new RDD[]{
                        sparkContext.union(Predef$.MODULE$.wrapRefArray(rddArr), ClassTag$.MODULE$.apply(Tuple2.class)),
                        sparkContext.union(Predef$.MODULE$.wrapRefArray(rddArr2), ClassTag$.MODULE$.apply(Tuple2.class))})),
                ClassTag$.MODULE$.apply(Tuple2.class));
    }

    /**
     * Unions two families of MatrixEntry RDDs derived from every RDD in
     * {@code blockMatrix.blocksRDDs()}:
     * <ul>
     *   <li>those produced by {@code anonfun$4(blockMatrix)};</li>
     *   <li>those produced by {@code anonfun$5}, which receives {@code i}, {@code j}, the block
     *       geometry and the last-block remainders {@code matWidth % blockSize} and
     *       {@code matHeight % blockSize}. These are presumably {@code i} random entries per block
     *       seeded by {@code j}; confirm.</li>
     * </ul>
     *
     * @throws ArrayIndexOutOfBoundsException if blocksRDDs() is empty
     */
    public RDD<MatrixEntry> generateRandomEntries(BlockMatrix blockMatrix, int i, long j) {
        return blockMatrix.blocksRDDs()[0].sparkContext().union(
                Predef$.MODULE$.refArrayOps(
                        (RDD[]) Predef$.MODULE$.refArrayOps(blockMatrix.blocksRDDs()).map(
                                new ReconstructionError$$anonfun$4(blockMatrix),
                                Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(RDD.class))))
                    .toIterator()
                    .$plus$plus(new ReconstructionError$$anonfun$generateRandomEntries$1(
                            (RDD[]) Predef$.MODULE$.refArrayOps(blockMatrix.blocksRDDs()).map(
                                    new ReconstructionError$$anonfun$5(i, j, blockMatrix.blockSize(), blockMatrix.numDimBlocksHeight(), blockMatrix.numDimBlocksWidth(), blockMatrix.matWidth() % blockMatrix.blockSize(), blockMatrix.matHeight() % blockMatrix.blockSize()),
                                    Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(RDD.class)))))
                    .toSeq(),
                ClassTag$.MODULE$.apply(MatrixEntry.class));
    }

    /**
     * Flattens {@link #generateBlockedEntries} into a plain RDD of MatrixEntry using
     * {@code anonfun$generateEntries$1(blockSize)}. That function presumably converts block-local
     * coordinates to global ones.
     */
    public RDD<MatrixEntry> generateEntries(BlockMatrix blockMatrix, int i, long j, boolean z) {
        return generateBlockedEntries(blockMatrix, i, j, z).flatMap(new ReconstructionError$$anonfun$generateEntries$1(blockMatrix.blockSize()), ClassTag$.MODULE$.apply(MatrixEntry.class));
    }

    /**
     * Per-partition kernel. It materializes all of {@code iterator} into an in-memory map, then folds
     * every element of {@code iterator2} into one double accumulator via
     * {@code anonfun$computeNormPartition$1}. That function also receives denseVector, function2 and
     * the map.
     *
     * @return a one-element iterator holding the accumulated value
     */
    public Iterator<Object> computeNormPartition(Iterator<Tuple2<Tuple2<Object, Object>, Iterable<MatrixEntry>>> iterator, Iterator<Tuple2<Tuple2<Object, DenseMatrix<Object>>, Tuple2<Object, DenseMatrix<Object>>>> iterator2, DenseVector<Object> denseVector, Function2<Object, Object, Object> function2) {
        Map map = iterator.toMap(Predef$.MODULE$.$conforms());
        DoubleRef create = DoubleRef.create(0.0d);
        iterator2.foreach(new ReconstructionError$$anonfun$computeNormPartition$1(denseVector, function2, map, create));
        return package$.MODULE$.Iterator().apply(Predef$.MODULE$.wrapDoubleArray(new double[]{create.elem}));
    }

    /**
     * Aggregates the per-entry values from {@link #computeReconstructionErrorRDD} into one number.
     *
     * <p>Each value is mapped to a (double, long) pair by {@code anonfun$6}, and the pairs are reduced
     * by {@code anonfun$7}. The result is {@code pow(first / second, 1 / d)}. The pair is presumably
     * (sum of |error|^d, count), which makes this the d-power mean error (RMS for d = 2); confirm.
     *
     * @throws IllegalArgumentException if {@code d <= 0}
     * @throws MatchError if the reduced tuple is null (decompiled Scala pattern match)
     */
    public double computeReconstructionErrorFromEntries(SkinnyBlockMatrix skinnyBlockMatrix, DenseVector<Object> denseVector, SkinnyBlockMatrix skinnyBlockMatrix2, RDD<MatrixEntry> rdd, double d) {
        Predef$.MODULE$.require(d > 0.0d);
        // Spark's RDD.reduce throws on an empty RDD, so rdd must contain at least one entry.
        Tuple2 tuple2 = (Tuple2) computeReconstructionErrorRDD(skinnyBlockMatrix, denseVector, skinnyBlockMatrix2, rdd, d).map(new ReconstructionError$$anonfun$6(), ClassTag$.MODULE$.apply(Tuple2.class)).reduce(new ReconstructionError$$anonfun$7());
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        // The decompiler had emitted an undefined local (`r0`) here. Read both components directly.
        double numerator = tuple2._1$mcD$sp();
        long denominator = tuple2._2$mcJ$sp();
        return scala.math.package$.MODULE$.pow(numerator / denominator, 1.0d / d);
    }

    /**
     * Computes one double per entry of {@code rdd}. Each entry is joined with the matching rows of
     * both factors, and the result is passed, with denseVector and d, to
     * {@code anonfun$computeReconstructionErrorRDD$1}.
     *
     * <p>Uses {@code rdd.count() / 2000000 + 1} hash partitions, i.e. roughly 2M entries per
     * partition. The join is two co-partitioned {@code zipPartitions} passes:
     * <ol>
     *   <li>Entries keyed by {@code anonfun$8} are zipped with rows of skinnyBlockMatrix
     *       ({@code anonfun$9}, {@code anonfun$10}).</li>
     *   <li>Entries keyed by {@code anonfun$12} are zipped with rows of skinnyBlockMatrix2
     *       ({@code anonfun$13}, {@code anonfun$14}).</li>
     * </ol>
     * Both sides are then re-partitioned with hashPartitioner2 and zipped together. The Long keys are
     * presumably the row index (left factor) and column index (right factor); confirm.
     *
     * <p>NOTE(review): rdd is traversed three times (count and two maps) and is not persisted here.
     * Callers with an expensive lineage may want to cache it first.
     */
    public RDD<Object> computeReconstructionErrorRDD(SkinnyBlockMatrix skinnyBlockMatrix, DenseVector<Object> denseVector, SkinnyBlockMatrix skinnyBlockMatrix2, RDD<MatrixEntry> rdd, double d) {
        long count = (rdd.count() / 2000000) + 1;
        if (m23logger().underlying().isInfoEnabled()) {
            m23logger().underlying().info(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"computeReconstructionError: using ", " partitions"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToLong(count)})));
        }
        if (m23logger().underlying().isInfoEnabled()) {
            // Estimated MiB per partition: 2M entries x 2 vectors (presumably) x numCols doubles x 8 bytes.
            // Computed in double: the former int product overflowed once numCols >= 68.
            m23logger().underlying().info(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"computeReconstructionError: Partition size: ", " MB"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToDouble((2.0d * 2000000 * skinnyBlockMatrix.numCols() * 8) / 1048576.0d)})));
        }
        int blockSize = skinnyBlockMatrix.blockSize();
        HashPartitioner hashPartitioner = new HashPartitioner((int) count);
        HashPartitioner hashPartitioner2 = new HashPartitioner((int) count);
        return RDD$.MODULE$.rddToPairRDDFunctions(
                        // Pass 1: rows of skinnyBlockMatrix zipped with entries, co-partitioned by hashPartitioner.
                        RDD$.MODULE$.rddToPairRDDFunctions(
                                        skinnyBlockMatrix.blocks().flatMap(new ReconstructionError$$anonfun$9(blockSize), ClassTag$.MODULE$.apply(Tuple2.class)),
                                        ClassTag$.MODULE$.Long(), ClassTag$.MODULE$.apply(DenseVector.class), Ordering$Long$.MODULE$)
                                .partitionBy(hashPartitioner)
                                .zipPartitions(
                                        RDD$.MODULE$.rddToPairRDDFunctions(
                                                        rdd.map(new ReconstructionError$$anonfun$8(), ClassTag$.MODULE$.apply(Tuple2.class)),
                                                        ClassTag$.MODULE$.Long(), ClassTag$.MODULE$.apply(MatrixEntry.class), Ordering$Long$.MODULE$)
                                                .partitionBy(hashPartitioner),
                                        new ReconstructionError$$anonfun$10(),
                                        ClassTag$.MODULE$.apply(Tuple2.class), ClassTag$.MODULE$.apply(Tuple2.class)),
                        ClassTag$.MODULE$.Long(), ClassTag$.MODULE$.apply(Tuple2.class), Ordering$Long$.MODULE$)
                .partitionBy(hashPartitioner2)
                .zipPartitions(
                        // Pass 2: rows of skinnyBlockMatrix2 zipped with entries, then re-keyed onto hashPartitioner2.
                        RDD$.MODULE$.rddToPairRDDFunctions(
                                        RDD$.MODULE$.rddToPairRDDFunctions(
                                                        skinnyBlockMatrix2.blocks().flatMap(new ReconstructionError$$anonfun$13(blockSize), ClassTag$.MODULE$.apply(Tuple2.class)),
                                                        ClassTag$.MODULE$.Long(), ClassTag$.MODULE$.apply(DenseVector.class), Ordering$Long$.MODULE$)
                                                .partitionBy(hashPartitioner)
                                                .zipPartitions(
                                                        RDD$.MODULE$.rddToPairRDDFunctions(
                                                                        rdd.map(new ReconstructionError$$anonfun$12(), ClassTag$.MODULE$.apply(Tuple2.class)),
                                                                        ClassTag$.MODULE$.Long(), ClassTag$.MODULE$.apply(MatrixEntry.class), Ordering$Long$.MODULE$)
                                                                .partitionBy(hashPartitioner),
                                                        new ReconstructionError$$anonfun$14(hashPartitioner2),
                                                        ClassTag$.MODULE$.apply(Tuple2.class), ClassTag$.MODULE$.apply(Tuple2.class)),
                                        ClassTag$.MODULE$.Long(), ClassTag$.MODULE$.apply(Tuple2.class), Ordering$Long$.MODULE$)
                                .partitionBy(hashPartitioner2),
                        // Final zip: combine both sides per entry into one double.
                        new ReconstructionError$$anonfun$computeReconstructionErrorRDD$1(denseVector, d),
                        ClassTag$.MODULE$.apply(Tuple2.class), ClassTag$.MODULE$.Double());
    }

    /**
     * Computes summary statistics of the per-entry reconstruction error on a random sample of
     * {@code rdd} containing approximately {@code j} entries in expectation (all entries if
     * {@code rdd} has fewer than {@code j}).
     *
     * <p>Entries are sampled without replacement at rate {@code min(1, j / count)}. Their values from
     * {@link #computeReconstructionErrorRDD} are mapped by {@code anonfun$1(d)} (presumably the d-th
     * root) and sorted. From the sorted sample the method computes:
     * <ul>
     *   <li>{@code i} quantiles via {@code anonfun$computeStatsForReconstructionError$1}, with the
     *       maximum stored at index {@code i};</li>
     *   <li>the mean;</li>
     *   <li>an accumulator from {@code anonfun$computeStatsForReconstructionError$2} divided by
     *       {@code n - 1}. This is presumably the sample variance; with a single sampled entry it
     *       divides by zero (NaN or Infinity).</li>
     * </ul>
     *
     * <p>Fix: the sampling rate used to be computed with {@code long} division, which is 0 whenever
     * {@code count > j}. The sample was then empty and the method crashed with
     * ArrayIndexOutOfBoundsException.
     *
     * @param d norm exponent; must be positive (default 2.0)
     * @param j target maximum number of sampled entries (default 50,000,000)
     * @param i number of quantile buckets (default 100)
     * @return sample size, mean, dispersion value and the {@code i + 1} quantile values
     * @throws IllegalArgumentException if {@code d <= 0} or the sample is empty
     */
    public ReconstructionError.ReconstructionErrorDistribution computeStatsForReconstructionError(SkinnyBlockMatrix skinnyBlockMatrix, DenseVector<Object> denseVector, SkinnyBlockMatrix skinnyBlockMatrix2, RDD<MatrixEntry> rdd, double d, long j, int i) {
        Predef$.MODULE$.require(d > 0.0d);
        long count = rdd.count();
        // Floating-point division: `j / count` on two longs truncates to 0 whenever count > j.
        double d2 = count < j ? 1.0d : ((double) j) / count;
        if (m23logger().underlying().isInfoEnabled()) {
            m23logger().underlying().info(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Sampling rate: ", "%"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToDouble(d2 * 100.0d)})));
        }
        RDD<MatrixEntry> sampledEntries = rdd.sample(false, d2, rdd.sample$default$3());
        double[] dArr = (double[]) Predef$.MODULE$.doubleArrayOps(
                        (double[]) Predef$.MODULE$.doubleArrayOps(
                                        (double[]) computeReconstructionErrorRDD(skinnyBlockMatrix, denseVector, skinnyBlockMatrix2, sampledEntries, d).collect())
                                .map(new ReconstructionError$$anonfun$1(d), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double())))
                .sorted(Ordering$Double$.MODULE$);
        int length = dArr.length;
        if (length == 0) {
            throw new IllegalArgumentException("computeStatsForReconstructionError: the sample of entries is empty; cannot compute statistics");
        }
        // i quantiles followed by the maximum at index i.
        double[] dArr2 = new double[i + 1];
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), i).foreach$mVc$sp(new ReconstructionError$$anonfun$computeStatsForReconstructionError$1(i, dArr, length, dArr2));
        dArr2[i] = dArr[length - 1];
        double unboxToDouble = BoxesRunTime.unboxToDouble(Predef$.MODULE$.doubleArrayOps(dArr).sum(Numeric$DoubleIsFractional$.MODULE$)) / dArr.length;
        DoubleRef create = DoubleRef.create(0.0d);
        Predef$.MODULE$.doubleArrayOps(dArr).foreach(new ReconstructionError$$anonfun$computeStatsForReconstructionError$2(unboxToDouble, create));
        create.elem /= length - 1;
        return new ReconstructionError.ReconstructionErrorDistribution(length, unboxToDouble, create.elem, dArr2);
    }

    /** Scala default-argument accessor: default {@code d} (norm exponent) of computeStatsForReconstructionError. */
    public double computeStatsForReconstructionError$default$5() {
        return 2.0d;
    }

    /** Scala default-argument accessor: default {@code j} (target sample size) of computeStatsForReconstructionError. */
    public long computeStatsForReconstructionError$default$6() {
        return 50000000L;
    }

    /** Scala default-argument accessor: default {@code i} (number of quantiles) of computeStatsForReconstructionError. */
    public int computeStatsForReconstructionError$default$7() {
        return 100;
    }

    /**
     * Generates entries from {@code blockMatrix} with {@code generateEntries(blockMatrix, 1, j, true)}
     * and reduces their errors to one number with {@link #computeReconstructionErrorFromEntries}.
     *
     * @param j forwarded to generateEntries (presumably a random seed; default 12345)
     */
    public double computeReconstructionError(SkinnyBlockMatrix skinnyBlockMatrix, DenseVector<Object> denseVector, SkinnyBlockMatrix skinnyBlockMatrix2, double d, BlockMatrix blockMatrix, long j) {
        return computeReconstructionErrorFromEntries(skinnyBlockMatrix, denseVector, skinnyBlockMatrix2, generateEntries(blockMatrix, 1, j, true), d);
    }

    /** Scala default-argument accessor: default {@code d} of computeReconstructionErrorFromEntries. */
    public double computeReconstructionErrorFromEntries$default$5() {
        return 2.0d;
    }

    /** Scala default-argument accessor: default {@code d} of computeReconstructionErrorRDD. */
    public double computeReconstructionErrorRDD$default$5() {
        return 2.0d;
    }

    /** Scala default-argument accessor: default {@code j} of computeReconstructionError. */
    public long computeReconstructionError$default$6() {
        return 12345L;
    }

    /** Registers the singleton and runs the StrictLogging trait initializer (which installs the logger). */
    private ReconstructionError$() {
        MODULE$ = this;
        StrictLogging.class.$init$(this);
    }
}
