package smile.stat.distribution;

import java.io.Serializable;
import java.util.Arrays;
import java.util.stream.Collectors;
import smile.math.MathEx;
import smile.math.matrix.Matrix;

/* loaded from: input_file:smile/stat/distribution/MultivariateMixture.class */
/**
 * A finite mixture of multivariate distributions.
 *
 * <p>The density of the mixture at a point is the sum, over all components,
 * of the component weight ({@link Component#priori}) times the component
 * density. The class does not check that the weights are non-negative or
 * that they sum to one; the moment methods ({@link #mean()}, {@link #cov()})
 * are only meaningful when they do.
 */
public class MultivariateMixture implements MultivariateDistribution {
    private static final long serialVersionUID = 2;

    /** The mixture components, exactly as passed to the constructor (not copied). */
    public final Component[] components;

    /**
     * A single mixture component: a distribution together with the weight
     * it is multiplied by when the mixture density is evaluated.
     */
    public static class Component implements Serializable {
        private static final long serialVersionUID = 2;

        /** The weight of this component in the mixture. */
        public final double priori;

        /** The distribution of this component. */
        public final MultivariateDistribution distribution;

        /**
         * Creates a component.
         *
         * @param d the weight of the component in the mixture.
         * @param multivariateDistribution the distribution of the component.
         */
        public Component(double d, MultivariateDistribution multivariateDistribution) {
            this.priori = d;
            this.distribution = multivariateDistribution;
        }
    }

    /**
     * Creates a mixture from the given components.
     *
     * @param componentArr the components; the array is stored as-is.
     * @throws IllegalStateException if no component is given.
     */
    public MultivariateMixture(Component... componentArr) {
        if (componentArr.length == 0) {
            throw new IllegalStateException("Empty mixture!");
        }
        this.components = componentArr;
    }

    /**
     * Returns, for each component, its weighted density at the given point
     * divided by the sum of all weighted densities.
     *
     * <p>If the weighted densities sum to zero (for example because every
     * component density underflows at {@code dArr}) the entries are NaN.
     *
     * @param dArr the point at which to evaluate the components.
     * @return an array with one entry per component.
     */
    public double[] posteriori(double[] dArr) {
        double[] posterior = weightedDensities(dArr);
        double total = MathEx.sum(posterior);
        for (int k = 0; k < posterior.length; k++) {
            posterior[k] /= total;
        }
        return posterior;
    }

    /**
     * Returns the index of the component with the largest weighted density
     * at the given point, as selected by {@code MathEx.whichMax}.
     *
     * @param dArr the point at which to evaluate the components.
     * @return the index into {@link #components}.
     */
    public int map(double[] dArr) {
        return MathEx.whichMax(weightedDensities(dArr));
    }

    /** Evaluates {@code priori * p(x)} for every component, in component order. */
    private double[] weightedDensities(double[] x) {
        double[] weighted = new double[this.components.length];
        for (int k = 0; k < weighted.length; k++) {
            Component component = this.components[k];
            weighted[k] = component.priori * component.distribution.p(x);
        }
        return weighted;
    }

    /**
     * Returns the weighted sum of the component means.
     *
     * @return a new array whose length is that of the first component's mean.
     */
    @Override // smile.stat.distribution.MultivariateDistribution
    public double[] mean() {
        double[] firstMean = this.components[0].distribution.mean();
        double[] mixtureMean = new double[firstMean.length];
        for (int k = 0; k < this.components.length; k++) {
            double weight = this.components[k].priori;
            // Reuse the mean already fetched for the first component.
            double[] componentMean = (k == 0) ? firstMean : this.components[k].distribution.mean();
            for (int i = 0; i < componentMean.length; i++) {
                mixtureMean[i] += weight * componentMean[i];
            }
        }
        return mixtureMean;
    }

    /**
     * Returns the covariance matrix of the mixture.
     *
     * <p>By the law of total covariance this is
     * {@code sum_k w_k * (S_k + (m_k - m)(m_k - m)^T)}, where {@code w_k},
     * {@code m_k} and {@code S_k} are the weight, mean and covariance of
     * component {@code k} and {@code m} is the mixture mean. The second term
     * accounts for the spread between the component means; summing
     * {@code w_k^2 * S_k} instead would describe a weighted sum of
     * independent variables, not a mixture.
     *
     * @return a new matrix with the dimensions of the first component's covariance.
     */
    @Override // smile.stat.distribution.MultivariateDistribution
    public Matrix cov() {
        double[] mixtureMean = mean();
        Matrix firstCov = this.components[0].distribution.cov();
        int nrow = firstCov.nrow();
        int ncol = firstCov.ncol();
        Matrix result = new Matrix(nrow, ncol);
        for (int k = 0; k < this.components.length; k++) {
            Component component = this.components[k];
            double weight = component.priori;
            double[] componentMean = component.distribution.mean();
            Matrix componentCov = (k == 0) ? firstCov : component.distribution.cov();
            for (int i = 0; i < nrow; i++) {
                double rowOffset = componentMean[i] - mixtureMean[i];
                for (int j = 0; j < ncol; j++) {
                    double colOffset = componentMean[j] - mixtureMean[j];
                    double term = weight * (componentCov.get(i, j) + rowOffset * colOffset);
                    if (k == 0) {
                        // Assign on the first pass so the result does not rely on
                        // the freshly allocated matrix being zero-filled.
                        result.set(i, j, term);
                    } else {
                        result.add(i, j, term);
                    }
                }
            }
        }
        return result;
    }

