package org.deeplearning4j.nn.params;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/params/BidirectionalParamInitializer.class */
/**
 * Parameter initializer for {@link Bidirectional} layers.
 *
 * <p>The bidirectional configuration wraps one {@link BaseRecurrentLayer} (obtained via
 * {@link Bidirectional#getFwd()}). Its parameters are stored twice in a single flattened view:
 * the first half holds the forward-direction parameters and the second half holds the
 * backward-direction parameters. Every key reported by the underlying layer's initializer is
 * exposed twice: once prefixed with {@link #FORWARD_PREFIX} and once with
 * {@link #BACKWARD_PREFIX}. Forward keys always come first.
 */
public class BidirectionalParamInitializer implements ParamInitializer {
    public static final String FORWARD_PREFIX = "f";
    public static final String BACKWARD_PREFIX = "b";

    private final Bidirectional layer;
    private final BaseRecurrentLayer underlying;

    // Lazily computed prefixed key lists. Each is computed once, from the first layer argument
    // seen, and then returned for every subsequent call.
    private List<String> paramKeys;
    private List<String> weightKeys;
    private List<String> biasKeys;

    /**
     * @param bidirectional the bidirectional layer configuration; its forward layer must be a
     *     {@link BaseRecurrentLayer}
     */
    public BidirectionalParamInitializer(Bidirectional bidirectional) {
        this.layer = bidirectional;
        this.underlying = underlying(bidirectional);
    }

    /** Returns the total parameter count for the layer held by {@code conf}. */
    @Override
    public int numParams(NeuralNetConfiguration conf) {
        return numParams(conf.getLayer());
    }

    /** Returns twice the underlying layer's parameter count (one copy per direction). */
    @Override
    public int numParams(Layer layer) {
        BaseRecurrentLayer inner = underlying(layer);
        return 2 * inner.initializer().numParams(inner);
    }

    /** Returns all parameter keys of the underlying layer, forward-prefixed then backward-prefixed. */
    @Override
    public List<String> paramKeys(Layer layer) {
        if (paramKeys == null) {
            BaseRecurrentLayer inner = underlying(layer);
            paramKeys = withPrefixes(inner.initializer().paramKeys(inner));
        }
        return paramKeys;
    }

    /** Returns the underlying layer's weight keys, forward-prefixed then backward-prefixed. */
    @Override
    public List<String> weightKeys(Layer layer) {
        if (weightKeys == null) {
            BaseRecurrentLayer inner = underlying(layer);
            weightKeys = withPrefixes(inner.initializer().weightKeys(inner));
        }
        return weightKeys;
    }

    /** Returns the underlying layer's bias keys, forward-prefixed then backward-prefixed. */
    @Override
    public List<String> biasKeys(Layer layer) {
        if (biasKeys == null) {
            BaseRecurrentLayer inner = underlying(layer);
            // Must delegate to biasKeys (not weightKeys); otherwise weights are misreported as biases.
            biasKeys = withPrefixes(inner.initializer().biasKeys(inner));
        }
        return biasKeys;
    }

    /** Whether {@code key} is a (prefixed) weight key of the configured bidirectional layer. */
    @Override
    public boolean isWeightParam(Layer layer, String key) {
        // Uses the layer this initializer was built for; the argument is not consulted.
        return weightKeys(this.layer).contains(key);
    }

