package com.dtsx.astra.sdk.cassio;

import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.core.data.CqlVector;
import io.stargate.sdk.utils.AnsiUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
import lombok.NonNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/dtsx/astra/sdk/cassio/ClusteredMetadataVectorTable.class */
/**
 * Vector table whose rows are clustered by {@code row_id} (timeuuid, descending) inside a
 * {@code partition_id} partition, with a {@code metadata_s} text map usable as a search filter.
 *
 * <p>CRUD statements are prepared lazily on first use. All statements, DDL included, target the
 * keyspace-qualified table {@code keyspaceName.tableName}.
 */
public class ClusteredMetadataVectorTable extends AbstractCassandraTable<ClusteredMetadataVectorRecord> {

    private static final Logger log = LoggerFactory.getLogger(ClusteredMetadataVectorTable.class);

    /** Number of rows requested by {@link #similaritySearch(AnnQuery)} when the query asks for none. */
    private static final int DEFAULT_RECORD_COUNT = 4;

    /** Dimension of the {@code vector} column, used in the table DDL. */
    private final int vectorDimension;

    /** Similarity function configured on the SAI vector index. */
    private final CassandraSimilarityMetric similarityMetric;

    // Lazily prepared in prepareStatements(); guarded by the synchronized initializer.
    private PreparedStatement findPartitionStatement;
    private PreparedStatement deletePartitionStatement;
    private PreparedStatement findRowStatement;
    private PreparedStatement deleteRowStatement;
    private PreparedStatement insertRowStatement;

    /**
     * Fluent builder for {@link ClusteredMetadataVectorTable}. The metric defaults to
     * {@link CassandraSimilarityMetric#COSINE}; all other values must be set before {@link #build()}.
     */
    public static class Builder {
        private CqlSession session;
        private String keyspaceName;
        private String tableName;
        private Integer vectorDimension;
        private CassandraSimilarityMetric metric = CassandraSimilarityMetric.COSINE;

        /** @param cqlSession session used for all statements */
        public Builder withSession(CqlSession cqlSession) {
            this.session = cqlSession;
            return this;
        }

        /** @param str keyspace containing the table */
        public Builder withKeyspaceName(String str) {
            this.keyspaceName = str;
            return this;
        }

        /** @param str table name (without keyspace) */
        public Builder withTableName(String str) {
            this.tableName = str;
            return this;
        }

        /** @param num dimension of the stored vectors */
        public Builder withVectorDimension(Integer num) {
            this.vectorDimension = num;
            return this;
        }

        /** @param cassandraSimilarityMetric similarity function for the vector index */
        public Builder withMetric(CassandraSimilarityMetric cassandraSimilarityMetric) {
            this.metric = cassandraSimilarityMetric;
            return this;
        }

        /**
         * Creates the table wrapper.
         *
         * @throws NullPointerException if any required value was not set
         */
        public ClusteredMetadataVectorTable build() {
            return new ClusteredMetadataVectorTable(this.session, this.keyspaceName, this.tableName, this.vectorDimension, this.metric);
        }
    }

    /**
     * Creates a wrapper over {@code keyspaceName.tableName}. No statement is executed here.
     *
     * @param cqlSession session used for all statements
     * @param str keyspace name
     * @param str2 table name
     * @param num vector dimension
     * @param cassandraSimilarityMetric similarity function for the vector index
     * @throws NullPointerException if any argument is null
     */
    public ClusteredMetadataVectorTable(@NonNull CqlSession cqlSession, @NonNull String str, @NonNull String str2, @NonNull Integer num, @NonNull CassandraSimilarityMetric cassandraSimilarityMetric) {
        super(cqlSession, str, str2);
        if (cqlSession == null) {
            throw new NullPointerException("session is marked non-null but is null");
        }
        if (str == null) {
            throw new NullPointerException("keyspaceName is marked non-null but is null");
        }
        if (str2 == null) {
            throw new NullPointerException("tableName is marked non-null but is null");
        }
        if (num == null) {
            throw new NullPointerException("vectorDimension is marked non-null but is null");
        }
        if (cassandraSimilarityMetric == null) {
            throw new NullPointerException("metric is marked non-null but is null");
        }
        this.vectorDimension = num.intValue();
        this.similarityMetric = cassandraSimilarityMetric;
    }

