package org.apache.mahout.math.decomposer.lanczos;

import java.util.EnumMap;
import java.util.Map;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.PlusMult;
import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasure;
import org.apache.mahout.math.matrix.DoubleMatrix1D;
import org.apache.mahout.math.matrix.DoubleMatrix2D;
import org.apache.mahout.math.matrix.linalg.EigenvalueDecomposition;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/math/decomposer/lanczos/LanczosSolver.class */
/**
 * Runs Lanczos iterations over the corpus held by a {@link LanczosState}, then diagonalizes the
 * tri-diagonal auxiliary matrix it builds to produce right singular vectors and singular values,
 * which are stored back on the state.
 *
 * <p>Instances keep mutable, unsynchronized timing bookkeeping, so a single instance should not
 * be shared between concurrently running solves.
 */
public class LanczosSolver {
    private static final Logger log = LoggerFactory.getLogger(LanczosSolver.class);

    /** Magnitude beyond which an alpha/beta value is treated as numerically unusable. */
    public static final double SAFE_MAX = 1.0E150d;

    /** Most recent {@link System#nanoTime()} start stamp per timing section. */
    private final Map<TimingSection, Long> startTimes = new EnumMap<TimingSection, Long>(TimingSection.class);

    /**
     * Accumulated elapsed nanoseconds per timing section.
     * NOTE(review): nothing in this class reads this map; consider exposing or removing it.
     */
    private final Map<TimingSection, Long> times = new EnumMap<TimingSection, Long>(TimingSection.class);

    /** Element-wise {@link DoubleFunction} that multiplies its argument by a fixed factor. */
    public static final class Scale implements DoubleFunction {
        private final double factor;

        private Scale(double factor) {
            this.factor = factor;
        }

        @Override
        public double apply(double value) {
            return value * factor;
        }
    }

    /**
     * Phases of {@link #solve} whose wall-clock time is accumulated.
     * The misspelled {@code ORTHOGANLIZE} constant is kept as-is because it is public API.
     */
    public enum TimingSection {
        ITERATE,
        ORTHOGANLIZE,
        TRIDIAG_DECOMP,
        FINAL_EIGEN_CREATE
    }

    /**
     * Equivalent to {@code solve(state, desiredRank, false)}.
     *
     * @param state Lanczos state holding the corpus, basis vectors and outputs
     * @param desiredRank iteration number at which the Lanczos loop stops
     */
    public void solve(LanczosState state, int desiredRank) {
        solve(state, desiredRank, false);
    }

