package org.apache.mahout.clustering.lda;

import com.google.common.base.Preconditions;
import com.google.common.io.Closeables;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
import org.apache.mahout.clustering.lda.LDAInference;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.IntPairWritable;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterator;
import org.apache.mahout.fpm.pfpgrowth.PFPGrowth;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasure;
import org.apache.mahout.math.hadoop.stochasticsvd.YtYJob;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Command-line driver and programmatic entry point for Latent Dirichlet Allocation (LDA).
 *
 * <p>Each iteration {@code n} reads the model from {@code <output>/state-<n-1>} and writes
 * {@code <output>/state-<n>}. The work is done either by a MapReduce job or sequentially in
 * memory. Iteration resumes after the highest-numbered existing state directory.
 *
 * <p>States written directly by this class (the initial random state and sequential iterations)
 * have this layout:
 * <ul>
 *   <li>one SequenceFile {@code part-<topic>} per topic, holding {@code (topic, word) -> log value}
 *       records followed by a {@code (topic, -1)} topic-sum record;</li>
 *   <li>a {@code part--2} file holding the {@code (-2, -2)} log-likelihood record.</li>
 * </ul>
 *
 * <p>After the last iteration, document/topic output is written under {@code <output>/docTopics}.
 */
public final class LDADriver extends AbstractJob {

    private static final String TOPIC_SMOOTHING_OPTION = "topicSmoothing";
    private static final String NUM_TOPICS_OPTION = "numTopics";

    static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
    static final String NUM_TOPICS_KEY = "org.apache.mahout.clustering.lda.numTopics";
    static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";
    static final String TOPIC_SMOOTHING_KEY = "org.apache.mahout.clustering.lda.topicSmoothing";

    /** Key ({@code topic == word == -2}) of the record holding a state's log-likelihood. */
    static final int LOG_LIKELIHOOD_KEY = -2;
    /** Word index of the per-topic record holding the log of that topic's total mass. */
    static final int TOPIC_SUM_KEY = -1;
    /** Relative log-likelihood change below which iteration stops (only after iteration 3). */
    static final double OVERALL_CONVERGENCE = 1.0E-5d;

    /** Prefix of a part file inside a state directory; suffixed with the topic index. */
    private static final String PART_PREFIX = "part-";
    /** Glob matching every part file of a state or input directory. */
    private static final String PART_GLOB = "part-*";
    /** Prefix of a state directory name; suffixed with the iteration number. */
    private static final String STATE_PREFIX = "state-";
    /** Explanation attached when inference or accumulation indexes past the word dimension. */
    private static final String NUM_WORDS_TOO_SMALL =
        "This is probably because the --numWords argument is set too small.  \n\t"
        + "It needs to be >= than the number of words (terms actually) in the corpus and can be \n\t"
        + "larger if some storage inefficiency can be tolerated.";

    private static final Logger log = LoggerFactory.getLogger(LDADriver.class);

    /** Current model for the sequential (in-memory) code path; loaded lazily. */
    private LDAState state;
    /** Inference engine bound to {@link #state}; rebuilt whenever the state is replaced. */
    private LDAInference inference;
    /** Training documents cached in memory for the sequential code path; loaded lazily. */
    private Iterable<Pair<Writable, VectorWritable>> trainingCorpus;

    public static void main(String[] args) throws Exception {
        ToolRunner.run(new Configuration(), new LDADriver(), args);
    }

    /**
     * Loads the model state named by {@link #STATE_IN_KEY} from {@code conf}.
     *
     * @param conf configuration holding the state path, topic count, word count and smoothing
     * @return the populated state
     */
    public static LDAState createState(Configuration conf) {
        return createState(conf, false);
    }

