package org.apache.mahout.classifier.bayes;

import java.util.Collection;
import org.apache.hadoop.conf.Configuration;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.map.OpenIntDoubleHashMap;
import org.apache.mahout.math.map.OpenObjectIntHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/classifier/bayes/InMemoryBayesDatastore.class */
public class InMemoryBayesDatastore implements Datastore {
    private static final Logger log = LoggerFactory.getLogger(InMemoryBayesDatastore.class);

    /** Parameters the model paths are registered into and the model is loaded with. */
    private final BayesParameters params;
    /** Value of the "alpha_i" parameter (defaults to "1.0"), exposed via {@code getWeight("params", "alpha_i")}. */
    private final double alphaI;
    /** Feature name -> row index into {@link #weightMatrix} and key into {@link #sigmaJ}. */
    private final OpenObjectIntHashMap<String> featureDictionary = new OpenObjectIntHashMap<>();
    /** Label name -> column index into {@link #weightMatrix} and key into the per-label maps. */
    private final OpenObjectIntHashMap<String> labelDictionary = new OpenObjectIntHashMap<>();
    /** Per-feature values stored by {@link #setSumFeatureWeight}. */
    private final OpenIntDoubleHashMap sigmaJ = new OpenIntDoubleHashMap();
    /** Per-label values stored by {@link #setSumLabelWeight}. */
    private final OpenIntDoubleHashMap sigmaK = new OpenIntDoubleHashMap();
    /** Per-label values stored by {@link #setThetaNormalizer}. */
    private final OpenIntDoubleHashMap thetaNormalizerPerLabel = new OpenIntDoubleHashMap();
    /** Feature x label weights stored by {@link #loadFeatureWeight}. */
    private final Matrix weightMatrix = new SparseMatrix(1, 0);
    /**
     * Starts at 1.0 and is raised to the largest absolute per-label theta normalizer seen,
     * so it never drops below 1.0 (and is never zero when used as a divisor).
     */
    private double thetaNormalizer = 1.0d;
    /** Value stored by {@link #setSigmaJSigmaK}; defaults to 1.0. */
    private double sigmaJsigmaK = 1.0d;

    /**
     * Creates a datastore and registers the model input path globs in {@code bayesParameters}.
     *
     * <p>Note: this mutates {@code bayesParameters} by setting the "sigma_j", "sigma_k",
     * "sigma_kSigma_j", "thetaNormalizer" and "weight" keys relative to its base path.
     *
     * @param bayesParameters parameters providing the base path and optional "alpha_i"
     * @throws NumberFormatException if "alpha_i" is present but not a valid double
     */
    public InMemoryBayesDatastore(BayesParameters bayesParameters) {
        this.params = bayesParameters;
        String basePath = bayesParameters.getBasePath();
        bayesParameters.set("sigma_j", basePath + "/trainer-weights/Sigma_j/part-*");
        bayesParameters.set("sigma_k", basePath + "/trainer-weights/Sigma_k/part-*");
        bayesParameters.set("sigma_kSigma_j", basePath + "/trainer-weights/Sigma_kSigma_j/part-*");
        bayesParameters.set("thetaNormalizer", basePath + "/trainer-thetaNormalizer/part-*");
        bayesParameters.set("weight", basePath + "/trainer-tfIdf/trainer-tfIdf/part-*");
        this.alphaI = Double.parseDouble(bayesParameters.get("alpha_i", "1.0"));
    }

    /**
     * Loads the model via {@code SequenceFileModelReader} and logs, for every known label, its
     * per-label theta normalizer, the global normalizer, and their ratio.
     */
    @Override // org.apache.mahout.classifier.bayes.Datastore
    public void initialize() throws InvalidDatastoreException {
        SequenceFileModelReader.loadModel(this, this.params, new Configuration());
        for (String label : getKeys("")) {
            // Keys come straight from labelDictionary, so a direct lookup is safe here.
            double labelNormalizer = this.thetaNormalizerPerLabel.get(this.labelDictionary.get(label));
            log.info("{} {} {} {}", new Object[] {
                label, labelNormalizer, this.thetaNormalizer, labelNormalizer / this.thetaNormalizer});
        }
    }

    /**
     * Returns the names of all known labels.
     *
     * @param name ignored by this implementation
     * @return the keys of the label dictionary
     */
    @Override // org.apache.mahout.classifier.bayes.Datastore
    public Collection<String> getKeys(String name) throws InvalidDatastoreException {
        return this.labelDictionary.keys();
    }

