package com.datastax.oss.streaming.ai;

import com.datastax.oss.streaming.ai.jstl.JstlTypeConverter;
import com.datastax.oss.streaming.ai.model.ComputeField;
import com.datastax.oss.streaming.ai.model.ComputeFieldType;
import com.datastax.oss.streaming.ai.model.TransformSchemaType;
import com.datastax.oss.streaming.ai.util.AvroUtil;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.sql.Time;
import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.avro.LogicalType;
import org.apache.avro.LogicalTypes;
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;

/* loaded from: input_file:META-INF/bundled-dependencies/streaming-ai-3.1.5.jar:com/datastax/oss/streaming/ai/ComputeStep.class */
/**
 * Transform step that evaluates a list of {@link ComputeField}s against the current record and
 * writes each result to the place selected by the field's scope: the primitive key/value
 * ({@code "primitive"}), fields of the key ({@code "key"}), fields of the value ({@code "value"}),
 * message headers ({@code "header"}) or message properties ({@code "header.properties"}).
 */
public class ComputeStep implements TransformStep {
    /** Milliseconds in one day; used to convert {@link Date} values to Avro {@code date} (epoch days). */
    public static final long MILLIS_PER_DAY = TimeUnit.DAYS.toMillis(1);

    private final List<ComputeField> fields;
    // Caches passed to TransformContext when key/value schemas are rewritten (original schema -> new schema).
    private final Map<Schema, Schema> keySchemaCache = new ConcurrentHashMap<>();
    private final Map<Schema, Schema> valueSchemaCache = new ConcurrentHashMap<>();
    // Avro schemas for non-decimal field types depend only on the type, so they are built once and reused.
    private final Map<ComputeFieldType, Schema> fieldTypeToAvroSchemaCache = new ConcurrentHashMap<>();

    /** Builder for {@link ComputeStep}; when no fields are supplied an empty list is used. */
    public static class ComputeStepBuilder {
        private boolean fieldsSet;
        private List<ComputeField> fieldsValue;

        ComputeStepBuilder() {
        }

        /** Sets the compute fields of the step being built. */
        public ComputeStepBuilder fields(List<ComputeField> list) {
            this.fieldsValue = list;
            this.fieldsSet = true;
            return this;
        }

        /** Creates the step, defaulting to an empty field list if {@link #fields} was never called. */
        public ComputeStep build() {
            return new ComputeStep(fieldsSet ? fieldsValue : defaultFields());
        }

        @Override
        public String toString() {
            return "ComputeStep.ComputeStepBuilder(fields$value=" + this.fieldsValue + ")";
        }
    }

    /**
     * Applies all compute fields to the given context. Scopes are processed in a fixed order
     * (primitive, key, value, header, header.properties), so the key/value scoped fields observe
     * any primitive key/value replacement made first.
     */
    @Override // com.datastax.oss.streaming.ai.TransformStep
    public void process(TransformContext transformContext) {
        computePrimitiveField(fieldsWithScope("primitive"), transformContext);
        computeKeyFields(fieldsWithScope("key"), transformContext);
        computeValueFields(fieldsWithScope("value"), transformContext);
        computeHeaderFields(fieldsWithScope("header"), transformContext);
        computeHeaderPropertiesFields(fieldsWithScope("header.properties"), transformContext);
    }

    /** Returns the configured fields whose scope equals {@code scope}, in configuration order. */
    private List<ComputeField> fieldsWithScope(String scope) {
        return fields.stream()
                .filter(field -> scope.equals(field.getScope()))
                .collect(Collectors.toList());
    }

    /**
     * Adds or replaces fields on the record value. A {@link Map} value is updated in place;
     * an AVRO or JSON value is rewritten through {@link TransformContext#addOrReplaceValueFields}.
     * Other value types are left untouched.
     */
    public void computeValueFields(List<ComputeField> valueFields, TransformContext transformContext) {
        TransformSchemaType valueSchemaType = transformContext.getValueSchemaType();
        Object valueObject = transformContext.getValueObject();
        if (valueObject instanceof Map) {
            @SuppressWarnings("unchecked")
            Map<Object, Object> valueMap = (Map<Object, Object>) valueObject;
            getEvaluatedFields(valueFields, transformContext)
                    .forEach((field, value) -> valueMap.put(field.name(), value));
        } else if (valueSchemaType == TransformSchemaType.AVRO || valueSchemaType == TransformSchemaType.JSON) {
            transformContext.addOrReplaceValueFields(getEvaluatedFields(valueFields, transformContext), valueSchemaCache);
        }
    }

