package com.linkedin.venice.compression;

import com.github.luben.zstd.Zstd;
import com.github.luben.zstd.ZstdCompressCtx;
import com.github.luben.zstd.ZstdDecompressCtx;
import com.github.luben.zstd.ZstdDictCompress;
import com.github.luben.zstd.ZstdDictDecompress;
import com.github.luben.zstd.ZstdDictTrainer;
import com.github.luben.zstd.ZstdException;
import com.github.luben.zstd.ZstdInputStream;
import com.linkedin.venice.compression.protocol.FakeCompressingSchema;
import com.linkedin.venice.serializer.AvroSerializer;
import com.linkedin.venice.utils.concurrent.CloseableThreadLocal;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.avro.generic.GenericData;
import org.apache.commons.io.IOUtils;

/* loaded from: input_file:com/linkedin/venice/compression/ZstdWithDictCompressor.class */
/**
 * {@link VeniceCompressor} backed by Zstandard with a pre-trained dictionary
 * ({@link CompressionStrategy#ZSTD_WITH_DICT}).
 *
 * <p>The dictionary is digested once into a {@link ZstdDictCompress} / {@link ZstdDictDecompress} pair. Compression and
 * decompression contexts that load those digested dictionaries are obtained per thread from
 * {@link CloseableThreadLocal} holders. {@link #close()} closes the thread-local holders and the digested dictionaries.
 *
 * <p>Two instances are equal when they were built from the same dictionary bytes and the same compression level.
 */
public class ZstdWithDictCompressor extends VeniceCompressor {
    /** Number of synthetic records used by {@link #buildDictionaryOnSyntheticAvroData()}. */
    private static final int SYNTHETIC_SAMPLE_COUNT = 50;
    /** Sample-buffer size handed to {@link ZstdDictTrainer} (200 MiB). */
    private static final int DICT_TRAINER_SAMPLE_BUFFER_SIZE = 200 * 1024 * 1024;
    /** Dictionary size handed to {@link ZstdDictTrainer} (100 KiB). */
    private static final int DICT_TRAINER_DICT_SIZE = 100 * 1024;

    private final CloseableThreadLocal<ZstdCompressCtx> compressor;
    private final CloseableThreadLocal<ZstdDecompressCtx> decompressor;
    private final ZstdDictCompress dictCompress;
    private final ZstdDictDecompress dictDecompress;
    /** Private copy of the raw dictionary bytes. Used to build the digested dictionaries and by equals/hashCode. */
    private final byte[] dictionary;
    private final int level;

    /**
     * @param dictionary raw Zstandard dictionary bytes. The array is copied, so later changes by the caller have no
     *                   effect on this instance.
     * @param level      Zstandard compression level
     */
    public ZstdWithDictCompressor(byte[] dictionary, int level) {
        super(CompressionStrategy.ZSTD_WITH_DICT);
        // Defensive copy: equals()/hashCode() depend on the array contents, so they must not drift away from the
        // bytes the native dictionaries below were built from.
        this.dictionary = dictionary.clone();
        this.level = level;
        this.dictCompress = new ZstdDictCompress(this.dictionary, level);
        this.dictDecompress = new ZstdDictDecompress(this.dictionary);
        this.compressor =
            new CloseableThreadLocal<>(() -> new ZstdCompressCtx().loadDict(this.dictCompress).setLevel(level));
        this.decompressor = new CloseableThreadLocal<>(() -> new ZstdDecompressCtx().loadDict(this.dictDecompress));
    }

    /** Compresses {@code data} into a new Zstandard frame using the current thread's dictionary-loaded context. */
    @Override
    public byte[] compress(byte[] data) {
        return compressor.get().compress(data);
    }

    /**
     * Compresses the remaining bytes of {@code data} and reserves {@code startPositionOfOutput} leading bytes in the
     * returned buffer, ahead of the compressed frame.
     *
     * <p>The returned buffer is positioned at {@code startPositionOfOutput`, and its limit is the end of the
     * compressed frame. It is heap-backed when {@code data} is array-backed and direct when {@code data} is direct.
     * The position of {@code data} is the same on return as on entry.
     *
     * @throws ZstdException            if the worst-case output size exceeds {@code Integer.MAX_VALUE}
     * @throws IllegalArgumentException if {@code data} is neither array-backed nor direct
     */
    @Override
    public ByteBuffer compress(ByteBuffer data, int startPositionOfOutput) throws IOException {
        long maxCompressedSize = Zstd.compressBound(data.remaining());
        if (maxCompressedSize + startPositionOfOutput > Integer.MAX_VALUE) {
            throw new ZstdException(Zstd.errGeneric(), "Max output size is greater than Integer.MAX_VALUE");
        }
        int outputCapacity = (int) maxCompressedSize + startPositionOfOutput;

        if (data.hasArray()) {
            byte[] output = new byte[outputCapacity];
            // arrayOffset() is non-zero for sliced/offset heap buffers; without it the wrong bytes would be read.
            int compressedSize = compressor.get()
                .compressByteArray(
                    output,
                    startPositionOfOutput,
                    (int) maxCompressedSize,
                    data.array(),
                    data.arrayOffset() + data.position(),
                    data.remaining());
            return ByteBuffer.wrap(output, startPositionOfOutput, compressedSize);
        }
        if (!data.isDirect()) {
            throw new IllegalArgumentException("The passed in ByteBuffer must be either direct or be backed by an array!");
        }

        ByteBuffer output = ByteBuffer.allocateDirect(outputCapacity);
        output.position(startPositionOfOutput);
        // The direct-buffer compress call advances the source position. Restore it explicitly rather than using
        // mark()/reset(), which would overwrite any mark the caller had set.
        int originalPosition = data.position();
        int compressedSize = compressor.get().compress(output, data);
        data.position(originalPosition);
        output.position(startPositionOfOutput);
        output.limit(startPositionOfOutput + compressedSize);
        return output;
    }