    /**
     * Reads a cell of the "weight" matrix.
     *
     * <p>When {@code column} is "sigma_j" the per-feature value for {@code row} is returned,
     * otherwise the feature/label weight. Unknown features or labels yield {@code 0.0} and are
     * NOT registered, so lookups never alter the dictionaries (and hence never change
     * {@link #getKeys} or the "vocabCount" entry).
     *
     * @param matrixName must be "weight"
     * @param row feature name
     * @param column label name, or "sigma_j"
     * @throws InvalidDatastoreException if {@code matrixName} is not "weight"
     */
    @Override // org.apache.mahout.classifier.bayes.Datastore
    public double getWeight(String matrixName, String row, String column) throws InvalidDatastoreException {
        if (!"weight".equals(matrixName)) {
            throw new InvalidDatastoreException("Matrix not found: " + matrixName);
        }
        int featureId = lookupFeatureID(row);
        if (featureId < 0) {
            return 0.0d;
        }
        if ("sigma_j".equals(column)) {
            return this.sigmaJ.get(featureId);
        }
        int labelId = lookupLabelID(column);
        return labelId < 0 ? 0.0d : this.weightMatrix.getQuick(featureId, labelId);
    }

    /**
     * Reads a scalar or per-label entry.
     *
     * <ul>
     *   <li>"sumWeight": "sigma_jSigma_k" (stored total) or "vocabCount" (number of known features)</li>
     *   <li>"thetaNormalizer": per-label normalizer divided by the global normalizer</li>
     *   <li>"params": "alpha_i"</li>
     *   <li>"labelWeight": per-label value stored by {@link #setSumLabelWeight}</li>
     * </ul>
     * Unknown labels yield {@code 0.0} and are NOT registered.
     *
     * @throws InvalidDatastoreException if the matrix name or entry is not recognized
     */
    @Override // org.apache.mahout.classifier.bayes.Datastore
    public double getWeight(String matrixName, String key) throws InvalidDatastoreException {
        if ("sumWeight".equals(matrixName)) {
            if ("sigma_jSigma_k".equals(key)) {
                return this.sigmaJsigmaK;
            }
            if ("vocabCount".equals(key)) {
                return this.featureDictionary.size();
            }
            throw new InvalidDatastoreException("Unknown sumWeight entry: " + key);
        }
        if ("thetaNormalizer".equals(matrixName)) {
            int labelId = lookupLabelID(key);
            return labelId < 0 ? 0.0d : this.thetaNormalizerPerLabel.get(labelId) / this.thetaNormalizer;
        }
        if ("params".equals(matrixName)) {
            if ("alpha_i".equals(key)) {
                return this.alphaI;
            }
            throw new InvalidDatastoreException("Unknown params entry: " + key);
        }
        if ("labelWeight".equals(matrixName)) {
            int labelId = lookupLabelID(key);
            return labelId < 0 ? 0.0d : this.sigmaK.get(labelId);
        }
        throw new InvalidDatastoreException("Matrix not found: " + matrixName);
    }

    /** Returns the index of a known feature, or -1 if absent, without registering it. */
    private int lookupFeatureID(String feature) {
        return this.featureDictionary.containsKey(feature) ? this.featureDictionary.get(feature) : -1;
    }

    /** Returns the index of a known label, or -1 if absent, without registering it. */
    private int lookupLabelID(String label) {
        return this.labelDictionary.containsKey(label) ? this.labelDictionary.get(label) : -1;
    }

    /** Returns the index of {@code feature}, assigning the next dense index if it is new. */
    private int getFeatureID(String feature) {
        int existing = lookupFeatureID(feature);
        if (existing >= 0) {
            return existing;
        }
        int assigned = this.featureDictionary.size();
        this.featureDictionary.put(feature, assigned);
        return assigned;
    }

    /** Returns the index of {@code label}, assigning the next dense index if it is new. */
    private int getLabelID(String label) {
        int existing = lookupLabelID(label);
        if (existing >= 0) {
            return existing;
        }
        int assigned = this.labelDictionary.size();
        this.labelDictionary.put(label, assigned);
        return assigned;
    }

    /**
     * Stores the weight of a feature/label pair, registering either name if new.
     * Presumably called by the model reader during {@link #initialize()} — confirm.
     */
    public void loadFeatureWeight(String feature, String label, double weight) {
        this.weightMatrix.setQuick(getFeatureID(feature), getLabelID(label), weight);
    }

    /** Stores the per-feature "sigma_j" value, registering the feature if new. */
    public void setSumFeatureWeight(String feature, double weight) {
        this.sigmaJ.put(getFeatureID(feature), weight);
    }

    /** Stores the per-label value returned by {@code getWeight("labelWeight", label)}. */
    public void setSumLabelWeight(String label, double weight) {
        this.sigmaK.put(getLabelID(label), weight);
    }

    /**
     * Stores a per-label theta normalizer and raises the global normalizer to
     * {@code |weight|} if that is larger.
     */
    public void setThetaNormalizer(String label, double weight) {
        this.thetaNormalizerPerLabel.put(getLabelID(label), weight);
        this.thetaNormalizer = Math.max(this.thetaNormalizer, Math.abs(weight));
    }

    /** Stores the value returned by {@code getWeight("sumWeight", "sigma_jSigma_k")}. */
    public void setSigmaJSigmaK(double weight) {
        this.sigmaJsigmaK = weight;
    }
}
