package ai.djl.nn.core;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.stream.IntStream;

/* loaded from: input_file:META-INF/bundled-dependencies/api-0.22.1.jar:ai/djl/nn/core/SparseMax.class */
/**
 * A top-K softmax block.
 *
 * <p>For every slice along the configured axis, only the {@code topK} largest input values take
 * part in a softmax; every other position of the output is zero. With the default axis of {@code
 * -1} the last dimension is used. Note that this is a softmax masked to the top-K entries, not the
 * simplex-projection "sparsemax" of Martins &amp; Astudillo (2016).
 *
 * <p>The block expects a single input and produces one output of the same shape.
 */
public class SparseMax extends AbstractBlock {

    private static final byte VERSION = 1;

    /** Axis along which the top-K softmax is computed; {@code -1} means the last axis. */
    private final int axis;

    /** Number of largest entries kept per slice along {@link #axis}. */
    private final int topK;

    /** Creates a block over the last axis that keeps the 3 largest entries. */
    public SparseMax() {
        this(-1, 3);
    }

    /**
     * Creates a block over {@code axis} that keeps the 3 largest entries.
     *
     * @param axis the axis to compute over; {@code -1} means the last axis
     */
    public SparseMax(int axis) {
        this(axis, 3);
    }

    /**
     * Creates a block over {@code axis} that keeps the {@code topK} largest entries.
     *
     * @param axis the axis to compute over; {@code -1} means the last axis
     * @param topK how many of the largest entries to keep per slice; a value larger than the axis
     *     length keeps every entry (i.e. a plain softmax)
     */
    public SparseMax(int axis, int topK) {
        super(VERSION);
        this.axis = axis;
        this.topK = topK;
    }

    /** {@inheritDoc} */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        // Masking and normalization are element-wise, so the output shape equals the input shape.
        return new Shape[] {inputShapes[0]};
    }

    /**
     * Computes the top-K softmax of the single input along {@link #axis}.
     *
     * @throws IllegalArgumentException if the effective K (the smaller of {@code topK} and the
     *     axis length) is less than 1
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        // Move the target axis to the end so all work below operates on the last dimension.
        if (axis != -1) {
            input = input.swapAxes(axis, -1);
        }
        Shape shape = input.getShape();
        int lastDimSize = (int) input.size(shape.dimension() - 1);
        // Clamp so that a topK larger than the axis length selects every entry instead of
        // indexing past the end of the sorted indices.
        int k = Math.min(topK, lastDimSize);
        if (k < 1) {
            throw new IllegalArgumentException(
                    "SparseMax requires topK >= 1 and a non-empty axis, got topK="
                            + topK
                            + ", axis length="
                            + lastDimSize);
        }

        // Indices of the entries sorted in descending order; positions 0..k-1 are the top k.
        NDArray order = input.argSort(-1, false).toType(DataType.INT64, false);
        // 1 at the top-k positions of each slice, 0 elsewhere.
        NDArray mask =
                NDArrays.add(
                        IntStream.range(0, k)
                                .mapToObj(j -> order.get("..., {}", j).oneHot(lastDimSize))
                                .toArray(NDArray[]::new));

        // Subtract the per-slice max before exponentiating. Softmax is shift-invariant, and this
        // keeps exp() from overflowing to infinity (which would yield inf/inf = NaN). The max is
        // always inside the mask, so the denominator below is at least exp(0) = 1.
        NDArray shifted = input.sub(input.max(new int[] {-1}, true).broadcast(shape));
        NDArray maskedExp = shifted.exp().mul(mask);
        NDArray output = maskedExp.div(maskedExp.sum(new int[] {-1}, true).broadcast(shape));

        // swapAxes is its own inverse: restore the caller's axis order.
        if (axis != -1) {
            output = output.swapAxes(axis, -1);
        }
        return new NDList(output);
    }
}
