package org.apache.beam.sdk.transforms;

import java.util.Iterator;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.io.range.OffsetRangeTracker;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.TimeDomain;
import org.apache.beam.sdk.util.Timer;
import org.apache.beam.sdk.util.TimerSpec;
import org.apache.beam.sdk.util.TimerSpecs;
import org.apache.beam.sdk.util.state.BagState;
import org.apache.beam.sdk.util.state.CombiningState;
import org.apache.beam.sdk.util.state.StateSpec;
import org.apache.beam.sdk.util.state.StateSpecs;
import org.apache.beam.sdk.util.state.ValueState;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdks.java.core.repackaged.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.sdks.java.core.repackaged.com.google.common.base.Preconditions;
import org.apache.beam.sdks.java.core.repackaged.com.google.common.collect.Iterables;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/beam/sdk/transforms/GroupIntoBatches.class */
/**
 * A {@link PTransform} that groups the values of a keyed {@link PCollection} into batches
 * ({@code Iterable}s) of at most {@code batchSize} elements, emitted together with their key.
 *
 * <p>A batch is emitted as soon as it holds {@code batchSize} elements. Any remaining, partially
 * filled batch is emitted when an event-time timer fires at the window's max timestamp plus the
 * input's allowed lateness.
 *
 * <p>The input {@link PCollection} must use a {@link KvCoder}; the key and value coders are taken
 * from it to encode the transform's state.
 *
 * @param <K> the key type
 * @param <InputT> the value type
 */