    /** Convenience constructor using {@link CassandraSimilarityMetric#COSINE}. */
    public ClusteredMetadataVectorTable(CqlSession cqlSession, String str, String str2, int i) {
        this(cqlSession, str, str2, Integer.valueOf(i), CassandraSimilarityMetric.COSINE);
    }

    /** @return a new {@link Builder} */
    public static Builder builder() {
        return new Builder();
    }

    /** Keyspace-qualified table name, so statements do not depend on the session's current keyspace. */
    private String qualifiedTableName() {
        return this.keyspaceName + "." + this.tableName;
    }

    /**
     * Renders a value as a CQL string literal. Embedded single quotes are doubled, so the value
     * cannot terminate the literal early.
     */
    private static String cqlLiteral(Object value) {
        return "'" + String.valueOf(value).replace("'", "''") + "'";
    }

    /** Prepares the CRUD statements once; synchronized so concurrent first calls prepare only once. */
    private synchronized void prepareStatements() {
        if (this.findPartitionStatement == null) {
            String table = qualifiedTableName();
            this.findPartitionStatement = this.cqlSession.prepare("select * from " + table + " where partition_id = ? ");
            this.deletePartitionStatement = this.cqlSession.prepare("delete from " + table + " where partition_id = ? ");
            this.findRowStatement = this.cqlSession.prepare("select * from " + table + " where partition_id = ?  and row_id = ? ");
            this.deleteRowStatement = this.cqlSession.prepare("delete from " + table + " where partition_id = ?  and row_id = ? ");
            this.insertRowStatement = this.cqlSession.prepare("INSERT INTO " + table + " (partition_id,row_id,vector,attributes_blob,body_blob,metadata_s) VALUES (?,?,?,?,?,?)");
        }
    }

    /**
     * Creates the table and two SAI indexes, if they do not already exist: one on
     * {@code vector} using the configured similarity function, and one on
     * {@code ENTRIES(metadata_s)}. The DDL targets {@code keyspaceName.tableName}.
     */
    @Override // com.dtsx.astra.sdk.cassio.AbstractCassandraTable
    public void create() {
        String table = qualifiedTableName();
        this.cqlSession.execute("CREATE TABLE IF NOT EXISTS " + table + " (partition_id text, row_id timeuuid, attributes_blob text, body_blob text, metadata_s map<text, text>, vector vector<float, " + this.vectorDimension + ">, PRIMARY KEY ((partition_id), row_id)) WITH CLUSTERING ORDER BY (row_id DESC)");
        // Index names cannot be keyspace-qualified; only the target table is.
        String vectorIndex = "idx_vector_" + this.tableName;
        this.cqlSession.execute("CREATE CUSTOM INDEX IF NOT EXISTS " + vectorIndex + " ON " + table + " (vector) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' WITH OPTIONS = { 'similarity_function': '" + this.similarityMetric.getOption() + "'};");
        log.info("+ Index '{}' has been created (if needed).", vectorIndex);
        String metadataIndex = "eidx_metadata_s_" + this.tableName;
        this.cqlSession.execute("CREATE CUSTOM INDEX IF NOT EXISTS " + metadataIndex + " ON " + table + " (ENTRIES(metadata_s)) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' ");
        log.info("+ Index '{}' has been created (if needed).", metadataIndex);
    }

