package io.trino.operator.aggregation.listagg;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionKind;
import io.trino.metadata.FunctionMetadata;
import io.trino.metadata.FunctionNullability;
import io.trino.metadata.Signature;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.AggregationFunctionAdapter;
import io.trino.operator.aggregation.AggregationMetadata;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.TypeSignatureParameter;
import io.trino.spi.type.VarcharType;
import io.trino.util.Reflection;
import java.lang.invoke.MethodHandle;
import java.util.Optional;

/**
 * The SQL {@code listagg} aggregate function: concatenates the non-null input values, separated by
 * a caller-supplied separator, into a single {@code varchar}.
 *
 * <p>The concatenated output is capped at {@code MAX_OUTPUT_LENGTH} bytes. When the next entry would
 * exceed the cap, the function either fails with {@code EXCEEDED_FUNCTION_MEMORY_LIMIT} (when the
 * overflow-error flag is set) or truncates the result. Truncation appends the separator, the overflow
 * filler and, optionally, the number of omitted entries in parentheses.
 */
public class ListaggAggregationFunction extends SqlAggregationFunction {
    public static final String NAME = "listagg";
    private static final int MAX_OUTPUT_LENGTH = 1048576;
    private static final int MAX_OVERFLOW_FILLER_LENGTH = 65536;
    public static final ListaggAggregationFunction LISTAGG = new ListaggAggregationFunction();
    private static final MethodHandle INPUT_FUNCTION = Reflection.methodHandle(ListaggAggregationFunction.class, "input", Type.class, ListaggAggregationState.class, Block.class, Slice.class, Boolean.TYPE, Slice.class, Boolean.TYPE, Integer.TYPE);
    private static final MethodHandle COMBINE_FUNCTION = Reflection.methodHandle(ListaggAggregationFunction.class, "combine", Type.class, ListaggAggregationState.class, ListaggAggregationState.class);
    private static final MethodHandle OUTPUT_FUNCTION = Reflection.methodHandle(ListaggAggregationFunction.class, "output", Type.class, ListaggAggregationState.class, BlockBuilder.class);

    // Constant delimiters around the omitted-entry count. They are created once instead of per output
    // row; BlockBuilder.writeBytes only reads from them.
    private static final Slice OPEN_PARENTHESIS = Slices.utf8Slice("(");
    private static final Slice CLOSE_PARENTHESIS = Slices.utf8Slice(")");

    /**
     * Mutable accumulator that is updated from inside the {@code forEach} callback in
     * {@link #outputState}. It exists because a lambda cannot reassign captured local variables.
     *
     * <p>NOTE(review): the decompiler reports this class was originally {@code private}. It is kept
     * {@code public} only to avoid narrowing visibility before confirming it has no other users.
     */
    public static class OutputContext {
        /** Bytes written to the output so far, counting entries and separators. */
        long outputLength;
        /** Number of entries written to the output so far. */
        int emittedEntryCount;
        /** Set once an entry did not fit within the maximum output length. */
        boolean overflow;

        private OutputContext() {
        }
    }

    /**
     * Declares the signature {@code listagg(varchar(v), varchar(d), boolean, varchar(f), boolean) -> varchar}.
     * Only the first argument (the value) is nullable, and the result is nullable.
     */
    private ListaggAggregationFunction() {
        super(
                new FunctionMetadata(
                        new Signature(
                                NAME,
                                ImmutableList.of(),
                                ImmutableList.of(),
                                VarcharType.VARCHAR.getTypeSignature(),
                                ImmutableList.of(
                                        varcharTypeVariable("v"),                 // value
                                        varcharTypeVariable("d"),                 // separator
                                        BooleanType.BOOLEAN.getTypeSignature(),   // overflow-error flag
                                        varcharTypeVariable("f"),                 // overflow filler
                                        BooleanType.BOOLEAN.getTypeSignature()),  // show-overflow-count flag
                                false),
                        new FunctionNullability(true, ImmutableList.of(true, false, false, false, false)),
                        false,
                        true,
                        "concatenates the input values with the specified separator",
                        FunctionKind.AGGREGATE),
                // NOTE(review): presumably the serialized intermediate-state layout (separator,
                // overflow flag, filler, count flag, entries); confirm against ListaggAggregationStateSerializer.
                new AggregationFunctionMetadata(
                        true,
                        VarcharType.VARCHAR.getTypeSignature(),
                        BooleanType.BOOLEAN.getTypeSignature(),
                        VarcharType.VARCHAR.getTypeSignature(),
                        BooleanType.BOOLEAN.getTypeSignature(),
                        TypeSignature.arrayType(VarcharType.VARCHAR.getTypeSignature())));
    }

