package org.apache.mahout.math.hadoop.stochasticsvd;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import org.apache.commons.lang.Validate;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.compress.DefaultCodec;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.hadoop.mapred.lib.MultipleOutputs;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.mahout.common.IOUtils;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasure;
import org.apache.mahout.math.hadoop.stochasticsvd.qr.QRLastStep;

/* loaded from: input_file:org/apache/mahout/math/hadoop/stochasticsvd/BtJob.class */
public final class BtJob {
    public static final String OUTPUT_Q = "Q";
    public static final String OUTPUT_BT = "part";
    public static final String OUTPUT_BBT = "bbt";
    public static final String PROP_QJOB_PATH = "ssvd.QJob.path";
    public static final String PROP_OUPTUT_BBT_PRODUCTS = "ssvd.BtJob.outputBBtProducts";
    public static final String PROP_OUTER_PROD_BLOCK_HEIGHT = "ssvd.outerProdBlockHeight";
    public static final String PROP_RHAT_BROADCAST = "ssvd.rhat.broadcast";
    static final double SPARSE_ZEROS_PCT_THRESHOLD = 0.1d;

    /* loaded from: input_file:org/apache/mahout/math/hadoop/stochasticsvd/BtJob$BtMapper.class */
    /**
     * Map side of the Bt job.
     *
     * <p>For every input row this mapper pulls the next row from {@link QRLastStep}, writes that
     * row to the {@link BtJob#OUTPUT_Q} named output under the input key, and then feeds the
     * products of the row with each element of the input row into a
     * {@link SparseRowBlockAccumulator}, keyed by the element's index. The accumulator's output is
     * forwarded to the regular map output.
     *
     * <p>Note: the bridge method that a decompiler emits for {@code map} is intentionally not
     * declared here; the compiler generates it, and declaring it by hand causes a name clash.
     */
    public static class BtMapper extends Mapper<Writable, VectorWritable, LongWritable, SparseRowBlockWritable> {
        private QRLastStep qr;
        private int blockNum;
        private MultipleOutputs outputs;
        /** Scratch vector reused for every collected row; allocated lazily on the first map call. */
        private Vector btRow;
        private SparseRowBlockAccumulator btCollector;
        /** Context of the most recent map call; used by the collector callback created in setup. */
        private Mapper<Writable, VectorWritable, LongWritable, SparseRowBlockWritable>.Context mapContext;
        /** Resources to release in cleanup; most recently registered comes first. */
        private final Deque<Closeable> closeables = new ArrayDeque<Closeable>();
        private final VectorWritable qRowValue = new VectorWritable();

        /**
         * Closes every resource registered during setup.
         *
         * @param context the task context (unused)
         * @throws IOException if closing a resource fails
         */
        @Override
        protected void cleanup(Mapper<Writable, VectorWritable, LongWritable, SparseRowBlockWritable>.Context context) throws IOException, InterruptedException {
            IOUtils.close(this.closeables);
        }

        /** Writes one key/value pair to the {@link BtJob#OUTPUT_Q} named output. */
        @SuppressWarnings("unchecked") // MultipleOutputs hands back a raw OutputCollector
        private void outputQRow(Writable writable, Writable writable2) throws IOException {
            this.outputs.getCollector(BtJob.OUTPUT_Q, (Reporter) null).collect(writable, writable2);
        }

