package org.apache.beam.runners.spark.translation;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.io.IOException;
import java.io.Serializable;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import org.apache.beam.runners.spark.aggregators.AggAccumParam;
import org.apache.beam.runners.spark.aggregators.NamedAggregators;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.runners.AggregatorValues;
import org.apache.beam.sdk.transforms.Aggregator;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Max;
import org.apache.beam.sdk.transforms.Min;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaSparkContext;

/* loaded from: input_file:org/apache/beam/runners/spark/translation/SparkRuntimeContext.class */
public class SparkRuntimeContext implements Serializable {
    private final Accumulator<NamedAggregators> accum;
    private final String serializedPipelineOptions;
    private final Map<String, Aggregator<?, ?>> aggregators = new HashMap();
    private transient CoderRegistry coderRegistry;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/beam/runners/spark/translation/SparkRuntimeContext$SparkAggregator.class */
    public static class SparkAggregator<InputT, OutputT> implements Aggregator<InputT, OutputT>, Serializable {
        private final String name;
        private final NamedAggregators.State<InputT, ?, OutputT> state;

        SparkAggregator(String str, NamedAggregators.State<InputT, ?, OutputT> state) {
            this.name = str;
            this.state = state;
        }

        public String getName() {
            return this.name;
        }

        public void addValue(InputT inputt) {
            this.state.update(inputt);
        }

        public Combine.CombineFn<InputT, ?, OutputT> getCombineFn() {
            return this.state.getCombineFn();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public SparkRuntimeContext(JavaSparkContext javaSparkContext, Pipeline pipeline) {
        this.accum = javaSparkContext.accumulator(new NamedAggregators(), new AggAccumParam());
        this.serializedPipelineOptions = serializePipelineOptions(pipeline.getOptions());
    }

    private static String serializePipelineOptions(PipelineOptions pipelineOptions) {
        try {
            return new ObjectMapper().writeValueAsString(pipelineOptions);
        } catch (JsonProcessingException e) {
            throw new IllegalStateException("Failed to serialize the pipeline options.", e);
        }
    }

    private static PipelineOptions deserializePipelineOptions(String str) {
        try {
            return (PipelineOptions) new ObjectMapper().readValue(str, PipelineOptions.class);
        } catch (IOException e) {
            throw new IllegalStateException("Failed to deserialize the pipeline options.", e);
        }
    }

    public <T> T getAggregatorValue(String str, Class<T> cls) {
        return (T) ((NamedAggregators) this.accum.value()).getValue(str, cls);
    }

    public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator) {
        final Object aggregatorValue = getAggregatorValue(aggregator.getName(), aggregator.getCombineFn().getOutputType().getRawType());
        return new AggregatorValues<T>() { // from class: org.apache.beam.runners.spark.translation.SparkRuntimeContext.1
            public Collection<T> getValues() {
                return ImmutableList.of(aggregatorValue);
            }

            public Map<String, T> getValuesAtSteps() {
                throw new UnsupportedOperationException("getValuesAtSteps is not supported.");
            }
        };
    }

    public synchronized PipelineOptions getPipelineOptions() {
        return deserializePipelineOptions(this.serializedPipelineOptions);
    }

    public synchronized <InputT, InterT, OutputT> Aggregator<InputT, OutputT> createAggregator(String str, Combine.CombineFn<? super InputT, InterT, OutputT> combineFn) {
        Aggregator<?, ?> aggregator = this.aggregators.get(str);
        if (aggregator == null) {
            NamedAggregators.CombineFunctionState combineFunctionState = new NamedAggregators.CombineFunctionState(combineFn, getCoder(combineFn), this);
            this.accum.add(new NamedAggregators(str, combineFunctionState));
            aggregator = new SparkAggregator(str, combineFunctionState);
            this.aggregators.put(str, aggregator);
        }
        return (Aggregator<InputT, OutputT>) aggregator;
    }

    public CoderRegistry getCoderRegistry() {
        if (this.coderRegistry == null) {
            this.coderRegistry = new CoderRegistry();
            this.coderRegistry.registerStandardCoders();
        }
        return this.coderRegistry;
    }

    private Coder<?> getCoder(Combine.CombineFn<?, ?, ?> combineFn) {
        try {
            if (combineFn.getClass() == Sum.SumIntegerFn.class) {
                return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Integer.class));
            }
            if (combineFn.getClass() == Sum.SumLongFn.class) {
                return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Long.class));
            }
            if (combineFn.getClass() == Sum.SumDoubleFn.class) {
                return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Double.class));
            }
            if (combineFn.getClass() == Min.MinIntegerFn.class) {
                return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Integer.class));
            }
            if (combineFn.getClass() == Min.MinLongFn.class) {
                return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Long.class));
            }
            if (combineFn.getClass() == Min.MinDoubleFn.class) {
                return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Double.class));
            }
            if (combineFn.getClass() == Max.MaxIntegerFn.class) {
                return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Integer.class));
            }
            if (combineFn.getClass() == Max.MaxLongFn.class) {
                return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Long.class));
            }
            if (combineFn.getClass() == Max.MaxDoubleFn.class) {
                return getCoderRegistry().getDefaultCoder(TypeDescriptor.of(Double.class));
            }
            throw new IllegalArgumentException("unsupported combiner in Aggregator: " + combineFn.getClass().getName());
        } catch (CannotProvideCoderException e) {
            throw new IllegalStateException("Could not determine default coder for combiner", e);
        }
    }
}
