package org.apache.mahout.clustering;

import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;

/**
 * {@link ClusteringPolicy} that assigns each observation to exactly one cluster by sampling,
 * using mixture weights drawn from {@code UncommonDistributions.rDirichlet}.
 *
 * <p>The per-cluster counts passed to {@code rDirichlet} are accumulated across every call to
 * {@link #update(ClusterClassifier)} and are never reset.
 * NOTE(review): confirm that accumulating counts over all iterations (rather than using only
 * the most recent iteration's counts) is the intended behavior.
 *
 * <p>Instances hold mutable state and perform no synchronization of their own.
 */
public class DirichletClusteringPolicy implements ClusteringPolicy {

    /** Running per-cluster point counts, summed over all calls to {@code update}. */
    private final Vector totalCounts;

    /** Value passed as the {@code alpha0} argument of every {@code rDirichlet} draw. */
    private final double alpha0;

    /** Current mixture weights; drawn by the constructor and redrawn by each {@code update}. */
    private Vector mixture;

    /**
     * Creates a policy for a fixed number of clusters, with every running count starting at zero.
     *
     * @param numClusters number of clusters; fixes the cardinality of the count vector
     * @param alpha0 prior parameter forwarded to {@code rDirichlet} on every draw
     */
    public DirichletClusteringPolicy(int numClusters, double alpha0) {
        this.totalCounts = new DenseVector(numClusters);
        this.alpha0 = alpha0;
        // The initial mixture is drawn from the freshly created (all-zero) count vector.
        this.mixture = UncommonDistributions.rDirichlet(totalCounts, alpha0);
    }

    /**
     * Samples a single cluster for one observation.
     *
     * @param probabilities per-cluster values for the observation (presumably likelihoods —
     *     confirm against callers); multiplied by the current mixture before sampling
     * @return a sparse vector with the same cardinality as {@code probabilities}, holding 1.0
     *     at the sampled index and nothing elsewhere
     */
    @Override
    public Vector select(Vector probabilities) {
        Vector weighted = probabilities.times(mixture);
        int chosen = UncommonDistributions.rMultinom(weighted);
        Vector selection = new SequentialAccessSparseVector(probabilities.size());
        selection.set(chosen, 1.0d);
        return selection;
    }

    /**
     * Adds each cluster's current point count to the running totals, then redraws the mixture
     * from the updated totals.
     *
     * @param clusterClassifier classifier whose models supply the per-cluster point counts;
     *     assumed to hold at least as many models as this policy has clusters (any extra models
     *     are ignored, and too few models makes the list lookup fail)
     */
    @Override
    public void update(ClusterClassifier clusterClassifier) {
        int numClusters = totalCounts.size();
        for (int k = 0; k < numClusters; k++) {
            double observed = clusterClassifier.getModels().get(k).getNumPoints();
            totalCounts.set(k, totalCounts.get(k) + observed);
        }
        mixture = UncommonDistributions.rDirichlet(totalCounts, alpha0);
    }
}
