package org.apache.mahout.clustering.lda;

import java.util.Iterator;
import org.apache.commons.math.special.Gamma;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasure;

/* loaded from: input_file:org/apache/mahout/clustering/lda/LDAInference.class */
/**
 * Per-document E-step for a trained LDA model held in an {@link LDAState}.
 *
 * <p>{@link #infer(Vector)} alternates two updates.
 * <ul>
 *   <li>The per-word topic responsibilities phi, kept in log space.</li>
 *   <li>The document's Dirichlet parameters gamma, set to
 *       {@code gamma_k = alpha + sum_w count_w * phi_wk}.</li>
 * </ul>
 * It stops when the bound returned by {@link #computeLikelihood} converges (see
 * {@link #infer(Vector)}), or after {@code MAX_ITER} iterations.
 *
 * <p>The only instance state is the model reference. Every call to {@link #infer(Vector)}
 * allocates its own working vectors and matrix. NOTE(review): thread-safety therefore depends on
 * {@link LDAState}, which is not visible here.
 */
public class LDAInference {

    /** Relative change of the likelihood bound below which the E-step stops iterating. */
    private static final double E_STEP_CONVERGENCE = 1.0E-6;

    /** Maximum number of E-step iterations per document. */
    private static final int MAX_ITER = 20;

    private final LDAState state;

    /**
     * Result of inferring a single document.
     *
     * <p>Holds the following:
     * <ul>
     *   <li>The caller's word-count vector (not copied).</li>
     *   <li>The fitted gamma vector.</li>
     *   <li>The log responsibilities (log phi), with one matrix column per word that had a
     *       non-zero count.</li>
     *   <li>The final value of the likelihood bound.</li>
     * </ul>
     * The phi matrix is private to this result; later {@code infer} calls do not modify it.
     */
    public static class InferredDocument {
        private final Vector wordCounts;
        private final Vector gamma;
        /** Log phi: rows are topics, columns are distinct non-zero words of the document. */
        private final Matrix mphi;
        /** Maps a word index of {@link #wordCounts} to its column in {@link #mphi}. */
        private final int[] columnMap;
        private final double logLikelihood;

        InferredDocument(Vector wordCounts, Vector gamma, int[] columnMap, Matrix phi,
                         double logLikelihood) {
            this.wordCounts = wordCounts;
            this.gamma = gamma;
            this.mphi = phi;
            this.columnMap = columnMap;
            this.logLikelihood = logLikelihood;
        }

        /**
         * Returns log phi for the given topic and word.
         *
         * <p>Only meaningful for words with a non-zero count in the document. Any other word index
         * maps to column 0, which holds the first non-zero word, or zeros for an empty document.
         *
         * @param topic topic index, in {@code [0, numTopics)}
         * @param word word index into the document vector
         * @return the log responsibility of {@code topic} for {@code word}
         */
        public double phi(int topic, int word) {
            return mphi.getQuick(topic, columnMap[word]);
        }

        /** @return the document's word-count vector, exactly as passed to {@code infer} */
        public Vector getWordCounts() {
            return wordCounts;
        }

        /** @return the fitted per-topic Dirichlet parameters gamma */
        public Vector getGamma() {
            return gamma;
        }

        /** @return the likelihood bound computed in the last E-step iteration */
        public double getLogLikelihood() {
            return logLikelihood;
        }
    }

    /**
     * @param state trained model providing the topic count, topic smoothing and word/topic
     *              log probabilities
     */
    public LDAInference(LDAState state) {
        this.state = state;
    }