        /**
         * Processes one input row.
         *
         * @param writable       key of the input row; reused as the key of the emitted Q row
         * @param vectorWritable the input row
         * @param context        the task context; remembered so the collector callback can write to it
         * @throws IOException if writing the Q row or collecting a product row fails
         */
        @Override
        protected void map(Writable writable, VectorWritable vectorWritable, Mapper<Writable, VectorWritable, LongWritable, SparseRowBlockWritable>.Context context) throws IOException, InterruptedException {
            this.mapContext = context;
            Vector aRow = vectorWritable.get();
            Vector qRow = this.qr.next();
            int width = qRow.size();

            this.qRowValue.set(qRow);
            outputQRow(writable, this.qRowValue);

            if (this.btRow == null) {
                this.btRow = new DenseVector(width);
            }

            if (!aRow.isDense()) {
                // Sparse input: only the non-zero elements are visited.
                Iterator<Vector.Element> nonZeros = aRow.iterateNonZero();
                while (nonZeros.hasNext()) {
                    // Keep the element itself: both its value and its index are needed.
                    Vector.Element element = nonZeros.next();
                    double multiplier = element.get();
                    for (int j = 0; j < width; j++) {
                        this.btRow.setQuick(j, multiplier * qRow.getQuick(j));
                    }
                    this.btCollector.collect(Long.valueOf(element.index()), this.btRow);
                }
                return;
            }

            // Dense input: every position is visited, including zero-valued ones.
            int length = aRow.size();
            for (int i = 0; i < length; i++) {
                double multiplier = aRow.getQuick(i);
                for (int j = 0; j < width; j++) {
                    this.btRow.setQuick(j, multiplier * qRow.getQuick(j));
                }
                this.btCollector.collect(Long.valueOf(i), this.btRow);
            }
        }

