package org.apache.mahout.clustering.lda.cvb;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/clustering/lda/cvb/ModelTrainer.class */
/**
 * Trains a {@link TopicModel} from document vectors, either asynchronously through a fixed-size
 * worker pool or synchronously on the calling thread.
 *
 * <p>The trainer holds a "read" model, which per-document inference runs against, and a "write"
 * model, which receives the resulting updates. When both are the same instance the trainer works
 * in batches and applies the updates on the calling thread after each batch has finished;
 * otherwise documents are handed to the pool one at a time and each worker updates the write
 * model itself. {@link #stop()} swaps the two models.
 */
public class ModelTrainer {
    private static final Logger log = LoggerFactory.getLogger(ModelTrainer.class);
    private final int numTopics;
    private final int numTerms;
    private TopicModel readModel;
    private TopicModel writeModel;
    private ThreadPoolExecutor threadPool;
    private BlockingQueue<Runnable> workQueue;
    private final int numTrainThreads;
    private final boolean isReadWrite;

    /**
     * Unit of work for one document: runs the per-document inference a fixed number of times
     * against the read model and, when a write model was supplied, pushes the resulting
     * document/topic model into it.
     */
    public static final class TrainerRunnable implements Runnable, Callable<Double> {
        private final TopicModel readModel;
        private final TopicModel writeModel;
        private final Vector document;
        private final Vector docTopics;
        private final Matrix docTopicModel;
        private final int numDocTopicIters;

        /**
         * @param readModel model the inference is run against
         * @param writeModel model to update after inference, or {@code null} to skip the update
         * @param document the document vector
         * @param docTopics the document's topic vector, handed to the read model on every pass
         * @param docTopicModel scratch matrix handed to the read model and then to the write model
         * @param numDocTopicIters number of inference passes to run
         */
        private TrainerRunnable(TopicModel readModel, TopicModel writeModel, Vector document,
                                Vector docTopics, Matrix docTopicModel, int numDocTopicIters) {
            this.readModel = readModel;
            this.writeModel = writeModel;
            this.document = document;
            this.docTopics = docTopics;
            this.docTopicModel = docTopicModel;
            this.numDocTopicIters = numDocTopicIters;
        }

        /** Runs the inference passes, then updates the write model if one was supplied. */
        @Override
        public void run() {
            for (int pass = 0; pass < this.numDocTopicIters; pass++) {
                this.readModel.trainDocTopicModel(this.document, this.docTopics, this.docTopicModel);
            }
            if (this.writeModel != null) {
                this.writeModel.update(this.docTopicModel);
            }
        }

        /**
         * Performs {@link #run()} and then asks the read model for the document's perplexity.
         *
         * @return the value reported by the read model for this document
         */
        @Override
        public Double call() {
            run();
            return Double.valueOf(this.readModel.perplexity(this.document, this.docTopics));
        }
    }

    /**
     * @param initialReadModel model that inference is run against
     * @param initialWriteModel model that receives updates; passing the same instance as
     *     {@code initialReadModel} selects batch mode in
     *     {@link #train(VectorIterable, VectorIterable, int)}
     * @param numTrainThreads size of the worker pool created by {@link #start()}
     * @param numTopics row count of the per-document scratch matrix
     * @param numTerms column count of the per-document scratch matrix
     */
    public ModelTrainer(TopicModel initialReadModel, TopicModel initialWriteModel, int numTrainThreads,
                        int numTopics, int numTerms) {
        this.readModel = initialReadModel;
        this.writeModel = initialWriteModel;
        this.numTrainThreads = numTrainThreads;
        this.numTopics = numTopics;
        this.numTerms = numTerms;
        this.isReadWrite = initialReadModel == initialWriteModel;
    }

    /**
     * Creates a trainer that reads from and writes to the same model.
     *
     * @param model model used for both inference and updates
     * @param numTrainThreads size of the worker pool created by {@link #start()}
     * @param numTopics row count of the per-document scratch matrix
     * @param numTerms column count of the per-document scratch matrix
     */
    public ModelTrainer(TopicModel model, int numTrainThreads, int numTopics, int numTerms) {
        this(model, model, numTrainThreads, numTopics, numTerms);
    }