    /**
     * Builds a model state from the dimensions stored in {@code conf}.
     *
     * @param conf configuration holding {@link #STATE_IN_KEY}, {@link #NUM_TOPICS_KEY},
     *     {@link #NUM_WORDS_KEY} and {@link #TOPIC_SMOOTHING_KEY}
     * @param empty if {@code true}, return a zeroed state (all log totals at negative infinity,
     *     log-likelihood 0) without reading any files; otherwise read the {@code part-*} files of
     *     the state directory
     * @return the state
     * @throws IllegalArgumentException if a stored record has a negative index, is duplicated,
     *     or holds an infinite value
     */
    public static LDAState createState(Configuration conf, boolean empty) {
        String statePathString = conf.get(STATE_IN_KEY);
        int numTopics = Integer.parseInt(conf.get(NUM_TOPICS_KEY));
        int numWords = Integer.parseInt(conf.get(NUM_WORDS_KEY));
        double topicSmoothing = Double.parseDouble(conf.get(TOPIC_SMOOTHING_KEY));
        Path statePath = new Path(statePathString);

        DenseMatrix topicWordLogProbs = new DenseMatrix(numTopics, numWords);
        double[] logTotals = new double[numTopics];
        Arrays.fill(logTotals, Double.NEGATIVE_INFINITY);
        double logLikelihood = 0.0;
        if (empty) {
            return new LDAState(numTopics, numWords, topicSmoothing, topicWordLogProbs, logTotals, logLikelihood);
        }

        for (Pair<IntPairWritable, DoubleWritable> record
                : new SequenceFileDirIterable<IntPairWritable, DoubleWritable>(
                    new Path(statePath, PART_GLOB), PathType.GLOB, null, null, true, conf)) {
            int topic = record.getFirst().getFirst();
            int word = record.getFirst().getSecond();
            double value = record.getSecond().get();
            if (word == TOPIC_SUM_KEY) {
                logTotals[topic] = value;
                Preconditions.checkArgument(!Double.isInfinite(value));
            } else if (topic == LOG_LIKELIHOOD_KEY) {
                logLikelihood = value;
            } else {
                Preconditions.checkArgument(topic >= 0, "topic should be non-negative, not %d", Integer.valueOf(topic));
                Preconditions.checkArgument(word >= 0, "word should be non-negative not %d", Integer.valueOf(word));
                // Each (topic, word) cell must be written exactly once; a fresh matrix holds 0.0.
                Preconditions.checkArgument(topicWordLogProbs.getQuick(topic, word) == 0.0);
                topicWordLogProbs.setQuick(topic, word, value);
                Preconditions.checkArgument(!Double.isInfinite(topicWordLogProbs.getQuick(topic, word)));
            }
        }
        return new LDAState(numTopics, numWords, topicSmoothing, topicWordLogProbs, logTotals, logLikelihood);
    }

    /**
     * Command-line entry point: parses options and runs the MapReduce version of LDA.
     *
     * @return 0 on success, -1 if argument parsing failed
     */
    @Override
    public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
        addInputOption();
        addOutputOption();
        addOption(DefaultOptionCreator.overwriteOption().create());
        addOption(NUM_TOPICS_OPTION, "k", "The total number of topics in the corpus", true);
        addOption(TOPIC_SMOOTHING_OPTION, "a", "Topic smoothing parameter. Default is 50/numTopics.", "-1.0");
        addOption(DefaultOptionCreator.maxIterationsOption().withRequired(false).create());
        if (parseArguments(args) == null) {
            return -1;
        }