    /**
     * Runs the E-step for one document.
     *
     * <p>Iteration stops in any of these cases:
     * <ul>
     *   <li>The relative change {@code (oldLL - ll) / oldLL} of a negative previous bound drops
     *       below {@code E_STEP_CONVERGENCE}. Because of the sign convention, this also fires
     *       when the bound decreases.</li>
     *   <li>{@code MAX_ITER} iterations have run. The initial sentinel of 1.0 means at least two
     *       iterations run.</li>
     * </ul>
     *
     * @param wordCounts word index to count for the document; iterated via {@code iterateNonZero}
     * @return the inferred gamma, log phi and final likelihood bound
     */
    public InferredDocument infer(Vector wordCounts) {
        int numTopics = state.getNumTopics();
        double alpha = state.getTopicSmoothing();
        double docTotal = wordCounts.zSum();

        // Standard initialisation: gamma_k = alpha + N / K.
        Vector gamma = new DenseVector(numTopics);
        gamma.assign(alpha + docTotal / numTopics);
        Vector nextGamma = new DenseVector(numTopics);

        // One column per non-zero word. At least one column is kept so that phi() on an empty
        // document still reads 0.0 from column 0.
        DenseMatrix phi = new DenseMatrix(numTopics, Math.max(1, countNonZero(wordCounts)));
        int[] columnMap = new int[wordCounts.size()];
        Vector digammaGamma = digammaGamma(gamma);

        boolean converged = false;
        double oldLL = 1.0; // positive sentinel: never passes the "oldLL < 0" convergence guard
        for (int iteration = 0; !converged && iteration < MAX_ITER; iteration++) {
            nextGamma.assign(alpha);
            int column = 0;
            Iterator<Vector.Element> it = wordCounts.iterateNonZero();
            while (it.hasNext()) {
                Vector.Element e = it.next();
                int word = e.index();
                Vector logPhiWord = eStepForWord(word, digammaGamma);
                phi.assignColumn(column, logPhiWord);
                if (iteration == 0) {
                    // Iteration order is the same every pass, so the mapping is built only once.
                    columnMap[word] = column;
                }
                double count = e.get();
                for (int k = 0; k < numTopics; k++) {
                    nextGamma.setQuick(k,
                        nextGamma.getQuick(k) + count * Math.exp(logPhiWord.getQuick(k)));
                }
                column++;
            }

            // Swap buffers: the freshly accumulated vector becomes the current gamma.
            Vector previous = gamma;
            gamma = nextGamma;
            nextGamma = previous;

            digammaGamma = digammaGamma(gamma);
            double ll = computeLikelihood(wordCounts, columnMap, phi, gamma, digammaGamma);
            converged = oldLL < 0.0 && (oldLL - ll) / oldLL < E_STEP_CONVERGENCE;
            oldLL = ll;
        }
        return new InferredDocument(wordCounts, gamma, columnMap, phi, oldLL);
    }

    /** Counts the elements yielded by {@code vector.iterateNonZero()}. */
    private static int countNonZero(Vector vector) {
        int count = 0;
        for (Iterator<Vector.Element> it = vector.iterateNonZero(); it.hasNext(); it.next()) {
            count++;
        }
        return count;
    }

    /**
     * Returns {@code psi(gamma_k) - psi(sum_j gamma_j)} for every topic k. This is the expectation
     * of {@code log theta_k} under a Dirichlet(gamma) distribution.
     */
    private Vector digammaGamma(Vector gamma) {
        Vector result = digamma(gamma);
        double digammaOfSum = digamma(gamma.zSum());
        for (int k = 0; k < state.getNumTopics(); k++) {
            result.setQuick(k, result.getQuick(k) - digammaOfSum);
        }
        return result;
    }