    /** @return the model currently used for inference */
    public TopicModel getReadModel() {
        return this.readModel;
    }

    /**
     * Creates the work queue (capacity {@code numTrainThreads * 10}) and a fixed-size pool of
     * {@code numTrainThreads} pre-started workers, then resets the write model.
     */
    public void start() {
        log.info("Starting training threadpool with {} threads", this.numTrainThreads);
        this.workQueue = new ArrayBlockingQueue<Runnable>(this.numTrainThreads * 10);
        this.threadPool = new ThreadPoolExecutor(this.numTrainThreads, this.numTrainThreads, 0L,
                TimeUnit.SECONDS, this.workQueue);
        this.threadPool.allowCoreThreadTimeOut(false);
        // Workers must already be running: train(Vector, ...) puts work straight onto the queue
        // instead of going through execute(), so nothing else would start them.
        this.threadPool.prestartAllCoreThreads();
        this.writeModel.reset();
    }

    /**
     * Trains over the whole corpus with a single inference pass per document.
     *
     * @param matrix the documents, one per row
     * @param docTopicCounts the per-document topic vectors, iterated in lock-step with
     *     {@code matrix}
     */
    public void train(VectorIterable matrix, VectorIterable docTopicCounts) {
        train(matrix, docTopicCounts, 1);
    }

