package org.apache.pinot.$internal.org.apache.pinot.core.query.aggregation.function;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Stack;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.theta.Intersection;
import org.apache.datasketches.theta.SetOperation;
import org.apache.datasketches.theta.SetOperationBuilder;
import org.apache.datasketches.theta.Sketch;
import org.apache.datasketches.theta.Union;
import org.apache.pinot.$internal.com.google.common.base.Preconditions;
import org.apache.pinot.$internal.org.apache.commons.configuration.tree.DefaultExpressionEngine;
import org.apache.pinot.$internal.org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.$internal.org.apache.pinot.core.operator.filter.predicate.PredicateEvaluator;
import org.apache.pinot.$internal.org.apache.pinot.core.operator.filter.predicate.PredicateEvaluatorProvider;
import org.apache.pinot.$internal.org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.$internal.org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.$internal.org.apache.pinot.core.query.aggregation.ThetaSketchParams;
import org.apache.pinot.$internal.org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.$internal.org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.$internal.org.apache.pinot.core.query.request.context.ExpressionContext;
import org.apache.pinot.$internal.org.apache.pinot.core.query.request.context.FilterContext;
import org.apache.pinot.$internal.org.apache.pinot.core.query.request.context.predicate.Predicate;
import org.apache.pinot.$internal.org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.sql.parsers.CalciteSqlParser;

/* loaded from: input_file:org/apache/pinot/$internal/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.class */
public class DistinctCountThetaSketchAggregationFunction implements AggregationFunction<Map<String, Sketch>, Long> {
    private final ExpressionContext _thetaSketchColumn;
    private final ThetaSketchParams _thetaSketchParams;
    private final SetOperationBuilder _setOperationBuilder;
    private final List<ExpressionContext> _inputExpressions;
    private final FilterContext _postAggregationExpression;
    private final Map<Predicate, PredicateInfo> _predicateInfoMap;

    /**
     * Pairs a {@link Predicate} with its string form and a lazily created {@link PredicateEvaluator}.
     *
     * <p>NOTE(review): the decompiler reports that the access modifier of this class was changed from
     * {@code private}; it is kept {@code public} here to match the decompiled code.
     */
    public static class PredicateInfo {
        final Predicate _predicate;
        final String _stringPredicate;
        // Stays null until getPredicateEvaluator is called for the first time.
        PredicateEvaluator _predicateEvaluator;

        /**
         * Stores the predicate together with the result of its {@code toString()}.
         *
         * @param predicate predicate to wrap
         */
        PredicateInfo(Predicate predicate) {
            _predicate = predicate;
            _stringPredicate = predicate.toString();
        }

        /** Returns the wrapped predicate. */
        Predicate getPredicate() {
            return _predicate;
        }

        /** Returns the string form of the predicate captured at construction time. */
        String getStringPredicate() {
            return _stringPredicate;
        }

        /**
         * Returns the evaluator for the wrapped predicate, creating it on the first call.
         *
         * <p>Once an evaluator has been created it is returned as-is on later calls, so {@code dataType}
         * only has an effect on the call that creates it.
         *
         * @param dataType value type handed to {@link PredicateEvaluatorProvider} when the evaluator is created
         * @return the cached (or newly created) evaluator
         */
        PredicateEvaluator getPredicateEvaluator(FieldSpec.DataType dataType) {
            PredicateEvaluator cached = _predicateEvaluator;
            if (cached != null) {
                return cached;
            }
            PredicateEvaluator created = PredicateEvaluatorProvider.getPredicateEvaluator(_predicate, null, dataType);
            _predicateEvaluator = created;
            return created;
        }
    }

