package org.apache.mahout.math.hadoop.similarity.cooccurrence;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.ClassUtils;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.mapreduce.VectorSumReducer;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasure;
import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasures;
import org.apache.mahout.math.map.OpenIntIntHashMap;

/* loaded from: input_file:org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJob.class */
public class RowSimilarityJob extends AbstractJob {
    /**
     * Sentinel threshold meaning "no similarity threshold was requested". {@code run} falls back to
     * this value when the threshold option is absent.
     */
    public static final double NO_THRESHOLD = Double.MIN_VALUE;

    // Hadoop configuration keys shared between the driver and its mappers/reducers.
    // Concatenating a Class object uses Class.toString(), so that text is part of every key.
    private static final String SIMILARITY_CLASSNAME = RowSimilarityJob.class + ".distributedSimilarityClassname";
    private static final String NUMBER_OF_COLUMNS = RowSimilarityJob.class + ".numberOfColumns";
    private static final String MAX_SIMILARITIES_PER_ROW = RowSimilarityJob.class + ".maxSimilaritiesPerRow";
    private static final String EXCLUDE_SELF_SIMILARITY = RowSimilarityJob.class + ".excludeSelfSimilarity";
    private static final String THRESHOLD = RowSimilarityJob.class + ".threshold";
    private static final String NORMS_PATH = RowSimilarityJob.class + ".normsPath";
    private static final String MAXVALUES_PATH = RowSimilarityJob.class + ".maxWeightsPath";
    private static final String NUM_NON_ZERO_ENTRIES_PATH = RowSimilarityJob.class + ".nonZeroEntriesPath";

    /** Default for the "maxSimilaritiesPerRow" command line option. */
    private static final int DEFAULT_MAX_SIMILARITIES_PER_ROW = 100;

    // Reserved keys under which VectorNormMapper emits its per-mapper side vectors; MergeVectorsReducer
    // recognizes them and writes the merged vectors to side files instead of the job output.
    private static final int NORM_VECTOR_MARKER = Integer.MIN_VALUE;
    private static final int MAXVALUE_VECTOR_MARKER = Integer.MIN_VALUE + 1;
    private static final int NUM_NON_ZERO_ENTRIES_VECTOR_MARKER = Integer.MIN_VALUE + 2;

    /**
     * Mapper of the second phase. For every input vector it emits, per element, a sparse vector holding
     * the aggregated value of that element with each element at the same or a later index, optionally
     * skipping pairs that the similarity measure declares not worth considering.
     */
    public static class CooccurrencesMapper extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {

        /** Orders vector elements by ascending index. */
        private static final Comparator<Vector.Element> BY_INDEX = new Comparator<Vector.Element>() {
            @Override
            public int compare(Vector.Element one, Vector.Element two) {
                return Ints.compare(one.index(), two.index());
            }
        };

        private VectorSimilarityMeasure similarity;
        private OpenIntIntHashMap numNonZeroEntries;
        private Vector maxValues;
        private double threshold;

        /**
         * Instantiates the configured similarity measure and loads the side data (non-zero counts and
         * maximum values) from the paths stored in the job configuration.
         */
        protected void setup(Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            String similarityClassname = context.getConfiguration().get(RowSimilarityJob.SIMILARITY_CLASSNAME);
            this.similarity = (VectorSimilarityMeasure) ClassUtils.instantiateAs(similarityClassname, VectorSimilarityMeasure.class);

            Path numNonZeroEntriesPath = new Path(context.getConfiguration().get(RowSimilarityJob.NUM_NON_ZERO_ENTRIES_PATH));
            this.numNonZeroEntries = Vectors.readAsIntMap(numNonZeroEntriesPath, context.getConfiguration());

            Path maxValuesPath = new Path(context.getConfiguration().get(RowSimilarityJob.MAXVALUES_PATH));
            this.maxValues = Vectors.read(maxValuesPath, context.getConfiguration());

            this.threshold = Double.parseDouble(context.getConfiguration().get(RowSimilarityJob.THRESHOLD));
        }

        /** Asks the similarity measure whether the pair identified by the two elements' indexes should be kept. */
        private boolean consider(Vector.Element occurrenceA, Vector.Element occurrenceB) {
            int numNonZeroEntriesA = this.numNonZeroEntries.get(occurrenceA.index());
            int numNonZeroEntriesB = this.numNonZeroEntries.get(occurrenceB.index());
            double maxValueA = this.maxValues.get(occurrenceA.index());
            double maxValueB = this.maxValues.get(occurrenceB.index());
            return this.similarity.consider(numNonZeroEntriesA, numNonZeroEntriesB, maxValueA, maxValueB, this.threshold);
        }