    /**
     * Runs Lanczos iterations from the state's current iteration number until it reaches
     * {@code desiredRank}, or stops early if alpha or beta leaves the range accepted by
     * {@link #SAFE_MAX}. It then eigen-decomposes the state's diagonal matrix and stores one right
     * singular vector and one singular value per completed iteration on the state.
     *
     * @param state Lanczos state holding the corpus, basis vectors, scale factor and outputs
     * @param desiredRank iteration number at which the Lanczos loop stops
     * @param isSymmetric if {@code true} the corpus is applied with {@code times}; otherwise with
     *     {@code timesSquared} (presumably A'A — confirm against {@link VectorIterable}), and the
     *     reported values are the square roots of the scaled eigenvalues
     * @throws IllegalStateException if an eigenvector cannot be built because no basis vectors
     *     are available
     */
    public void solve(LanczosState state, int desiredRank, boolean isSymmetric) {
        VectorIterable corpus = state.getCorpus();
        log.info("Finding {} singular vectors of matrix with {} rows, via Lanczos", desiredRank, corpus.numRows());
        int iteration = state.getIterationNumber();
        Vector currentVector = state.getBasisVector(iteration - 1);
        Vector previousVector = state.getBasisVector(iteration - 2);
        double beta = 0.0;
        Matrix triDiag = state.getDiagonalMatrix();

        while (iteration < desiredRank) {
            startTime(TimingSection.ITERATE);
            Vector nextVector = isSymmetric ? corpus.times(currentVector) : corpus.timesSquared(currentVector);
            log.info("{} passes through the corpus so far...", iteration);
            // The scale factor is fixed from the first product seen while it is still unset (<= 0).
            if (state.getScaleFactor() <= 0.0) {
                state.setScaleFactor(calculateScaleFactor(nextVector));
            }
            nextVector.assign(new Scale(1.0 / state.getScaleFactor()));
            if (previousVector != null) {
                nextVector.assign(previousVector, new PlusMult(-beta));
            }
            double alpha = currentVector.dot(nextVector);
            nextVector.assign(currentVector, new PlusMult(-alpha));
            endTime(TimingSection.ITERATE);

            startTime(TimingSection.ORTHOGANLIZE);
            orthoganalizeAgainstAllButLast(nextVector, state);
            endTime(TimingSection.ORTHOGANLIZE);

            beta = nextVector.norm(2.0);
            if (outOfRange(beta) || outOfRange(alpha)) {
                log.warn("Lanczos parameters out of range: alpha = {}, beta = {}.  Bailing out early!", alpha, beta);
                break;
            }
            nextVector.assign(new Scale(1.0 / beta));
            state.setBasisVector(iteration, nextVector);
            previousVector = currentVector;
            currentVector = nextVector;

            triDiag.set(iteration - 1, iteration - 1, alpha);
            // Off-diagonal entries are only written while still inside the desiredRank bound.
            if (iteration < desiredRank - 1) {
                triDiag.set(iteration - 1, iteration, beta);
                triDiag.set(iteration, iteration - 1, beta);
            }
            iteration++;
            state.setIterationNumber(iteration);
        }

        startTime(TimingSection.TRIDIAG_DECOMP);
        log.info("Lanczos iteration complete - now to diagonalize the tri-diagonal auxiliary matrix.");
        EigenvalueDecomposition decomposition = new EigenvalueDecomposition(triDiag);
        DoubleMatrix2D eigenVectors = decomposition.getV();
        DoubleMatrix1D eigenValues = decomposition.getRealEigenvalues();
        endTime(TimingSection.TRIDIAG_DECOMP);

        startTime(TimingSection.FINAL_EIGEN_CREATE);
        for (int row = 0; row < iteration; row++) {
            // Columns of V are consumed in reverse order while eigenvalues are read in forward
            // order. NOTE(review): confirm this pairing matches EigenvalueDecomposition's ordering.
            DoubleMatrix1D coefficients = eigenVectors.viewColumn(iteration - row - 1);
            Vector realEigen = combineBasisVectors(coefficients, state, row);
            state.setRightSingularVector(row, realEigen.normalize());
            double singularValue = eigenValues.get(row) * state.getScaleFactor();
            if (!isSymmetric) {
                singularValue = Math.sqrt(singularValue);
            }
            log.info("Eigenvector {} found with eigenvalue {}", row, singularValue);
            state.setSingularValue(row, singularValue);
        }
        log.info("LanczosSolver finished.");
        endTime(TimingSection.FINAL_EIGEN_CREATE);
    }

    /**
     * Builds the linear combination sum_j coefficients[j] * basis[j] over the first
     * min(coefficients.size(), basisSize) basis vectors of the state.
     *
     * @throws IllegalStateException if no basis vector was combined, since the result
     *     would otherwise be null
     */
    private static Vector combineBasisVectors(DoubleMatrix1D coefficients, LanczosState state, int row) {
        Vector combination = null;
        int size = Math.min(coefficients.size(), state.getBasisSize());
        for (int j = 0; j < size; j++) {
            double coefficient = coefficients.get(j);
            Vector basisVector = state.getBasisVector(j);
            if (combination == null) {
                combination = basisVector.like();
            }
            combination.assign(basisVector, new PlusMult(coefficient));
        }
        if (combination == null) {
            throw new IllegalStateException("No basis vectors available to build eigenvector " + row);
        }
        return combination;
    }

    /**
     * Returns the factor by which every corpus product is divided; the default is the 2-norm
     * of the given vector. Subclasses may override.
     */
    protected double calculateScaleFactor(Vector vector) {
        return vector.norm(2.0);
    }

    /** Returns {@code true} if {@code value} is NaN or its magnitude exceeds {@link #SAFE_MAX}. */
    private static boolean outOfRange(double value) {
        return Double.isNaN(value) || value > SAFE_MAX || -value > SAFE_MAX;
    }

    /**
     * Subtracts from {@code vector} its projection onto every non-null basis vector with index
     * below the state's current iteration number. Projections with a dot product of exactly
     * zero are skipped.
     */
    protected void orthoganalizeAgainstAllButLast(Vector vector, LanczosState state) {
        for (int i = 0; i < state.getIterationNumber(); i++) {
            Vector basisVector = state.getBasisVector(i);
            if (basisVector != null) {
                double dot = vector.dot(basisVector);
                if (dot != 0.0) {
                    vector.assign(basisVector, new PlusMult(-dot));
                }
            }
        }
    }

    /** Records the current {@link System#nanoTime()} as the start of {@code section}. */
    private void startTime(TimingSection section) {
        startTimes.put(section, System.nanoTime());
    }

    /** Adds the time elapsed since the matching {@link #startTime} call to {@code section}'s total. */
    private void endTime(TimingSection section) {
        Long accumulated = times.get(section);
        long previous = accumulated == null ? 0L : accumulated;
        times.put(section, previous + System.nanoTime() - startTimes.get(section));
    }
}