    /**
     * Builds the aggregation function from its query arguments.
     *
     * <p>Expected layout of {@code list}:
     * <ol>
     *   <li>index 0: identifier of the theta-sketch column;</li>
     *   <li>index 1: literal holding the parameters, parsed with {@link ThetaSketchParams#fromString};</li>
     *   <li>indexes 2 .. size-2 (optional): literals, each holding a single predicate expression;</li>
     *   <li>last index: literal holding the post-aggregation expression, which is compiled and converted
     *       into a {@link FilterContext}.</li>
     * </ol>
     * When no explicit predicate arguments are given (exactly three arguments), the predicates are collected
     * by walking the post-aggregation filter tree.
     *
     * @param list function arguments, at least three
     * @throws SqlParseException declared because compiling the post-aggregation expression may fail to parse
     * @throws IllegalArgumentException if there are fewer than three arguments, an argument has the wrong
     *     expression type, or a predicate argument is not a valid single predicate
     */
    public DistinctCountThetaSketchAggregationFunction(List<ExpressionContext> list) throws SqlParseException {
        int size = list.size();
        // The template needs a %s placeholder; without it the value is only appended in brackets.
        Preconditions.checkArgument(size >= 3, "DistinctCountThetaSketch expects at least three arguments (theta-sketch column, parameters, post-aggregation expression), got: %s", size);

        _thetaSketchColumn = list.get(0);
        Preconditions.checkArgument(_thetaSketchColumn.getType() == ExpressionContext.Type.IDENTIFIER, "First argument of DistinctCountThetaSketch must be identifier (theta-sketch column)");

        ExpressionContext paramsExpression = list.get(1);
        Preconditions.checkArgument(paramsExpression.getType() == ExpressionContext.Type.LITERAL, "Second argument of DistinctCountThetaSketch must be literal (parameters)");
        _thetaSketchParams = ThetaSketchParams.fromString(paramsExpression.getLiteral());
        // Must run after _thetaSketchParams is assigned: the builder reads it.
        _setOperationBuilder = getSetOperationBuilder();

        _inputExpressions = new ArrayList<>();
        _inputExpressions.add(_thetaSketchColumn);

        ExpressionContext postAggregationExpression = list.get(size - 1);
        // Validate the LAST argument here (the previous code re-checked the second argument by mistake).
        Preconditions.checkArgument(postAggregationExpression.getType() == ExpressionContext.Type.LITERAL, "Last argument of DistinctCountThetaSketch must be literal (post-aggregation expression)");
        _postAggregationExpression = QueryContextConverterUtils.getFilter(CalciteSqlParser.compileToExpression(postAggregationExpression.getLiteral()));

        _predicateInfoMap = new HashMap<>();
        if (size > 3) {
            // Predicates were passed explicitly as the arguments between the parameters and the last argument.
            for (int i = 2; i < size - 1; i++) {
                ExpressionContext predicateExpression = list.get(i);
                Preconditions.checkArgument(predicateExpression.getType() == ExpressionContext.Type.LITERAL, "Third to second last argument of DistinctCountThetaSketch must be literal (predicate expression)");
                Predicate predicate = getPredicate(predicateExpression.getLiteral());
                _inputExpressions.add(predicate.getLhs());
                _predicateInfoMap.put(predicate, new PredicateInfo(predicate));
            }
        } else {
            // No explicit predicates: collect every PREDICATE node of the post-aggregation filter tree.
            Stack<FilterContext> pending = new Stack<>();
            pending.push(_postAggregationExpression);
            while (!pending.isEmpty()) {
                FilterContext filter = pending.pop();
                if (filter.getType() == FilterContext.Type.PREDICATE) {
                    Predicate predicate = filter.getPredicate();
                    _inputExpressions.add(predicate.getLhs());
                    _predicateInfoMap.put(predicate, new PredicateInfo(predicate));
                } else {
                    pending.addAll(filter.getChildren());
                }
            }
        }
    }

    /**
     * Returns the aggregation function type of this class.
     *
     * @return always {@link AggregationFunctionType#DISTINCTCOUNTTHETASKETCH}
     */
    @Override
    public AggregationFunctionType getType() {
        return AggregationFunctionType.DISTINCTCOUNTTHETASKETCH;
    }

    /**
     * Builds the column name as {@code <function name>_<theta-sketch column>}.
     *
     * @return the function name, an underscore and the string form of the theta-sketch column expression
     */
    @Override
    public String getColumnName() {
        String functionName = AggregationFunctionType.DISTINCTCOUNTTHETASKETCH.getName();
        return functionName + "_" + _thetaSketchColumn;
    }