        /**
         * Emits one sparse vector per element of the input vector, keyed by that element's index, and
         * reports the number of kept and pruned pairs through the job counters.
         */
        protected void map(IntWritable intWritable, VectorWritable vectorWritable, Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            Vector.Element[] occurrences = Vectors.toArray(vectorWritable);
            Arrays.sort(occurrences, BY_INDEX);

            // The number of pairs grows quadratically with the vector length (n * (n + 1) / 2), so the
            // tallies are kept as long to avoid silently wrapping around for long vectors.
            long cooccurrences = 0L;
            long prunedCooccurrences = 0L;
            boolean pruningDisabled = this.threshold == RowSimilarityJob.NO_THRESHOLD;

            for (int n = 0; n < occurrences.length; n++) {
                Vector.Element occurrenceA = occurrences[n];
                Vector dots = new RandomAccessSparseVector(Integer.MAX_VALUE);
                // Start at n: each element is also paired with itself, and only with same-or-later indexes.
                for (int m = n; m < occurrences.length; m++) {
                    Vector.Element occurrenceB = occurrences[m];
                    if (pruningDisabled || consider(occurrenceA, occurrenceB)) {
                        dots.setQuick(occurrenceB.index(), this.similarity.aggregate(occurrenceA.get(), occurrenceB.get()));
                        cooccurrences++;
                    } else {
                        prunedCooccurrences++;
                    }
                }
                context.write(new IntWritable(occurrenceA.index()), new VectorWritable(dots));
            }

            context.getCounter(Counters.COOCCURRENCES).increment(cooccurrences);
            context.getCounter(Counters.PRUNED_COOCCURRENCES).increment(prunedCooccurrences);
        }
    }

    /**
     * Hadoop job counters reported by the mappers of this job.
     * NOTE(review): the decompiler reported that this type's access was widened from package-private.
     */
    public enum Counters {
        /** Incremented once per call of VectorNormMapper.map, i.e. once per input row. */
        ROWS,
        /** Number of element pairs CooccurrencesMapper aggregated and emitted. */
        COOCCURRENCES,
        /** Number of element pairs CooccurrencesMapper skipped because the similarity measure rejected them. */
        PRUNED_COOCCURRENCES
    }

    /**
     * Reducer (also used as combiner) of the third phase: merges all partial vectors of a row and keeps
     * only the configured maximum number of top elements.
     */
    public static class MergeToTopKSimilaritiesReducer extends Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable> {

        private int maxSimilaritiesPerRow;

        /**
         * Reads the maximum number of similarities per row from the job configuration.
         *
         * @throws IllegalArgumentException if the configured value is missing or not positive
         */
        protected void setup(Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            int configuredMaximum = context.getConfiguration().getInt(RowSimilarityJob.MAX_SIMILARITIES_PER_ROW, 0);
            this.maxSimilaritiesPerRow = configuredMaximum;
            Preconditions.checkArgument(this.maxSimilaritiesPerRow > 0, "Incorrect maximum number of similarities per row!");
        }

        /** Merges the partial vectors of one row and writes the top elements of the merged vector. */
        protected void reduce(IntWritable intWritable, Iterable<VectorWritable> iterable, Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            Vector allSimilarities = Vectors.merge(iterable);
            Vector topKSimilarities = Vectors.topKElements(this.maxSimilaritiesPerRow, allSimilarities);
            context.write(intWritable, new VectorWritable(topKSimilarities));
        }
    }

    /** Combiner of the first phase: merges all partial vectors that share a key into a single vector. */
    private static class MergeVectorsCombiner extends Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable> {

        /** Writes the merge of all vectors received for the key under that same key. */
        protected void reduce(IntWritable intWritable, Iterable<VectorWritable> iterable, Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            Vector merged = Vectors.merge(iterable);
            context.write(intWritable, new VectorWritable(merged));
        }
    }

    /**
     * Reducer of the first phase. Regular keys are merged and written to the job output; the three
     * reserved marker keys carry side vectors, which are merged and written to their own files.
     */
    public static class MergeVectorsReducer extends Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable> {

        private Path normsPath;
        private Path numNonZeroEntriesPath;
        private Path maxValuesPath;