    /**
     * Replaces a primitive key and/or value. Only the first field named {@code "key"} and the first
     * field named {@code "value"} are considered. Each one is applied only when the current key
     * (resp. value) schema type is non-null and primitive. The new schema type comes from the
     * field's declared type, or is inferred from the evaluated object when no type is declared.
     */
    public void computePrimitiveField(List<ComputeField> primitiveFields, TransformContext transformContext) {
        primitiveFields.stream()
                .filter(field -> "key".equals(field.getName()))
                .findFirst()
                .ifPresent(field -> {
                    if (!isPrimitive(transformContext.getKeySchemaType())) {
                        return;
                    }
                    Object key = field.getEvaluator().evaluate(transformContext);
                    // Resolve the schema before mutating the context so a failure leaves it unchanged.
                    TransformSchemaType keySchemaType = primitiveSchemaFor(field, key);
                    transformContext.setKeyObject(key);
                    transformContext.setKeySchemaType(keySchemaType);
                });
        primitiveFields.stream()
                .filter(field -> "value".equals(field.getName()))
                .findFirst()
                .ifPresent(field -> {
                    // Null-safe, consistent with the key branch above.
                    if (!isPrimitive(transformContext.getValueSchemaType())) {
                        return;
                    }
                    Object value = field.getEvaluator().evaluate(transformContext);
                    TransformSchemaType valueSchemaType = primitiveSchemaFor(field, value);
                    transformContext.setValueObject(value);
                    transformContext.setValueSchemaType(valueSchemaType);
                });
    }

    /** True when {@code type} is non-null and primitive. */
    private static boolean isPrimitive(TransformSchemaType type) {
        return type != null && type.isPrimitive();
    }

    /** Schema type for a primitive compute result: declared field type first, else inferred from the value. */
    private TransformSchemaType primitiveSchemaFor(ComputeField field, Object value) {
        return field.getType() != null ? getPrimitiveSchema(field.getType()) : getPrimitiveSchema(value);
    }

    /**
     * Adds or replaces fields on the record key. Does nothing when the key is null. A {@link Map} key
     * is updated in place; an AVRO or JSON key is rewritten through
     * {@link TransformContext#addOrReplaceKeyFields}.
     */
    public void computeKeyFields(List<ComputeField> keyFields, TransformContext transformContext) {
        Object keyObject = transformContext.getKeyObject();
        if (keyObject == null) {
            return;
        }
        if (keyObject instanceof Map) {
            @SuppressWarnings("unchecked")
            Map<Object, Object> keyMap = (Map<Object, Object>) keyObject;
            getEvaluatedFields(keyFields, transformContext)
                    .forEach((field, value) -> keyMap.put(field.name(), value));
            return;
        }
        TransformSchemaType keySchemaType = transformContext.getKeySchemaType();
        if (keySchemaType == TransformSchemaType.AVRO || keySchemaType == TransformSchemaType.JSON) {
            transformContext.addOrReplaceKeyFields(getEvaluatedFields(keyFields, transformContext), keySchemaCache);
        }
    }

    /**
     * Sets header values. Supported names are {@code destinationTopic} (output topic) and
     * {@code messageKey} (message key); both must evaluate to a String.
     *
     * @throws IllegalArgumentException for any other field name or a non-String result
     */
    public void computeHeaderFields(List<ComputeField> headerFields, TransformContext transformContext) {
        for (ComputeField field : headerFields) {
            switch (field.getName()) {
                case "destinationTopic":
                    transformContext.setOutputTopic(validateAndGetString(field, transformContext));
                    break;
                case "messageKey":
                    transformContext.setKey(validateAndGetString(field, transformContext));
                    break;
                default:
                    throw new IllegalArgumentException("Invalid compute field name: " + field.getName());
            }
        }
    }

    /**
     * Sets one message property per field, keyed by the field name.
     *
     * @throws IllegalArgumentException if a field does not evaluate to a String
     */
    public void computeHeaderPropertiesFields(List<ComputeField> propertyFields, TransformContext transformContext) {
        for (ComputeField field : propertyFields) {
            transformContext.setProperty(field.getName(), validateAndGetString(field, transformContext));
        }
    }

    /** Evaluates {@code computeField} and returns the result, which must be a String. */
    private String validateAndGetString(ComputeField computeField, TransformContext transformContext) {
        Object evaluated = computeField.getEvaluator().evaluate(transformContext);
        if (evaluated instanceof String) {
            return (String) evaluated;
        }
        String actualType = evaluated == null ? "null" : evaluated.getClass().getSimpleName();
        throw new IllegalArgumentException(String.format(
                "Invalid compute field type. Name: %s, Type: %s, Expected Type: %s",
                computeField.getName(), actualType, "String"));
    }