public class GroupIntoBatches<K, InputT>
        extends PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, Iterable<InputT>>>> {

    /** Maximum number of elements per emitted batch. */
    private final long batchSize;

    /**
     * Stateful {@link DoFn} that buffers values in a bag state, counts them, and flushes the
     * buffered values as one output element either when the count reaches {@code batchSize} or when
     * the end-of-window timer fires.
     */
    @VisibleForTesting
    public static class GroupIntoBatchesDoFn<K, InputT>
            extends DoFn<KV<K, InputT>, KV<K, Iterable<InputT>>> {
        private static final Logger LOGGER = LoggerFactory.getLogger(GroupIntoBatchesDoFn.class);

        // State / timer identifiers. They are used both on the spec fields and on the
        // processElement/onTimer parameters, so they must stay in sync.
        private static final String END_OF_WINDOW_ID = "endOFWindow";
        private static final String BATCH_ID = "batch";
        private static final String NUM_ELEMENTS_IN_BATCH_ID = "numElementsInBatch";
        private static final String KEY_ID = "key";

        private final long batchSize;
        private final Duration allowedLateness;

        /**
         * The bag is asked to {@code readLater()} (a prefetch hint) whenever the element count is a
         * multiple of this value. Long.MAX_VALUE effectively disables the hint for small batches.
         */
        private final long prefetchFrequency;

        @DoFn.TimerId(END_OF_WINDOW_ID)
        private final TimerSpec timer = TimerSpecs.timer(TimeDomain.EVENT_TIME);

        /** Buffered values of the batch currently being built. */
        @DoFn.StateId(BATCH_ID)
        private final StateSpec<Object, BagState<InputT>> batchSpec;

        /** Number of values currently buffered in {@link #batchSpec}. */
        @DoFn.StateId(NUM_ELEMENTS_IN_BATCH_ID)
        private final StateSpec<Object, CombiningState<Long, Long, Long>> numElementsInBatchSpec =
                StateSpecs.combining(VarLongCoder.of(), new CountFn());

        /** Last key seen; used as the key of the emitted batch. */
        @DoFn.StateId(KEY_ID)
        private final StateSpec<Object, ValueState<K>> keySpec;

        /**
         * Creates the DoFn.
         *
         * @param batchSize maximum number of values per emitted batch
         * @param allowedLateness added to the window's max timestamp to schedule the final flush
         * @param inputKeyCoder coder used for the key state
         * @param inputValueCoder coder used for the buffered values
         */
        GroupIntoBatchesDoFn(
                long batchSize,
                Duration allowedLateness,
                Coder<K> inputKeyCoder,
                Coder<InputT> inputValueCoder) {
            this.batchSize = batchSize;
            this.allowedLateness = allowedLateness;
            this.batchSpec = StateSpecs.bag(inputValueCoder);
            this.keySpec = StateSpecs.value(inputKeyCoder);
            long fifthOfBatch = batchSize / 5;
            this.prefetchFrequency = fifthOfBatch <= 1 ? Long.MAX_VALUE : fifthOfBatch;
        }

        /**
         * Buffers the incoming value, (re)arms the end-of-window timer and flushes the batch once
         * it holds {@code batchSize} values.
         */
        @DoFn.ProcessElement
        public void processElement(
                @DoFn.TimerId(END_OF_WINDOW_ID) Timer windowTimer,
                @DoFn.StateId(BATCH_ID) BagState<InputT> batch,
                @DoFn.StateId(NUM_ELEMENTS_IN_BATCH_ID)
                        CombiningState<Long, Long, Long> numElementsInBatch,
                @DoFn.StateId(KEY_ID) ValueState<K> key,
                DoFn<KV<K, InputT>, KV<K, Iterable<InputT>>>.ProcessContext c,
                BoundedWindow window) {
            Instant windowExpires = window.maxTimestamp().plus(allowedLateness);
            LOGGER.debug(
                    "*** SET TIMER *** to point in time {} for window {}", windowExpires, window);
            windowTimer.set(windowExpires);

            key.write(c.element().getKey());
            batch.add(c.element().getValue());
            LOGGER.debug("*** BATCH *** Add element for window {} ", window);

            numElementsInBatch.add(1L);
            long num = numElementsInBatch.read();
            if (num % prefetchFrequency == 0) {
                batch.readLater();
            }
            if (num >= batchSize) {
                LOGGER.debug("*** END OF BATCH *** for window {}", window);
                flushBatch(c, key, batch, numElementsInBatch);
            }
        }

        /** Flushes whatever is left in the batch when the end-of-window timer fires. */
        @DoFn.OnTimer(END_OF_WINDOW_ID)
        public void onTimerCallback(
                DoFn<KV<K, InputT>, KV<K, Iterable<InputT>>>.OnTimerContext context,
                @DoFn.StateId(KEY_ID) ValueState<K> key,
                @DoFn.StateId(BATCH_ID) BagState<InputT> batch,
                @DoFn.StateId(NUM_ELEMENTS_IN_BATCH_ID)
                        CombiningState<Long, Long, Long> numElementsInBatch,
                BoundedWindow window) {
            LOGGER.debug(
                    "*** END OF WINDOW *** for timer timestamp {} in windows {}",
                    context.timestamp(),
                    window);
            flushBatch(context, key, batch, numElementsInBatch);
        }

        /**
         * Outputs the buffered values (if any) keyed by the stored key, then resets the bag and the
         * element counter.
         */
        private void flushBatch(
                DoFn<KV<K, InputT>, KV<K, Iterable<InputT>>>.Context context,
                ValueState<K> key,
                BagState<InputT> batch,
                CombiningState<Long, Long, Long> numElementsInBatch) {
            Iterable<InputT> values = batch.read();
            // The bag can be empty, e.g. when the timer fires right after a full batch was
            // flushed by processElement; never emit empty batches.
            if (!Iterables.isEmpty(values)) {
                context.output(KV.of(key.read(), values));
            }
            batch.clear();
            LOGGER.debug("*** BATCH *** clear");
            numElementsInBatch.clear();
        }

        /** Sums {@code Long} inputs; used to count the values in the current batch. */
        private static final class CountFn extends Combine.CombineFn<Long, Long, Long> {
            @Override
            public Long createAccumulator() {
                return 0L;
            }

            @Override
            public Long addInput(Long accumulator, Long input) {
                return accumulator + input;
            }

            @Override
            public Long mergeAccumulators(Iterable<Long> accumulators) {
                long sum = 0L;
                for (Long accumulator : accumulators) {
                    sum += accumulator;
                }
                return sum;
            }

            @Override
            public Long extractOutput(Long accumulator) {
                return accumulator;
            }
        }
    }

    private GroupIntoBatches(long batchSize) {
        this.batchSize = batchSize;
    }

    /**
     * Creates a transform emitting batches of at most {@code batchSize} values per key.
     *
     * @param batchSize maximum number of values per batch; must be positive
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public static <K, InputT> GroupIntoBatches<K, InputT> ofSize(long batchSize) {
        Preconditions.checkArgument(
                batchSize > 0, "batchSize must be positive, but was %s", batchSize);
        return new GroupIntoBatches<>(batchSize);
    }

    /**
     * Applies the batching {@link DoFn} to {@code input}.
     *
     * @throws IllegalArgumentException if the input coder is not a {@link KvCoder}
     */
    @Override
    public PCollection<KV<K, Iterable<InputT>>> expand(PCollection<KV<K, InputT>> input) {
        Duration allowedLateness = input.getWindowingStrategy().getAllowedLateness();
        Preconditions.checkArgument(
                input.getCoder() instanceof KvCoder,
                "coder specified in the input PCollection is not a KvCoder");
        KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder();
        return input.apply(
                ParDo.of(
                        new GroupIntoBatchesDoFn<>(
                                batchSize,
                                allowedLateness,
                                inputCoder.getKeyCoder(),
                                inputCoder.getValueCoder())));
    }
}