        /**
         * Opens the Q-hat and R-hat inputs produced by the Q job, builds the {@link QRLastStep}
         * iterator over them, and creates the named outputs and the product-row accumulator.
         *
         * <p>R-hat files are read from the distributed cache when
         * {@link BtJob#PROP_RHAT_BROADCAST} is set, and from the {@code R-*} glob under the Q job
         * path otherwise.
         *
         * @param context the task context supplying the configuration and the task id
         * @throws IOException if any of the inputs cannot be opened
         */
        @Override
        @SuppressWarnings({"rawtypes", "unchecked"}) // element types of the iterators are not visible here
        protected void setup(Mapper<Writable, VectorWritable, LongWritable, SparseRowBlockWritable>.Context context) throws IOException, InterruptedException {
            super.setup(context);
            Configuration conf = context.getConfiguration();

            Path qJobPath = new Path(conf.get(BtJob.PROP_QJOB_PATH));
            // The Q-hat file name is derived from this task's own unique file name.
            Path qHatPath = new Path(qJobPath, FileOutputFormat.getUniqueFile(context, QJob.OUTPUT_QHAT, ""));
            this.blockNum = context.getTaskAttemptID().getTaskID().getId();

            SequenceFileValueIterator qHatInput = new SequenceFileValueIterator(qHatPath, true, conf);
            this.closeables.addFirst(qHatInput);

            SequenceFileDirValueIterator rHatInput;
            if (conf.get(BtJob.PROP_RHAT_BROADCAST) != null) {
                Path[] localCacheFiles = DistributedCache.getLocalCacheFiles(conf);
                Validate.notNull(localCacheFiles, "no RHat files in distributed cache job definition");
                // Cached files are read through a configuration whose default filesystem is local.
                Configuration localConf = new Configuration();
                localConf.set("fs.default.name", "file:///");
                rHatInput = new SequenceFileDirValueIterator(localCacheFiles, SSVDSolver.PARTITION_COMPARATOR, true, localConf);
            } else {
                rHatInput = new SequenceFileDirValueIterator(new Path(qJobPath, "R-*"), PathType.GLOB, null, SSVDSolver.PARTITION_COMPARATOR, true, conf);
            }
            Validate.isTrue(rHatInput.hasNext(), "Empty R-hat input!");
            this.closeables.addFirst(rHatInput);

            this.outputs = new MultipleOutputs(new JobConf(conf));
            this.closeables.addFirst(new IOUtils.MultipleOutputsCloseableAdapter(this.outputs));

            this.qr = new QRLastStep(qHatInput, rHatInput, this.blockNum);
            this.closeables.addFirst(this.qr);

            // If the R-hat input is already exhausted once QRLastStep has been built, release it
            // right away instead of holding it until cleanup.
            // NOTE(review): presumably QRLastStep consumes R-hat in its constructor - confirm.
            if (!rHatInput.hasNext()) {
                this.closeables.remove(rHatInput);
                rHatInput.close();
            }

            OutputCollector<LongWritable, SparseRowBlockWritable> blockCollector =
                    new OutputCollector<LongWritable, SparseRowBlockWritable>() {
                        @Override
                        public void collect(LongWritable longWritable, SparseRowBlockWritable sparseRowBlockWritable) throws IOException {
                            try {
                                BtMapper.this.mapContext.write(longWritable, sparseRowBlockWritable);
                            } catch (InterruptedException e) {
                                // OutputCollector only allows IOException; keep the cause attached.
                                throw new IOException("Interrupted.", e);
                            }
                        }
                    };
            this.btCollector = new SparseRowBlockAccumulator(conf.getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, -1), blockCollector);
            this.closeables.addFirst(this.btCollector);
        }
    }

    /* loaded from: input_file:org/apache/mahout/math/hadoop/stochasticsvd/BtJob$OuterProductCombiner.class */
    /**
     * Combiner that sums all {@link SparseRowBlockWritable} values sharing a key into a single
     * block and emits that block under the same key.
     *
     * <p>Note: the bridge method that a decompiler emits for {@code reduce} is intentionally not
     * declared here; the compiler generates it, and declaring it by hand causes a name clash.
     */
    public static class OuterProductCombiner extends Reducer<Writable, SparseRowBlockWritable, Writable, SparseRowBlockWritable> {
        /** Accumulator reused across keys; it is cleared after each key has been written. */
        protected final SparseRowBlockWritable accum = new SparseRowBlockWritable();
        /** Resources to release in cleanup. */
        protected final Deque<Closeable> closeables = new ArrayDeque<Closeable>();
        /** Value of {@link BtJob#PROP_OUTER_PROD_BLOCK_HEIGHT}, or -1 when the property is absent. */
        protected int blockHeight;

        /**
         * Reads the block height from the job configuration.
         *
         * @param context the task context supplying the configuration
         */
        @Override
        protected void setup(Reducer<Writable, SparseRowBlockWritable, Writable, SparseRowBlockWritable>.Context context) throws IOException, InterruptedException {
            this.blockHeight = context.getConfiguration().getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, -1);
        }

        /**
         * Adds up every block received for {@code writable} and writes the sum.
         *
         * @param writable the key shared by the blocks
         * @param iterable the blocks to sum
         * @param context  the task context the summed block is written to
         * @throws IOException if writing the result fails
         */
        @Override
        protected void reduce(Writable writable, Iterable<SparseRowBlockWritable> iterable, Reducer<Writable, SparseRowBlockWritable, Writable, SparseRowBlockWritable>.Context context) throws IOException, InterruptedException {
            for (SparseRowBlockWritable block : iterable) {
                this.accum.plusBlock(block);
            }
            context.write(writable, this.accum);
            // Reset so the next key starts from an empty accumulator.
            this.accum.clear();
        }

        /**
         * Closes every registered resource.
         *
         * @param context the task context (unused)
         * @throws IOException if closing a resource fails
         */
        @Override
        protected void cleanup(Reducer<Writable, SparseRowBlockWritable, Writable, SparseRowBlockWritable>.Context context) throws IOException, InterruptedException {
            IOUtils.close(this.closeables);
        }
    }

    /* loaded from: input_file:org/apache/mahout/math/hadoop/stochasticsvd/BtJob$OuterProductReducer.class */
    /**
     * Reducer that sums the blocks received for each key and writes the rows of the summed block
     * as {@code (IntWritable, VectorWritable)} pairs. When
     * {@link BtJob#PROP_OUPTUT_BBT_PRODUCTS} is enabled it also accumulates an upper-triangular
     * sum of products of each row with itself and writes it to the {@link BtJob#OUTPUT_BBT}
     * named output during cleanup.
     *
     * <p>Note: the bridge method that a decompiler emits for {@code reduce} is intentionally not
     * declared here; the compiler generates it, and declaring it by hand causes a name clash.
     */
    public static class OuterProductReducer extends Reducer<LongWritable, SparseRowBlockWritable, IntWritable, VectorWritable> {
        /** Value of {@link BtJob#PROP_OUTER_PROD_BLOCK_HEIGHT}, or -1 when the property is absent. */
        protected int blockHeight;
        private boolean outputBBt;
        /** Running upper-triangular accumulator; only created when {@code outputBBt} is set. */
        private UpperTriangular mBBt;
        /** Named outputs; only created when {@code outputBBt} is set. */
        private MultipleOutputs outputs;
        /** Accumulator reused across keys; it is cleared at the start of each reduce call. */
        protected final SparseRowBlockWritable accum = new SparseRowBlockWritable();
        /** Resources to release in cleanup. */
        protected final Deque<Closeable> closeables = new ArrayDeque<Closeable>();
        private final IntWritable btKey = new IntWritable();
        private final VectorWritable btValue = new VectorWritable();

        /**
         * Reads the block height and the BBt switch from the configuration and, when BBt output
         * is requested, validates {@code ssvd.k} / {@code ssvd.p} and allocates the accumulator
         * and the named outputs.
         *
         * @param context the task context supplying the configuration
         */
        @Override
        protected void setup(Reducer<LongWritable, SparseRowBlockWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            Configuration conf = context.getConfiguration();
            this.blockHeight = conf.getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, -1);
            this.outputBBt = conf.getBoolean(BtJob.PROP_OUPTUT_BBT_PRODUCTS, false);
            if (!this.outputBBt) {
                return;
            }
            int k = conf.getInt("ssvd.k", -1);
            int p = conf.getInt("ssvd.p", -1);
            Validate.isTrue(k > 0, "invalid k parameter");
            Validate.isTrue(p >= 0, "invalid p parameter");
            this.mBBt = new UpperTriangular(k + p);
            this.outputs = new MultipleOutputs(new JobConf(conf));
            this.closeables.addFirst(new IOUtils.MultipleOutputsCloseableAdapter(this.outputs));
        }

        /**
         * Sums the blocks of one key and emits each row of the result.
         *
         * <p>The output key of a row is {@code blockKey * blockHeight + rowIndexWithinBlock},
         * narrowed to {@code int}.
         *
         * @param longWritable the block key
         * @param iterable     the partial blocks to sum
         * @param context      the task context rows are written to
         * @throws IOException if writing a row fails
         */
        @Override
        protected void reduce(LongWritable longWritable, Iterable<SparseRowBlockWritable> iterable, Reducer<LongWritable, SparseRowBlockWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            this.accum.clear();
            for (SparseRowBlockWritable block : iterable) {
                this.accum.plusBlock(block);
            }
            for (int i = 0; i < this.accum.getNumRows(); i++) {
                Vector row = this.accum.getRows()[i];
                this.btKey.set((int) ((longWritable.get() * this.blockHeight) + this.accum.getRowIndices()[i]));
                this.btValue.set(row);
                context.write(this.btKey, this.btValue);
                if (this.outputBBt) {
                    accumulateBBt(row);
                }
            }
        }

        /** Adds the upper-triangular part of the product of {@code row} with itself to {@code mBBt}. */
        private void accumulateBBt(Vector row) {
            int dimension = this.mBBt.numRows();
            for (int r = 0; r < dimension; r++) {
                double left = row.get(r);
                // NOTE(review): NO_NORM is presumably 0.0, making this a skip-zero check; the
                // reference to VectorSimilarityMeasure looks like a decompiler constant
                // substitution - confirm.
                if (left == VectorSimilarityMeasure.NO_NORM) {
                    continue;
                }
                for (int c = r; c < dimension; c++) {
                    double right = row.get(c);
                    if (right != VectorSimilarityMeasure.NO_NORM) {
                        this.mBBt.setQuick(r, c, this.mBBt.getQuick(r, c) + (left * right));
                    }
                }
            }
        }

        /**
         * Writes the accumulated BBt data (when enabled) and then closes every registered
         * resource, even if the write fails.
         *
         * @param context the task context (unused)
         * @throws IOException if writing or closing fails
         */
        @Override
        @SuppressWarnings("unchecked") // MultipleOutputs hands back a raw OutputCollector
        protected void cleanup(Reducer<LongWritable, SparseRowBlockWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            try {
                if (this.outputBBt) {
                    this.outputs.getCollector(BtJob.OUTPUT_BBT, (Reporter) null).collect(new IntWritable(), new VectorWritable(new DenseVector(this.mBBt.getData())));
                }
            } finally {
                IOUtils.close(this.closeables);
            }
        }
    }

    /** Private constructor: this final class is not meant to be instantiated. */
    private BtJob() {
    }

    /**
     * Configures, submits and waits for the "Bt-job" MapReduce job.
     *
     * @param configuration base configuration the job is derived from
     * @param pathArr       input paths of the job
     * @param path          output path of the Q job; stored under {@link #PROP_QJOB_PATH} and
     *                      searched for {@code R-*} files when {@code z} is set
     * @param path2         output path of this job
     * @param i             minimum input split size; applied only when greater than zero
     * @param i2            value stored under {@code ssvd.k}
     * @param i3            value stored under {@code ssvd.p}
     * @param i4            value stored under {@link #PROP_OUTER_PROD_BLOCK_HEIGHT}
     * @param i5            number of reduce tasks
     * @param z             when true, sets {@link #PROP_RHAT_BROADCAST} and adds every
     *                      {@code R-*} file under {@code path} to the distributed cache
     * @param cls           key class of the {@link #OUTPUT_Q} named output
     * @param z2            when true, registers the {@link #OUTPUT_BBT} named output and sets
     *                      {@link #PROP_OUPTUT_BBT_PRODUCTS}
     * @throws IOException if the job cannot be set up or finishes unsuccessfully
     * @throws InterruptedException if waiting for the job is interrupted
     * @throws ClassNotFoundException if a job class cannot be resolved on submission
     */
    public static void run(Configuration configuration, Path[] pathArr, Path path, Path path2, int i, int i2, int i3, int i4, int i5, boolean z, Class<? extends Writable> cls, boolean z2) throws ClassNotFoundException, InterruptedException, IOException {
        // Named outputs are registered on the old-API JobConf before the Job is created from it.
        JobConf jobConf = new JobConf(configuration);
        MultipleOutputs.addNamedOutput(jobConf, OUTPUT_Q, SequenceFileOutputFormat.class, cls, VectorWritable.class);
        if (z2) {
            MultipleOutputs.addNamedOutput(jobConf, OUTPUT_BBT, SequenceFileOutputFormat.class, IntWritable.class, VectorWritable.class);
        }

        Job job = new Job(jobConf);
        job.setJobName("Bt-job");
        job.setJarByClass(BtJob.class);

        job.setInputFormatClass(SequenceFileInputFormat.class);
        job.setOutputFormatClass(org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class);
        FileInputFormat.setInputPaths(job, pathArr);
        if (i > 0) {
            FileInputFormat.setMinInputSplitSize(job, i);
        }
        FileOutputFormat.setOutputPath(job, path2);

        Configuration jobConfiguration = job.getConfiguration();
        jobConfiguration.set("mapreduce.output.basename", OUTPUT_BT);
        FileOutputFormat.setOutputCompressorClass(job, DefaultCodec.class);
        org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.setOutputCompressionType(job, SequenceFile.CompressionType.BLOCK);

        job.setMapOutputKeyClass(LongWritable.class);
        job.setMapOutputValueClass(SparseRowBlockWritable.class);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(VectorWritable.class);

        job.setMapperClass(BtMapper.class);
        job.setCombinerClass(OuterProductCombiner.class);
        job.setReducerClass(OuterProductReducer.class);

        jobConfiguration.setInt("ssvd.k", i2);
        jobConfiguration.setInt("ssvd.p", i3);
        jobConfiguration.set(PROP_QJOB_PATH, path.toString());
        jobConfiguration.setBoolean(PROP_OUPTUT_BBT_PRODUCTS, z2);
        jobConfiguration.setInt(PROP_OUTER_PROD_BLOCK_HEIGHT, i4);
        job.setNumReduceTasks(i5);

        if (z) {
            jobConfiguration.set(PROP_RHAT_BROADCAST, "y");
            // Resolve the filesystem from the Q job path itself, so the glob also works when that
            // path is not on the default filesystem.
            FileSystem qJobFileSystem = path.getFileSystem(configuration);
            FileStatus[] rHatFiles = qJobFileSystem.globStatus(new Path(path, "R-*"));
            if (rHatFiles != null) {
                for (FileStatus rHatFile : rHatFiles) {
                    DistributedCache.addCacheFile(rHatFile.getPath().toUri(), jobConfiguration);
                }
            }
        }

        job.submit();
        job.waitForCompletion(false);
        if (!job.isSuccessful()) {
            throw new IOException("Bt job unsuccessful.");
        }
    }
}
