package ai.djl.huggingface.translator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;

/* loaded from: input_file:META-INF/bundled-dependencies/tokenizers-0.22.1.jar:ai/djl/huggingface/translator/FillMaskBatchTranslator.class */
/**
 * A {@link NoBatchifyTranslator} that runs fill-mask inference on an array of input strings and
 * returns one {@link Classifications} per input.
 *
 * <p>Batching is performed inside this translator with the {@link Batchifier} passed to the
 * constructor: per-input encodings are combined in {@link #processInput} and the model output is
 * split back apart in {@link #processOutput}.
 */
public class FillMaskBatchTranslator implements NoBatchifyTranslator<String[], Classifications[]> {

    /** Context attachment key used to hand per-input mask positions from pre- to post-processing. */
    private static final String MASK_INDICES_KEY = "maskIndices";

    private final HuggingFaceTokenizer tokenizer;
    private final String maskToken;
    private final long maskTokenId;
    private final int topK;
    private final Batchifier batchifier;

    /**
     * Creates a batch fill-mask translator.
     *
     * @param tokenizer the tokenizer used to encode inputs (also passed to the output conversion)
     * @param maskToken the text of the mask token to locate in each input
     * @param topK forwarded to {@code FillMaskTranslator.toClassifications}; presumably the number
     *     of top candidates kept per input — TODO confirm against that helper
     * @param batchifier combines per-input {@link NDList}s into a batch and splits outputs back
     */
    public FillMaskBatchTranslator(
            HuggingFaceTokenizer tokenizer, String maskToken, int topK, Batchifier batchifier) {
        this.tokenizer = tokenizer;
        this.maskToken = maskToken;
        this.topK = topK;
        this.batchifier = batchifier;
        // The mask token id is taken as the first id of the mask token's own encoding.
        // NOTE(review): the `false` flag presumably disables special tokens so that index 0 is the
        // mask itself — confirm against HuggingFaceTokenizer.encode.
        this.maskTokenId = tokenizer.encode(maskToken, false).getIds()[0];
    }

    /**
     * Encodes every input, records where the mask token sits in each one, and batchifies the
     * resulting {@link NDList}s.
     *
     * @param ctx the translator context; receives the mask positions as an attachment
     * @param inputs the input strings, each expected to contain the mask token
     * @return the batched model input
     * @throws TranslateException if a mask position cannot be determined for an input
     */
    @Override
    public NDList processInput(TranslatorContext ctx, String[] inputs) throws TranslateException {
        NDManager manager = ctx.getNDManager();
        Encoding[] encodings = tokenizer.batchEncode(inputs);
        NDList[] encoded = new NDList[encodings.length];
        int[] maskIndices = new int[encodings.length];
        for (int i = 0; i < encodings.length; i++) {
            maskIndices[i] =
                    FillMaskTranslator.getMaskIndex(
                            encodings[i].getIds(), maskToken, maskTokenId);
            encoded[i] = encodings[i].toNDList(manager, false);
        }
        // Publish only once fully populated, so a failure above leaves no partial state behind.
        ctx.setAttachment(MASK_INDICES_KEY, maskIndices);
        return batchifier.batchify(encoded);
    }

    /**
     * Splits the batched model output per input and converts each slice into classifications at
     * that input's mask position.
     *
     * @param ctx the translator context holding the mask positions recorded by {@link
     *     #processInput}
     * @param list the batched model output
     * @return one {@link Classifications} per input, in input order
     * @throws IllegalStateException if the mask positions are missing or the number of unbatched
     *     outputs does not match the number of inputs
     */
    @Override
    public Classifications[] processOutput(TranslatorContext ctx, NDList list) {
        NDList[] perInput = batchifier.unbatchify(list);
        int[] maskIndices = (int[]) ctx.getAttachment(MASK_INDICES_KEY);
        if (maskIndices == null) {
            throw new IllegalStateException(
                    "Missing '" + MASK_INDICES_KEY + "' attachment; processInput was not run");
        }
        if (perInput.length != maskIndices.length) {
            throw new IllegalStateException(
                    "Unbatched output count "
                            + perInput.length
                            + " does not match input count "
                            + maskIndices.length);
        }
        Classifications[] results = new Classifications[maskIndices.length];
        for (int i = 0; i < results.length; i++) {
            results[i] =
                    FillMaskTranslator.toClassifications(
                            tokenizer, perInput[i], maskIndices[i], topK);
        }
        return results;
    }
}