    /**
     * Builds the result column name: the lower-cased function name followed by the theta-sketch column
     * wrapped between {@code DefaultExpressionEngine.DEFAULT_INDEX_START} and {@code DEFAULT_INDEX_END}.
     *
     * <p>Lower-casing uses {@link Locale#ROOT} so the name does not depend on the JVM default locale
     * (for example, the Turkish locale lower-cases {@code 'I'} to a dotless {@code 'ı'}).
     *
     * <p>NOTE(review): the two {@code DefaultExpressionEngine} constants are presumably {@code "("} and
     * {@code ")"} substituted by the decompiler — confirm.
     *
     * @return the result column name
     */
    @Override
    public String getResultColumnName() {
        String functionName = AggregationFunctionType.DISTINCTCOUNTTHETASKETCH.getName().toLowerCase(Locale.ROOT);
        return functionName + DefaultExpressionEngine.DEFAULT_INDEX_START + _thetaSketchColumn + DefaultExpressionEngine.DEFAULT_INDEX_END;
    }

    /**
     * Returns the expressions this function reads: the theta-sketch column followed by the left-hand side
     * of every predicate registered by the constructor.
     *
     * @return the internal list itself (not a copy)
     */
    @Override
    public List<ExpressionContext> getInputExpressions() {
        return _inputExpressions;
    }

    /**
     * Dispatches to the visitor by calling {@code visit(this)}.
     *
     * @param aggregationFunctionVisitorBase visitor to invoke
     */
    @Override
    public void accept(AggregationFunctionVisitorBase aggregationFunctionVisitorBase) {
        aggregationFunctionVisitorBase.visit(this);
    }

    /**
     * Creates the holder used for aggregation without group-by.
     *
     * @return a new {@link ObjectAggregationResultHolder}
     */
    @Override
    public AggregationResultHolder createAggregationResultHolder() {
        return new ObjectAggregationResultHolder();
    }

    /**
     * Creates the holder used for group-by aggregation.
     *
     * @param i first constructor argument of {@link ObjectGroupByResultHolder}
     *     (presumably the initial capacity — TODO confirm)
     * @param i2 second constructor argument of {@link ObjectGroupByResultHolder}
     *     (presumably the maximum capacity — TODO confirm)
     * @return a new {@link ObjectGroupByResultHolder}
     */
    @Override
    public GroupByResultHolder createGroupByResultHolder(int i, int i2) {
        return new ObjectGroupByResultHolder(i, i2);
    }

