package org.apache.mahout.clustering.lda;

import com.google.common.base.Charsets;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.Closeables;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.PathFilter;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.IntPairWritable;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.utils.clustering.ClusterDumper;
import org.apache.mahout.utils.vectors.VectorHelper;

/* loaded from: input_file:org/apache/mahout/clustering/lda/LDAPrintTopics.class */
/**
 * Command-line utility that reads LDA state from sequence files and prints, for every
 * topic, the words with the highest scores.
 *
 * <p>The input records are keyed by an {@link IntPairWritable} whose first component is
 * used as the topic index and whose second component is used as an index into the term
 * dictionary; the {@link DoubleWritable} value is exponentiated before being normalised,
 * i.e. it is treated as a log-scale score.
 *
 * <p>Output goes either to one file per topic ({@code topic_<n>}) inside the output
 * directory, or to standard output when no output directory is given.
 */
public final class LDAPrintTopics {

    /** Number of words printed per topic when the "words" option is not supplied. */
    private static final int DEFAULT_WORDS_PER_TOPIC = 20;

    /** Initial capacity of a freshly created per-topic queue (the JDK's own default). */
    private static final int DEFAULT_QUEUE_CAPACITY = 11;

    /**
     * Orders (word, score) pairs by ascending score, so the head of a priority queue built
     * with it is always the lowest-scoring entry. The eviction logic in
     * {@link #maybeEnqueue} depends on exactly that, so the ordering is stated explicitly
     * instead of relying on whatever natural ordering {@code Pair} defines.
     */
    private static final Comparator<Pair<String, Double>> BY_SCORE_ASCENDING =
            new Comparator<Pair<String, Double>>() {
                @Override
                public int compare(Pair<String, Double> left, Pair<String, Double> right) {
                    return left.getSecond().compareTo(right.getSecond());
                }
            };

    private LDAPrintTopics() {
    }

    /**
     * Creates an empty queue whose head is the lowest-scoring pair.
     *
     * @param expectedSize capacity hint; clamped to at least 1 because
     *                     {@link PriorityQueue} rejects smaller initial capacities
     */
    private static Queue<Pair<String, Double>> newScoreQueue(int expectedSize) {
        return new PriorityQueue<Pair<String, Double>>(Math.max(1, expectedSize), BY_SCORE_ASCENDING);
    }

    /**
     * Grows {@code queues} with empty queues until it holds an entry at index
     * {@code topic}. Does nothing when the collection is already large enough or when
     * {@code topic} is negative.
     */
    private static void ensureQueueSize(Collection<Queue<Pair<String, Double>>> queues, int topic) {
        for (int next = queues.size(); next <= topic; next++) {
            queues.add(newScoreQueue(DEFAULT_QUEUE_CAPACITY));
        }
    }

