package org.apache.mahout.clustering;

import com.google.common.collect.Lists;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansClusterer;
import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
import org.apache.mahout.common.ClassUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.TimesFunction;

/* loaded from: input_file:org/apache/mahout/clustering/ClusterClassifier.class */
/**
 * A classifier whose categories are a list of {@link Cluster} models, one model per category.
 *
 * <p>{@link #classify(Vector)} scores an instance against every model.
 * The {@code train} overloads feed an instance to the model at the given category index.
 * {@link #close()} asks every model to recompute its parameters from what it has observed.
 * The models and their class name are serialized through the Hadoop {@link Writable} contract.
 * Every model is assumed to be of the same class as the first one.
 */
public class ClusterClassifier extends AbstractVectorClassifier implements OnlineLearner, Writable {
    /** The cluster models; the model at index {@code i} represents category {@code i}. */
    private List<Cluster> models;
    /** Fully-qualified class name of the models, used by {@link #readFields} to re-create them. */
    private String modelClass;

    /**
     * Creates a classifier over the given models.
     *
     * @param list the models, one per category. The list is stored directly, not copied, so
     *             later changes to it are visible to this classifier. The class of the first
     *             element is recorded as the class of every model.
     * @throws IllegalArgumentException if {@code list} is empty
     */
    public ClusterClassifier(List<Cluster> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("ClusterClassifier requires at least one model");
        }
        this.models = list;
        this.modelClass = list.get(0).getClass().getName();
    }

    /** No-arg constructor for Writable deserialization; state is filled in by {@link #readFields}. */
    public ClusterClassifier() {
    }

    /**
     * Computes the probability of {@code instance} belonging to each model.
     *
     * <p>If the first model is a {@link SoftCluster}, all models are treated as soft clusters.
     * In that case the result comes from {@code FuzzyKMeansClusterer.computePi} applied to the
     * distances between the instance and each cluster center.
     *
     * <p>Otherwise each model's {@code pdf} is evaluated and the resulting vector is scaled so
     * that its entries sum to 1. If every pdf is 0, the scaling divides by zero and the entries
     * are not finite.
     *
     * @param instance the vector to classify
     * @return a vector with one score per model, in model order
     */
    @Override
    public Vector classify(Vector instance) {
        if (models.get(0) instanceof SoftCluster) {
            return classifySoft(instance);
        }
        // The wrapper is the same for every model, so build it once.
        VectorWritable wrapped = new VectorWritable(instance);
        Vector pdfs = new DenseVector(models.size());
        int index = 0;
        for (Cluster model : models) {
            pdfs.set(index++, model.pdf(wrapped));
        }
        return pdfs.assign(new TimesFunction(), 1.0d / pdfs.zSum());
    }

    /**
     * Fuzzy-membership path of {@link #classify}: collects each soft cluster and its distance
     * to {@code instance}, then delegates to {@code FuzzyKMeansClusterer.computePi}.
     * Every model must be a {@link SoftCluster}; otherwise a {@link ClassCastException} is thrown.
     */
    private Vector classifySoft(Vector instance) {
        List<SoftCluster> clusters = Lists.newArrayList();
        List<Double> distances = Lists.newArrayList();
        for (Cluster model : models) {
            SoftCluster cluster = (SoftCluster) model;
            clusters.add(cluster);
            distances.add(cluster.getMeasure().distance(instance, cluster.getCenter()));
        }
        return new FuzzyKMeansClusterer().computePi(clusters, distances);
    }

    /**
     * Returns the share of the first model's pdf in the sum of both models' pdfs.
     * This is only defined for a two-model classifier.
     *
     * @param instance the vector to classify
     * @return {@code pdf0 / (pdf0 + pdf1)}
     * @throws IllegalStateException if this classifier does not hold exactly two models
     */
    @Override
    public double classifyScalar(Vector instance) {
        if (models.size() != 2) {
            throw new IllegalStateException(
                "classifyScalar requires exactly 2 models but there are " + models.size());
        }
        VectorWritable wrapped = new VectorWritable(instance);
        double pdf0 = models.get(0).pdf(wrapped);
        return pdf0 / (pdf0 + models.get(1).pdf(wrapped));
    }

    /** @return the number of categories, which equals the number of models */
    @Override
    public int numCategories() {
        return models.size();
    }

    /**
     * Serializes the model count, the model class name and then each model, in that order.
     * {@link #readFields} reads them back in the same order.
     */
    @Override
    public void write(DataOutput dataOutput) throws IOException {
        dataOutput.writeInt(models.size());
        dataOutput.writeUTF(modelClass);
        for (Cluster model : models) {
            model.write(dataOutput);
        }
    }

    /**
     * Replaces this classifier's state with data written by {@link #write}.
     * A fresh model of the recorded class is instantiated and populated for each stored entry.
     */
    @Override
    public void readFields(DataInput dataInput) throws IOException {
        int count = dataInput.readInt();
        this.modelClass = dataInput.readUTF();
        this.models = Lists.newArrayList();
        for (int i = 0; i < count; i++) {
            Cluster cluster = (Cluster) ClassUtils.instantiateAs(modelClass, Cluster.class);
            cluster.readFields(dataInput);
            models.add(cluster);
        }
    }

    /**
     * Feeds {@code instance} to the model for category {@code actual}.
     *
     * @param actual   index of the model to update
     * @param instance the observed vector
     */
    @Override
    public void train(int actual, Vector instance) {
        // Observe the wrapped vector itself; casting it to Cluster was a decompiler artifact.
        models.get(actual).observe(new VectorWritable(instance));
    }

    /**
     * Feeds {@code instance} with the given weight to the model for category {@code actual}.
     *
     * @param actual   index of the model to update
     * @param instance the observed vector
     * @param weight   weight passed to the model's {@code observe}
     */
    public void train(int actual, Vector instance, double weight) {
        models.get(actual).observe(new VectorWritable(instance), weight);
    }

    /** Same as {@link #train(int, Vector)}; the tracking and group keys are ignored. */
    @Override
    public void train(long trackingKey, String groupKey, int actual, Vector instance) {
        train(actual, instance);
    }

    /** Same as {@link #train(int, Vector)}; the tracking key is ignored. */
    @Override
    public void train(long trackingKey, int actual, Vector instance) {
        train(actual, instance);
    }

    /** Asks every model to recompute its parameters from the observations it has accumulated. */
    @Override
    public void close() {
        for (Cluster model : models) {
            model.computeParameters();
        }
    }

    /** @return the live model list held by this classifier (not a copy) */
    public List<Cluster> getModels() {
        return models;
    }
}