        /** Resolves the side-file locations from the job configuration. */
        protected void setup(Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            this.normsPath = new Path(context.getConfiguration().get(RowSimilarityJob.NORMS_PATH));
            this.numNonZeroEntriesPath = new Path(context.getConfiguration().get(RowSimilarityJob.NUM_NON_ZERO_ENTRIES_PATH));
            this.maxValuesPath = new Path(context.getConfiguration().get(RowSimilarityJob.MAXVALUES_PATH));
        }

        /** Merges the vectors of one key and routes the result either to a side file or to the job output. */
        protected void reduce(IntWritable intWritable, Iterable<VectorWritable> iterable, Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            Vector merged = Vectors.merge(iterable);
            switch (intWritable.get()) {
                case RowSimilarityJob.NORM_VECTOR_MARKER:
                    Vectors.write(merged, this.normsPath, context.getConfiguration());
                    break;
                case RowSimilarityJob.MAXVALUE_VECTOR_MARKER:
                    Vectors.write(merged, this.maxValuesPath, context.getConfiguration());
                    break;
                case RowSimilarityJob.NUM_NON_ZERO_ENTRIES_VECTOR_MARKER:
                    // Same call as above but with the additional boolean flag set, as in the original code.
                    Vectors.write(merged, this.numNonZeroEntriesPath, context.getConfiguration(), true);
                    break;
                default:
                    context.write(intWritable, new VectorWritable(merged));
                    break;
            }
        }
    }

    /**
     * Reducer of the second phase: sums the partial aggregates of a row, converts every summed value
     * into a similarity using the stored norms, and drops similarities below the threshold.
     */
    public static class SimilarityReducer extends Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable> {

        private VectorSimilarityMeasure similarity;
        private int numberOfColumns;
        private boolean excludeSelfSimilarity;
        private Vector norms;
        private double threshold;

        /**
         * Instantiates the similarity measure and reads number of columns, self-similarity flag, norms
         * and threshold from the job configuration.
         *
         * @throws IllegalArgumentException if the configured number of columns is missing or not positive
         */
        protected void setup(Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            String similarityClassname = context.getConfiguration().get(RowSimilarityJob.SIMILARITY_CLASSNAME);
            this.similarity = (VectorSimilarityMeasure) ClassUtils.instantiateAs(similarityClassname, VectorSimilarityMeasure.class);

            this.numberOfColumns = context.getConfiguration().getInt(RowSimilarityJob.NUMBER_OF_COLUMNS, -1);
            Preconditions.checkArgument(this.numberOfColumns > 0, "Incorrect number of columns!");

            this.excludeSelfSimilarity = context.getConfiguration().getBoolean(RowSimilarityJob.EXCLUDE_SELF_SIMILARITY, false);

            Path normsPath = new Path(context.getConfiguration().get(RowSimilarityJob.NORMS_PATH));
            this.norms = Vectors.read(normsPath, context.getConfiguration());

            this.threshold = Double.parseDouble(context.getConfiguration().get(RowSimilarityJob.THRESHOLD));
        }

        /** Computes and writes the similarity vector of one row. */
        protected void reduce(IntWritable intWritable, Iterable<VectorWritable> iterable, Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            // Sum all partial vectors into the first one received.
            Iterator<VectorWritable> partialDots = iterable.iterator();
            Vector dots = partialDots.next().get();
            while (partialDots.hasNext()) {
                Vector toAdd = partialDots.next().get();
                for (Vector.Element nonZero : toAdd.nonZeroes()) {
                    dots.setQuick(nonZero.index(), dots.getQuick(nonZero.index()) + nonZero.get());
                }
            }

            Vector similarities = dots.like();
            double normA = this.norms.getQuick(intWritable.get());
            for (Vector.Element b : dots.nonZeroes()) {
                double normB = this.norms.getQuick(b.index());
                double similarityValue = this.similarity.similarity(b.get(), normA, normB, this.numberOfColumns);
                if (similarityValue >= this.threshold) {
                    similarities.set(b.index(), similarityValue);
                }
            }

            if (this.excludeSelfSimilarity) {
                // NOTE(review): NO_NORM is used here as the value that blanks the row's own entry;
                // presumably it is 0 and this reference is a decompiler constant substitution - confirm.
                similarities.setQuick(intWritable.get(), VectorSimilarityMeasure.NO_NORM);
            }
            context.write(intWritable, new VectorWritable(similarities));
        }
    }

    /**
     * Mapper of the third phase. For every non-zero entry (row, column, value) of the input it emits the
     * mirrored entry keyed by the column, and additionally emits the row's own top entries keyed by the row.
     */
    public static class UnsymmetrifyMapper extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {

        private int maxSimilaritiesPerRow;

        /**
         * Reads the maximum number of similarities per row from the job configuration.
         *
         * @throws IllegalArgumentException if the configured value is missing or not positive
         */
        protected void setup(Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            this.maxSimilaritiesPerRow = context.getConfiguration().getInt(RowSimilarityJob.MAX_SIMILARITIES_PER_ROW, 0);
            Preconditions.checkArgument(this.maxSimilaritiesPerRow > 0, "Incorrect maximum number of similarities per row!");
        }

        /** Emits the mirrored entries of one row followed by the row's own top entries. */
        protected void map(IntWritable intWritable, VectorWritable vectorWritable, Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            Vector similarities = vectorWritable.get();
            int row = intWritable.get();

            // A single one-entry vector is reused for all mirrored emissions: set, write, reset.
            Vector transposedPartial = new RandomAccessSparseVector(similarities.size(), 1);
            TopElementsQueue topKQueue = new TopElementsQueue(this.maxSimilaritiesPerRow);

            for (Vector.Element nonZero : similarities.nonZeroes()) {
                double candidateValue = nonZero.get();

                // Overwrite the queue's current top entry in place when the candidate is larger.
                MutableElement top = (MutableElement) topKQueue.top();
                if (candidateValue > top.get()) {
                    top.setIndex(nonZero.index());
                    top.set(candidateValue);
                    topKQueue.updateTop();
                }

                transposedPartial.setQuick(row, candidateValue);
                context.write(new IntWritable(nonZero.index()), new VectorWritable(transposedPartial));
                // NOTE(review): NO_NORM serves as the reset value here; presumably 0 - confirm.
                transposedPartial.setQuick(row, VectorSimilarityMeasure.NO_NORM);
            }

            Vector topKSimilarities = new RandomAccessSparseVector(similarities.size(), this.maxSimilaritiesPerRow);
            for (MutableElement topKSimilarity : topKQueue.getTopElements()) {
                topKSimilarities.setQuick(topKSimilarity.index(), topKSimilarity.get());
            }
            context.write(intWritable, new VectorWritable(topKSimilarities));
        }
    }

    /**
     * Mapper of the first phase. It normalizes each input row, emits the row's entries keyed by their
     * column index, and collects per-row norms, non-zero counts and maximum values, which are emitted
     * under reserved marker keys when the mapper finishes.
     */
    public static class VectorNormMapper extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {

        private VectorSimilarityMeasure similarity;
        private Vector norms;
        private Vector nonZeroEntries;
        private Vector maxValues;
        private double threshold;

        /** Instantiates the similarity measure, reads the threshold and creates the empty side vectors. */
        protected void setup(Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            String similarityClassname = context.getConfiguration().get(RowSimilarityJob.SIMILARITY_CLASSNAME);
            this.similarity = (VectorSimilarityMeasure) ClassUtils.instantiateAs(similarityClassname, VectorSimilarityMeasure.class);
            this.threshold = Double.parseDouble(context.getConfiguration().get(RowSimilarityJob.THRESHOLD));

            this.norms = new RandomAccessSparseVector(Integer.MAX_VALUE);
            this.nonZeroEntries = new RandomAccessSparseVector(Integer.MAX_VALUE);
            this.maxValues = new RandomAccessSparseVector(Integer.MAX_VALUE);
        }

        /**
         * Normalizes one row, emits each non-zero entry as a one-entry vector keyed by the entry's index,
         * and records the row's norm (always) and its non-zero count and maximum (only with a threshold).
         */
        protected void map(IntWritable intWritable, VectorWritable vectorWritable, Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            int row = intWritable.get();
            Vector rowVector = this.similarity.normalize(vectorWritable.get());

            int numNonZeroEntries = 0;
            // NOTE(review): Double.MIN_VALUE is the smallest positive double, not the most negative one,
            // so a row without positive entries reports this seed as its maximum. Kept as-is because the
            // value feeds the similarity measure's pruning decision; confirm before changing.
            double maxValue = Double.MIN_VALUE;

            for (Vector.Element element : rowVector.nonZeroes()) {
                Vector partialColumnVector = new RandomAccessSparseVector(Integer.MAX_VALUE);
                partialColumnVector.setQuick(row, element.get());
                context.write(new IntWritable(element.index()), new VectorWritable(partialColumnVector));

                numNonZeroEntries++;
                if (maxValue < element.get()) {
                    maxValue = element.get();
                }
            }

            // Counts and maxima are only needed for pruning, so they are skipped when no threshold is set.
            if (this.threshold != RowSimilarityJob.NO_THRESHOLD) {
                this.nonZeroEntries.setQuick(row, numNonZeroEntries);
                this.maxValues.setQuick(row, maxValue);
            }
            this.norms.setQuick(row, this.similarity.norm(rowVector));

            context.getCounter(Counters.ROWS).increment(1L);
        }