    /**
     * Entry point. Recognised options: input path, dictionary file ("dict"), optional
     * output directory, optional number of words per topic ("words", default 20), optional
     * dictionary type ("text" or "sequencefile", default "text") and "help".
     *
     * @param strArr raw command-line arguments
     * @throws OptionException          if the arguments cannot be parsed (usage is printed first)
     * @throws IllegalArgumentException if the dictionary type is neither "text" nor "sequencefile"
     * @throws IOException              if the output directory cannot be created or writing fails
     */
    public static void main(String[] strArr) throws Exception {
        DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
        ArgumentBuilder argumentBuilder = new ArgumentBuilder();
        GroupBuilder groupBuilder = new GroupBuilder();

        DefaultOption inputOpt = DefaultOptionCreator.inputOption().create();
        DefaultOption dictOpt = optionBuilder.withLongName("dict").withRequired(true)
                .withArgument(argumentBuilder.withName("dict").withMinimum(1).withMaximum(1).create())
                .withDescription("Dictionary to read in, in the same format as one created by org.apache.mahout.utils.vectors.lucene.Driver")
                .withShortName("d").create();
        DefaultOption outputOpt = DefaultOptionCreator.outputOption().create();
        DefaultOption wordsOpt = optionBuilder.withLongName("words").withRequired(false)
                .withArgument(argumentBuilder.withName("words").withMinimum(0).withMaximum(1)
                        .withDefault(String.valueOf(DEFAULT_WORDS_PER_TOPIC)).create())
                .withDescription("Number of words to print")
                .withShortName("w").create();
        DefaultOption dictTypeOpt = optionBuilder.withLongName(ClusterDumper.DICTIONARY_TYPE_OPTION).withRequired(false)
                .withArgument(argumentBuilder.withName(ClusterDumper.DICTIONARY_TYPE_OPTION).withMinimum(1).withMaximum(1).create())
                .withDescription("The dictionary file type (text|sequencefile)")
                .withShortName("dt").create();
        DefaultOption helpOpt = optionBuilder.withLongName("help")
                .withDescription("Print out help")
                .withShortName("h").create();

        // The help option is registered with the group as well; it used to be created but
        // never added, so the parser had no way to recognise it and the hasOption(helpOpt)
        // check below could not succeed.
        Group options = groupBuilder.withName("Options")
                .withOption(dictOpt)
                .withOption(outputOpt)
                .withOption(wordsOpt)
                .withOption(inputOpt)
                .withOption(dictTypeOpt)
                .withOption(helpOpt)
                .create();

        try {
            Parser parser = new Parser();
            parser.setGroup(options);
            CommandLine commandLine = parser.parse(strArr);
            if (commandLine.hasOption(helpOpt)) {
                CommandLineUtil.printHelp(options);
                return;
            }

            String inputDir = commandLine.getValue(inputOpt).toString();
            String dictionaryFile = commandLine.getValue(dictOpt).toString();

            int wordsPerTopic = DEFAULT_WORDS_PER_TOPIC;
            if (commandLine.hasOption(wordsOpt)) {
                wordsPerTopic = Integer.parseInt(commandLine.getValue(wordsOpt).toString());
            }

            Configuration configuration = new Configuration();
            String dictionaryType = commandLine.hasOption(dictTypeOpt)
                    ? commandLine.getValue(dictTypeOpt).toString()
                    : "text";

            List<String> dictionary;
            if ("text".equals(dictionaryType)) {
                dictionary = Arrays.asList(VectorHelper.loadTermDictionary(new File(dictionaryFile)));
            } else if ("sequencefile".equals(dictionaryType)) {
                dictionary = Arrays.asList(VectorHelper.loadTermDictionary(configuration, dictionaryFile));
            } else {
                throw new IllegalArgumentException("Invalid dictionary format");
            }

            List<Queue<Pair<String, Double>>> topWords =
                    topWordsForTopics(inputDir, configuration, dictionary, wordsPerTopic);

            // A null output directory means "print to standard output".
            File outputDir = null;
            if (commandLine.hasOption(outputOpt)) {
                outputDir = new File(commandLine.getValue(outputOpt).toString());
                if (!outputDir.exists() && !outputDir.mkdirs()) {
                    throw new IOException("Could not create directory: " + outputDir);
                }
            }
            printTopWords(topWords, outputDir);
        } catch (OptionException e) {
            CommandLineUtil.printHelp(options);
            throw e;
        }
    }

    /**
     * Offers a (word, score) pair to a bounded "best N" queue.
     *
     * <p>When the queue is full and the new score beats the current lowest score, the
     * lowest entry is evicted to make room. When the queue is full and the new score is
     * not better, the pair is dropped.
     *
     * @param queue    queue whose head is its lowest-scoring entry
     * @param word     the word to offer
     * @param score    the word's score
     * @param maxWords maximum number of entries to keep; nothing is kept when this is
     *                 zero or negative
     */
    private static void maybeEnqueue(Queue<Pair<String, Double>> queue, String word, double score, int maxWords) {
        if (maxWords <= 0) {
            // Without this guard an empty queue would count as "full" and peek() would
            // return null below.
            return;
        }
        if (queue.size() >= maxWords && score > queue.peek().getSecond()) {
            queue.poll();
        }
        if (queue.size() < maxWords) {
            queue.add(new Pair<String, Double>(word, Double.valueOf(score)));
        }
    }