    /**
     * Computes the document's likelihood bound for the current gamma and phi.
     *
     * @param wordCounts   the document's word counts
     * @param columnMap    word index to column of {@code phi}
     * @param phi          log responsibilities (topics by words)
     * @param gamma        current Dirichlet parameters
     * @param digammaGamma output of {@link #digammaGamma(Vector)} for {@code gamma}
     * @return the bound; at most the log-likelihood in exact arithmetic
     */
    private double computeLikelihood(Vector wordCounts, int[] columnMap, Matrix phi, Vector gamma,
                                     Vector digammaGamma) {
        int numTopics = state.getNumTopics();
        double alpha = state.getTopicSmoothing();

        // Dirichlet normaliser of the prior: log Gamma(K*alpha) - K * log Gamma(alpha).
        double ll = Gamma.logGamma(alpha * numTopics) - numTopics * Gamma.logGamma(alpha);
        for (int k = 0; k < numTopics; k++) {
            double gammaK = gamma.get(k);
            ll = ll + (alpha - gammaK) * digammaGamma.getQuick(k) + Gamma.logGamma(gammaK);
        }
        ll -= Gamma.logGamma(gamma.zSum());

        Iterator<Vector.Element> it = wordCounts.iterateNonZero();
        while (it.hasNext()) {
            Vector.Element e = it.next();
            int word = e.index();
            double count = e.get();
            int column = columnMap[word];
            for (int k = 0; k < numTopics; k++) {
                double logPhi = phi.getQuick(k, column);
                if (logPhi == Double.NEGATIVE_INFINITY) {
                    // phi == 0. The term's limit is 0, but evaluating it would compute 0 * inf,
                    // which is NaN.
                    continue;
                }
                ll += Math.exp(logPhi)
                    * (digammaGamma.getQuick(k) - logPhi + state.logProbWordGivenTopic(word, k))
                    * count;
            }
        }
        return ll;
    }

    /**
     * Computes the normalised log responsibilities of all topics for one word.
     *
     * <p>Each unnormalised value is {@code state.logProbWordGivenTopic(word, k) + digammaGamma_k}.
     * The values are then normalised with {@code LDAUtil.logSum}. NOTE(review): this assumes
     * logSum(a, b) = log(exp(a) + exp(b)), so the result exponentiates to a distribution over
     * topics; confirm against LDAUtil.
     *
     * @param word         word index
     * @param digammaGamma output of {@link #digammaGamma(Vector)}
     * @return log phi for {@code word}, one entry per topic
     */
    private Vector eStepForWord(int word, Vector digammaGamma) {
        int numTopics = state.getNumTopics();
        DenseVector logPhi = new DenseVector(numTopics);
        double logTotal = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < numTopics; k++) {
            double unnormalised = state.logProbWordGivenTopic(word, k) + digammaGamma.getQuick(k);
            logPhi.setQuick(k, unnormalised);
            logTotal = LDAUtil.logSum(logTotal, unnormalised);
        }
        for (int k = 0; k < numTopics; k++) {
            logPhi.setQuick(k, logPhi.getQuick(k) - logTotal);
        }
        return logPhi;
    }

    /** Applies {@link #digamma(double)} element-wise, returning a new dense vector. */
    private static Vector digamma(Vector values) {
        int size = values.size();
        Vector result = new DenseVector(size);
        for (int i = 0; i < size; i++) {
            result.setQuick(i, digamma(values.getQuick(i)));
        }
        return result;
    }

    /**
     * Computes the digamma function {@code psi(x) = d/dx log Gamma(x)}.
     *
     * <p>Method:
     * <ol>
     *   <li>Shift x above 5 using the recurrence {@code psi(x) = psi(x + 1) - 1/x}.</li>
     *   <li>Apply the asymptotic series
     *       {@code ln x - 1/(2x) - sum_n B_2n / (2n x^2n)}, truncated after the {@code x^-16}
     *       term.</li>
     * </ol>
     *
     * <p>NOTE(review): intended for {@code x > 0}. For negative x the shift loop runs about
     * |x| times. It never terminates for {@code x = -Infinity}, or when |x| is large enough that
     * {@code x + 1 == x}.
     *
     * <p>Kept public for compatibility with existing callers.
     *
     * @param x argument
     * @return an approximation of psi(x)
     */
    public static double digamma(double x) {
        double shift = 0.0;
        while (x <= 5.0) {
            shift -= 1.0 / x;
            x += 1.0;
        }
        double f = 1.0 / (x * x);
        double series = f * (-1.0 / 12.0
            + f * (1.0 / 120.0
            + f * (-1.0 / 252.0
            + f * (1.0 / 240.0
            + f * (-1.0 / 132.0
            + f * (691.0 / 32760.0
            + f * (-1.0 / 12.0
            + f * 3617.0 / 8160.0)))))));
        return shift + Math.log(x) - 0.5 / x + series;
    }
}