    /** Whether {@code key} is a (prefixed) bias key of the configured bidirectional layer. */
    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return biasKeys(this.layer).contains(key);
    }

    /**
     * Initializes the forward and backward parameter halves.
     *
     * <p>This clears {@code conf}'s variables. It then runs the underlying initializer once per
     * direction, each on its own clone of {@code conf} with the layer replaced by the wrapped
     * recurrent layer. Finally it writes the merged, prefixed L1/L2-by-param maps and variable
     * list back into {@code conf}.
     *
     * @param conf configuration of the bidirectional layer; mutated as described above
     * @param paramsView flattened parameter view covering both directions
     * @param initializeParams passed through unchanged to the underlying initializer
     * @return parameter views keyed by prefixed name, forward entries first
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        INDArray[] halves = splitHalves(paramsView);

        conf.clearVariables();
        NeuralNetConfiguration fwdConf = conf.m33clone();
        NeuralNetConfiguration bwdConf = conf.m33clone();
        fwdConf.setLayer(underlying);
        bwdConf.setLayer(underlying);

        Map<String, INDArray> fwdParams = underlying.initializer().init(fwdConf, halves[0], initializeParams);
        Map<String, INDArray> bwdParams = underlying.initializer().init(bwdConf, halves[1], initializeParams);

        Map<String, Double> l1ByParam = addPrefixes(fwdConf.getL1ByParam(), bwdConf.getL1ByParam());
        Map<String, Double> l2ByParam = addPrefixes(fwdConf.getL2ByParam(), bwdConf.getL2ByParam());
        List<String> variables = addPrefixes(fwdConf.getVariables(), bwdConf.getVariables());
        conf.setL1ByParam(l1ByParam);
        conf.setL2ByParam(l2ByParam);
        conf.setVariables(variables);

        return addPrefixes(fwdParams, bwdParams);
    }

    /**
     * Splits a flattened gradient view into per-direction gradient maps.
     *
     * @param conf configuration passed through to the underlying initializer
     * @param gradientView flattened gradient view covering both directions
     * @return gradient views keyed by prefixed name, forward entries first
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        INDArray[] halves = splitHalves(gradientView);
        // NOTE(review): unlike init(), conf here still carries the Bidirectional layer rather than
        // the wrapped recurrent layer. Confirm the underlying initializer does not cast
        // conf.getLayer() to its own layer type.
        Map<String, INDArray> fwdGrads = underlying.initializer().getGradientsFromFlattened(conf, halves[0]);
        Map<String, INDArray> bwdGrads = underlying.initializer().getGradientsFromFlattened(conf, halves[1]);
        return addPrefixes(fwdGrads, bwdGrads);
    }

    /**
     * Splits a flattened view into {@code [0, n/2)} (forward) and {@code [n/2, 2*(n/2))}
     * (backward) slices along dimension 1 of row 0.
     * NOTE(review): this assumes a single-row view, as the {@code point(0)} index implies.
     */
    private static INDArray[] splitHalves(INDArray view) {
        int half = view.length() / 2;
        INDArray fwd = view.get(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(0, half)});
        INDArray bwd = view.get(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(half, 2 * half)});
        return new INDArray[] {fwd, bwd};
    }

    /** Merges two maps into one insertion-ordered map: forward entries (prefixed) first, then backward. */
    private static <T> Map<String, T> addPrefixes(Map<String, T> fwd, Map<String, T> bwd) {
        Map<String, T> merged = new LinkedHashMap<>();
        for (Map.Entry<String, T> e : fwd.entrySet()) {
            merged.put(FORWARD_PREFIX + e.getKey(), e.getValue());
        }
        for (Map.Entry<String, T> e : bwd.entrySet()) {
            merged.put(BACKWARD_PREFIX + e.getKey(), e.getValue());
        }
        return merged;
    }

    /** Concatenates forward-prefixed {@code fwd} keys followed by backward-prefixed {@code bwd} keys. */
    private static List<String> addPrefixes(List<String> fwd, List<String> bwd) {
        List<String> out = new ArrayList<>(fwd.size() + bwd.size());
        for (String key : fwd) {
            out.add(FORWARD_PREFIX + key);
        }
        for (String key : bwd) {
            out.add(BACKWARD_PREFIX + key);
        }
        return out;
    }

    /** Returns the wrapped recurrent layer configuration of a {@link Bidirectional} layer. */
    private static BaseRecurrentLayer underlying(Layer layer) {
        return (BaseRecurrentLayer) ((Bidirectional) layer).getFwd();
    }

    /** Returns {@code keys} with forward prefixes, followed by the same keys with backward prefixes. */
    private static List<String> withPrefixes(List<String> keys) {
        return addPrefixes(keys, keys);
    }
}