    /** Builds the type signature {@code varchar(<name>)}, where {@code name} is a length type variable. */
    private static TypeSignature varcharTypeVariable(String name) {
        return new TypeSignature("varchar", new TypeSignatureParameter[] {TypeSignatureParameter.typeVariable(name)});
    }

    /**
     * Binds the input, combine and output method handles to {@code VarcharType.VARCHAR}.
     * The bound signature is used only to normalize the input method.
     */
    @Override // io.trino.metadata.SqlAggregationFunction
    public AggregationMetadata specialize(BoundSignature boundSignature) {
        VarcharType valueType = VarcharType.VARCHAR;
        ListaggAggregationStateSerializer serializer = new ListaggAggregationStateSerializer(valueType);
        ListaggAggregationStateFactory factory = new ListaggAggregationStateFactory(valueType);
        return new AggregationMetadata(
                AggregationFunctionAdapter.normalizeInputMethod(
                        INPUT_FUNCTION.bindTo(valueType),
                        boundSignature,
                        AggregationFunctionAdapter.AggregationParameterKind.STATE,
                        AggregationFunctionAdapter.AggregationParameterKind.NULLABLE_BLOCK_INPUT_CHANNEL, // value
                        AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL,                // separator
                        AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL,                // overflow-error flag
                        AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL,                // overflow filler
                        AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL,                // show-overflow-count flag
                        AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INDEX),
                Optional.empty(),
                Optional.of(COMBINE_FUNCTION.bindTo(valueType)),
                OUTPUT_FUNCTION.bindTo(valueType),
                ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor(ListaggAggregationState.class, serializer, factory)));
    }

    /**
     * Accumulates one input row into {@code state}.
     *
     * <p>While the state holds no entries, the listagg parameters (separator, overflow-error flag,
     * filler and show-count flag) are recorded on it, so they are available later to {@link #output}.
     * Null values are skipped.
     *
     * @param type unused; present to match the bound method-handle shape
     * @param state the aggregation state to update
     * @param value the block holding the input value
     * @param separator the separator placed between concatenated entries
     * @param overflowError whether exceeding the maximum output length is an error
     * @param overflowFiller the text appended when the output is truncated
     * @param showOverflowEntryCount whether to append the count of omitted entries after the filler
     * @param position the row position in {@code value}
     * @throws TrinoException with {@code INVALID_FUNCTION_ARGUMENT} if the filler is longer than
     *     {@code MAX_OVERFLOW_FILLER_LENGTH}
     */
    public static void input(
            Type type,
            ListaggAggregationState state,
            Block value,
            Slice separator,
            boolean overflowError,
            Slice overflowFiller,
            boolean showOverflowEntryCount,
            int position) {
        if (state.isEmpty()) {
            if (overflowFiller.length() > MAX_OVERFLOW_FILLER_LENGTH) {
                throw new TrinoException(
                        StandardErrorCode.INVALID_FUNCTION_ARGUMENT,
                        String.format("Overflow filler length %d exceeds maximum length %d", overflowFiller.length(), MAX_OVERFLOW_FILLER_LENGTH));
            }
            state.setSeparator(separator);
            state.setOverflowError(overflowError);
            state.setOverflowFiller(overflowFiller);
            state.setShowOverflowEntryCount(showOverflowEntryCount);
        }
        if (!value.isNull(position)) {
            state.add(value, position);
        }
    }

    /**
     * Merges {@code otherState} into {@code state}.
     *
     * <p>If {@code state} has not recorded a separator yet, it first adopts all listagg parameters
     * from {@code otherState}.
     *
     * @param type unused; present to match the bound method-handle shape
     */
    public static void combine(Type type, ListaggAggregationState state, ListaggAggregationState otherState) {
        if (state.getSeparator() == null) {
            state.setSeparator(otherState.getSeparator());
            state.setOverflowError(otherState.isOverflowError());
            state.setOverflowFiller(otherState.getOverflowFiller());
            state.setShowOverflowEntryCount(otherState.showOverflowEntryCount());
        }
        state.merge(otherState);
    }

    /**
     * Writes the final result. It appends null when the state holds no entries; otherwise it appends
     * the concatenation capped at {@code MAX_OUTPUT_LENGTH} bytes.
     *
     * @param type unused; present to match the bound method-handle shape
     */
    public static void output(Type type, ListaggAggregationState state, BlockBuilder out) {
        if (state.isEmpty()) {
            out.appendNull();
            return;
        }
        outputState(state, out, MAX_OUTPUT_LENGTH);
    }

    /**
     * Writes the separator-joined entries of {@code state} to {@code out} as one closed entry.
     *
     * <p>Entries are emitted in iteration order until the next entry, plus its separator, would push
     * the output past {@code maxOutputLength} bytes. Then one of two things happens:
     * <ul>
     *   <li>if the overflow-error flag is set, an exception is thrown;</li>
     *   <li>otherwise the truncation suffix is appended. The suffix itself is not counted against the limit.</li>
     * </ul>
     *
     * @throws TrinoException with {@code EXCEEDED_FUNCTION_MEMORY_LIMIT} on overflow when the state's
     *     overflow-error flag is set
     */
    @VisibleForTesting
    protected static void outputState(ListaggAggregationState state, BlockBuilder out, int maxOutputLength) {
        Slice separator = state.getSeparator();
        int separatorLength = separator.length();
        OutputContext context = new OutputContext();

        // Returning false signals forEach to stop. NOTE(review): this assumes
        // ListaggAggregationState.forEach honors the return value. If it did not, later shorter
        // entries could still be emitted after an overflow.
        state.forEach((block, position) -> {
            int entryLength = block.getSliceLength(position);
            boolean needsSeparator = context.emittedEntryCount > 0;
            // The first operand is long, so the whole sum is evaluated in long arithmetic.
            if (context.outputLength + entryLength + (needsSeparator ? separatorLength : 0) > maxOutputLength) {
                context.overflow = true;
                return false;
            }
            if (needsSeparator) {
                out.writeBytes(separator, 0, separatorLength);
                context.outputLength += separatorLength;
            }
            block.writeBytesTo(position, 0, entryLength, out);
            context.outputLength += entryLength;
            context.emittedEntryCount++;
            return true;
        });

        if (context.overflow) {
            if (state.isOverflowError()) {
                throw new TrinoException(
                        StandardErrorCode.EXCEEDED_FUNCTION_MEMORY_LIMIT,
                        String.format("Concatenated string has the length in bytes larger than the maximum output length %d", maxOutputLength));
            }
            appendOverflowSuffix(state, out, separator, context.emittedEntryCount);
        }
        out.closeEntry();
    }

    /**
     * Appends the truncation suffix to {@code out}. The suffix is:
     * <ul>
     *   <li>the separator, if at least one entry was emitted;</li>
     *   <li>then the overflow filler;</li>
     *   <li>then, if requested, {@code "(N)"}, where N is the number of entries that were not emitted.</li>
     * </ul>
     */
    private static void appendOverflowSuffix(ListaggAggregationState state, BlockBuilder out, Slice separator, int emittedEntryCount) {
        if (emittedEntryCount > 0) {
            out.writeBytes(separator, 0, separator.length());
        }
        Slice filler = state.getOverflowFiller();
        out.writeBytes(filler, 0, filler.length());
        if (state.showOverflowEntryCount()) {
            out.writeBytes(OPEN_PARENTHESIS, 0, OPEN_PARENTHESIS.length());
            Slice omittedCount = Slices.utf8Slice(Integer.toString(state.getEntryCount() - emittedEntryCount));
            out.writeBytes(omittedCount, 0, omittedCount.length());
            out.writeBytes(CLOSE_PARENTHESIS, 0, CLOSE_PARENTHESIS.length());
        }
    }
}
