package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/* loaded from: input_file:META-INF/bundled-dependencies/api-0.22.1.jar:ai/djl/training/loss/TabNetClassificationLoss.class */
/**
 * Loss used for TabNet classification.
 *
 * <p>The loss is computed in two steps:
 *
 * <ol>
 *   <li>Softmax cross-entropy between the labels and the first prediction array.
 *   <li>The second prediction array is added to that cross-entropy.
 * </ol>
 *
 * <p>The second array is presumably TabNet's sparsity regularization term. NOTE(review):
 * confirm this against the block that produces the predictions.
 */
public final class TabNetClassificationLoss extends Loss {

    /**
     * Cross-entropy loss created once and reused by every {@link #evaluate} call.
     *
     * <p>The original code allocated a fresh instance on each call. NOTE(review): this relies
     * on {@code SoftmaxCrossEntropyLoss#evaluate} keeping no per-call state; confirm when
     * upgrading DJL.
     */
    private final Loss classificationLoss = Loss.softmaxCrossEntropyLoss();

    /** Creates the loss with the default name {@code "TabNetClassificationLoss"}. */
    public TabNetClassificationLoss() {
        this("TabNetClassificationLoss");
    }

    /**
     * Creates the loss with a custom name.
     *
     * @param name the display name of this loss
     */
    public TabNetClassificationLoss(String name) {
        super(name);
    }

    /**
     * Computes the TabNet classification loss.
     *
     * @param labels the ground-truth labels, passed unchanged to the softmax cross-entropy loss
     * @param predictions the model outputs. Index 0 holds the classification output; index 1
     *     holds an extra term that is added to the cross-entropy.
     * @return the softmax cross-entropy of {@code predictions.get(0)} against {@code labels},
     *     plus {@code predictions.get(1)}
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        // Only the first prediction is the classification output. Wrap it in its own list so
        // the cross-entropy loss never sees the extra term.
        NDArray crossEntropy =
                classificationLoss.evaluate(labels, new NDList(predictions.get(0)));
        return crossEntropy.add(predictions.get(1));
    }
}