    /**
     * Decompresses the remaining bytes of {@code data}.
     *
     * <p>An input with no remaining bytes is returned as-is. Otherwise the result is a new buffer positioned at 0 whose
     * remaining bytes are the decompressed payload. That buffer is heap-backed for array-backed input and direct for
     * direct input. The position of {@code data} is the same on return as on entry, for both kinds of input.
     *
     * @throws IllegalStateException    if the decompressed size cannot be determined up front, exceeds
     *                                  {@code Integer.MAX_VALUE}, or does not match the number of bytes produced
     * @throws IllegalArgumentException if {@code data} is neither array-backed nor direct
     */
    @Override
    public ByteBuffer decompress(ByteBuffer data) throws IOException {
        if (!data.hasRemaining()) {
            return data;
        }
        if (data.hasArray()) {
            // Include arrayOffset() so sliced/offset heap buffers decompress the right bytes.
            return decompress(data.array(), data.arrayOffset() + data.position(), data.remaining());
        }
        if (!data.isDirect()) {
            throw new IllegalArgumentException("The passed in ByteBuffer must be either direct or be backed by an array!");
        }

        int expectedSize = validateExpectedDecompressedSize(Zstd.decompressedSize(data));
        ByteBuffer output = ByteBuffer.allocateDirect(expectedSize);
        // The direct-buffer decompress call consumes the source. Restore its position so the direct path leaves the
        // caller's buffer untouched, just as the heap path does.
        int originalPosition = data.position();
        int actualSize = decompressor.get().decompress(output, data);
        data.position(originalPosition);
        output.position(0);
        validateActualDecompressedSize(actualSize, expectedSize);
        return output;
    }

    /**
     * Decompresses {@code length} bytes of {@code data}, starting at {@code offset}, into a new heap buffer positioned
     * at 0.
     *
     * @throws IllegalStateException if the decompressed size cannot be determined up front, exceeds
     *                               {@code Integer.MAX_VALUE}, or does not match the number of bytes produced
     */
    @Override
    public ByteBuffer decompress(byte[] data, int offset, int length) throws IOException {
        int expectedSize = validateExpectedDecompressedSize(Zstd.decompressedSize(data, offset, length));
        ByteBuffer output = ByteBuffer.allocate(expectedSize);
        int actualSize = decompressor.get()
            .decompressByteArray(output.array(), output.position(), output.remaining(), data, offset, length);
        validateActualDecompressedSize(actualSize, expectedSize);
        output.position(0);
        return output;
    }

    /** Wraps {@code inputStream} in a {@link ZstdInputStream} that uses this compressor's dictionary. */
    @Override
    public InputStream decompress(InputStream inputStream) throws IOException {
        return new ZstdInputStream(inputStream).setDict(dictDecompress);
    }

    /** Closes the thread-local context holders, then quietly closes the digested dictionaries. */
    @Override
    public void close() throws IOException {
        compressor.close();
        decompressor.close();
        IOUtils.closeQuietly(dictCompress);
        IOUtils.closeQuietly(dictDecompress);
    }

    /**
     * Converts the size reported by {@link Zstd#decompressedSize} into a buffer capacity.
     *
     * <p>0 is treated as "unknown". Negative values are rejected the same way. A negative value would otherwise reach
     * {@code ByteBuffer.allocate} and fail there with an unrelated error.
     */
    private static int validateExpectedDecompressedSize(long size) {
        if (size <= 0) {
            throw new IllegalStateException("The size of the compressed payload cannot be known.");
        }
        if (size > Integer.MAX_VALUE) {
            throw new IllegalStateException("The size of the compressed payload is > 2147483647");
        }
        return (int) size;
    }

    /** Fails when the number of bytes actually decompressed differs from the size announced up front. */
    private static void validateActualDecompressedSize(int actual, int expected) {
        if (actual != expected) {
            throw new IllegalStateException("The decompressed payload size (" + actual + ") is not as expected (" + expected + ").");
        }
    }

    /**
     * Trains a Zstandard dictionary on {@value #SYNTHETIC_SAMPLE_COUNT} serialized {@link FakeCompressingSchema}
     * records. Record {@code i} has {@code id = i} and {@code name = i + "_name"}.
     *
     * @return the trained dictionary bytes
     */
    public static byte[] buildDictionaryOnSyntheticAvroData() {
        AvroSerializer serializer = new AvroSerializer(FakeCompressingSchema.getClassSchema());
        ZstdDictTrainer trainer = new ZstdDictTrainer(DICT_TRAINER_SAMPLE_BUFFER_SIZE, DICT_TRAINER_DICT_SIZE);
        for (int i = 0; i < SYNTHETIC_SAMPLE_COUNT; i++) {
            GenericData.Record record = new GenericData.Record(FakeCompressingSchema.getClassSchema());
            record.put("id", i);
            record.put("name", i + "_name");
            trainer.addSample(serializer.serialize(record));
        }
        return trainer.trainSamples();
    }

    /** Content-based hash, consistent with {@link #equals(Object)} (dictionary bytes and level). */
    @Override
    public int hashCode() {
        return 31 * Arrays.hashCode(dictionary) + level;
    }

    /** Two compressors are equal when they use the same dictionary bytes and the same compression level. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ZstdWithDictCompressor)) {
            return false;
        }
        ZstdWithDictCompressor other = (ZstdWithDictCompressor) obj;
        return level == other.level && Arrays.equals(dictionary, other.dictionary);
    }
}
