package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;

/* loaded from: input_file:META-INF/bundled-dependencies/api-0.22.1.jar:ai/djl/training/loss/MaskedSoftmaxCrossEntropyLoss.class */
/**
 * A softmax cross-entropy {@link Loss} that multiplies the per-position loss by a sequence mask,
 * so that only positions within each sequence's valid length contribute a non-zero value.
 *
 * <p>The labels {@link NDList} passed to {@link #evaluate(NDList, NDList)} is read at two
 * indices: index 0 holds the target labels and index 1 holds the sequence lengths that are passed
 * to {@link NDArray#sequenceMask(NDArray)}. The predictions {@link NDList} must contain exactly
 * one array.
 */
public class MaskedSoftmaxCrossEntropyLoss extends Loss {

    // Configuration is fixed at construction time and never reassigned.
    private final float weight;
    private final int classAxis;
    private final boolean sparseLabel;
    private final boolean fromLogit;

    /** Creates a loss named {@code "MaskedSoftmaxCrossEntropyLoss"} with default settings. */
    public MaskedSoftmaxCrossEntropyLoss() {
        this("MaskedSoftmaxCrossEntropyLoss");
    }

    /**
     * Creates a loss with the given name and the default settings: weight {@code 1}, class axis
     * {@code -1}, sparse labels, and predictions that still need {@code logSoftmax} applied.
     *
     * @param name the display name of the loss
     */
    public MaskedSoftmaxCrossEntropyLoss(String name) {
        this(name, 1.0f, -1, true, false);
    }

    /**
     * Creates a loss with fully specified settings.
     *
     * @param name the display name of the loss
     * @param weight a scale factor applied to the masked loss; skipped when exactly {@code 1}
     * @param classAxis the axis of the predictions that indexes classes; may be negative
     * @param sparseLabel {@code true} if labels are class indices to pick along {@code classAxis},
     *     {@code false} if labels are dense values reshaped to the predictions' shape
     * @param fromLogit {@code true} if predictions are already log-probabilities; when {@code
     *     false}, {@code logSoftmax} is applied along {@code classAxis} first
     */
    public MaskedSoftmaxCrossEntropyLoss(
            String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) {
        super(name);
        this.weight = weight;
        this.classAxis = classAxis;
        this.sparseLabel = sparseLabel;
        this.fromLogit = fromLogit;
    }

    /**
     * Computes the masked softmax cross-entropy loss.
     *
     * @param labels index 0: target labels; index 1: sequence lengths used to build the mask
     * @param predictions a single-element list holding the prediction scores
     * @return the masked (and optionally weighted) loss averaged over axis 1; masked positions are
     *     zero but are still counted in that mean's denominator
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        NDArray label = labels.head();
        // Ones shaped like the labels plus a trailing axis, passed through sequenceMask with the
        // lengths from labels.get(1) to obtain the multiplicative mask.
        NDArray mask = label.onesLike().expandDims(-1).sequenceMask(labels.get(1));

        NDArray pred = predictions.singletonOrThrow();
        if (!fromLogit) {
            // Turn raw scores into log-probabilities along the class axis.
            pred = pred.logSoftmax(classAxis);
        }

        NDArray loss;
        if (sparseLabel) {
            // Class-index labels: pick the log-probability of each target class. floorMod maps a
            // negative classAxis to its non-negative position, which NDIndex.addAllDim needs.
            int axis = Math.floorMod(classAxis, pred.getShape().dimension());
            NDIndex pickIndex = new NDIndex().addAllDim(axis).addPickDim(label);
            loss = pred.get(pickIndex).neg();
        } else {
            // Dense labels: negated sum of label * log-probability over the class axis.
            NDArray denseLabel = label.reshape(pred.getShape());
            loss = pred.mul(denseLabel).neg().sum(new int[] {classAxis}, true);
        }

        loss = loss.mul(mask);
        if (weight != 1.0f) {
            loss = loss.mul(weight);
        }
        return loss.mean(new int[] {1});
    }
}
