package com.criteo.rsvd;

import breeze.linalg.DenseMatrix;
import breeze.linalg.DenseMatrix$;
import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.linalg.svd;
import breeze.linalg.svd$reduced$;
import breeze.linalg.svd$reduced$reduced_Svd_DM_Impl$;
import com.typesafe.scalalogging.slf4j.Logger;
import com.typesafe.scalalogging.slf4j.StrictLogging;
import org.apache.spark.SparkContext;
import org.apache.spark.storage.StorageLevel$;
import scala.MatchError;
import scala.None$;
import scala.Predef$;
import scala.Some;
import scala.StringContext;
import scala.Tuple2;
import scala.Tuple3;
import scala.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;

/* compiled from: RSVD.scala */
/* loaded from: input_file:com/criteo/rsvd/RSVD$.class */
/**
 * Decompiled form of the Scala object {@code RSVD} (compiled from RSVD.scala): entry points for a
 * randomized truncated SVD of a distributed {@link BlockMatrix}.
 *
 * <p>NOTE(review): this class mirrors the Scala object's bytecode and is not compilable Java as
 * written (e.g. the {@code static final} {@code MODULE$} is reassigned in the constructor, and the
 * {@code final} {@code logger} field is assigned from a setter). Prefer editing the Scala source.
 */
public final class RSVD$ implements StrictLogging {
    /** Singleton instance of the Scala object; set by the private constructor. */
    public static final RSVD$ MODULE$ = null;
    /** Logger installed through the {@code StrictLogging} trait setter below. */
    private final Logger logger;

    static {
        new RSVD$();
    }

    /** Returns the logger provided by the {@code StrictLogging} trait. */
    /* renamed from: logger, reason: merged with bridge method [inline-methods] */
    public Logger m16logger() {
        return this.logger;
    }

    /**
     * Trait-field setter for the {@code StrictLogging} logger; presumably invoked from
     * {@code StrictLogging.$init$}, which the constructor calls.
     */
    public void com$typesafe$scalalogging$slf4j$StrictLogging$_setter_$logger_$eq(Logger logger) {
        this.logger = logger;
    }