    /**
     * Computes perplexity over every document (no sampling).
     *
     * @param matrix the documents, one per row
     * @param docTopicCounts the per-document topic vectors, iterated in lock-step with
     *     {@code matrix}
     * @return summed per-document perplexity divided by the summed L1 norms of the documents
     */
    public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts) {
        return calculatePerplexity(matrix, docTopicCounts, 0.0d);
    }

    /**
     * Computes perplexity over the documents selected by {@code testFraction}.
     *
     * <p>Each selected document is first trained synchronously (10 passes, no write-model update)
     * and then scored against the read model.
     *
     * @param matrix the documents, one per row
     * @param docTopicCounts the per-document topic vectors, iterated in lock-step with
     *     {@code matrix}
     * @param testFraction {@code 0} to use every document; otherwise a document is used only when
     *     its row index is an exact multiple of {@code 1 / testFraction}
     * @return summed perplexity divided by the summed L1 norms of the selected documents; NaN when
     *     no document was selected (0/0)
     */
    public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts, double testFraction) {
        Iterator<MatrixSlice> docIterator = matrix.iterator();
        Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator();
        double perplexity = 0.0d;
        double matrixNorm = 0.0d;
        while (docIterator.hasNext() && docTopicIterator.hasNext()) {
            MatrixSlice docSlice = docIterator.next();
            MatrixSlice topicSlice = docTopicIterator.next();
            int docId = docSlice.index();
            Vector document = docSlice.vector();
            Vector topicDist = topicSlice.vector();
            if (testFraction == 0.0d || docId % (1.0d / testFraction) == 0.0d) {
                trainSync(document, topicDist, false, 10);
                perplexity += this.readModel.perplexity(document, topicDist);
                matrixNorm += document.norm(1.0d);
            }
        }
        return perplexity / matrixNorm;
    }

    /**
     * Trains over the whole corpus: calls {@link #start()}, feeds every document/topic-vector
     * pair to the workers, then calls {@link #stop()}.
     *
     * <p>When the read and write models are the same instance, documents are collected into
     * batches of {@code numTrainThreads} and each batch goes through
     * {@link #batchTrain(Map, boolean, int)}; a final partial batch is trained as well. Otherwise
     * each document is queued individually via {@link #train(Vector, Vector, boolean, int)}.
     *
     * @param matrix the documents, one per row
     * @param docTopicCounts the per-document topic vectors, iterated in lock-step with
     *     {@code matrix}; iteration ends when either input is exhausted
     * @param numDocTopicIters number of inference passes per document
     */
    public void train(VectorIterable matrix, VectorIterable docTopicCounts, int numDocTopicIters) {
        start();
        Iterator<MatrixSlice> docIterator = matrix.iterator();
        Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator();
        long startTime = System.nanoTime();
        int numDocs = 0;
        double[] times = new double[100];
        // NOTE(review): documents are used as map keys; if Vector equality is content-based, two
        // identical documents in one batch would collapse into one entry - confirm.
        Map<Vector, Vector> batch = Maps.newHashMap();
        int numTokensInBatch = 0;
        long batchStart = System.nanoTime();
        while (docIterator.hasNext() && docTopicIterator.hasNext()) {
            numDocs++;
            Vector document = docIterator.next().vector();
            Vector topicDist = docTopicIterator.next().vector();
            if (this.isReadWrite) {
                // Always keep the current document; train and empty the batch once it is full so
                // that no document is skipped and no batch is trained twice.
                batch.put(document, topicDist);
                if (log.isDebugEnabled()) {
                    numTokensInBatch += document.getNumNondefaultElements();
                }
                if (batch.size() >= this.numTrainThreads) {
                    batchStart = trainAndClearBatch(batch, numDocTopicIters, numTokensInBatch, batchStart);
                    numTokensInBatch = 0;
                }
            } else {
                long submitStart = System.nanoTime();
                train(document, topicDist, true, numDocTopicIters);
                if (log.isDebugEnabled()) {
                    // Time taken by the train(...) call above (queue hand-off), per non-default element.
                    times[numDocs % times.length] = (System.nanoTime() - submitStart)
                            / (1000000.0d * document.getNumNondefaultElements());
                    if (numDocs % 100 == 0) {
                        log.debug("trained {} documents in {}ms", numDocs,
                                (System.nanoTime() - startTime) / 1000000.0d);
                        if (numDocs % 500 == 0) {
                            Arrays.sort(times);
                            log.debug("training took median {}ms per token-instance", times[times.length / 2]);
                        }
                    }
                }
            }
        }
        if (!batch.isEmpty()) {
            // Flush the trailing partial batch so the last documents are trained too.
            trainAndClearBatch(batch, numDocTopicIters, numTokensInBatch, batchStart);
        }
        stop();
    }

    /**
     * Trains the given batch with write-model updates, empties it and logs timing.
     *
     * @return the time (from {@link System#nanoTime()}) at which the batch finished
     */
    private long trainAndClearBatch(Map<Vector, Vector> batch, int numDocTopicIters, int numTokensInBatch,
                                    long batchStart) {
        int batchSize = batch.size();
        batchTrain(batch, true, numDocTopicIters);
        batch.clear();
        long batchEnd = System.nanoTime();
        log.debug("trained {} docs with {} tokens, start time {}, end time {}",
                batchSize, numTokensInBatch, batchStart, batchEnd);
        return batchEnd;
    }

    /**
     * Runs inference for every entry of {@code batch} on the worker pool and waits for all of
     * them; then, if requested, applies each document's result to the write model on the calling
     * thread. The whole batch is retried from scratch if the wait is interrupted.
     *
     * <p>Requires {@link #start()} to have been called (the pool is created there).
     *
     * @param batch document vector mapped to its topic vector
     * @param update whether to apply the results to the write model
     * @param numDocTopicsIters number of inference passes per document
     */
    public void batchTrain(Map<Vector, Vector> batch, boolean update, int numDocTopicsIters) {
        while (true) {
            try {
                ArrayList<TrainerRunnable> tasks = Lists.newArrayList();
                for (Map.Entry<Vector, Vector> entry : batch.entrySet()) {
                    // No write model is handed to the task: updates are applied below, after
                    // every task of the batch has completed.
                    tasks.add(new TrainerRunnable(this.readModel, null, entry.getKey(), entry.getValue(),
                            new SparseRowMatrix(this.numTopics, this.numTerms, true), numDocTopicsIters));
                }
                this.threadPool.invokeAll(tasks);
                if (update) {
                    for (TrainerRunnable task : tasks) {
                        this.writeModel.update(task.docTopicModel);
                    }
                }
                return;
            } catch (InterruptedException e) {
                log.warn("Interrupted during batch training, retrying!", e);
            }
        }
    }

    /**
     * Queues one document for asynchronous training, blocking while the work queue is full and
     * retrying if interrupted.
     *
     * <p>Requires {@link #start()} to have been called (the queue is created there).
     *
     * @param document the document vector
     * @param docTopicCounts the document's topic vector
     * @param update whether the worker should update the write model after inference
     * @param numDocTopicIters number of inference passes
     */
    public void train(Vector document, Vector docTopicCounts, boolean update, int numDocTopicIters) {
        while (true) {
            try {
                this.workQueue.put(new TrainerRunnable(this.readModel, update ? this.writeModel : null,
                        document, docTopicCounts, new SparseRowMatrix(this.numTopics, this.numTerms, true),
                        numDocTopicIters));
                return;
            } catch (InterruptedException e) {
                log.warn("Interrupted waiting to submit document to work queue: {}", document, e);
            }
        }
    }

    /**
     * Trains one document on the calling thread; does not use the worker pool.
     *
     * @param document the document vector
     * @param docTopicCounts the document's topic vector
     * @param update whether to update the write model after inference
     * @param numDocTopicIters number of inference passes
     */
    public void trainSync(Vector document, Vector docTopicCounts, boolean update, int numDocTopicIters) {
        TrainerRunnable task = new TrainerRunnable(this.readModel, update ? this.writeModel : null,
                document, docTopicCounts, new SparseRowMatrix(this.numTopics, this.numTerms, true),
                numDocTopicIters);
        task.run();
    }

    /**
     * Trains one document on the calling thread without updating the write model, then returns
     * the perplexity reported by the read model.
     *
     * @param document the document vector
     * @param docTopicCounts the document's topic vector
     * @param numDocTopicIters number of inference passes
     * @return the read model's perplexity for the document
     */
    public double calculatePerplexity(Vector document, Vector docTopicCounts, int numDocTopicIters) {
        TrainerRunnable task = new TrainerRunnable(this.readModel, null, document, docTopicCounts,
                new SparseRowMatrix(this.numTopics, this.numTerms, true), numDocTopicIters);
        return task.call().doubleValue();
    }

    /**
     * Shuts the worker pool down, waits up to 60 seconds for queued work to finish, stops both
     * models and finally swaps the read and write models.
     *
     * <p>If the calling thread is interrupted along the way, the remaining steps (including the
     * swap) are skipped, the error is logged and the thread's interrupt status is restored.
     */
    public void stop() {
        long startTime = System.nanoTime();
        log.info("Initiating stopping of training threadpool");
        try {
            this.threadPool.shutdown();
            if (!this.threadPool.awaitTermination(60L, TimeUnit.SECONDS)) {
                log.warn("Threadpool timed out on await termination - jobs still running!");
            }
            long poolStopped = System.nanoTime();
            log.info("threadpool took: {}ms", (poolStopped - startTime) / 1000000.0d);
            this.readModel.stop();
            long readStopped = System.nanoTime();
            log.info("readModel.stop() took {}ms", (readStopped - poolStopped) / 1000000.0d);
            this.writeModel.stop();
            log.info("writeModel.stop() took {}ms", (System.nanoTime() - readStopped) / 1000000.0d);
            // The freshly written model becomes the one to read from next time.
            TopicModel previousWriteModel = this.writeModel;
            this.writeModel = this.readModel;
            this.readModel = previousWriteModel;
        } catch (InterruptedException e) {
            log.error("Interrupted shutting down!", e);
            // Keep the interruption visible to the caller instead of silently consuming it.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Persists the current read model by delegating to {@code readModel.persist(path, true)}.
     *
     * @param path destination path
     * @throws IOException if the model fails to persist
     */
    public void persist(Path path) throws IOException {
        // NOTE(review): the meaning of the 'true' flag is defined by TopicModel.persist;
        // presumably "overwrite" - confirm.
        this.readModel.persist(path, true);
    }
}