        Path input = getInputPath();
        Path output = getOutputPath();
        if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
            HadoopUtil.delete(getConf(), output);
        }

        // The option is not required; when absent, iterate until convergence (maxIterations < 1).
        String maxIterationsValue = getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION);
        int maxIterations = maxIterationsValue == null ? -1 : Integer.parseInt(maxIterationsValue);
        int numTopics = Integer.parseInt(getOption(NUM_TOPICS_OPTION));
        int numWords = determineNumberOfWordsFromFirstVector();

        // The option's default is the sentinel -1.0; only a non-positive value selects 50/numTopics.
        // Positive values below 1 (sparse priors) are legitimate and are kept as given.
        double topicSmoothing = Double.parseDouble(getOption(TOPIC_SMOOTHING_OPTION));
        if (topicSmoothing <= 0.0) {
            topicSmoothing = 50.0 / numTopics;
        }

        run(getConf(), input, output, numTopics, numWords, topicSmoothing, maxIterations, false);
        return 0;
    }

    /**
     * Finds the {@code state-<n>} directory with the largest {@code n} under {@code outputPath}.
     *
     * @return the latest state directory, or {@code null} if there is none
     * @throws IOException if a matching name does not carry a numeric iteration suffix
     */
    private static Path getLastKnownStatePath(Configuration conf, Path outputPath) throws IOException {
        FileStatus[] candidates = FileSystem.get(conf).globStatus(new Path(outputPath, STATE_PREFIX + '*'));
        if (candidates == null) {
            return null;
        }
        Path latest = null;
        int latestIteration = Integer.MIN_VALUE;
        for (FileStatus candidate : candidates) {
            int iteration;
            try {
                iteration = stateIteration(candidate.getPath());
            } catch (NumberFormatException e) {
                throw new IOException(e);
            }
            if (iteration > latestIteration) {
                latestIteration = iteration;
                latest = candidate.getPath();
            }
        }
        return latest;
    }

    /**
     * Extracts {@code n} from a state directory named {@code state-<n>}.
     *
     * @throws NumberFormatException if the name has no numeric suffix after the first '-'
     */
    private static int stateIteration(Path statePath) {
        String[] tokens = statePath.getName().split("-");
        if (tokens.length < 2) {
            throw new NumberFormatException("not a state directory name: " + statePath.getName());
        }
        return Integer.parseInt(tokens[1]);
    }

    /**
     * Returns the cardinality of the first non-null vector found in the input, or 0 if none.
     */
    private int determineNumberOfWordsFromFirstVector() throws IOException {
        SequenceFileDirValueIterator<VectorWritable> vectors = new SequenceFileDirValueIterator<VectorWritable>(
            getInputPath(), PathType.LIST, PathFilters.logsCRCFilter(), null, true, getConf());
        try {
            while (vectors.hasNext()) {
                Vector first = vectors.next().get();
                if (first != null) {
                    return first.size();
                }
            }
        } finally {
            Closeables.closeQuietly(vectors);
        }
        log.warn("can't determine number of words; no vectors in {}", getInputPath());
        return 0;
    }

    /**
     * Runs LDA until convergence or until {@code maxIterations} is reached, then computes
     * document/topic output under {@code <output>/docTopics}.
     *
     * @param conf configuration; the LDA keys are set on it as a side effect
     * @param input directory of the input SequenceFiles of document vectors
     * @param output directory receiving the {@code state-<n>} directories and {@code docTopics}
     * @param numTopics number of topics
     * @param numWords size of the vocabulary (word-index dimension)
     * @param topicSmoothing topic smoothing parameter
     * @param maxIterations last iteration number to run; values below 1 mean "until convergence"
     * @param runSequential {@code true} to run in memory instead of as MapReduce jobs
     * @return the negated log-likelihood of the last completed iteration
     *     (positive infinity if no iteration ran)
     */
    public double run(Configuration conf, Path input, Path output, int numTopics, int numWords,
            double topicSmoothing, int maxIterations, boolean runSequential)
            throws IOException, InterruptedException, ClassNotFoundException {
        Path stateIn = getLastKnownStatePath(conf, output);
        if (stateIn == null) {
            stateIn = new Path(output, STATE_PREFIX + 0);
            writeInitialState(conf, stateIn, numTopics, numWords);
        }
        conf.set(STATE_IN_KEY, stateIn.toString());
        conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
        conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
        conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));

        double oldLL = Double.NEGATIVE_INFINITY;
        boolean converged = false;
        int iteration = stateIteration(stateIn) + 1;
        while ((maxIterations < 1 || iteration <= maxIterations) && !converged) {
            log.info("LDA Iteration {}", Integer.valueOf(iteration));
            conf.set(STATE_IN_KEY, stateIn.toString());
            Path stateOut = new Path(output, STATE_PREFIX + iteration);
            double ll = runSequential
                ? runIterationSequential(conf, input, stateOut)
                : runIteration(conf, input, stateIn, stateOut);
            double relChange = (oldLL - ll) / oldLL;

            log.info("Iteration {} finished. Log Likelihood: {}", Integer.valueOf(iteration), Double.valueOf(ll));
            log.info("(Old LL: {})", Double.valueOf(oldLL));
            log.info("(Rel Change: {})", Double.valueOf(relChange));

            // Always run at least a few iterations before trusting the convergence test.
            converged = iteration > 3 && relChange < OVERALL_CONVERGENCE;
            stateIn = stateOut;
            oldLL = ll;
            iteration++;
        }

        Path docTopics = new Path(output, "docTopics");
        if (runSequential) {
            computeDocumentTopicProbabilitiesSequential(conf, input, docTopics);
        } else {
            computeDocumentTopicProbabilities(conf, input, stateIn, docTopics, numTopics, numWords, topicSmoothing);
        }
        return -oldLL;
    }

    /**
     * Writes a random initial state.
     *
     * <p>Each topic file holds log pseudo-counts drawn from
     * {@code random.nextDouble() + 1e-8}, followed by the topic-sum record.
     */
    private static void writeInitialState(Configuration conf, Path statePath, int numTopics, int numWords)
            throws IOException {
        FileSystem fs = statePath.getFileSystem(conf);
        DoubleWritable value = new DoubleWritable();
        Random random = RandomUtils.getRandom();
        for (int topic = 0; topic < numTopics; topic++) {
            SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf,
                new Path(statePath, PART_PREFIX + topic), IntPairWritable.class, DoubleWritable.class);
            try {
                double total = 0.0;
                for (int word = 0; word < numWords; word++) {
                    double pseudoCount = random.nextDouble() + 1.0E-8d;
                    total += pseudoCount;
                    value.set(Math.log(pseudoCount));
                    writer.append(new IntPairWritable(topic, word), value);
                }
                value.set(Math.log(total));
                writer.append(new IntPairWritable(topic, TOPIC_SUM_KEY), value);
            } finally {
                Closeables.closeQuietly(writer);
            }
        }
    }

    /**
     * Persists {@code ldaState} in the state-directory layout that {@link #createState} reads.
     *
     * <p>Writes one file per topic, plus a {@code part--2} file holding the log-likelihood.
     */
    private static void writeState(Configuration conf, LDAState ldaState, Path statePath) throws IOException {
        FileSystem fs = statePath.getFileSystem(conf);
        DoubleWritable value = new DoubleWritable();
        for (int topic = 0; topic < ldaState.getNumTopics(); topic++) {
            SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf,
                new Path(statePath, PART_PREFIX + topic), IntPairWritable.class, DoubleWritable.class);
            try {
                for (int word = 0; word < ldaState.getNumWords(); word++) {
                    value.set(ldaState.logProbWordGivenTopic(word, topic) + ldaState.getLogTotal(topic));
                    writer.append(new IntPairWritable(topic, word), value);
                }
                value.set(ldaState.getLogTotal(topic));
                writer.append(new IntPairWritable(topic, TOPIC_SUM_KEY), value);
            } finally {
                Closeables.closeQuietly(writer);
            }
        }
        SequenceFile.Writer llWriter = new SequenceFile.Writer(fs, conf,
            new Path(statePath, PART_PREFIX + LOG_LIKELIHOOD_KEY), IntPairWritable.class, DoubleWritable.class);
        try {
            value.set(ldaState.getLogLikelihood());
            llWriter.append(new IntPairWritable(LOG_LIKELIHOOD_KEY, LOG_LIKELIHOOD_KEY), value);
        } finally {
            Closeables.closeQuietly(llWriter);
        }
    }

    /**
     * Scans the part files of {@code statePath} for the log-likelihood record.
     *
     * @return the value of the last log-likelihood record found, or 0.0 if none
     */
    private static double findLL(Path statePath, Configuration conf) throws IOException {
        double ll = 0.0;
        FileStatus[] parts = statePath.getFileSystem(conf).globStatus(new Path(statePath, PART_GLOB));
        if (parts == null) {
            return ll;
        }
        for (FileStatus part : parts) {
            SequenceFileIterator<IntPairWritable, DoubleWritable> records =
                new SequenceFileIterator<IntPairWritable, DoubleWritable>(part.getPath(), true, conf);
            try {
                while (records.hasNext()) {
                    Pair<IntPairWritable, DoubleWritable> record = records.next();
                    if (record.getFirst().getFirst() == LOG_LIKELIHOOD_KEY) {
                        ll = record.getSecond().get();
                        break;
                    }
                }
            } finally {
                Closeables.closeQuietly(records);
            }
        }
        return ll;
    }

    /**
     * Runs one in-memory iteration, writes the new state to {@code stateOut} and makes it current.
     *
     * @return the corpus log-likelihood, summed once per document
     * @throws IllegalStateException if a word index exceeds the configured word count
     */
    private double runIterationSequential(Configuration conf, Path input, Path stateOut) throws IOException {
        ensureSequentialModel(conf, input);
        LDAState newState = createState(conf, true);
        int numTopics = state.getNumTopics();
        double ll = 0.0;
        for (Pair<Writable, VectorWritable> slice : trainingCorpus) {
            Vector document = slice.getSecond().get();
            try {
                LDAInference.InferredDocument inferred = inference.infer(document);
                // Document log-likelihood is a per-document quantity: add it once per document,
                // not once per non-zero term.
                ll += inferred.getLogLikelihood();
                Iterator<Vector.Element> terms = document.iterateNonZero();
                while (terms.hasNext()) {
                    Vector.Element term = terms.next();
                    int word = term.index();
                    double logCount = Math.log(term.get());
                    for (int topic = 0; topic < numTopics; topic++) {
                        double contribution = inferred.phi(topic, word) + logCount;
                        newState.updateLogProbGivenTopic(word, topic, contribution);
                        newState.updateLogTotals(topic, contribution);
                    }
                }
            } catch (ArrayIndexOutOfBoundsException e) {
                throw new IllegalStateException(NUM_WORDS_TOO_SMALL, e);
            }
        }
        newState.setLogLikelihood(ll);
        writeState(conf, newState, stateOut);
        state = newState;
        // Rebind inference to the updated model. Otherwise every later E-step and the final
        // document/topic pass would keep using the initial state.
        inference = new LDAInference(state);
        return ll;
    }

    /**
     * Lazily loads the current state, the training corpus and the inference engine.
     */
    private void ensureSequentialModel(Configuration conf, Path input) throws IOException {
        if (state == null) {
            state = createState(conf);
        }
        if (trainingCorpus == null) {
            trainingCorpus = loadCorpus(conf, input);
        }
        if (inference == null) {
            inference = new LDAInference(state);
        }
    }

    /**
     * Reads every (key, vector) record of the input {@code part-*} files into memory.
     */
    private static List<Pair<Writable, VectorWritable>> loadCorpus(Configuration conf, Path input)
            throws IOException {
        Class<? extends Writable> keyClass = peekAtSequenceFileForKeyType(conf, input);
        FileSystem fs = input.getFileSystem(conf);
        List<Pair<Writable, VectorWritable>> corpus = new ArrayList<Pair<Writable, VectorWritable>>();
        FileStatus[] parts = fs.globStatus(new Path(input, PART_GLOB));
        if (parts == null) {
            return corpus;
        }
        for (FileStatus part : parts) {
            SequenceFile.Reader reader = new SequenceFile.Reader(fs, part.getPath(), conf);
            try {
                Writable key = ReflectionUtils.newInstance(keyClass, conf);
                VectorWritable value = new VectorWritable();
                while (reader.next(key, value)) {
                    corpus.add(new Pair<Writable, VectorWritable>(key, value));
                    // Fresh instances per record: the reader fills the objects it is given.
                    key = ReflectionUtils.newInstance(keyClass, conf);
                    value = new VectorWritable();
                }
            } finally {
                Closeables.closeQuietly(reader);
            }
        }
        return corpus;
    }

    /**
     * Runs one MapReduce iteration from {@code stateIn} to {@code stateOut}.
     *
     * @return the log-likelihood recorded in {@code stateOut}
     * @throws InterruptedException if the job does not complete successfully
     */
    private static double runIteration(Configuration conf, Path input, Path stateIn, Path stateOut)
            throws IOException, InterruptedException, ClassNotFoundException {
        conf.set(STATE_IN_KEY, stateIn.toString());
        Job job = new Job(conf, "LDA Driver running runIteration over stateIn: " + stateIn);
        job.setOutputKeyClass(IntPairWritable.class);
        job.setOutputValueClass(DoubleWritable.class);
        FileInputFormat.addInputPaths(job, input.toString());
        FileOutputFormat.setOutputPath(job, stateOut);
        job.setMapperClass(LDAWordTopicMapper.class);
        job.setReducerClass(LDAReducer.class);
        job.setCombinerClass(LDAReducer.class);
        job.setOutputFormatClass(SequenceFileOutputFormat.class);
        job.setInputFormatClass(SequenceFileInputFormat.class);
        job.setJarByClass(LDADriver.class);
        if (!job.waitForCompletion(true)) {
            throw new InterruptedException("LDA Iteration failed processing " + stateIn);
        }
        return findLL(stateOut, conf);
    }

    /**
     * Runs a map-only job of {@link LDADocumentTopicMapper} over the input using {@code stateIn}.
     *
     * @throws InterruptedException if the job does not complete successfully
     */
    private static void computeDocumentTopicProbabilities(Configuration conf, Path input, Path stateIn,
            Path outputPath, int numTopics, int numWords, double topicSmoothing)
            throws IOException, InterruptedException, ClassNotFoundException {
        conf.set(STATE_IN_KEY, stateIn.toString());
        conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
        conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
        conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
        Job job = new Job(conf, "LDA Driver computing p(topic|doc) for all docs/topics with stateIn: " + stateIn);
        job.setOutputKeyClass(peekAtSequenceFileForKeyType(conf, input));
        job.setOutputValueClass(VectorWritable.class);
        FileInputFormat.addInputPaths(job, input.toString());
        FileOutputFormat.setOutputPath(job, outputPath);
        job.setMapperClass(LDADocumentTopicMapper.class);
        job.setNumReduceTasks(0);
        job.setOutputFormatClass(SequenceFileOutputFormat.class);
        job.setInputFormatClass(SequenceFileInputFormat.class);
        job.setJarByClass(LDADriver.class);
        if (!job.waitForCompletion(true)) {
            throw new InterruptedException("LDA failed to compute and output document topic probabilities with: " + stateIn);
        }
    }

    /**
     * Writes, for every cached training document, its original key and its L1-normalised
     * inferred gamma vector into a single SequenceFile at {@code outputPath}.
     *
     * <p>Uses the current in-memory model.
     *
     * @throws IllegalStateException if a word index exceeds the configured word count
     */
    private void computeDocumentTopicProbabilitiesSequential(Configuration conf, Path input, Path outputPath)
            throws IOException {
        // Also covers the case where no iteration ran (e.g. resuming past maxIterations).
        ensureSequentialModel(conf, input);
        FileSystem fs = outputPath.getFileSystem(conf);
        Class<? extends Writable> keyClass = peekAtSequenceFileForKeyType(conf, input);
        SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, outputPath, keyClass, VectorWritable.class);
        try {
            VectorWritable topics = new VectorWritable();
            for (Pair<Writable, VectorWritable> slice : trainingCorpus) {
                try {
                    LDAInference.InferredDocument inferred = inference.infer(slice.getSecond().get());
                    topics.set(inferred.getGamma().normalize(1));
                    writer.append(slice.getFirst(), topics);
                } catch (ArrayIndexOutOfBoundsException e) {
                    throw new IllegalStateException(NUM_WORDS_TOO_SMALL, e);
                }
            }
        } finally {
            Closeables.closeQuietly(writer);
        }
    }

    /**
     * Returns the key class of the SequenceFile at {@code input}.
     *
     * @return the key class, or {@link Text} if {@code input} cannot be opened as a single
     *     SequenceFile
     */
    @SuppressWarnings("unchecked") // SequenceFile keys used by this driver are Writables.
    private static Class<? extends Writable> peekAtSequenceFileForKeyType(Configuration conf, Path input) {
        SequenceFile.Reader reader = null;
        try {
            reader = new SequenceFile.Reader(input.getFileSystem(conf), input, conf);
            return (Class<? extends Writable>) reader.getKeyClass();
        } catch (IOException e) {
            return Text.class;
        } finally {
            Closeables.closeQuietly(reader);
        }
    }
}