    /**
     * Reads every row of a partition.
     *
     * @param str partition id
     * @return the partition's records, in the order the driver returns them
     * @throws NullPointerException if {@code str} is null
     */
    public List<ClusteredMetadataVectorRecord> findPartition(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("partitionId is marked non-null but is null");
        }
        prepareStatements();
        return this.cqlSession.execute(this.findPartitionStatement.bind(str))
                .all()
                .stream()
                .map(this::mapRow)
                .collect(Collectors.toList());
    }

    /**
     * Deletes a whole partition.
     *
     * @param str partition id
     * @throws NullPointerException if {@code str} is null
     */
    public void deletePartition(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("partitionId is marked non-null but is null");
        }
        prepareStatements();
        this.cqlSession.execute(this.deletePartitionStatement.bind(str));
    }

    /**
     * Reads a single row.
     *
     * @param str partition id
     * @param uuid row id
     * @return the record, or empty if no row matched
     */
    public Optional<ClusteredMetadataVectorRecord> get(String str, UUID uuid) {
        prepareStatements();
        Row row = this.cqlSession.execute(this.findRowStatement.bind(str, uuid)).one();
        return Optional.ofNullable(row).map(this::mapRow);
    }

    /**
     * Deletes a single row.
     *
     * @param str partition id
     * @param uuid row id
     */
    public void delete(String str, UUID uuid) {
        prepareStatements();
        this.cqlSession.execute(this.deleteRowStatement.bind(str, uuid));
    }

    /** Alias of {@link #put(ClusteredMetadataVectorRecord)}. */
    public void save(ClusteredMetadataVectorRecord clusteredMetadataVectorRecord) {
        put(clusteredMetadataVectorRecord);
    }

    /**
     * Inserts a record. All six columns are written.
     *
     * @param clusteredMetadataVectorRecord record to write
     */
    @Override // com.dtsx.astra.sdk.cassio.AbstractCassandraTable
    public void put(ClusteredMetadataVectorRecord clusteredMetadataVectorRecord) {
        prepareStatements();
        this.cqlSession.execute(this.insertRowStatement.bind(
                clusteredMetadataVectorRecord.getPartitionId(),
                clusteredMetadataVectorRecord.getRowId(),
                CqlVector.newInstance(clusteredMetadataVectorRecord.getVector()),
                clusteredMetadataVectorRecord.getAttributes(),
                clusteredMetadataVectorRecord.getBody(),
                clusteredMetadataVectorRecord.getMetadata()));
    }

    /**
     * Maps a driver row to a record. {@code attributes_blob} and {@code metadata_s} are read only
     * when the row contains those columns.
     *
     * @param row driver row, may be null
     * @return the mapped record, or null when {@code row} is null
     * @throws NullPointerException if the row's {@code vector} value is null
     */
    @Override // com.dtsx.astra.sdk.cassio.AbstractCassandraTable
    public ClusteredMetadataVectorRecord mapRow(Row row) {
        if (row == null) {
            return null;
        }
        ClusteredMetadataVectorRecord clusteredMetadataVectorRecord = new ClusteredMetadataVectorRecord();
        clusteredMetadataVectorRecord.setPartitionId(row.getString(AbstractCassandraTable.PARTITION_ID));
        @SuppressWarnings("unchecked") // column is declared vector<float, N> in create()
        CqlVector<Float> vector = (CqlVector<Float>) Objects.requireNonNull(
                row.getObject(AbstractCassandraTable.VECTOR), "vector column is null");
        clusteredMetadataVectorRecord.setVector(vector.stream().collect(Collectors.toList()));
        clusteredMetadataVectorRecord.setRowId(row.getUuid(AbstractCassandraTable.ROW_ID));
        clusteredMetadataVectorRecord.setBody(row.getString(AbstractCassandraTable.BODY_BLOB));
        if (row.getColumnDefinitions().contains(AbstractCassandraTable.ATTRIBUTES_BLOB)) {
            clusteredMetadataVectorRecord.setAttributes(row.getString(AbstractCassandraTable.ATTRIBUTES_BLOB));
        }
        if (row.getColumnDefinitions().contains(AbstractCassandraTable.METADATA_S)) {
            clusteredMetadataVectorRecord.setMetadata(row.getMap(AbstractCassandraTable.METADATA_S, String.class, String.class));
        }
        return clusteredMetadataVectorRecord;
    }

    /**
     * Runs an ANN query over the {@code vector} column.
     *
     * <p>A {@code partition_id} entry in the query metadata becomes a partition restriction. Every
     * other entry becomes a {@code metadata_s[key] = value} restriction. Keys and values are
     * emitted as escaped CQL string literals. The caller's metadata map is not modified.
     *
     * @param annQuery embeddings, metric, optional metadata filter, record count (values
     *     {@code <= 0} mean {@value #DEFAULT_RECORD_COUNT}) and similarity threshold
     * @return results whose similarity is at least the query threshold
     */
    public List<AnnResult<ClusteredMetadataVectorRecord>> similaritySearch(AnnQuery annQuery) {
        StringBuilder cql = new StringBuilder("SELECT ")
                .append(String.join(",", AbstractCassandraTable.PARTITION_ID, AbstractCassandraTable.ROW_ID, AbstractCassandraTable.VECTOR, AbstractCassandraTable.BODY_BLOB, AbstractCassandraTable.ATTRIBUTES_BLOB, AbstractCassandraTable.METADATA_S, ""))
                .append(annQuery.getMetric().getFunction())
                .append("(vector, :vector) as ")
                .append(AbstractCassandraTable.COLUMN_SIMILARITY)
                .append(" FROM ")
                .append(qualifiedTableName());

        List<String> conditions = new ArrayList<>();
        if (annQuery.getMetaData() != null && !annQuery.getMetaData().isEmpty()) {
            // Partition restriction first, then metadata entries; the map itself is left untouched.
            if (annQuery.getMetaData().containsKey(AbstractCassandraTable.PARTITION_ID)) {
                conditions.add("partition_id = " + cqlLiteral(annQuery.getMetaData().get(AbstractCassandraTable.PARTITION_ID)));
            }
            annQuery.getMetaData().forEach((key, value) -> {
                if (!AbstractCassandraTable.PARTITION_ID.equals(key)) {
                    conditions.add("metadata_s[" + cqlLiteral(key) + "] = " + cqlLiteral(value));
                }
            });
        }
        if (!conditions.isEmpty()) {
            cql.append(" WHERE ").append(String.join(" AND ", conditions));
        }
        cql.append(" ORDER BY vector ANN OF :vector LIMIT :maxRecord");

        int recordCount = annQuery.getRecordCount() > 0 ? annQuery.getRecordCount() : DEFAULT_RECORD_COUNT;
        log.debug("Query on table '{}' with vector size '{}' and max record='{}'",
                AnsiUtils.yellow(this.tableName),
                AnsiUtils.cyan("[" + annQuery.getEmbeddings().size() + "]"),
                AnsiUtils.cyan(String.valueOf(recordCount)));

        SimpleStatement statement = SimpleStatement.builder(cql.toString())
                .addNamedValue(AbstractCassandraTable.VECTOR, CqlVector.newInstance(annQuery.getEmbeddings()))
                .addNamedValue("maxRecord", Integer.valueOf(recordCount))
                .build();
        return this.cqlSession.execute(statement)
                .all()
                .stream()
                .map(this::mapResult)
                .filter(annResult -> ((double) annResult.getSimilarity()) >= annQuery.getThreshold())
                .collect(Collectors.toList());
    }

    /**
     * Maps an ANN result row to a record paired with the similarity value from the row's
     * {@code COLUMN_SIMILARITY} column.
     *
     * @param row driver row, may be null
     * @return the result, or null when {@code row} is null
     */
    private AnnResult<ClusteredMetadataVectorRecord> mapResult(Row row) {
        if (row == null) {
            return null;
        }
        AnnResult<ClusteredMetadataVectorRecord> annResult = new AnnResult<>();
        annResult.setEmbedded(mapRow(row));
        annResult.setSimilarity(row.getFloat(AbstractCassandraTable.COLUMN_SIMILARITY));
        log.debug("Result similarity '{}' for embedded id='{}'", Float.valueOf(annResult.getSimilarity()), annResult.getEmbedded().getRowId());
        return annResult;
    }
}
