//! Database storage and retrieval for document embeddings. use futures::future::try_join_all; use snafu::{ResultExt, Snafu}; use sqlx::{FromRow, Pool, postgres::PgPoolOptions}; use crate::{hash::SHA256_LENGTH, text_encoder::Embeddings}; /// A document chunk with similarity score from vector search. #[derive(Debug, FromRow)] pub struct DocumentMatch { /// Calibre book ID. pub book_id: i64, /// Text content of the chunk. pub text_chunk: String, /// Cosine similarity score (0.0 to 1.0). pub similarity: f64, } /// PostgreSQL database connection pool and operations. pub struct Postgres { /// Connection pool for database operations. pool: Pool, } /// Error when connecting to database. #[derive(Debug, Snafu)] #[snafu(display("Failed to connect to database."))] pub struct ConnectionError { source: sqlx::Error, } /// Error when retrieving document version. #[derive(Debug, Snafu)] #[snafu(display("Failed to get highest document version."))] pub struct VersionError { source: sqlx::Error, } /// Error when storing embeddings. #[derive(Debug, Snafu)] #[snafu(display("Failed to store embeddings."))] pub struct EmbeddingError { source: sqlx::Error, } /// Error when querying for similar documents. #[derive(Debug, Snafu)] #[snafu(display("Failed to query similarity data."))] pub struct QueryError { source: sqlx::Error, } /// Error when updating document hash. #[derive(Debug, Snafu)] #[snafu(display("Failed to update document hash."))] pub struct UpdateHashError { source: sqlx::Error, } /// Error when checking if document needs update. #[derive(Debug, Snafu)] #[snafu(display("Failed to check if document needs update."))] pub struct CheckUpdateError { source: sqlx::Error, } /// Error when removing old document versions. #[derive(Debug, Snafu)] #[snafu(display("Failed to delete old document versions."))] pub struct RemoveOldVersionsError { source: sqlx::Error, } impl Postgres { /// Create a new database connection pool. pub async fn connect(db_url: &str, max_connections: u32) -> Result { let pool = PgPoolOptions::new() .max_connections(max_connections) .connect(db_url) .await .context(ConnectionSnafu)?; Ok(Self { pool }) } /// Get the highest version number for a book's embeddings. pub async fn embedding_version(&self, book_id: i64) -> Result { let version = sqlx::query_scalar( "SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1", ) .bind(book_id) .fetch_one(&self.pool) .await .context(VersionSnafu)?; Ok(version) } /// Store multiple embeddings for a book in parallel. pub async fn add_multiple_embeddings( &self, book_id: i64, version: i32, embeddings: Vec, ) -> Result<(), EmbeddingError> { let inserts: Vec<_> = embeddings .into_iter() .map(|e| self.add_embeddings(book_id, version, e)) .collect(); try_join_all(inserts).await?; Ok(()) } /// Store a single embedding for a book. pub async fn add_embeddings( &self, book_id: i64, version: i32, embeddings: Embeddings, ) -> Result<(), EmbeddingError> { let vector_string = embeddings_to_vec_string(embeddings.0); sqlx::query( "INSERT INTO documents (book_id, version, embedding, text_chunk) VALUES ($1, $2, $3::vector, $4)", ) .bind(book_id) .bind(version) .bind(vector_string) .bind(embeddings.2) .execute(&self.pool) .await .context(EmbeddingSnafu)?; Ok(()) } /// Find documents similar to the query embedding using cosine similarity. pub async fn query( &self, query: Embeddings, limit: i32, ) -> Result, QueryError> { let vector_string = embeddings_to_vec_string(query.0); let books = sqlx::query_as::<_, DocumentMatch>( "SELECT book_id, text_chunk, 1 - (embedding <=> $1::vector) as similarity FROM documents ORDER BY embedding <=> $1::vector LIMIT $2", ) .bind(vector_string) .bind(limit) .fetch_all(&self.pool) .await .context(QuerySnafu)?; Ok(books) } /// Update or insert the hash for a book. pub async fn update_hash( &self, book_id: i64, hash: [u8; SHA256_LENGTH], ) -> Result<(), UpdateHashError> { sqlx::query( "INSERT INTO hashes (book_id, hash) VALUES ($1, $2) ON CONFLICT (book_id) DO UPDATE SET hash = $2", ) .bind(book_id) .bind(&hash[..]) .execute(&self.pool) .await .context(UpdateHashSnafu)?; Ok(()) } /// Check if a book needs to be reprocessed based on its hash. pub async fn book_needs_update( &self, book_id: i64, hash: [u8; SHA256_LENGTH], ) -> Result { let exists = sqlx::query_scalar::<_, bool>( "SELECT NOT EXISTS ( SELECT 1 FROM hashes WHERE book_id = $1 AND hash = $2 )", ) .bind(book_id) .bind(&hash[..]) .fetch_one(&self.pool) .await .context(CheckUpdateSnafu)?; Ok(exists) } /// Remove all but the latest version of embeddings for a book. pub async fn remove_old_versions(&self, book_id: i64) -> Result<(), RemoveOldVersionsError> { sqlx::query( "DELETE FROM documents WHERE version < (SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1)", ) .bind(book_id) .execute(&self.pool) .await .context(RemoveOldVersionsSnafu)?; Ok(()) } } /// Convert a vector of floats to PostgreSQL vector format string. fn embeddings_to_vec_string(vector: Vec) -> String { format!( "[{}]", vector .iter() .map(|x| x.to_string()) .collect::>() .join(",") ) }