    /**
     * Evaluates each field and pairs its Avro field definition with the Avro-compatible value.
     * Iteration order of the result follows the input list.
     */
    private Map<Schema.Field, Object> getEvaluatedFields(List<ComputeField> computeFields, TransformContext transformContext) {
        Map<Schema.Field, Object> evaluated = new LinkedHashMap<>();
        for (ComputeField computeField : computeFields) {
            Object value = computeField.getEvaluator().evaluate(transformContext);
            ComputeFieldType fieldType = computeField.getType() == null ? getFieldType(value) : computeField.getType();
            Schema.Field avroField = createAvroField(computeField, fieldType, value);
            evaluated.put(avroField, getAvroValue(avroField.schema(), value));
        }
        return evaluated;
    }

    /**
     * Converts a Java value to the representation Avro expects for {@code schema}: byte arrays become
     * {@link ByteBuffer}s, Byte/Short become Integer, and values of logical-type schemas
     * (date, time-millis, timestamp-millis, decimal) are converted after type validation.
     */
    private Object getAvroValue(Schema schema, Object obj) {
        if (obj == null) {
            return null;
        }
        if (obj instanceof byte[]) {
            return ByteBuffer.wrap((byte[]) obj);
        }
        if (obj instanceof Byte || obj instanceof Short) {
            // INT8/INT16 fields are mapped to the Avro int type (see getAvroSchema).
            return ((Number) obj).intValue();
        }
        // NOTE(review): AvroUtil.getLogicalType is presumably used so that nullable unions resolve to
        // the logical type of their non-null branch; schema.getLogicalType() is null for a union.
        LogicalType logicalType = AvroUtil.getLogicalType(schema);
        if (logicalType == null) {
            return obj;
        }
        switch (logicalType.getName()) {
            case "date":
                return getAvroDate(obj, logicalType);
            case "time-millis":
                return getAvroTimeMillis(obj, logicalType);
            case "timestamp-millis":
                return getAvroTimestampMillis(obj, logicalType);
            case "decimal":
                return getAvroDecimal(obj, logicalType);
            default:
                throw new IllegalArgumentException(String.format("Invalid logical type %s for value %s", logicalType, obj));
        }
    }

    /** Converts an Instant, Timestamp or LocalDateTime to epoch millis via JstlTypeConverter. */
    private Long getAvroTimestampMillis(Object obj, LogicalType logicalType) {
        validateLogicalType(obj, logicalType, Instant.class, Timestamp.class, LocalDateTime.class);
        return (Long) JstlTypeConverter.INSTANCE.coerceToType(obj, Long.class);
    }

    /**
     * Converts a {@link Date} or {@link LocalDate} to days since the epoch. Date millis are
     * floor-divided so that pre-1970 instants map to the correct (earlier) day.
     */
    private Integer getAvroDate(Object obj, LogicalType logicalType) {
        validateLogicalType(obj, logicalType, Date.class, LocalDate.class);
        long epochDay = obj instanceof LocalDate
                ? ((LocalDate) obj).toEpochDay()
                : Math.floorDiv(((Date) obj).getTime(), MILLIS_PER_DAY);
        return Math.toIntExact(epochDay);
    }

    /** Converts a Time or LocalTime to an Integer via JstlTypeConverter. */
    private Integer getAvroTimeMillis(Object obj, LogicalType logicalType) {
        validateLogicalType(obj, logicalType, Time.class, LocalTime.class);
        return (Integer) JstlTypeConverter.INSTANCE.coerceToType(obj, Integer.class);
    }

    /** Validates that the value is a BigDecimal and returns it via JstlTypeConverter. */
    private BigDecimal getAvroDecimal(Object obj, LogicalType logicalType) {
        validateLogicalType(obj, logicalType, BigDecimal.class);
        return (BigDecimal) JstlTypeConverter.INSTANCE.coerceToType(obj, BigDecimal.class);
    }

    /**
     * Checks that the runtime class of {@code obj} is exactly one of {@code clsArr}
     * (subclasses are rejected).
     *
     * @throws IllegalArgumentException if no class matches
     */
    void validateLogicalType(Object obj, LogicalType logicalType, Class<?>... clsArr) {
        for (Class<?> cls : clsArr) {
            if (obj.getClass().equals(cls)) {
                return;
            }
        }
        throw new IllegalArgumentException(String.format("Invalid java type %s for logical type %s", obj.getClass(), logicalType));
    }

