package org.deeplearning4j.nn.conf.layers;

import java.beans.ConstructorProperties;
import java.util.Collection;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.params.PretrainParamInitializer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/RBM.class */
/**
 * Layer configuration for a Restricted Boltzmann Machine.
 *
 * <p>On top of the settings inherited from {@link BasePretrainNetwork}, this configuration stores
 * the hidden-unit type, the visible-unit type, an integer {@code k} and a {@code sparsity} value.
 * Instances built via {@link Builder} start from BINARY/BINARY units, {@code k = 1} and
 * {@code sparsity = 0.0}; the public no-argument constructor leaves these four fields at their Java
 * defaults ({@code null}, {@code null}, {@code 0}, {@code 0.0}).
 *
 * <p>NOTE(review): {@code k} is presumably the number of Gibbs sampling steps used in contrastive
 * divergence — confirm against the runtime RBM layer implementation.
 */
public class RBM extends BasePretrainNetwork {
    protected HiddenUnit hiddenUnit;
    protected VisibleUnit visibleUnit;
    protected int k;
    protected double sparsity;

    /**
     * Fluent builder for {@link RBM}.
     *
     * <p>Unless overridden, the builder uses {@link HiddenUnit#BINARY}, {@link VisibleUnit#BINARY},
     * {@code k = 1} and {@code sparsity = 0.0}.
     */
    public static class Builder extends BasePretrainNetwork.Builder<Builder> {
        private HiddenUnit hiddenUnit = HiddenUnit.BINARY;
        private VisibleUnit visibleUnit = VisibleUnit.BINARY;
        private int k = 1;
        private double sparsity = 0.0d;

        /** Creates a builder whose settings are all the defaults listed on the class. */
        public Builder() {
            // Every default is supplied by the field initializers above.
        }

        /**
         * Creates a builder with the given unit types and default {@code k} and {@code sparsity}.
         *
         * @param hiddenUnit  unit type for the hidden layer
         * @param visibleUnit unit type for the visible layer
         */
        public Builder(HiddenUnit hiddenUnit, VisibleUnit visibleUnit) {
            this(hiddenUnit, visibleUnit, 1, 0.0d);
        }

        /**
         * Creates a builder with every RBM-specific setting given explicitly.
         *
         * @param hiddenUnit  unit type for the hidden layer
         * @param visibleUnit unit type for the visible layer
         * @param k           value stored as {@code k}
         * @param sparsity    value stored as {@code sparsity}
         */
        @ConstructorProperties({"hiddenUnit", "visibleUnit", "k", "sparsity"})
        public Builder(HiddenUnit hiddenUnit, VisibleUnit visibleUnit, int k, double sparsity) {
            this.hiddenUnit = hiddenUnit;
            this.visibleUnit = visibleUnit;
            this.k = k;
            this.sparsity = sparsity;
        }

        /** Creates an {@link RBM} configuration from this builder's current state. */
        @Override
        public RBM build() {
            return new RBM(this);
        }

        /** Sets {@code k}; returns this builder for chaining. */
        public Builder k(int k) {
            this.k = k;
            return this;
        }

        /** Sets the hidden-unit type; returns this builder for chaining. */
        public Builder hiddenUnit(HiddenUnit hiddenUnit) {
            this.hiddenUnit = hiddenUnit;
            return this;
        }

        /** Sets the visible-unit type; returns this builder for chaining. */
        public Builder visibleUnit(VisibleUnit visibleUnit) {
            this.visibleUnit = visibleUnit;
            return this;
        }

        /** Sets the sparsity value; returns this builder for chaining. */
        public Builder sparsity(double sparsity) {
            this.sparsity = sparsity;
            return this;
        }
    }

    /**
     * Hidden-unit types an RBM can be configured with.
     *
     * <p>NOTE(review): declaration order defines {@code ordinal()}; keep it stable in case persisted
     * configurations depend on it.
     */
    public enum HiddenUnit { RECTIFIED, BINARY, GAUSSIAN, SOFTMAX, IDENTITY }

    /**
     * Visible-unit types an RBM can be configured with.
     *
     * <p>NOTE(review): declaration order defines {@code ordinal()}; keep it stable in case persisted
     * configurations depend on it.
     */
    public enum VisibleUnit { BINARY, GAUSSIAN, SOFTMAX, LINEAR, IDENTITY }