    /**
     * Writes the words of every topic, highest score first.
     *
     * @param topWords  one queue of (word, score) pairs per topic, indexed by topic
     * @param outputDir directory receiving one {@code topic_<n>} file per topic, or
     *                  {@code null} to print everything to standard output
     * @throws IOException if a topic file cannot be opened or written
     */
    private static void printTopWords(List<Queue<Pair<String, Double>>> topWords, File outputDir) throws IOException {
        for (int topic = 0; topic < topWords.size(); topic++) {
            List<Pair<String, Double>> ranked = new ArrayList<Pair<String, Double>>(topWords.get(topic));
            Collections.sort(ranked, Collections.reverseOrder(BY_SCORE_ASCENDING));

            if (outputDir != null) {
                // The writer must stay open while the words are written and be closed
                // afterwards; closing it right after construction loses all output.
                try (OutputStreamWriter out = new OutputStreamWriter(
                        new FileOutputStream(new File(outputDir, "topic_" + topic)), Charsets.UTF_8)) {
                    writeRankedWords(out, ranked, topic);
                }
            } else {
                // System.out must not be closed, but the wrapper buffers internally and
                // therefore has to be flushed.
                OutputStreamWriter out = new OutputStreamWriter(System.out, Charsets.UTF_8);
                try {
                    out.write("Topic " + topic);
                    out.write('\n');
                    out.write("===========");
                    out.write('\n');
                    writeRankedWords(out, ranked, topic);
                } finally {
                    out.flush();
                }
            }
        }
    }

    /** Writes one line per (word, score) pair of a single topic, in the given order. */
    private static void writeRankedWords(OutputStreamWriter out, List<Pair<String, Double>> ranked, int topic)
            throws IOException {
        for (Pair<String, Double> entry : ranked) {
            String word = entry.getFirst();
            out.write(word + " [p(" + word + "|topic_" + topic + ") = " + entry.getSecond());
            out.write('\n');
        }
    }

    /**
     * Reads every {@code part-*} sequence file below {@code inputDir} and collects, per
     * topic, the {@code maxWords} best-scoring words with their scores normalised by the
     * topic's total.
     *
     * <p>Each kept score is {@code exp(value) / sum(exp(value))}, where the sum runs over
     * all records of that topic, not only over the kept ones. Records with a negative
     * topic or word index are ignored.
     *
     * @param inputDir      directory containing the {@code part-*} files
     * @param configuration Hadoop configuration used to open the files
     * @param dictionary    word index to word mapping
     * @param maxWords      maximum number of words to keep per topic
     * @return one queue per topic index; topics without any word get an empty queue
     */
    private static List<Queue<Pair<String, Double>>> topWordsForTopics(
            String inputDir, Configuration configuration, List<String> dictionary, int maxWords) {
        List<Queue<Pair<String, Double>>> queues = Lists.newArrayList();
        HashMap<Integer, Double> expSumByTopic = Maps.newHashMap();

        // The casts on the two null arguments only pin the constructor overload, exactly
        // as the original call did.
        SequenceFileDirIterable<IntPairWritable, DoubleWritable> records =
                new SequenceFileDirIterable<IntPairWritable, DoubleWritable>(
                        new Path(inputDir, "part-*"), PathType.GLOB, (PathFilter) null, (Comparator) null,
                        true, configuration);

        for (Pair<IntPairWritable, DoubleWritable> record : records) {
            IntPairWritable key = record.getFirst();
            int topic = key.getFirst();
            int wordIndex = key.getSecond();
            ensureQueueSize(queues, topic);
            if (topic < 0 || wordIndex < 0) {
                continue;
            }
            double logScore = record.getSecond().get();
            Double runningSum = expSumByTopic.get(topic);
            expSumByTopic.put(topic, (runningSum == null ? 0.0d : runningSum) + Math.exp(logScore));
            maybeEnqueue(queues.get(topic), dictionary.get(wordIndex), logScore, maxWords);
        }

        for (int topic = 0; topic < queues.size(); topic++) {
            Queue<Pair<String, Double>> rawScores = queues.get(topic);
            Queue<Pair<String, Double>> normalised = newScoreQueue(rawScores.size());
            if (!rawScores.isEmpty()) {
                // A non-empty queue implies at least one record contributed to the sum,
                // so the lookup cannot be null here; empty topics have no sum at all.
                double total = expSumByTopic.get(topic);
                for (Pair<String, Double> entry : rawScores) {
                    normalised.add(new Pair<String, Double>(
                            entry.getFirst(), Double.valueOf(Math.exp(entry.getSecond()) / total)));
                }
            }
            queues.set(topic, normalised);
        }
        return queues;
    }
}