    /**
     * Builds the Avro field for a compute field. Optional fields get a {@code [null, type]} union
     * schema with a null default.
     */
    private Schema.Field createAvroField(ComputeField computeField, ComputeFieldType computeFieldType, Object obj) {
        Schema avroSchema = getAvroSchema(computeFieldType, obj);
        Object defaultValue = null;
        if (computeField.isOptional()) {
            avroSchema = SchemaBuilder.unionOf().nullType().and().type(avroSchema).endUnion();
            defaultValue = Schema.Field.NULL_DEFAULT_VALUE;
        }
        return new Schema.Field(computeField.getName(), avroSchema, (String) null, defaultValue);
    }

    /**
     * Returns the Avro schema for a compute field type. Decimal schemas derive precision and scale
     * from the value and are built per call; every other type's schema is cached.
     *
     * @throws IllegalArgumentException if the type is DECIMAL and {@code obj} is not a BigDecimal
     * @throws UnsupportedOperationException for unsupported field types
     */
    private Schema getAvroSchema(ComputeFieldType computeFieldType, Object obj) {
        if (computeFieldType == ComputeFieldType.DECIMAL) {
            if (!(obj instanceof BigDecimal)) {
                throw new IllegalArgumentException("Cannot derive decimal precision and scale from value: " + obj);
            }
            BigDecimal decimal = (BigDecimal) obj;
            // Avro requires scale <= precision; e.g. 0.001 has precision 1 but scale 3.
            int precision = Math.max(decimal.precision(), decimal.scale());
            return LogicalTypes.decimal(precision, decimal.scale()).addToSchema(Schema.create(Schema.Type.BYTES));
        }
        return fieldTypeToAvroSchemaCache.computeIfAbsent(computeFieldType, ComputeStep::createAvroSchema);
    }

    /** Creates the (possibly logical-typed) Avro schema for a non-decimal compute field type. */
    private static Schema createAvroSchema(ComputeFieldType computeFieldType) {
        switch (computeFieldType) {
            case STRING:
                return Schema.create(Schema.Type.STRING);
            case INT8:
            case INT16:
            case INT32:
                return Schema.create(Schema.Type.INT);
            case DATE:
            case LOCAL_DATE:
                return LogicalTypes.date().addToSchema(Schema.create(Schema.Type.INT));
            case TIME:
            case LOCAL_TIME:
                return LogicalTypes.timeMillis().addToSchema(Schema.create(Schema.Type.INT));
            case INT64:
                return Schema.create(Schema.Type.LONG);
            case DATETIME:
            case TIMESTAMP:
            case INSTANT:
            case LOCAL_DATE_TIME:
                return LogicalTypes.timestampMillis().addToSchema(Schema.create(Schema.Type.LONG));
            case FLOAT:
                return Schema.create(Schema.Type.FLOAT);
            case DOUBLE:
                return Schema.create(Schema.Type.DOUBLE);
            case BOOLEAN:
                return Schema.create(Schema.Type.BOOLEAN);
            case BYTES:
                return Schema.create(Schema.Type.BYTES);
            default:
                throw new UnsupportedOperationException("Unsupported compute field type: " + computeFieldType);
        }
    }

    /** Maps a declared compute field type to the matching primitive schema type (DATETIME maps to INSTANT). */
    private TransformSchemaType getPrimitiveSchema(ComputeFieldType computeFieldType) {
        switch (computeFieldType) {
            case STRING:
                return TransformSchemaType.STRING;
            case INT8:
                return TransformSchemaType.INT8;
            case INT16:
                return TransformSchemaType.INT16;
            case INT32:
                return TransformSchemaType.INT32;
            case DATE:
                return TransformSchemaType.DATE;
            case LOCAL_DATE:
                return TransformSchemaType.LOCAL_DATE;
            case TIME:
                return TransformSchemaType.TIME;
            case LOCAL_TIME:
                return TransformSchemaType.LOCAL_TIME;
            case INT64:
                return TransformSchemaType.INT64;
            case DATETIME:
            case INSTANT:
                return TransformSchemaType.INSTANT;
            case TIMESTAMP:
                return TransformSchemaType.TIMESTAMP;
            case LOCAL_DATE_TIME:
                return TransformSchemaType.LOCAL_DATE_TIME;
            case FLOAT:
                return TransformSchemaType.FLOAT;
            case DOUBLE:
                return TransformSchemaType.DOUBLE;
            case BOOLEAN:
                return TransformSchemaType.BOOLEAN;
            case BYTES:
                return TransformSchemaType.BYTES;
            default:
                throw new UnsupportedOperationException("Unsupported compute field type: " + computeFieldType);
        }
    }