    /**
     * Not supported for mixtures.
     *
     * @throws UnsupportedOperationException always.
     */
    @Override // smile.stat.distribution.MultivariateDistribution
    public double entropy() {
        throw new UnsupportedOperationException("Mixture does not support entropy()");
    }

    /**
     * Returns the mixture density at the given point: the sum of
     * {@code priori * p(dArr)} over all components.
     *
     * @param dArr the point at which to evaluate the density.
     * @return the mixture density.
     */
    @Override // smile.stat.distribution.MultivariateDistribution
    public double p(double[] dArr) {
        double density = 0.0d;
        for (Component component : this.components) {
            density += component.priori * component.distribution.p(dArr);
        }
        return density;
    }

    /**
     * Returns the logarithm of the mixture density at the given point.
     *
     * <p>The value is computed in the log domain (log-sum-exp over
     * {@code log(priori) + logp(dArr)} of each component) so that it stays
     * finite where {@link #p(double[])} itself would underflow to zero.
     * When no component yields a finite term the method falls back to
     * {@code Math.log(p(dArr))}.
     *
     * @param dArr the point at which to evaluate the log density.
     * @return the log density of the mixture.
     */
    @Override // smile.stat.distribution.MultivariateDistribution
    public double logp(double[] dArr) {
        double[] logTerms = new double[this.components.length];
        double largest = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < logTerms.length; k++) {
            Component component = this.components[k];
            logTerms[k] = Math.log(component.priori) + component.distribution.logp(dArr);
            if (logTerms[k] > largest) {
                largest = logTerms[k];
            }
        }
        if (Double.isInfinite(largest)) {
            // No finite term to rescale by (all zero density, NaN, or an
            // infinite density): the direct formula gives the right answer.
            return Math.log(p(dArr));
        }
        double scaledSum = 0.0d;
        for (double logTerm : logTerms) {
            scaledSum += Math.exp(logTerm - largest);
        }
        return largest + Math.log(scaledSum);
    }

    /**
     * Returns the mixture CDF at the given point: the sum of
     * {@code priori * cdf(dArr)} over all components.
     *
     * @param dArr the point at which to evaluate the CDF.
     * @return the mixture CDF.
     */
    @Override // smile.stat.distribution.MultivariateDistribution
    public double cdf(double[] dArr) {
        double cumulative = 0.0d;
        for (Component component : this.components) {
            cumulative += component.priori * component.distribution.cdf(dArr);
        }
        return cumulative;
    }

    /**
     * Returns the sum of the components' {@code length()} values plus the
     * number of components minus one.
     *
     * @return the combined length of the mixture.
     */
    @Override // smile.stat.distribution.MultivariateDistribution
    public int length() {
        int total = this.components.length - 1;
        for (Component component : this.components) {
            total += component.distribution.length();
        }
        return total;
    }

    /**
     * Returns the number of components in the mixture.
     *
     * @return the number of components.
     */
    public int size() {
        return this.components.length;
    }

    /**
     * Returns the BIC score of the mixture on the given data: the
     * log-likelihood minus {@code 0.5 * length() * log(n)}, where {@code n}
     * is the number of rows in {@code dArr}.
     *
     * <p>Rows whose density is not strictly positive contribute nothing to
     * the log-likelihood.
     *
     * @param dArr the data, one point per row.
     * @return the BIC score.
     */
    public double bic(double[][] dArr) {
        int n = dArr.length;
        double logLikelihood = 0.0d;
        for (double[] point : dArr) {
            double density = p(point);
            if (density > 0.0d) {
                logLikelihood += Math.log(density);
            }
        }
        return logLikelihood - ((0.5d * length()) * Math.log(n));
    }

    /**
     * Returns a description of the form
     * {@code MultivariateMixture(k)[w1 x D1 + w2 x D2 + ...]}.
     *
     * @return the string representation of the mixture.
     */
    @Override
    public String toString() {
        String prefix = String.format("MultivariateMixture(%d)[", this.components.length);
        return Arrays.stream(this.components)
                .map(component -> String.format("%.2f x %s", component.priori, component.distribution))
                .collect(Collectors.joining(" + ", prefix, "]"));
    }
}