    /**
     * Aggregates the first {@code i} rows of a block without group-by.
     *
     * <p>The sketches of the theta-sketch column are deserialized once. Then, for every registered
     * predicate, each row whose predicate-column value passes the predicate has its sketch merged into
     * the {@link Union} kept for that predicate in the result holder.
     *
     * @param i number of rows to process
     * @param aggregationResultHolder holder of the per-predicate union map (created on first use)
     * @param map block values keyed by input expression
     * @throws IllegalStateException if a predicate column has a value type other than
     *     INT, LONG, FLOAT, DOUBLE, STRING or BYTES
     */
    @Override
    public void aggregate(int i, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> map) {
        Map<Predicate, Union> unions = getUnionMap(aggregationResultHolder);
        byte[][] serializedSketches = map.get(_thetaSketchColumn).getBytesValuesSV();
        Sketch[] sketches = deserializeSketches(serializedSketches, i);

        for (PredicateInfo info : _predicateInfoMap.values()) {
            Predicate predicate = info.getPredicate();
            BlockValSet values = map.get(predicate.getLhs());
            FieldSpec.DataType type = values.getValueType();
            PredicateEvaluator evaluator = info.getPredicateEvaluator(type);
            Union target = unions.get(predicate);

            switch (type) {
                case INT: {
                    int[] ints = values.getIntValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(ints[row])) {
                            target.update(sketches[row]);
                        }
                    }
                    break;
                }
                case LONG: {
                    long[] longs = values.getLongValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(longs[row])) {
                            target.update(sketches[row]);
                        }
                    }
                    break;
                }
                case FLOAT: {
                    float[] floats = values.getFloatValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(floats[row])) {
                            target.update(sketches[row]);
                        }
                    }
                    break;
                }
                case DOUBLE: {
                    double[] doubles = values.getDoubleValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(doubles[row])) {
                            target.update(sketches[row]);
                        }
                    }
                    break;
                }
                case STRING: {
                    String[] strings = values.getStringValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(strings[row])) {
                            target.update(sketches[row]);
                        }
                    }
                    break;
                }
                case BYTES: {
                    byte[][] bytes = values.getBytesValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(bytes[row])) {
                            target.update(sketches[row]);
                        }
                    }
                    break;
                }
                default:
                    throw new IllegalStateException();
            }
        }
    }

    /**
     * Aggregates the first {@code i} rows of a block for a single-valued group-by.
     *
     * <p>For every registered predicate, each row whose predicate-column value passes the predicate has
     * its sketch merged into the {@link Union} kept for that predicate under the row's group key
     * ({@code iArr[row]}).
     *
     * <p>The BYTES case now ends with {@code break} like the other cases. It previously threw
     * {@link IllegalStateException} after having processed the rows, unlike {@code aggregate} and
     * {@code aggregateGroupByMV}.
     *
     * @param i number of rows to process
     * @param iArr group key of each row
     * @param groupByResultHolder holder of the per-group, per-predicate union maps (created on first use)
     * @param map block values keyed by input expression
     * @throws IllegalStateException if a predicate column has a value type other than
     *     INT, LONG, FLOAT, DOUBLE, STRING or BYTES
     */
    @Override
    public void aggregateGroupBySV(int i, int[] iArr, GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> map) {
        Sketch[] sketches = deserializeSketches(map.get(_thetaSketchColumn).getBytesValuesSV(), i);

        for (PredicateInfo info : _predicateInfoMap.values()) {
            Predicate predicate = info.getPredicate();
            BlockValSet values = map.get(predicate.getLhs());
            FieldSpec.DataType valueType = values.getValueType();
            PredicateEvaluator evaluator = info.getPredicateEvaluator(valueType);

            switch (valueType) {
                case INT: {
                    int[] ints = values.getIntValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(ints[row])) {
                            getUnionMap(groupByResultHolder, iArr[row]).get(predicate).update(sketches[row]);
                        }
                    }
                    break;
                }
                case LONG: {
                    long[] longs = values.getLongValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(longs[row])) {
                            getUnionMap(groupByResultHolder, iArr[row]).get(predicate).update(sketches[row]);
                        }
                    }
                    break;
                }
                case FLOAT: {
                    float[] floats = values.getFloatValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(floats[row])) {
                            getUnionMap(groupByResultHolder, iArr[row]).get(predicate).update(sketches[row]);
                        }
                    }
                    break;
                }
                case DOUBLE: {
                    double[] doubles = values.getDoubleValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(doubles[row])) {
                            getUnionMap(groupByResultHolder, iArr[row]).get(predicate).update(sketches[row]);
                        }
                    }
                    break;
                }
                case STRING: {
                    String[] strings = values.getStringValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(strings[row])) {
                            getUnionMap(groupByResultHolder, iArr[row]).get(predicate).update(sketches[row]);
                        }
                    }
                    break;
                }
                case BYTES: {
                    byte[][] bytes = values.getBytesValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(bytes[row])) {
                            getUnionMap(groupByResultHolder, iArr[row]).get(predicate).update(sketches[row]);
                        }
                    }
                    // Do not fall into the error path once the BYTES rows have been aggregated.
                    break;
                }
                default:
                    throw new IllegalStateException("Unsupported value type for DistinctCountThetaSketch predicate: " + valueType);
            }
        }
    }

    /**
     * Aggregates the first {@code i} rows of a block for a multi-valued group-by.
     *
     * <p>For every registered predicate, each row whose predicate-column value passes the predicate has
     * its sketch merged into the predicate's {@link Union} of every group key listed in {@code iArr[row]}.
     *
     * @param i number of rows to process
     * @param iArr group keys of each row
     * @param groupByResultHolder holder of the per-group, per-predicate union maps (created on first use)
     * @param map block values keyed by input expression
     * @throws IllegalStateException if a predicate column has a value type other than
     *     INT, LONG, FLOAT, DOUBLE, STRING or BYTES
     */
    @Override
    public void aggregateGroupByMV(int i, int[][] iArr, GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> map) {
        Sketch[] sketches = deserializeSketches(map.get(_thetaSketchColumn).getBytesValuesSV(), i);

        for (PredicateInfo info : _predicateInfoMap.values()) {
            Predicate predicate = info.getPredicate();
            BlockValSet values = map.get(predicate.getLhs());
            FieldSpec.DataType type = values.getValueType();
            PredicateEvaluator evaluator = info.getPredicateEvaluator(type);

            switch (type) {
                case INT: {
                    int[] ints = values.getIntValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(ints[row])) {
                            updateUnionsOfGroups(groupByResultHolder, iArr[row], predicate, sketches[row]);
                        }
                    }
                    break;
                }
                case LONG: {
                    long[] longs = values.getLongValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(longs[row])) {
                            updateUnionsOfGroups(groupByResultHolder, iArr[row], predicate, sketches[row]);
                        }
                    }
                    break;
                }
                case FLOAT: {
                    float[] floats = values.getFloatValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(floats[row])) {
                            updateUnionsOfGroups(groupByResultHolder, iArr[row], predicate, sketches[row]);
                        }
                    }
                    break;
                }
                case DOUBLE: {
                    double[] doubles = values.getDoubleValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(doubles[row])) {
                            updateUnionsOfGroups(groupByResultHolder, iArr[row], predicate, sketches[row]);
                        }
                    }
                    break;
                }
                case STRING: {
                    String[] strings = values.getStringValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(strings[row])) {
                            updateUnionsOfGroups(groupByResultHolder, iArr[row], predicate, sketches[row]);
                        }
                    }
                    break;
                }
                case BYTES: {
                    byte[][] bytes = values.getBytesValuesSV();
                    for (int row = 0; row < i; row++) {
                        if (evaluator.applySV(bytes[row])) {
                            updateUnionsOfGroups(groupByResultHolder, iArr[row], predicate, sketches[row]);
                        }
                    }
                    break;
                }
                default:
                    throw new IllegalStateException();
            }
        }
    }

    /** Merges {@code sketch} into the union of {@code predicate} for each of the given group keys, in order. */
    private void updateUnionsOfGroups(GroupByResultHolder groupByResultHolder, int[] groupKeys, Predicate predicate, Sketch sketch) {
        for (int groupKey : groupKeys) {
            getUnionMap(groupByResultHolder, groupKey).get(predicate).update(sketch);
        }
    }

    /**
     * Converts the per-predicate unions held by the result holder into a map from predicate string to
     * the sketch produced by {@code Union.getResult()}.
     *
     * @param aggregationResultHolder holder filled by {@code aggregate}
     * @return an empty map when the holder has no result or an empty one; otherwise one entry per
     *     registered predicate
     */
    @Override
    public Map<String, Sketch> extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
        Map<Predicate, Union> unions = (Map<Predicate, Union>) aggregationResultHolder.getResult();
        boolean nothingAggregated = unions == null || unions.isEmpty();
        if (nothingAggregated) {
            return Collections.emptyMap();
        }
        Map<String, Sketch> sketchesByPredicate = new HashMap<>();
        for (PredicateInfo info : _predicateInfoMap.values()) {
            Union union = unions.get(info.getPredicate());
            sketchesByPredicate.put(info.getStringPredicate(), union.getResult());
        }
        return sketchesByPredicate;
    }

    /**
     * Converts the per-predicate unions held for one group key into a map from predicate string to the
     * sketch produced by {@code Union.getResult()}.
     *
     * @param groupByResultHolder holder filled by the group-by aggregation methods
     * @param i group key to extract
     * @return an empty map when the group has no result or an empty one; otherwise one entry per
     *     registered predicate
     */
    @Override
    public Map<String, Sketch> extractGroupByResult(GroupByResultHolder groupByResultHolder, int i) {
        Map<Predicate, Union> unions = (Map<Predicate, Union>) groupByResultHolder.getResult(i);
        boolean nothingAggregated = unions == null || unions.isEmpty();
        if (nothingAggregated) {
            return Collections.emptyMap();
        }
        Map<String, Sketch> sketchesByPredicate = new HashMap<>();
        for (PredicateInfo info : _predicateInfoMap.values()) {
            Union union = unions.get(info.getPredicate());
            sketchesByPredicate.put(info.getStringPredicate(), union.getResult());
        }
        return sketchesByPredicate;
    }

    /**
     * Merges two intermediate results (maps from predicate string to sketch).
     *
     * <p>If either side is {@code null} or empty the other side is returned unchanged. Otherwise every
     * sketch of both sides is fed into a fresh per-predicate {@link Union}, and the union results are
     * returned keyed by {@code Predicate.toString()}.
     *
     * @param map first intermediate result
     * @param map2 second intermediate result
     * @return the merged intermediate result
     */
    @Override
    public Map<String, Sketch> merge(Map<String, Sketch> map, Map<String, Sketch> map2) {
        if (map == null || map.isEmpty()) {
            return map2;
        }
        if (map2 == null || map2.isEmpty()) {
            return map;
        }

        Map<Predicate, Union> unions = getDefaultUnionMap();
        foldIntoUnions(unions, map);
        foldIntoUnions(unions, map2);

        Map<String, Sketch> merged = new HashMap<>();
        for (Map.Entry<Predicate, Union> unionEntry : unions.entrySet()) {
            Predicate predicate = unionEntry.getKey();
            merged.put(predicate.toString(), unionEntry.getValue().getResult());
        }
        return merged;
    }

    /** Parses each key of {@code sketches} back into a predicate and merges its sketch into that predicate's union. */
    private void foldIntoUnions(Map<Predicate, Union> unions, Map<String, Sketch> sketches) {
        for (Map.Entry<String, Sketch> sketchEntry : sketches.entrySet()) {
            Predicate predicate = getPredicate(sketchEntry.getKey());
            unions.get(predicate).update(sketchEntry.getValue());
        }
    }

    /**
     * Reports whether intermediate results can be compared.
     *
     * @return always {@code false}
     */
    @Override
    public boolean isIntermediateResultComparable() {
        return false;
    }

    /**
     * Returns the column type used for intermediate results.
     *
     * @return always {@link DataSchema.ColumnDataType#OBJECT}
     */
    @Override
    public DataSchema.ColumnDataType getIntermediateResultColumnType() {
        return DataSchema.ColumnDataType.OBJECT;
    }

    /**
     * Returns the column type used for the final result.
     *
     * @return always {@link DataSchema.ColumnDataType#LONG}
     */
    @Override
    public DataSchema.ColumnDataType getFinalResultColumnType() {
        return DataSchema.ColumnDataType.LONG;
    }

    /**
     * Computes the final result: the estimate of the final sketch, rounded to the nearest {@code long}.
     *
     * @param map intermediate result (predicate string to sketch)
     * @return the rounded estimate of the sketch returned by {@code extractFinalSketch}
     */
    @Override
    public Long extractFinalResult(Map<String, Sketch> map) {
        Sketch finalSketch = extractFinalSketch(map);
        double estimate = finalSketch.getEstimate();
        return Long.valueOf(Math.round(estimate));
    }

    /**
     * Parses a predicate string into a {@link Predicate}.
     *
     * <p>The string is compiled with {@link CalciteSqlParser}, converted into a {@link FilterContext}, and
     * must resolve to a single {@code PREDICATE} node.
     *
     * @param str predicate expression
     * @return the parsed predicate
     * @throws IllegalArgumentException if the string cannot be parsed (the {@link SqlParseException} is
     *     attached as the cause) or if it does not resolve to a single predicate
     */
    private Predicate getPredicate(String str) {
        FilterContext filter;
        try {
            filter = QueryContextConverterUtils.getFilter(CalciteSqlParser.compileToExpression(str));
        } catch (SqlParseException e) {
            // Keep the parser error as the cause so the failing position/token is not lost.
            throw new IllegalArgumentException("Invalid predicate string: " + str, e);
        }
        Preconditions.checkArgument(filter.getType() == FilterContext.Type.PREDICATE, "Invalid predicate string: %s", str);
        return filter.getPredicate();
    }

    /**
     * Returns the per-predicate union map stored in the holder, installing a fresh default map first when
     * the holder has no result yet.
     *
     * @param aggregationResultHolder holder to read from (and initialize if needed)
     * @return the union map held by {@code aggregationResultHolder}
     */
    private Map<Predicate, Union> getUnionMap(AggregationResultHolder aggregationResultHolder) {
        Map<Predicate, Union> existing = (Map) aggregationResultHolder.getResult();
        if (existing != null) {
            return existing;
        }
        Map<Predicate, Union> fresh = getDefaultUnionMap();
        aggregationResultHolder.setValue(fresh);
        return fresh;
    }

    /**
     * Returns the per-predicate union map stored for a group key, installing a fresh default map first
     * when the group has no result yet.
     *
     * @param groupByResultHolder holder to read from (and initialize if needed)
     * @param i group key
     * @return the union map held for group key {@code i}
     */
    private Map<Predicate, Union> getUnionMap(GroupByResultHolder groupByResultHolder, int i) {
        Map<Predicate, Union> existing = (Map) groupByResultHolder.getResult(i);
        if (existing != null) {
            return existing;
        }
        Map<Predicate, Union> fresh = getDefaultUnionMap();
        groupByResultHolder.setValueForKey(i, fresh);
        return fresh;
    }

    /**
     * Creates a new map holding one freshly built {@link Union} for every registered predicate.
     *
     * @return a new mutable map from predicate to union
     */
    private Map<Predicate, Union> getDefaultUnionMap() {
        Map<Predicate, Union> unions = new HashMap<>();
        for (Predicate predicate : _predicateInfoMap.keySet()) {
            unions.put(predicate, _setOperationBuilder.buildUnion());
        }
        return unions;
    }

    /**
     * Wraps the first {@code i} serialized sketches as {@link Sketch} instances.
     *
     * @param bArr serialized sketches, one byte array per row
     * @param i number of leading entries of {@code bArr} to wrap
     * @return an array of length {@code i} with the wrapped sketches, in row order
     */
    private Sketch[] deserializeSketches(byte[][] bArr, int i) {
        Sketch[] sketches = new Sketch[i];
        int row = 0;
        while (row < i) {
            Memory memory = Memory.wrap(bArr[row]);
            sketches[row] = Sketch.wrap(memory);
            row++;
        }
        return sketches;
    }

    /**
     * Recursively evaluates a post-aggregation filter tree into a single sketch.
     *
     * <ul>
     *   <li>{@code AND}: intersection of the sketches of all children;</li>
     *   <li>{@code OR}: union of the sketches of all children;</li>
     *   <li>{@code PREDICATE}: the sketch looked up in {@code map} for that predicate.</li>
     * </ul>
     *
     * @param filterContext node of the filter tree to evaluate
     * @param map sketch of each predicate
     * @return the resulting sketch (for a {@code PREDICATE} node, whatever {@code map.get} returns)
     * @throws IllegalStateException if the node type is none of the above
     */
    private Sketch evalPostAggregationExpression(FilterContext filterContext, Map<Predicate, Sketch> map) {
        switch (filterContext.getType()) {
            case AND: {
                Intersection intersection = _setOperationBuilder.buildIntersection();
                for (FilterContext child : filterContext.getChildren()) {
                    intersection.update(evalPostAggregationExpression(child, map));
                }
                return intersection.getResult();
            }
            case OR: {
                Union union = _setOperationBuilder.buildUnion();
                for (FilterContext child : filterContext.getChildren()) {
                    union.update(evalPostAggregationExpression(child, map));
                }
                return union.getResult();
            }
            case PREDICATE:
                return map.get(filterContext.getPredicate());
            default:
                throw new IllegalStateException();
        }
    }

    /**
     * Evaluates the post-aggregation expression over an intermediate result.
     *
     * <p>Each key of {@code map} is parsed back into a {@link Predicate}; the resulting predicate-to-sketch
     * map is then evaluated against the post-aggregation filter tree.
     *
     * <p>NOTE(review): the decompiler reports that the access modifier was changed from {@code protected};
     * it is kept {@code public} here to match the decompiled code.
     *
     * @param map intermediate result (predicate string to sketch)
     * @return the sketch produced by the post-aggregation expression
     */
    public Sketch extractFinalSketch(Map<String, Sketch> map) {
        Map<Predicate, Sketch> sketchesByPredicate = new HashMap<>();
        for (Map.Entry<String, Sketch> sketchEntry : map.entrySet()) {
            Predicate predicate = getPredicate(sketchEntry.getKey());
            sketchesByPredicate.put(predicate, sketchEntry.getValue());
        }
        return evalPostAggregationExpression(_postAggregationExpression, sketchesByPredicate);
    }

    /**
     * Creates the builder used for all unions and intersections of this function.
     *
     * @return a default {@link SetOperationBuilder}, with the nominal entries taken from the theta-sketch
     *     parameters when those parameters are not {@code null}
     */
    private SetOperationBuilder getSetOperationBuilder() {
        SetOperationBuilder builder = SetOperation.builder();
        if (_thetaSketchParams != null) {
            builder = builder.setNominalEntries(_thetaSketchParams.getNominalEntries());
        }
        return builder;
    }
}