    /** Infers the primitive schema type from the exact runtime class of {@code obj}. */
    private TransformSchemaType getPrimitiveSchema(Object obj) {
        if (obj == null) {
            throw new UnsupportedOperationException("Cannot get schema from null value");
        }
        Class<?> cls = obj.getClass();
        if (cls == String.class) {
            return TransformSchemaType.STRING;
        }
        if (cls == byte[].class) {
            return TransformSchemaType.BYTES;
        }
        if (cls == Boolean.class) {
            return TransformSchemaType.BOOLEAN;
        }
        if (cls == Byte.class) {
            return TransformSchemaType.INT8;
        }
        if (cls == Short.class) {
            return TransformSchemaType.INT16;
        }
        if (cls == Integer.class) {
            return TransformSchemaType.INT32;
        }
        if (cls == Long.class) {
            return TransformSchemaType.INT64;
        }
        if (cls == Float.class) {
            return TransformSchemaType.FLOAT;
        }
        if (cls == Double.class) {
            return TransformSchemaType.DOUBLE;
        }
        if (cls == Date.class) {
            return TransformSchemaType.DATE;
        }
        if (cls == Timestamp.class) {
            return TransformSchemaType.TIMESTAMP;
        }
        if (cls == Time.class) {
            return TransformSchemaType.TIME;
        }
        if (cls == LocalDateTime.class) {
            return TransformSchemaType.LOCAL_DATE_TIME;
        }
        if (cls == LocalDate.class) {
            return TransformSchemaType.LOCAL_DATE;
        }
        if (cls == LocalTime.class) {
            return TransformSchemaType.LOCAL_TIME;
        }
        if (cls == Instant.class) {
            return TransformSchemaType.INSTANT;
        }
        throw new UnsupportedOperationException("Got an unsupported type: " + obj.getClass());
    }

    /**
     * Infers the compute field type of an evaluated value. Any CharSequence maps to STRING and any
     * ByteBuffer or byte[] maps to BYTES; the other types must match the runtime class exactly.
     */
    private ComputeFieldType getFieldType(Object obj) {
        if (obj == null) {
            throw new UnsupportedOperationException("Cannot get field type from null value");
        }
        if (obj instanceof CharSequence) {
            return ComputeFieldType.STRING;
        }
        Class<?> cls = obj.getClass();
        if (obj instanceof ByteBuffer || cls == byte[].class) {
            return ComputeFieldType.BYTES;
        }
        if (cls == Boolean.class) {
            return ComputeFieldType.BOOLEAN;
        }
        if (cls == Byte.class) {
            return ComputeFieldType.INT8;
        }
        if (cls == Short.class) {
            return ComputeFieldType.INT16;
        }
        if (cls == Integer.class) {
            return ComputeFieldType.INT32;
        }
        if (cls == Long.class) {
            return ComputeFieldType.INT64;
        }
        if (cls == Float.class) {
            return ComputeFieldType.FLOAT;
        }
        if (cls == Double.class) {
            return ComputeFieldType.DOUBLE;
        }
        if (cls == Date.class) {
            return ComputeFieldType.DATE;
        }
        if (cls == Timestamp.class) {
            return ComputeFieldType.TIMESTAMP;
        }
        if (cls == Time.class) {
            return ComputeFieldType.TIME;
        }
        if (cls == LocalDateTime.class) {
            return ComputeFieldType.LOCAL_DATE_TIME;
        }
        if (cls == LocalDate.class) {
            return ComputeFieldType.LOCAL_DATE;
        }
        if (cls == LocalTime.class) {
            return ComputeFieldType.LOCAL_TIME;
        }
        if (cls == Instant.class) {
            return ComputeFieldType.INSTANT;
        }
        if (cls == BigDecimal.class) {
            return ComputeFieldType.DECIMAL;
        }
        throw new UnsupportedOperationException("Got an unsupported type: " + obj.getClass());
    }

    /** Default (mutable, empty) field list used by the builder when none is supplied. */
    private static List<ComputeField> defaultFields() {
        return new ArrayList<>();
    }

    ComputeStep(List<ComputeField> list) {
        this.fields = list;
    }

    /** Returns a new builder for {@link ComputeStep}. */
    public static ComputeStepBuilder builder() {
        return new ComputeStepBuilder();
    }
}