        /** Emits the three collected side vectors under their reserved marker keys. */
        protected void cleanup(Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context) throws IOException, InterruptedException {
            super.cleanup(context);
            context.write(new IntWritable(RowSimilarityJob.NORM_VECTOR_MARKER), new VectorWritable(this.norms));
            context.write(new IntWritable(RowSimilarityJob.NUM_NON_ZERO_ENTRIES_VECTOR_MARKER), new VectorWritable(this.nonZeroEntries));
            context.write(new IntWritable(RowSimilarityJob.MAXVALUE_VECTOR_MARKER), new VectorWritable(this.maxValues));
        }
    }

    /**
     * Command line entry point: runs a new {@code RowSimilarityJob} through Hadoop's {@code ToolRunner}.
     * The exit code returned by the run is not inspected.
     *
     * @param strArr command line arguments handed to the job
     * @throws Exception whatever the tool runner or the job throws
     */
    public static void main(String[] strArr) throws Exception {
        RowSimilarityJob job = new RowSimilarityJob();
        ToolRunner.run(job, strArr);
    }

    /**
     * Parses the command line and runs up to three MapReduce phases:
     * <ol>
     *   <li>{@code VectorNormMapper} / {@code MergeVectorsReducer}: input to a temporary matrix plus side files;</li>
     *   <li>{@code CooccurrencesMapper} / {@code SimilarityReducer}: temporary matrix to pairwise similarities;</li>
     *   <li>{@code UnsymmetrifyMapper} / {@code MergeToTopKSimilaritiesReducer}: pairwise similarities to the output.</li>
     * </ol>
     * Each phase is only started when {@code shouldRunNextPhase} allows it.
     *
     * @param strArr command line arguments
     * @return {@code -1} if the arguments could not be parsed or a started phase failed, {@code 0} otherwise
     * @throws IllegalArgumentException if the option "similarityClassname" was not supplied
     * @throws Exception if option parsing or job execution fails
     */
    public int run(String[] strArr) throws Exception {
        addInputOption();
        addOutputOption();
        addOption("numberOfColumns", "r", "Number of columns in the input matrix", false);
        addOption("similarityClassname", "s", "Name of distributed similarity class to instantiate, alternatively use one of the predefined similarities (" + VectorSimilarityMeasures.list() + ')');
        // NOTE(review): M_OPTION (here) and TrainNaiveBayesJob.WEIGHTS (below) come from unrelated classes
        // and look like decompiler constant substitutions; confirm their values before inlining literals.
        addOption("maxSimilaritiesPerRow", FuzzyKMeansDriver.M_OPTION, "Number of maximum similarities per row (default: 100)", String.valueOf(DEFAULT_MAX_SIMILARITIES_PER_ROW));
        addOption("excludeSelfSimilarity", "ess", "compute similarity of rows to themselves?", String.valueOf(false));
        addOption(DefaultOptionCreator.THRESHOLD_OPTION, "tr", "discard row pairs with a similarity value below this", false);
        addOption(DefaultOptionCreator.overwriteOption().create());

        Map<String, List<String>> parsedArgs = parseArguments(strArr);
        if (parsedArgs == null) {
            return -1;
        }

        // The option is declared without a default, so it can be absent. Enum.valueOf(null) would then
        // throw a bare NullPointerException that the catch below does not handle; fail clearly instead.
        String similarityClassnameArg = getOption("similarityClassname");
        if (similarityClassnameArg == null) {
            throw new IllegalArgumentException("Option 'similarityClassname' (short name 's') must be given: "
                + "either a class name or one of the predefined similarities (" + VectorSimilarityMeasures.list() + ')');
        }

        int numberOfColumns;
        if (hasOption("numberOfColumns")) {
            numberOfColumns = Integer.parseInt(getOption("numberOfColumns"));
        } else {
            numberOfColumns = getDimensions(getInputPath());
        }

        // A predefined similarity name is translated to its class name; anything else is taken as a class name.
        String similarityClassname;
        try {
            similarityClassname = VectorSimilarityMeasures.valueOf(similarityClassnameArg).getClassname();
        } catch (IllegalArgumentException notPredefined) {
            similarityClassname = similarityClassnameArg;
        }

        if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
            HadoopUtil.delete(getConf(), getTempPath());
            HadoopUtil.delete(getConf(), getOutputPath());
        }

