package org.deeplearning4j.plot;

import com.google.common.util.concurrent.AtomicDouble;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.sptree.SpTree;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.learning.legacy.AdaGrad;
import org.nd4j.linalg.memory.abstracts.DummyWorkspace;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/plot/BarnesHutTsne.class */
public class BarnesHutTsne implements Model {
    /** Workspace identifier constant; not referenced by any method in this class. */
    public static final String workspaceCache = "LOOP_CACHE";
    /** Identifier of the workspace every compute method of this class opens. */
    public static final String workspaceExternal = "LOOP_EXTERNAL";
    /** Number of optimisation iterations executed by {@link #fit()}. */
    protected int maxIter;
    /** Only forwarded to {@code Tsne} when {@code theta == 0}; not used by the Barnes-Hut path. */
    protected double realMin;
    /** Only forwarded to {@code Tsne} when {@code theta == 0}. */
    protected double initialMomentum;
    /** Momentum that replaces {@link #momentum} at iteration {@link #switchMomentumIteration}. */
    protected double finalMomentum;
    /** Lower bound applied to every entry of {@link #gains} on each update. */
    protected double minGain;
    /** Factor applied to the previous increment ({@link #yIncs}) on each update. */
    protected double momentum;
    /** Iteration index at which {@link #momentum} is set to {@link #finalMomentum}. */
    protected int switchMomentumIteration;
    /** Only forwarded to {@code Tsne} when {@code theta == 0}. */
    protected boolean normalize;
    /** Iteration index at which the x12 scaling of {@link #vals} applied by {@link #fit()} is undone. */
    protected int stopLyingIteration;
    /** Tolerance on the entropy difference used by the binary search in computeGaussianPerplexity. */
    protected double tolerance;
    /** Step size: multiplies the gradient directly, or seeds AdaGrad when {@link #useAdaGrad} is set. */
    protected double learningRate;
    /** AdaGrad updater, created lazily on the first update when {@link #useAdaGrad} is set. */
    protected AdaGrad adaGrad;
    /** Whether updates go through {@link #adaGrad} instead of a plain learning-rate scaling. */
    protected boolean useAdaGrad;
    /** Target perplexity; also determines the neighbour count k = (int) (3 * perplexity). */
    protected double perplexity;
    /** The embedding being optimised: one row per input row, {@link #numDimensions} columns. */
    protected INDArray Y;
    /** Number of input rows; assigned in computeGaussianPerplexity. */
    private int N;
    /** Passed to {@code SpTree.computeNonEdgeForces}; a value of exactly 0 makes fit() delegate to {@code Tsne}. */
    private double theta;
    /** Row offsets of the sparse similarity structure: row n owns entries rows[n] .. rows[n+1]-1. */
    private INDArray rows;
    /** Column (neighbour) index of every sparse entry. */
    private INDArray cols;
    /** Similarity value of every sparse entry. */
    private INDArray vals;
    /** Similarity function name handed to the VP-tree. The misspelling is part of the accessor names. */
    private String simiarlityFunction;
    /** Inversion flag handed to the VP-tree together with the similarity function. */
    private boolean invert;
    /** Input data set by the fit(...) overloads. */
    private INDArray x;
    /** Number of columns of the embedding {@link #Y}. */
    private int numDimensions;
    /** Key under which gradient() stores the embedding gradient. */
    public static final String Y_GRAD = "yIncs";
    /** Space-partitioning tree over {@link #Y}; built on the first gradient() call and then reused. */
    private SpTree tree;
    /** Per-element adaptive gains, initialised to ones on the first gradient() call. */
    private INDArray gains;
    /** Previous update increment (momentum term), initialised to zeros on the first gradient() call. */
    private INDArray yIncs;
    /** Worker count handed to the VP-tree constructor. */
    private int vpTreeWorkers;
    /** Optional listener notified after every iteration of fit(); may be null. */
    protected transient TrainingListener TrainingListener;
    /** Never assigned in this class, so the {@code == WorkspaceMode.NONE} checks are always false here. */
    protected WorkspaceMode workspaceMode;
    private static final Logger log = LoggerFactory.getLogger(BarnesHutTsne.class);
    /** Configuration used for the {@link #workspaceExternal} workspace. */
    protected static final WorkspaceConfiguration workspaceConfigurationExternal = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(0.3d).policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.BLOCK_LEFT).policySpill(SpillPolicy.REALLOCATE).policyAllocation(AllocationPolicy.OVERALLOCATE).build();
    /** Cache workspace configuration; not referenced by any method in this class. */
    public static final WorkspaceConfiguration workspaceConfigurationCache = WorkspaceConfiguration.builder().overallocationLimit(0.2d).policyReset(ResetPolicy.BLOCK_LEFT).cyclesBeforeInitialization(3).policyMirroring(MirroringPolicy.FULL).policySpill(SpillPolicy.REALLOCATE).policyLearning(LearningPolicy.OVER_TIME).build();
    /** Only forwarded to {@code Tsne} when {@code theta == 0}; there is no setter for it in this class. */
    protected boolean usePca = false;
    /** Feed-forward workspace configuration; not referenced by any method in this class. */
    protected WorkspaceConfiguration workspaceConfigurationFeedForward = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(0.2d).policyReset(ResetPolicy.BLOCK_LEFT).policyLearning(LearningPolicy.OVER_TIME).policySpill(SpillPolicy.REALLOCATE).policyAllocation(AllocationPolicy.OVERALLOCATE).build();

    /* loaded from: input_file:org/deeplearning4j/plot/BarnesHutTsne$Builder.class */
    /**
     * Fluent builder for {@link BarnesHutTsne}.
     *
     * <p>Every option has a default (see the field initialisers); each setter stores its
     * argument and returns this builder so calls can be chained. No validation is performed.
     */
    public static class Builder {
        // --- optimisation schedule ---
        private int maxIter = 1000;
        private double realMin = 9.999999960041972E-13d;
        private double initialMomentum = 0.5d;
        private double finalMomentum = 0.800000011920929d;
        private double momentum = 0.5d;
        private int switchMomentumIteration = 100;
        private boolean normalize = true;
        private int stopLyingIteration = 100;
        private double tolerance = 9.999999747378752E-6d;
        private double learningRate = 0.10000000149011612d;
        private boolean useAdaGrad = false;
        private double perplexity = 30.0d;
        private double minGain = 0.10000000149011612d;
        // --- Barnes-Hut / neighbour search ---
        private double theta = 0.5d;
        private boolean invert = true;
        private int numDim = 2;
        private String similarityFunction = "cosinesimilarity";
        private int vpTreeWorkers = 1;

        /** Sets the worker count handed to the VP-tree. */
        public Builder vpTreeWorkers(int i) {
            vpTreeWorkers = i;
            return this;
        }

        /** Sets the lower bound applied to the adaptive gains. */
        public Builder minGain(double d) {
            minGain = d;
            return this;
        }

        /** Sets the target perplexity. */
        public Builder perplexity(double d) {
            perplexity = d;
            return this;
        }

        /** Enables or disables AdaGrad-scaled updates. */
        public Builder useAdaGrad(boolean z) {
            useAdaGrad = z;
            return this;
        }

        /** Sets the learning rate. */
        public Builder learningRate(double d) {
            learningRate = d;
            return this;
        }

        /** Sets the tolerance of the perplexity binary search. */
        public Builder tolerance(double d) {
            tolerance = d;
            return this;
        }

        /** Sets the iteration at which the similarity scaling applied by fit() is undone. */
        public Builder stopLyingIteration(int i) {
            stopLyingIteration = i;
            return this;
        }

        /** Sets the normalize flag. */
        public Builder normalize(boolean z) {
            normalize = z;
            return this;
        }

        /** Sets the number of optimisation iterations. */
        public Builder setMaxIter(int i) {
            maxIter = i;
            return this;
        }

        /** Sets the realMin value. */
        public Builder setRealMin(double d) {
            realMin = d;
            return this;
        }

        /** Sets the initial momentum value. */
        public Builder setInitialMomentum(double d) {
            initialMomentum = d;
            return this;
        }

        /** Sets the momentum used from the switch iteration onwards. */
        public Builder setFinalMomentum(double d) {
            finalMomentum = d;
            return this;
        }

        /** Sets the momentum the model starts with. */
        public Builder setMomentum(double d) {
            momentum = d;
            return this;
        }

        /** Sets the iteration at which the momentum switches to the final momentum. */
        public Builder setSwitchMomentumIteration(int i) {
            switchMomentumIteration = i;
            return this;
        }

        /** Sets the name of the similarity function handed to the VP-tree. */
        public Builder similarityFunction(String str) {
            similarityFunction = str;
            return this;
        }

        /** Sets the invert flag handed to the VP-tree. */
        public Builder invertDistanceMetric(boolean z) {
            invert = z;
            return this;
        }

        /** Sets theta; a value of exactly 0 makes fit() use the non-approximated {@code Tsne}. */
        public Builder theta(double d) {
            theta = d;
            return this;
        }

        /** Sets the number of output dimensions. */
        public Builder numDimension(int i) {
            numDim = i;
            return this;
        }

        /**
         * Creates the model from the current builder state. No training listener is attached.
         *
         * @return a new {@link BarnesHutTsne}
         */
        public BarnesHutTsne build() {
            return new BarnesHutTsne(
                    numDim,
                    similarityFunction,
                    theta,
                    invert,
                    maxIter,
                    realMin,
                    initialMomentum,
                    finalMomentum,
                    momentum,
                    switchMomentumIteration,
                    normalize,
                    stopLyingIteration,
                    tolerance,
                    learningRate,
                    useAdaGrad,
                    perplexity,
                    null,
                    minGain,
                    vpTreeWorkers);
        }
    }

    /**
     * Creates a fully configured instance; {@link Builder#build()} is the usual entry point.
     *
     * <p>Every configurable field is taken from an argument, so there are no implicit defaults
     * here: the defaults live in {@link Builder}. {@code workspaceMode} is not set.
     *
     * @param i number of output dimensions
     * @param str name of the similarity function handed to the VP-tree
     * @param d theta
     * @param z invert flag handed to the VP-tree
     * @param i2 number of optimisation iterations
     * @param d2 realMin
     * @param d3 initial momentum
     * @param d4 final momentum
     * @param d5 momentum the model starts with
     * @param i3 iteration at which the momentum switches to the final momentum
     * @param z2 normalize flag
     * @param i4 iteration at which the similarity scaling applied by fit() is undone
     * @param d6 tolerance of the perplexity binary search
     * @param d7 learning rate
     * @param z3 whether AdaGrad is used
     * @param d8 perplexity
     * @param trainingListener listener notified after every iteration, may be {@code null}
     * @param d9 lower bound for the adaptive gains
     * @param i5 worker count handed to the VP-tree
     */
    public BarnesHutTsne(int i, String str, double d, boolean z, int i2, double d2, double d3, double d4, double d5, int i3, boolean z2, int i4, double d6, double d7, boolean z3, double d8, TrainingListener trainingListener, double d9, int i5) {
        // Output space and neighbour search.
        this.numDimensions = i;
        this.simiarlityFunction = str;
        this.theta = d;
        this.invert = z;
        this.vpTreeWorkers = i5;
        this.perplexity = d8;
        this.tolerance = d6;

        // Optimisation schedule.
        this.maxIter = i2;
        this.realMin = d2;
        this.initialMomentum = d3;
        this.finalMomentum = d4;
        this.momentum = d5;
        this.switchMomentumIteration = i3;
        this.stopLyingIteration = i4;
        this.normalize = z2;
        this.learningRate = d7;
        this.useAdaGrad = z3;
        this.minGain = d9;

        this.TrainingListener = trainingListener;
    }

    /**
     * @return the name of the similarity function handed to the VP-tree
     */
    public String getSimiarlityFunction() {
        return simiarlityFunction;
    }

    /**
     * @param str name of the similarity function handed to the VP-tree
     */
    public void setSimiarlityFunction(String str) {
        simiarlityFunction = str;
    }

    /**
     * @return the invert flag handed to the VP-tree
     */
    public boolean isInvert() {
        return invert;
    }

    /**
     * @param z invert flag handed to the VP-tree
     */
    public void setInvert(boolean z) {
        invert = z;
    }

    /**
     * @return theta as passed to the constructor
     */
    public double getTheta() {
        return theta;
    }

    /**
     * @return the configured perplexity
     */
    public double getPerplexity() {
        return perplexity;
    }

    /**
     * @return the number of output dimensions
     */
    public int getNumDimensions() {
        return numDimensions;
    }

    /**
     * @param i number of output dimensions
     */
    public void setNumDimensions(int i) {
        numDimensions = i;
    }

    /**
     * Builds the sparse row-normalised similarity structure ({@code rows}, {@code cols},
     * {@code vals}) for the given data.
     *
     * <p>For every input row, {@code k + 1} results are requested from a VP-tree with
     * {@code k = (int) (3 * d)}; the first result is skipped (presumably the query point itself).
     * A binary search (at most 200 steps) tunes the kernel precision until the kernel entropy is
     * within {@code tolerance} of {@code log(d)}. The kernel row produced by the final precision
     * is normalised to sum to one and stored.
     *
     * <p>NOTE(review): the kernel is evaluated on {@code VPTree.buildFromData(results)}, while the
     * distance list filled by the search is discarded. Confirm that this array really holds the
     * neighbour distances.
     *
     * @param iNDArray input data, one sample per row
     * @param d target perplexity
     * @return the similarity values (the same array that is stored in {@code vals})
     * @throws IllegalStateException if the perplexity exceeds the derived neighbour count
     */
    public INDArray computeGaussianPerplexity(INDArray iNDArray, double d) {
        this.N = iNDArray.rows();
        final int k = (int) (3.0d * d);
        if (d > k) {
            throw new IllegalStateException("Illegal k value " + k + ": perplexity " + d + " is greater than k");
        }

        // Every row owns exactly k entries, so the row offsets are simply multiples of k.
        this.rows = Nd4j.zeros(1, this.N + 1);
        this.cols = Nd4j.zeros(1, this.N * k);
        this.vals = Nd4j.zeros(1, this.N * k);
        for (int n = 0; n < this.N; n++) {
            this.rows.putScalar(n + 1, this.rows.getDouble(n) + k);
        }

        INDArray initialBeta = Nd4j.ones(this.N, 1);
        final double targetEntropy = FastMath.log(d);
        VPTree vpTree = new VPTree(iNDArray, this.simiarlityFunction, this.vpTreeWorkers, this.invert);

        try (MemoryWorkspace workspace = (this.workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, workspaceExternal)).notifyScopeEntered()) {
            log.info("Calculating probabilities of data similarities...");
            for (int n = 0; n < this.N; n++) {
                if (n % 500 == 0) {
                    log.info("Handled " + n + " records");
                }

                double lowerBeta = -Double.MAX_VALUE;
                double upperBeta = Double.MAX_VALUE;
                List<DataPoint> neighbours = new ArrayList<>();
                vpTree.search(iNDArray.slice(n), k + 1, neighbours, new ArrayList<Double>());
                double beta = initialBeta.getDouble(n);
                INDArray neighbourData = VPTree.buildFromData(neighbours);

                Pair<INDArray, Double> kernel = computeGaussianKernel(neighbourData, beta, k);
                INDArray currP = kernel.getFirst();
                double entropyDiff = kernel.getSecond() - targetEntropy;

                int tries = 0;
                boolean found = false;
                while (!found && tries < 200) {
                    // Written as ">= || >=" on purpose: a NaN difference ends the search immediately.
                    boolean outsideTolerance = entropyDiff >= this.tolerance || (-entropyDiff) >= this.tolerance;
                    if (!outsideTolerance) {
                        found = true;
                    } else {
                        if (entropyDiff > 0.0d) {
                            lowerBeta = beta;
                            boolean upperUnset = upperBeta == Double.MAX_VALUE || upperBeta == -Double.MAX_VALUE;
                            beta = upperUnset ? beta * 2.0d : (beta + upperBeta) / 2.0d;
                        } else {
                            upperBeta = beta;
                            boolean lowerUnset = lowerBeta == -Double.MAX_VALUE || lowerBeta == Double.MAX_VALUE;
                            beta = lowerUnset ? beta / 2.0d : (beta + lowerBeta) / 2.0d;
                        }
                        kernel = computeGaussianKernel(neighbourData, beta, k);
                        // Keep the kernel row in step with the tuned precision; otherwise the
                        // search would have no influence on the stored similarities.
                        currP = kernel.getFirst();
                        entropyDiff = kernel.getSecond() - targetEntropy;
                        tries++;
                    }
                }

                // Row-normalise the kernel values.
                currP.divi(currP.sum(Integer.MAX_VALUE));

                INDArray neighbourIndices = Nd4j.create(1, k + 1);
                for (int j = 0; j < neighbourIndices.length() && j < neighbours.size(); j++) {
                    neighbourIndices.putScalar(j, neighbours.get(j).getIndex());
                }

                final int rowStart = this.rows.getInt(n);
                for (int j = 0; j < k; j++) {
                    // j + 1 skips the first search result.
                    this.cols.putScalar(rowStart + j, neighbourIndices.getDouble(j + 1));
                    this.vals.putScalar(rowStart + j, currP.getDouble(j));
                }
            }
            return this.vals;
        }
    }

    /**
     * @return the input data last supplied to one of the fit(...) overloads, or {@code null}
     */
    public INDArray input() {
        return x;
    }

    /** No-op: this implementation performs no input validation. */
    public void validateInput() {
    }

    /**
     * @return always {@code null}; this class drives its own optimisation loop in fit()
     */
    public ConvexOptimizer getOptimizer() {
        return null;
    }

    /**
     * @param str parameter name (ignored)
     * @return always {@code null}
     */
    public INDArray getParam(String str) {
        return null;
    }

    /** No-op. */
    public void initParams() {
    }

    /**
     * No-op: the supplied listeners are ignored.
     *
     * @param trainingListenerArr listeners (ignored)
     */
    public void addListeners(TrainingListener... trainingListenerArr) {
    }

    /**
     * @return always {@code null}
     */
    public Map<String, INDArray> paramTable() {
        return null;
    }

    /**
     * @param z ignored
     * @return always {@code null}
     */
    public Map<String, INDArray> paramTable(boolean z) {
        return null;
    }

    /**
     * No-op.
     *
     * @param map ignored
     */
    public void setParamTable(Map<String, INDArray> map) {
    }

    /**
     * No-op.
     *
     * @param str ignored
     * @param iNDArray ignored
     */
    public void setParam(String str, INDArray iNDArray) {
    }

    /** No-op: no state is released. */
    public void clear() {
    }

    /**
     * No-op.
     *
     * @param i ignored
     * @param i2 ignored
     */
    public void applyConstraints(int i, int i2) {
    }

    /**
     * Unsupported variant; use the no-argument {@code gradient()} instead.
     *
     * @param iNDArray ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    protected Pair<Double, INDArray> gradient(INDArray iNDArray) {
        throw new UnsupportedOperationException();
    }

    /**
     * Symmetrises a sparse similarity structure given as row offsets, column indices and values.
     *
     * <p>Pass 1 counts the entries each row will have after symmetrisation. Pass 2 fills the
     * symmetric column indices and values, adding the two values when both (n, col) and (col, n)
     * exist. The values are finally halved.
     *
     * <p>Both passes scan the full row of the neighbour ({@code rowP[col] .. rowP[col + 1]}) when
     * looking for the mirrored entry.
     *
     * <p>NOTE(review): the symmetric row offsets and column indices are computed here but neither
     * returned nor stored. Callers keep their original offsets/indices alongside the returned
     * values. Confirm this is intended.
     *
     * @param iNDArray row offsets (length N + 1)
     * @param iNDArray2 column index of every entry
     * @param iNDArray3 value of every entry
     * @return the symmetrised values, already divided by two
     */
    public INDArray symmetrized(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        INDArray rowCounts = Nd4j.create(this.N);
        try (MemoryWorkspace workspace = (this.workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, workspaceExternal)).notifyScopeEntered()) {
            // Pass 1: count the entries per row of the symmetric structure.
            for (int n = 0; n < this.N; n++) {
                final int rowBegin = iNDArray.getInt(n);
                final int rowEnd = iNDArray.getInt(n + 1);
                for (int e = rowBegin; e < rowEnd; e++) {
                    final int col = iNDArray2.getInt(e);
                    final int mirrorEnd = iNDArray.getInt(col + 1);
                    boolean present = false;
                    for (int m = iNDArray.getInt(col); m < mirrorEnd; m++) {
                        if (iNDArray2.getInt(m) == n) {
                            present = true;
                        }
                    }
                    rowCounts.putScalar(n, rowCounts.getDouble(n) + 1.0d);
                    if (!present) {
                        rowCounts.putScalar(col, rowCounts.getDouble(col) + 1.0d);
                    }
                }
            }

            final int numElements = rowCounts.sum(Integer.MAX_VALUE).getInt(0);
            INDArray offset = Nd4j.create(this.N);
            INDArray symRowP = Nd4j.create(this.N + 1);
            INDArray symColP = Nd4j.create(numElements);
            INDArray symValP = Nd4j.create(numElements);
            for (int n = 0; n < this.N; n++) {
                symRowP.putScalar(n + 1, symRowP.getDouble(n) + rowCounts.getDouble(n));
            }

            // Pass 2: fill the symmetric structure.
            for (int n = 0; n < this.N; n++) {
                final int rowEnd = iNDArray.getInt(n + 1);
                for (int e = iNDArray.getInt(n); e < rowEnd; e++) {
                    final int col = iNDArray2.getInt(e);
                    // Upper bound is the start of the NEXT row of `col`, same as in pass 1.
                    final int mirrorEnd = iNDArray.getInt(col + 1);
                    boolean present = false;
                    for (int m = iNDArray.getInt(col); m < mirrorEnd; m++) {
                        if (iNDArray2.getInt(m) == n) {
                            present = true;
                            // Handle each pair once, from its lower-index side.
                            if (n < col) {
                                final int slotN = symRowP.getInt(n) + offset.getInt(n);
                                final int slotCol = symRowP.getInt(col) + offset.getInt(col);
                                final double combined = iNDArray3.getDouble(e) + iNDArray3.getDouble(m);
                                symColP.putScalar(slotN, col);
                                symColP.putScalar(slotCol, n);
                                symValP.putScalar(slotN, combined);
                                symValP.putScalar(slotCol, combined);
                            }
                        }
                    }

                    // No mirrored entry: the single value is copied to both sides.
                    if (!present && n < col) {
                        final int slotN = symRowP.getInt(n) + offset.getInt(n);
                        final int slotCol = symRowP.getInt(col) + offset.getInt(col);
                        symColP.putScalar(slotN, col);
                        symColP.putScalar(slotCol, n);
                        symValP.putScalar(slotN, iNDArray3.getDouble(e));
                        symValP.putScalar(slotCol, iNDArray3.getDouble(e));
                    }

                    // Advance the write offsets of the rows that may have received an entry.
                    if (!present || n < col) {
                        offset.putScalar(n, offset.getInt(n) + 1);
                        if (col != n) {
                            offset.putScalar(col, offset.getDouble(col) + 1.0d);
                        }
                    }
                }
            }

            symValP.divi(2.0d);
            return symValP;
        }
    }

    /**
     * Evaluates {@code exp(-d * v)} for elements 1..i of the given array and derives the
     * entropy of the resulting (unnormalised) kernel row.
     *
     * <p>Element 0 of the input is skipped. The kernel values are read back from the result
     * array when accumulating the entropy.
     *
     * @param iNDArray values the kernel is evaluated on; elements 1..i are read
     * @param d precision (beta) of the kernel
     * @param i number of kernel entries to produce
     * @return the kernel row and the value {@code sum(d * v * p) / sum(p) + log(sum(p))}
     */
    public Pair<INDArray, Double> computeGaussianKernel(INDArray iNDArray, double d, int i) {
        INDArray kernelRow = Nd4j.create(i);
        int idx = 0;
        while (idx < i) {
            double value = iNDArray.getDouble(idx + 1);
            kernelRow.putScalar(idx, FastMath.exp(-d * value));
            idx++;
        }

        double sumP = kernelRow.sum(Integer.MAX_VALUE).getDouble(0);

        double weighted = 0.0d;
        for (int j = 0; j < i; j++) {
            weighted += d * iNDArray.getDouble(j + 1) * kernelRow.getDouble(j);
        }

        double entropy = (weighted / sumP) + FastMath.log(sumP);
        return new Pair<>(kernelRow, Double.valueOf(entropy));
    }

    /** No-op: all state is created lazily by fit() and gradient(). */
    public void init() {
    }

    /**
     * No-op: the supplied listeners are ignored.
     *
     * @param collection listeners (ignored)
     */
    public void setListeners(Collection<TrainingListener> collection) {
    }

    /**
     * No-op: the supplied listeners are ignored.
     *
     * @param trainingListenerArr listeners (ignored)
     */
    public void setListeners(TrainingListener... trainingListenerArr) {
    }

    /**
     * Computes the embedding for the previously supplied input.
     *
     * <p>When theta is exactly 0 the work is delegated to {@code Tsne}. Otherwise the embedding
     * is initialised with small random values (unless one was already set), the sparse
     * similarities are computed, symmetrised and normalised, and {@code maxIter} update steps
     * are run. The similarities are multiplied by 12 up front and divided by 12 again at
     * {@code stopLyingIteration}; the momentum switches at {@code switchMomentumIteration}.
     * The training listener, if any, is notified after every step.
     */
    public void fit() {
        if (theta == 0.0d) {
            log.debug("theta == 0, using decomposed version, might be slow");
            Tsne decomposed = new Tsne(maxIter, realMin, initialMomentum, finalMomentum, minGain, momentum, switchMomentumIteration, normalize, usePca, stopLyingIteration, tolerance, learningRate, useAdaGrad, perplexity);
            Y = decomposed.calculate(x, numDimensions, perplexity);
            return;
        }

        if (Y == null) {
            Y = Nd4j.randn(x.rows(), numDimensions, Nd4j.getRandom()).muli(Float.valueOf(0.001f));
        }

        try (MemoryWorkspace workspace = (workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, workspaceExternal)).notifyScopeEntered()) {
            computeGaussianPerplexity(x, perplexity);

            // Symmetrise first, then normalise by the sum of the pre-symmetrisation values.
            INDArray symmetric = symmetrized(rows, cols, vals);
            INDArray total = vals.sum(Integer.MAX_VALUE);
            vals = symmetric.divi(total);
            vals.muli(12);

            int iteration = 0;
            while (iteration < maxIter) {
                step(vals, iteration);
                if (iteration == switchMomentumIteration) {
                    momentum = finalMomentum;
                }
                if (iteration == stopLyingIteration) {
                    vals.divi(12);
                }
                if (TrainingListener != null) {
                    TrainingListener.iterationDone(this, iteration, 0);
                }
                iteration++;
            }
        }
    }

    /**
     * No-op; updates are applied through {@code update(INDArray, String)}.
     *
     * @param gradient ignored
     */
    public void update(Gradient gradient) {
    }

    /**
     * Runs one optimisation step: computes the current gradient and applies it.
     *
     * @param iNDArray unused
     * @param i unused
     */
    public void step(INDArray iNDArray, int i) {
        Gradient current = gradient();
        INDArray embeddingGradient = current.getGradientFor(Y_GRAD);
        update(embeddingGradient, Y_GRAD);
    }

    /**
     * Applies one gradient step to the embedding.
     *
     * <p>The adaptive gains follow the usual t-SNE rule: where the sign of the gradient differs
     * from the sign of the previous increment the gain grows by 0.2, and where the signs match it
     * shrinks to 80%. Gains are then clamped from below at {@code minGain}. The gained gradient
     * is scaled by AdaGrad or by the learning rate, folded into the momentum term and added to
     * the embedding.
     *
     * <p>Requires {@code gains}, {@code yIncs} and {@code Y} to be initialised, which
     * {@code gradient()} does on its first call.
     *
     * @param iNDArray gradient of the embedding
     * @param str parameter name (unused)
     */
    public void update(INDArray iNDArray, String str) {
        try (MemoryWorkspace workspace = (this.workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, workspaceExternal)).notifyScopeEntered()) {
            // Transforms.sign is assumed to return a copy (the original code relied on this too).
            INDArray gradientSign = Transforms.sign(iNDArray);
            INDArray incrementSign = Transforms.sign(this.yIncs);
            // 0/1 masks selecting which of the two gain rules applies to each element.
            INDArray signsDiffer = gradientSign.neq(incrementSign);
            INDArray signsMatch = gradientSign.eq(incrementSign);

            INDArray grown = this.gains.add(Double.valueOf(0.2d)).muli(signsDiffer);
            INDArray shrunk = this.gains.mul(Double.valueOf(0.8d)).muli(signsMatch);
            this.gains = grown.addi(shrunk);
            BooleanIndexing.applyWhere(this.gains, Conditions.lessThan(Double.valueOf(this.minGain)), new Value(Double.valueOf(this.minGain)));

            INDArray scaledGradient = this.gains.mul(iNDArray);
            if (this.useAdaGrad) {
                if (this.adaGrad == null) {
                    this.adaGrad = new AdaGrad(iNDArray.shape(), this.learningRate);
                    this.adaGrad.setStateViewArray(Nd4j.zeros(iNDArray.shape()).reshape(1, scaledGradient.length()), scaledGradient.shape(), iNDArray.ordering(), true);
                }
                scaledGradient = this.adaGrad.getGradient(scaledGradient, 0);
            } else {
                scaledGradient.muli(Double.valueOf(this.learningRate));
            }

            this.yIncs.muli(Double.valueOf(this.momentum)).subi(scaledGradient);
            this.Y.addi(this.yIncs);
        }
    }

    /**
     * Writes the embedding to a text file, one line per labelled row.
     *
     * <p>Each line is the comma-separated coordinates of the row, a comma, the label, a space
     * and a newline. Rows whose label is {@code null} are skipped. Only the first
     * {@code min(Y.rows(), list.size())} rows are written. The file is written with the platform
     * default charset ({@link FileWriter}).
     *
     * @param list labels, index-aligned with the rows of the embedding
     * @param str path of the output file
     * @throws IOException if the file cannot be opened or written
     */
    public void saveAsFile(List<String> list, String str) throws IOException {
        // try-with-resources closes the writer exactly once and keeps the primary exception.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(str)))) {
            for (int i = 0; i < this.Y.rows() && i < list.size(); i++) {
                String label = list.get(i);
                if (label == null) {
                    continue;
                }
                StringBuilder line = new StringBuilder();
                INDArray row = this.Y.getRow(i);
                for (int j = 0; j < row.length(); j++) {
                    line.append(row.getDouble(j));
                    if (j < row.length() - 1) {
                        line.append(",");
                    }
                }
                line.append(",");
                line.append(label);
                line.append(" ");
                line.append("\n");
                writer.write(line.toString());
            }
            writer.flush();
        }
    }

    /**
     * Fits the embedding for the given data and writes it to a file.
     *
     * @param iNDArray input data
     * @param i number of output dimensions
     * @param list labels, index-aligned with the input rows
     * @param str path of the output file
     * @throws IOException if the file cannot be written
     * @deprecated equivalent to calling {@code fit(INDArray, int)} followed by {@code saveAsFile}
     */
    @Deprecated
    public void plot(INDArray iNDArray, int i, List<String> list, String str) throws IOException {
        fit(iNDArray, i);
        saveAsFile(list, str);
    }

    /**
     * Approximates the Kullback-Leibler cost of the current embedding over the sparse
     * similarity entries.
     *
     * <p>The normalisation term is accumulated through the tree's non-edge force computation.
     * For every stored entry with value p, the low-dimensional affinity is
     * {@code q = (1 / (1 + squaredDistance)) / sumQ}, and the entry contributes
     * {@code p * log((p + eps) / (q + eps))}.
     *
     * <p>Requires the tree and the similarity structure to exist, i.e. {@code gradient()} and
     * {@code computeGaussianPerplexity} must have run before.
     *
     * @return the accumulated cost
     */
    public double score() {
        try (MemoryWorkspace workspace = (this.workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, workspaceExternal)).notifyScopeEntered()) {
            INDArray buffer = Nd4j.create(this.numDimensions);
            AtomicDouble sumQ = new AtomicDouble(0.0d);
            for (int n = 0; n < this.N; n++) {
                this.tree.computeNonEdgeForces(n, this.theta, buffer, sumQ);
            }

            double cost = 0.0d;
            INDArray embedding = this.Y;
            for (int n = 0; n < this.N; n++) {
                final int rowBegin = this.rows.getInt(n);
                final int rowEnd = this.rows.getInt(n + 1);
                for (int e = rowBegin; e < rowEnd; e++) {
                    final int neighbour = this.cols.getInt(e);
                    buffer.assign(embedding.slice(n));
                    buffer.subi(embedding.slice(neighbour));
                    double squaredDistance = Transforms.pow(buffer, 2).sum(Integer.MAX_VALUE).getDouble(0);
                    double q = (1.0d / (1.0d + squaredDistance)) / sumQ.doubleValue();
                    double p = this.vals.getDouble(e);
                    // KL term: the logarithm is taken of the ratio p/q, not of p alone.
                    cost += p * FastMath.log((p + Nd4j.EPS_THRESHOLD) / (q + Nd4j.EPS_THRESHOLD));
                }
            }
            return cost;
        }
    }

    /**
     * No-op; use {@code gradientAndScore()} to obtain both values.
     *
     * @param layerWorkspaceMgr ignored
     */
    public void computeGradientAndScore(LayerWorkspaceMgr layerWorkspaceMgr) {
    }

    /**
     * No-op.
     *
     * @param d ignored
     */
    public void accumulateScore(double d) {
    }

    /**
     * @return always {@code null}; the embedding is exposed through {@code getData()} instead
     */
    public INDArray params() {
        return null;
    }

    /**
     * @return always 0
     */
    public int numParams() {
        return 0;
    }

    /**
     * @param z ignored
     * @return always 0
     */
    public int numParams(boolean z) {
        return 0;
    }

    /**
     * No-op.
     *
     * @param iNDArray ignored
     */
    public void setParams(INDArray iNDArray) {
    }

    /**
     * Not supported.
     *
     * @param iNDArray ignored
     * @throws UnsupportedOperationException always
     */
    public void setParamsViewArray(INDArray iNDArray) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public INDArray getGradientsViewArray() {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported.
     *
     * @param iNDArray ignored
     * @throws UnsupportedOperationException always
     */
    public void setBackpropGradientsViewArray(INDArray iNDArray) {
        throw new UnsupportedOperationException();
    }

    /**
     * Stores the input data and computes the embedding with the configured number of dimensions.
     *
     * @param iNDArray input data, one sample per row
     */
    public void fit(INDArray iNDArray) {
        x = iNDArray;
        fit();
    }

    /**
     * Same as {@code fit(INDArray)}; the workspace manager is ignored.
     *
     * @param iNDArray input data, one sample per row
     * @param layerWorkspaceMgr ignored
     */
    public void fit(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        fit(iNDArray);
    }

    /**
     * Stores the input data, overrides the number of output dimensions and computes the embedding.
     *
     * @param iNDArray input data, one sample per row
     * @param i number of output dimensions
     * @deprecated the dimension count can be configured up front and {@code fit(INDArray)} used
     */
    @Deprecated
    public void fit(INDArray iNDArray, int i) {
        x = iNDArray;
        numDimensions = i;
        fit();
    }

    /**
     * Computes the gradient of the embedding from the tree's edge and non-edge force routines.
     *
     * <p>Lazily initialises the momentum term (zeros), the gains (ones) and the tree on the
     * first call. The result is {@code edgeForces - nonEdgeForces / sumQ}, stored under
     * {@link #Y_GRAD}.
     *
     * <p>NOTE(review): the tree is built once from the embedding and reused on later calls even
     * though the embedding is updated in place. Confirm that {@code SpTree} tolerates this.
     *
     * @return a gradient holding a single entry keyed by {@link #Y_GRAD}
     */
    public Gradient gradient() {
        try (MemoryWorkspace workspace = (this.workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal, workspaceExternal)).notifyScopeEntered()) {
            if (this.yIncs == null) {
                this.yIncs = Nd4j.zeros(this.Y.shape());
            }
            if (this.gains == null) {
                this.gains = Nd4j.ones(this.Y.shape());
            }
            if (this.tree == null) {
                this.tree = new SpTree(this.Y);
            }

            AtomicDouble sumQ = new AtomicDouble(0.0d);
            INDArray edgeForces = Nd4j.create(this.Y.shape());
            INDArray nonEdgeForces = Nd4j.create(this.Y.shape());

            this.tree.computeEdgeForces(this.rows, this.cols, this.vals, this.N, edgeForces);
            for (int n = 0; n < this.N; n++) {
                this.tree.computeNonEdgeForces(n, this.theta, nonEdgeForces.slice(n), sumQ);
            }

            INDArray embeddingGradient = edgeForces.subi(nonEdgeForces.divi(sumQ));
            DefaultGradient result = new DefaultGradient();
            result.gradientForVariable().put(Y_GRAD, embeddingGradient);
            return result;
        }
    }

    /**
     * Computes the gradient first and the score second, and returns both.
     *
     * @return pair of (gradient, score)
     */
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient currentGradient = gradient();
        Double currentScore = Double.valueOf(score());
        return new Pair<>(currentGradient, currentScore);
    }

    /**
     * @return always 0
     */
    public int batchSize() {
        return 0;
    }

    /**
     * @return always {@code null}; this model has no network configuration
     */
    public NeuralNetConfiguration conf() {
        return null;
    }

    /**
     * No-op.
     *
     * @param neuralNetConfiguration ignored
     */
    public void setConf(NeuralNetConfiguration neuralNetConfiguration) {
    }

    /**
     * @return the current embedding, or {@code null} if none has been computed or set
     */
    public INDArray getData() {
        return Y;
    }

    /**
     * Replaces the embedding; fit() keeps an embedding set this way instead of randomising.
     *
     * @param iNDArray the embedding to use
     */
    public void setData(INDArray iNDArray) {
        Y = iNDArray;
    }
}