    /**
     * Computes a truncated SVD of {@code blockMatrix} using the parameters in {@code rSVDConfig}.
     *
     * <p>Steps: check that the matrix layout matches the config, log when the estimated dense
     * partition size is large, persist the matrix and its transpose with {@code DISK_ONLY}, build a
     * basis from a seeded random projection ({@link #findGoodBasis}), then derive the singular
     * values and the requested singular vectors ({@link #computeSingularVectors}).
     *
     * <p>NOTE(review): the two {@code DISK_ONLY} persisted matrices are never unpersisted here;
     * confirm that callers release them.
     *
     * @param blockMatrix the input matrix; its block size and partition dimensions must match the
     *     config
     * @param rSVDConfig embedding dimension, oversampling, power iterations, seed, layout and which
     *     singular vectors to compute
     * @param sparkContext used to build the random matrix and to label Spark jobs
     * @return the singular values plus the singular vectors requested by the config
     * @throws IllegalArgumentException if {@code blockSize}, {@code partitionWidthInBlocks} or
     *     {@code partitionHeightInBlocks} of {@code blockMatrix} differ from {@code rSVDConfig}
     */
    public RsvdResults run(BlockMatrix blockMatrix, RSVDConfig rSVDConfig, SparkContext sparkContext) {
        if (blockMatrix.blockSize() != rSVDConfig.blockSize()) {
            throw new IllegalArgumentException(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Expected a matrix blocksize of ", ", got ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(rSVDConfig.blockSize()), BoxesRunTime.boxToInteger(blockMatrix.blockSize())})));
        }
        if (blockMatrix.partitionWidthInBlocks() != rSVDConfig.partitionWidthInBlocks()) {
            throw new IllegalArgumentException(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Expected a matrix partitionWidthInBlocks of ", ", got ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(rSVDConfig.partitionWidthInBlocks()), BoxesRunTime.boxToInteger(blockMatrix.partitionWidthInBlocks())})));
        }
        if (blockMatrix.partitionHeightInBlocks() != rSVDConfig.partitionHeightInBlocks()) {
            throw new IllegalArgumentException(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Expected a matrix partitionHeightInBlocks of ", ", got ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(rSVDConfig.partitionHeightInBlocks()), BoxesRunTime.boxToInteger(blockMatrix.partitionHeightInBlocks())})));
        }
        // Estimated size in MB of one dense partition:
        // (embeddingDim + oversample) * blockSize * partitionHeightInBlocks entries at 8 bytes each.
        // The product is formed in double. As an int product it overflowed for large configs
        // (e.g. 100 * 50_000 * 100 * 8 = 4e9 > Integer.MAX_VALUE), went negative, and silently
        // suppressed both the error and the warning below. Non-overflowing results are unchanged.
        double partitionSizeBytes = ((double) rSVDConfig.embeddingDim() + rSVDConfig.oversample())
                * rSVDConfig.blockSize() * rSVDConfig.partitionHeightInBlocks() * 8.0d;
        // Narrowing double -> int saturates at Integer.MAX_VALUE instead of wrapping.
        int partitionSizeMb = (int) (partitionSizeBytes / 1048576.0d);
        if (partitionSizeMb >= 2000) {
            if (m16logger().underlying().isErrorEnabled()) {
                m16logger().underlying().error(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Dense matrix partition size (", " MB) is above 2GB, which will probably make the job fail if they end up persisted on disk (cf SPARK-3151)"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(partitionSizeMb)})));
            }
        } else if (partitionSizeMb >= 1000 && m16logger().underlying().isWarnEnabled()) {
            m16logger().underlying().warn(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Dense matrix partition size is ", " MB, which may be too high"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(partitionSizeMb)})));
        }
        BlockMatrix persist = blockMatrix.persist(StorageLevel$.MODULE$.DISK_ONLY());
        BlockMatrix persist2 = persist.transpose().persist(StorageLevel$.MODULE$.DISK_ONLY());
        sparkContext.setJobDescription("Producing random projection basis");
        SkinnyBlockMatrix findGoodBasis = findGoodBasis(persist, persist2, SkinnyBlockMatrix$.MODULE$.randomMatrix(blockMatrix.matWidth(), rSVDConfig.embeddingDim() + rSVDConfig.oversample(), rSVDConfig.blockSize(), rSVDConfig.partitionWidthInBlocks(), sparkContext, rSVDConfig.seed()), rSVDConfig.powerIter());
        sparkContext.setJobDescription("Computing left singular vectors");
        return computeSingularVectors(persist2, findGoodBasis, rSVDConfig.embeddingDim(), rSVDConfig.computeLeftSingularVectors(), rSVDConfig.computeRightSingularVectors());
    }

    /**
     * Builds a basis (a QR Q-factor) from a projection of {@code blockMatrix} onto a starting matrix.
     *
     * <p>Computes Q from the QR factorization of {@code blockMatrix} times {@code skinnyBlockMatrix}.
     * It then runs {@code i} iterations of {@code RSVD$$anonfun$findGoodBasis$1}. That body is not
     * visible here; presumably each iteration is a power iteration using {@code blockMatrix2} that
     * replaces the basis held in {@code create} (TODO confirm). Finally, the result is
     * re-orthogonalized with another QR.
     *
     * @param blockMatrix the input matrix
     * @param blockMatrix2 {@link #run} passes the transpose of {@code blockMatrix}
     * @param skinnyBlockMatrix the starting matrix; {@link #run} passes a seeded random matrix
     * @param i number of loop iterations; {@link #run} passes {@code powerIter()}
     * @return the Q factor of the final QR factorization
     */
    public SkinnyBlockMatrix findGoodBasis(BlockMatrix blockMatrix, BlockMatrix blockMatrix2, SkinnyBlockMatrix skinnyBlockMatrix, int i) {
        // Mutable holder so the anonymous loop body can replace the current basis.
        ObjectRef create = ObjectRef.create((SkinnyBlockMatrix) blockMatrix.skinnyMultiply(skinnyBlockMatrix, false).qr()._1());
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), i).foreach$mVc$sp(new RSVD$$anonfun$findGoodBasis$1(blockMatrix, blockMatrix2, skinnyBlockMatrix, create));
        return (SkinnyBlockMatrix) ((SkinnyBlockMatrix) create.elem).qr()._1();
    }

    /**
     * Derives the singular values, and optionally the singular vectors, from a basis.
     *
     * <p>Multiplies {@code blockMatrix} with {@code skinnyBlockMatrix} (flag {@code true}) and
     * QR-factorizes the product into Q2 and R. It runs a reduced SVD R = U * S * Vt and makes the
     * signs of U and V deterministic. Then:
     * <ul>
     *   <li>left vectors = {@code skinnyBlockMatrix} times {@code V[:, 0:i]};</li>
     *   <li>right vectors = Q2 times {@code U[:, 0:i]}.</li>
     * </ul>
     *
     * @param blockMatrix {@link #run} passes the transpose of the input matrix
     * @param skinnyBlockMatrix basis produced by {@link #findGoodBasis}
     * @param i number of leading singular values / vectors to keep
     * @param z whether to compute the left singular vectors ({@link #run} passes
     *     {@code computeLeftSingularVectors()})
     * @param z2 whether to compute the right singular vectors ({@link #run} passes
     *     {@code computeRightSingularVectors()})
     * @return optional left vectors, the first {@code i} singular values, optional right vectors
     */
    public RsvdResults computeSingularVectors(BlockMatrix blockMatrix, SkinnyBlockMatrix skinnyBlockMatrix, int i, boolean z, boolean z2) {
        Some some;
        Some some2;
        Tuple2<SkinnyBlockMatrix, DenseMatrix<Object>> qr = blockMatrix.skinnyMultiply(skinnyBlockMatrix, true).qr();
        // Null checks / MatchError below are decompiled Scala tuple-pattern matches.
        if (qr == null) {
            throw new MatchError(qr);
        }
        Tuple2 tuple2 = new Tuple2((SkinnyBlockMatrix) qr._1(), (DenseMatrix) qr._2());
        SkinnyBlockMatrix skinnyBlockMatrix2 = (SkinnyBlockMatrix) tuple2._1();
        svd.SVD svd = (svd.SVD) svd$reduced$.MODULE$.apply((DenseMatrix) tuple2._2(), svd$reduced$reduced_Svd_DM_Impl$.MODULE$);
        if (svd == null) {
            throw new MatchError(svd);
        }
        Tuple3 tuple3 = new Tuple3((DenseMatrix) svd.leftVectors(), (DenseVector) svd.singularValues(), (DenseMatrix) svd.rightVectors());
        DenseMatrix<Object> denseMatrix = (DenseMatrix) tuple3._1();
        DenseVector denseVector = (DenseVector) tuple3._2();
        // Right vectors are returned as Vt by the SVD; transpose to get V.
        DenseMatrix<Object> denseMatrix2 = (DenseMatrix) ((DenseMatrix) tuple3._3()).t(DenseMatrix$.MODULE$.canTranspose());
        Utils$.MODULE$.deterministicSignsInplace(denseMatrix, denseMatrix2);
        // Keep only the first i singular values.
        DenseVector denseVector2 = (DenseVector) denseVector.apply(RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), i), DenseVector$.MODULE$.canSlice());
        if (denseVector2.length() >= 2) {
            // Runs RSVD$$anonfun$computeSingularVectors$1 for each index 0 .. length-2. That body is
            // not visible here; presumably it checks consecutive singular values are non-increasing
            // (TODO confirm).
            RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), denseVector2.length() - 1).foreach$mVc$sp(new RSVD$$anonfun$computeSingularVectors$1(denseVector2));
        }
        // Decompiled Scala boolean match; the MatchError branches are unreachable.
        if (true == z) {
            some = new Some(skinnyBlockMatrix.singleBlockMultiply((DenseMatrix) denseMatrix2.apply(package$.MODULE$.$colon$colon(), RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), i), DenseMatrix$.MODULE$.canSliceCols()), true));
        } else {
            if (false != z) {
                throw new MatchError(BoxesRunTime.boxToBoolean(z));
            }
            some = None$.MODULE$;
        }
        Some some3 = some;
        if (true == z2) {
            some2 = new Some(skinnyBlockMatrix2.singleBlockMultiply((DenseMatrix) denseMatrix.apply(package$.MODULE$.$colon$colon(), RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), i), DenseMatrix$.MODULE$.canSliceCols()), true));
        } else {
            if (false != z2) {
                throw new MatchError(BoxesRunTime.boxToBoolean(z2));
            }
            some2 = None$.MODULE$;
        }
        return new RsvdResults(some3, denseVector2, some2);
    }

    /** Private constructor: publishes the singleton and runs the StrictLogging trait initializer. */
    private RSVD$() {
        MODULE$ = this;
        StrictLogging.class.$init$(this);
    }
}