        int maxSimilaritiesPerRow = Integer.parseInt(getOption("maxSimilaritiesPerRow"));
        boolean excludeSelfSimilarity = Boolean.parseBoolean(getOption("excludeSelfSimilarity"));
        double threshold = hasOption(DefaultOptionCreator.THRESHOLD_OPTION)
            ? Double.parseDouble(getOption(DefaultOptionCreator.THRESHOLD_OPTION))
            : NO_THRESHOLD;

        Path weightsPath = getTempPath(TrainNaiveBayesJob.WEIGHTS);
        Path normsPath = getTempPath("norms.bin");
        Path numNonZeroEntriesPath = getTempPath("numNonZeroEntries.bin");
        Path maxValuesPath = getTempPath("maxValues.bin");
        Path pairwiseSimilarityPath = getTempPath("pairwiseSimilarity");

        AtomicInteger currentPhase = new AtomicInteger();

        // Phase 1: normalize rows, write the temporary matrix and the side files.
        if (shouldRunNextPhase(parsedArgs, currentPhase)) {
            Job normsAndTranspose = prepareJob(getInputPath(), weightsPath, VectorNormMapper.class, IntWritable.class, VectorWritable.class, MergeVectorsReducer.class, IntWritable.class, VectorWritable.class);
            normsAndTranspose.setCombinerClass(MergeVectorsCombiner.class);
            Configuration normsAndTransposeConf = normsAndTranspose.getConfiguration();
            normsAndTransposeConf.set(THRESHOLD, String.valueOf(threshold));
            normsAndTransposeConf.set(NORMS_PATH, normsPath.toString());
            normsAndTransposeConf.set(NUM_NON_ZERO_ENTRIES_PATH, numNonZeroEntriesPath.toString());
            normsAndTransposeConf.set(MAXVALUES_PATH, maxValuesPath.toString());
            normsAndTransposeConf.set(SIMILARITY_CLASSNAME, similarityClassname);
            if (!normsAndTranspose.waitForCompletion(true)) {
                return -1;
            }
        }

        // Phase 2: aggregate pairs and turn them into similarities.
        if (shouldRunNextPhase(parsedArgs, currentPhase)) {
            Job pairwiseSimilarity = prepareJob(weightsPath, pairwiseSimilarityPath, CooccurrencesMapper.class, IntWritable.class, VectorWritable.class, SimilarityReducer.class, IntWritable.class, VectorWritable.class);
            pairwiseSimilarity.setCombinerClass(VectorSumReducer.class);
            Configuration pairwiseConf = pairwiseSimilarity.getConfiguration();
            pairwiseConf.set(THRESHOLD, String.valueOf(threshold));
            pairwiseConf.set(NORMS_PATH, normsPath.toString());
            pairwiseConf.set(NUM_NON_ZERO_ENTRIES_PATH, numNonZeroEntriesPath.toString());
            pairwiseConf.set(MAXVALUES_PATH, maxValuesPath.toString());
            pairwiseConf.set(SIMILARITY_CLASSNAME, similarityClassname);
            pairwiseConf.setInt(NUMBER_OF_COLUMNS, numberOfColumns);
            pairwiseConf.setBoolean(EXCLUDE_SELF_SIMILARITY, excludeSelfSimilarity);
            if (!pairwiseSimilarity.waitForCompletion(true)) {
                return -1;
            }
        }

        // Phase 3: mirror the similarities and keep the top entries per row.
        if (shouldRunNextPhase(parsedArgs, currentPhase)) {
            Job asMatrix = prepareJob(pairwiseSimilarityPath, getOutputPath(), UnsymmetrifyMapper.class, IntWritable.class, VectorWritable.class, MergeToTopKSimilaritiesReducer.class, IntWritable.class, VectorWritable.class);
            asMatrix.setCombinerClass(MergeToTopKSimilaritiesReducer.class);
            asMatrix.getConfiguration().setInt(MAX_SIMILARITIES_PER_ROW, maxSimilaritiesPerRow);
            if (!asMatrix.waitForCompletion(true)) {
                return -1;
            }
        }

        return 0;
    }
}