    /**
     * Creates the runtime RBM layer for this configuration.
     *
     * <p>The layer is constructed from {@code conf}. Its listeners, index and parameter view are then
     * attached, and its parameter table is populated via {@link #initializer()}.
     *
     * @param conf               network configuration for the layer
     * @param iterationListeners listeners to register on the layer
     * @param layerIndex         index of the layer within the network
     * @param layerParamsView    view array backing the layer's parameters
     * @param initializeParams   forwarded to the parameter initializer
     * @return the newly created runtime layer
     */
    @Override
    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
            Collection<IterationListener> iterationListeners, int layerIndex,
            INDArray layerParamsView, boolean initializeParams) {
        org.deeplearning4j.nn.layers.feedforward.rbm.RBM layer =
                new org.deeplearning4j.nn.layers.feedforward.rbm.RBM(conf);
        layer.setListeners(iterationListeners);
        layer.setIndex(layerIndex);
        layer.setParamsViewArray(layerParamsView);
        layer.setParamTable(initializer().init(conf, layerParamsView, initializeParams));
        // NOTE(review): conf was already handed to the constructor; this looks redundant — confirm
        // before removing.
        layer.setConf(conf);
        return layer;
    }

    /** Returns the shared {@link PretrainParamInitializer} instance. */
    @Override
    public ParamInitializer initializer() {
        return PretrainParamInitializer.getInstance();
    }

    /** Copies the inherited settings and the four RBM-specific settings from {@code builder}. */
    private RBM(Builder builder) {
        super(builder);
        this.hiddenUnit = builder.hiddenUnit;
        this.visibleUnit = builder.visibleUnit;
        this.k = builder.k;
        this.sparsity = builder.sparsity;
    }

    /** Creates an empty configuration; the RBM-specific fields keep their Java defaults. */
    public RBM() {
    }

    public HiddenUnit getHiddenUnit() {
        return this.hiddenUnit;
    }

    public VisibleUnit getVisibleUnit() {
        return this.visibleUnit;
    }

    public int getK() {
        return this.k;
    }

    public double getSparsity() {
        return this.sparsity;
    }

    public void setHiddenUnit(HiddenUnit hiddenUnit) {
        this.hiddenUnit = hiddenUnit;
    }

    public void setVisibleUnit(VisibleUnit visibleUnit) {
        this.visibleUnit = visibleUnit;
    }

    public void setK(int k) {
        this.k = k;
    }

    public void setSparsity(double sparsity) {
        this.sparsity = sparsity;
    }

    /** Renders the inherited description followed by the four RBM-specific fields. */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("RBM(super=");
        text.append(super.toString());
        text.append(", hiddenUnit=").append(getHiddenUnit());
        text.append(", visibleUnit=").append(getVisibleUnit());
        text.append(", k=").append(getK());
        text.append(", sparsity=").append(getSparsity());
        return text.append(')').toString();
    }

    /**
     * Two RBM configurations are equal when the superclass considers them equal and all four
     * RBM-specific fields match. {@code sparsity} is compared with {@link Double#compare}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof RBM)) {
            return false;
        }
        RBM other = (RBM) obj;
        // Symmetry guard for subclasses that narrow canEqual.
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(obj)) {
            return false;
        }
        HiddenUnit ownHidden = getHiddenUnit();
        HiddenUnit otherHidden = other.getHiddenUnit();
        if (!nullSafeEquals(ownHidden, otherHidden)) {
            return false;
        }
        VisibleUnit ownVisible = getVisibleUnit();
        VisibleUnit otherVisible = other.getVisibleUnit();
        if (!nullSafeEquals(ownVisible, otherVisible)) {
            return false;
        }
        if (getK() != other.getK()) {
            return false;
        }
        return Double.compare(getSparsity(), other.getSparsity()) == 0;
    }

    /** Returns true when both are null, or when {@code a.equals(b)}. */
    private static boolean nullSafeEquals(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Allows subclasses to opt out of equality with plain {@code RBM} instances. */
    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof RBM;
    }

    /**
     * Hash consistent with {@link #equals(Object)}.
     *
     * <p>Each field is folded in with multiplier 59, and null units contribute 43. {@code sparsity}
     * contributes the folded bits of {@link Double#doubleToLongBits}.
     */
    @Override
    public int hashCode() {
        final int multiplier = 59;
        int result = 1;
        result = result * multiplier + super.hashCode();
        HiddenUnit hidden = getHiddenUnit();
        result = result * multiplier + (hidden == null ? 43 : hidden.hashCode());
        VisibleUnit visible = getVisibleUnit();
        result = result * multiplier + (visible == null ? 43 : visible.hashCode());
        result = result * multiplier + getK();
        long sparsityBits = Double.doubleToLongBits(getSparsity());
        result = result * multiplier + (int) ((sparsityBits >>> 32) ^ sparsityBits);
        return result;
    }
}
