From 6a5b309391f978763e3a582f7c439e8319a45b84 Mon Sep 17 00:00:00 2001 From: Sebastian Hugentobler Date: Tue, 1 Jul 2025 21:10:35 +0200 Subject: [PATCH] WIP job queue --- Cargo.lock | 92 ++++++++++++ Cargo.toml | 3 +- migrations/20250701133953_jobs.sql | 12 ++ src/api.rs | 5 + src/api/jobs.rs | 81 +++++++++++ src/api/query.rs | 126 ++++++++++++----- src/api/routes.rs | 3 +- src/api/state.rs | 4 + src/main.rs | 4 +- src/query.rs | 11 +- src/scanner.rs | 24 ++-- src/storage.rs | 220 +++-------------------------- src/storage/jobs.rs | 143 +++++++++++++++++++ src/storage/queries.rs | 211 +++++++++++++++++++++++++++ src/text_encoder.rs | 2 +- 15 files changed, 685 insertions(+), 256 deletions(-) create mode 100644 migrations/20250701133953_jobs.sql create mode 100644 src/api/jobs.rs create mode 100644 src/storage/jobs.rs create mode 100644 src/storage/queries.rs diff --git a/Cargo.lock b/Cargo.lock index c798432..0c85819 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -158,6 +158,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "atomic-polyfill" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4" +dependencies = [ + "critical-section", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -449,6 +458,15 @@ version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" +[[package]] +name = "cobs" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa961b519f0b462e3a3b4a34b64d119eeaca1d59af726fe450bbba07a9fc0a1" +dependencies = [ + "thiserror", +] + [[package]] name = "colorchoice" version = "1.0.4" @@ -543,6 +561,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "critical-section" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -842,6 +866,18 @@ dependencies = [ "serde", ] +[[package]] +name = "embedded-io" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced" + +[[package]] +name = "embedded-io" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d" + [[package]] name = "encode_unicode" version = "1.0.0" @@ -1171,6 +1207,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "hash32" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67" +dependencies = [ + "byteorder", +] + [[package]] name = "hashbrown" version = "0.14.5" @@ -1200,6 +1245,20 @@ dependencies = [ "hashbrown 0.15.3", ] +[[package]] +name = "heapless" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f" +dependencies = [ + "atomic-polyfill", + "hash32", + "rustc_version", + "serde", + "spin", + "stable_deref_trait", +] + [[package]] name = "heck" version = "0.5.0" @@ -1788,6 +1847,7 @@ dependencies = [ "futures", "ignore", "lopdf", + "postcard", "quick-xml", "rayon", "ring", @@ -2411,6 +2471,19 @@ dependencies = [ "portable-atomic", ] +[[package]] 
+name = "postcard" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c1de96e20f51df24ca73cafcc4690e044854d803259db27a00a461cb3b9d17a" +dependencies = [ + "cobs", + "embedded-io 0.4.0", + "embedded-io 0.6.1", + "heapless", + "serde", +] + [[package]] name = "potential_utf" version = "0.1.2" @@ -2745,6 +2818,15 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustfft" version = "6.3.0" @@ -2877,6 +2959,12 @@ dependencies = [ "smallvec", ] +[[package]] +name = "semver" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" + [[package]] name = "serde" version = "1.0.219" @@ -3125,6 +3213,7 @@ dependencies = [ "tokio-stream", "tracing", "url", + "uuid", "webpki-roots 0.26.11", ] @@ -3205,6 +3294,7 @@ dependencies = [ "stringprep", "thiserror", "tracing", + "uuid", "whoami", ] @@ -3242,6 +3332,7 @@ dependencies = [ "stringprep", "thiserror", "tracing", + "uuid", "whoami", ] @@ -3267,6 +3358,7 @@ dependencies = [ "thiserror", "tracing", "url", + "uuid", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index a2e3b64..a3eed13 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ axum = { version = "0.8.4", features = ["http2", "tracing"] } clap = { version = "4.5.40", features = ["derive"] } futures = "0.3.31" lopdf = "0.36.0" +postcard = { version = "1.1.2", features = ["alloc"] } quick-xml = "0.38.0" rayon = "1.10.0" ring = "0.17.14" @@ -16,7 +17,7 @@ scraper = "0.23.1" serde = { version = "1.0.219", features = ["derive"] } sha2 = "0.10.9" snafu = { version = "0.8.6", features = ["rust_1_81"] } -sqlx = { version = "0.8.6", features = [ "runtime-tokio", "tls-rustls-ring", "migrate", "postgres", "derive" ] } +sqlx = { version = "0.8.6", features = [ "runtime-tokio", "tls-rustls-ring", "migrate", "postgres", "derive", "uuid", "macros" ] } tokenizers = "0.21.2" tokio = { version = "1.45.1", features = ["rt-multi-thread", "macros"]} tower-http = { version = "0.6.6", features = ["trace"] } diff --git a/migrations/20250701133953_jobs.sql b/migrations/20250701133953_jobs.sql new file mode 100644 index 0000000..d2f0e5f --- /dev/null +++ b/migrations/20250701133953_jobs.sql @@ -0,0 +1,12 @@ +CREATE TYPE job_status AS ENUM ('running', 'completed', 'failed'); + +CREATE TABLE jobs ( + id UUID PRIMARY KEY, + status job_status NOT NULL DEFAULT 'running', + error TEXT, + result BYTEA, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL +); + +CREATE INDEX idx_jobs_status ON jobs(status); +CREATE INDEX idx_jobs_expires_at ON jobs(expires_at); diff --git a/src/api.rs b/src/api.rs index 0aa7c7e..47c7683 100644 --- a/src/api.rs +++ b/src/api.rs @@ -4,8 +4,10 @@ use std::{ io, net::{AddrParseError, SocketAddr}, str::FromStr, + sync::Arc, }; +use jobs::JobManager; use snafu::{ResultExt, Snafu}; use state::AppState; use tokio::net::TcpListener; @@ -17,6 +19,7 @@ use crate::{storage::Postgres, text_encoder::TextEncoder, tokenize::Tokenizer}; pub mod code; pub mod error; +pub mod jobs; pub mod query; pub mod routes; pub mod state; @@ -55,12 +58,14 @@ pub async fn serve( reranker: TextEncoder, 
     chunk_size: usize,
 ) -> Result<(), ServeError> {
+    let jobs = JobManager::new(Arc::new(db.jobs.clone()));
     let state = AppState {
         db,
         tokenizer,
         embedder,
         reranker,
         chunk_size,
+        jobs,
     };
 
     let (router, api) = OpenApiRouter::with_openapi(ApiDoc::openapi())
diff --git a/src/api/jobs.rs b/src/api/jobs.rs
new file mode 100644
index 0000000..ae9bdb5
--- /dev/null
+++ b/src/api/jobs.rs
@@ -0,0 +1,81 @@
+use snafu::Report;
+use snafu::{ResultExt, Snafu};
+use std::{sync::Arc, time::Duration};
+use tracing::error;
+
+use uuid::Uuid;
+
+use crate::storage;
+use crate::storage::jobs::AddError;
+
+/// Spawns background tasks and records their outcome in the database.
+#[derive(Clone, Debug)]
+pub struct JobManager {
+    db: Arc<storage::jobs::Jobs>,
+}
+
+/// Errors that occur when running a background job.
+#[derive(Debug, Snafu)]
+pub enum ExecuteError {
+    #[snafu(display("Failed to add new background job."))]
+    Add { source: AddError },
+}
+
+impl JobManager {
+    pub fn new(db: Arc<storage::jobs::Jobs>) -> Self {
+        Self { db }
+    }
+
+    /// Run `task` in the background, storing its serialized result (or its
+    /// error) under a freshly generated job id.
+    pub async fn execute<F, Fut, T, E>(&self, task: F) -> Result<Uuid, ExecuteError>
+    where
+        F: FnOnce() -> Fut + Send + 'static,
+        Fut: Future<Output = Result<T, E>> + Send,
+        T: serde::Serialize + Send + Sync + 'static,
+        E: snafu::Error + Send + Sync + 'static,
+    {
+        let job_id = Uuid::new_v4();
+        self.db.as_ref().add(&job_id).await.context(AddSnafu)?;
+        let db = self.db.clone();
+
+        tokio::spawn(async move {
+            match task().await {
+                Ok(result) => success(&job_id, &result, &db).await,
+                Err(error) => failure(&job_id, error, &db).await,
+            }
+
+            tokio::time::sleep(Duration::from_secs(300)).await;
+        });
+
+        Ok(job_id)
+    }
+}
+
+async fn success<T>(job_id: &Uuid, result: &T, db: &storage::jobs::Jobs)
+where
+    T: serde::Serialize + Send + Sync + 'static,
+{
+    let data = match postcard::to_allocvec(result) {
+        Ok(data) => data,
+        Err(error) => {
+            failure(job_id, error, db).await;
+            return;
+        }
+    };
+
+    if let Err(error) = db.finish(job_id, &data).await {
+        failure(job_id, error, db).await;
+    }
+}
+
+async fn failure<E>(job_id: &Uuid, error: E, db: &storage::jobs::Jobs)
+where
+    E: snafu::Error + Send + Sync + 'static,
+{
+    let error = Report::from_error(&error);
+    error!("{job_id}: {error}");
+    if let Err(error) = db.fail(job_id, &error.to_string()).await {
+        error!(
+            "Failed to save job {job_id} result: {}",
+            Report::from_error(&error)
+        );
+    }
+}
diff --git a/src/api/query.rs b/src/api/query.rs
index 4cda3c8..2661a27 100644
--- a/src/api/query.rs
+++ b/src/api/query.rs
@@ -1,18 +1,22 @@
 //! Query endpoint handlers and response types.
 
-use std::sync::Arc;
+use std::{str::FromStr, sync::Arc};
 
 use axum::{
     Json,
-    extract::{Query, State},
+    extract::{Path, Query, State},
     http::StatusCode,
 };
 use serde::{Deserialize, Serialize};
 use snafu::{ResultExt, Snafu, ensure};
 use utoipa::ToSchema;
+use uuid::Uuid;
 
-use super::{TAG, error::HttpStatus, state::AppState};
-use crate::{http_error, query, storage::DocumentMatch};
+use super::{TAG, error::HttpStatus, jobs, state::AppState};
+use crate::{
+    http_error, query,
+    storage::{self, queries::DocumentMatch},
+};
 
 const MAX_LIMIT: usize = 10;
 
@@ -22,14 +26,14 @@ pub enum QueryError {
     #[snafu(display("'limit' query parameter must be a positive integer <= {MAX_LIMIT}."))]
     Limit,
     #[snafu(display("Failed to run query."))]
-    Query { source: query::AskError },
+    QueryExecute { source: jobs::ExecuteError },
 }
 
 impl HttpStatus for QueryError {
     fn status_code(&self) -> StatusCode {
         match self {
             QueryError::Limit => StatusCode::BAD_REQUEST,
-            QueryError::Query { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
+            QueryError::QueryExecute { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
         }
     }
 }
@@ -46,25 +50,9 @@ pub struct QueryParams {
 /// Response format for successful query requests.
 #[derive(Debug, Serialize, ToSchema)]
 pub struct QueryResponse {
+    pub status: String,
     /// List of matching document chunks.
     pub results: Vec<DocumentResult>,
-    /// Total number of results returned.
-    pub count: usize,
-    /// Original query text that was searched.
-    pub query: String,
-}
-
-impl From<(Vec<DocumentMatch>, String)> for QueryResponse {
-    fn from((documents, query): (Vec<DocumentMatch>, String)) -> Self {
-        let results: Vec<DocumentResult> = documents.into_iter().map(Into::into).collect();
-        let count = results.len();
-
-        Self {
-            results,
-            count,
-            query,
-        }
-    }
 }
 
 /// A single document search result.
@@ -88,6 +76,13 @@ impl From<DocumentMatch> for DocumentResult {
     }
 }
 
+/// Response format for accepted query jobs.
+#[derive(Debug, Serialize, ToSchema)]
+pub struct QueryStartResponse {
+    /// Job id.
+    pub id: String,
+}
+
 /// Execute a semantic search query against the document database.
 #[utoipa::path(
     post,
@@ -98,31 +93,84 @@ impl From<DocumentMatch> for DocumentResult {
     ),
     request_body = String,
     responses(
-        (status = OK, body = QueryResponse),
+        (status = 202, body = QueryStartResponse),
         (status = 400, description = "Wrong parameter."),
         (status = 500, description = "Failure to query database.")
     )
 )]
-pub async fn route(
+pub async fn start(
     Query(params): Query<QueryParams>,
     State(state): State<Arc<AppState>>,
     body: String,
-) -> Result<Json<QueryResponse>, QueryError> {
+) -> Result<Json<QueryStartResponse>, QueryError> {
     let limit = params.limit.unwrap_or(5);
     ensure!(limit <= MAX_LIMIT, LimitSnafu);
 
-    let results = query::ask(
-        &body,
-        &state.db,
-        &state.tokenizer,
-        &state.embedder,
-        &state.reranker,
-        state.chunk_size,
-        limit,
-    )
-    .await
-    .context(QuerySnafu)?;
-    let response = QueryResponse::from((results, body));
+    let jobs = state.jobs.clone();
+    let id = jobs
+        .execute(move || async move {
+            query::ask(
+                &body,
+                &state.db.queries,
+                &state.tokenizer,
+                &state.embedder,
+                &state.reranker,
+                state.chunk_size,
+                limit,
+            )
+            .await
+        })
+        .await
+        .context(QueryExecuteSnafu)?;
 
-    Ok(Json(response))
+    Ok(Json(QueryStartResponse { id: id.to_string() }))
+}
+
+/// Errors that occur during query retrieval.
+#[derive(Debug, Snafu)]
+pub enum RetrieveError {
+    #[snafu(display("Failed to retrieve job from database."))]
+    Db {
+        source: storage::jobs::RetrieveError,
+    },
+    #[snafu(display("No job with id {id}."))]
+    NoJob { id: Uuid },
+    #[snafu(display("'{id}' is not a valid v4 UUID."))]
+    Uuid { source: uuid::Error, id: String },
+}
+
+impl HttpStatus for RetrieveError {
+    fn status_code(&self) -> StatusCode {
+        match self {
+            RetrieveError::Db { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
+            RetrieveError::NoJob { id: _ } => StatusCode::NOT_FOUND,
+            RetrieveError::Uuid { source: _, id: _ } => StatusCode::BAD_REQUEST,
+        }
+    }
+}
+
+http_error!(RetrieveError);
+
+/// Fetch the status and any results of a previously started query job.
+#[utoipa::path(
+    get,
+    path = "/query/{id}",
+    tag = TAG,
+    responses(
+        (status = OK, body = QueryResponse),
+        (status = 404, description = "No job with the requested id."),
+        (status = 500, description = "Failure to query database.")
+    )
+)]
+pub async fn result(
+    Path(id): Path<String>,
+    State(state): State<Arc<AppState>>,
+) -> Result<Json<QueryResponse>, RetrieveError> {
+    let id = Uuid::from_str(&id).context(UuidSnafu { id })?;
+    match state.db.jobs.retrieve(&id).await.context(DbSnafu)? {
+        Some((status, results)) => Ok(Json(QueryResponse {
+            status,
+            results: results.into_iter().map(Into::into).collect(),
+        })),
+        None => NoJobSnafu { id }.fail(),
+    }
+}
diff --git a/src/api/routes.rs b/src/api/routes.rs
index 9da1bc0..7db7bff 100644
--- a/src/api/routes.rs
+++ b/src/api/routes.rs
@@ -12,7 +12,8 @@ use crate::api::{code, query};
 pub fn router(state: AppState) -> OpenApiRouter {
     let store = Arc::new(state);
     OpenApiRouter::new()
-        .routes(routes!(query::route, code::route))
+        .routes(routes!(query::start, query::result))
+        .routes(routes!(code::route))
         .layer(TraceLayer::new_for_http())
         .with_state(store)
 }
diff --git a/src/api/state.rs b/src/api/state.rs
index fc9239f..c0f5a80 100644
--- a/src/api/state.rs
+++ b/src/api/state.rs
@@ -2,6 +2,8 @@
 
 use crate::{storage::Postgres, text_encoder::TextEncoder, tokenize::Tokenizer};
 
+use super::jobs::JobManager;
+
 /// Application state shared across all HTTP request handlers.
 #[derive(Debug, Clone)]
 pub struct AppState {
@@ -15,4 +17,6 @@ pub struct AppState {
     /// Text encoder for reranking.
     pub reranker: TextEncoder,
     /// Text chunk size in words for processing.
     pub chunk_size: usize,
+    /// Background jobs handling.
+ pub jobs: JobManager, } diff --git a/src/main.rs b/src/main.rs index 751e230..d605f8b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -75,7 +75,7 @@ async fn main() -> Result<(), Error> { TextEncoder::from_file(&config.reranking_path).context(RerankerLoadingSnafu)?; query::ask( &config.query, - &db, + &db.queries, &tokenizer, &embedder, &reranker, @@ -87,7 +87,7 @@ async fn main() -> Result<(), Error> { } Commands::Scan(config) => { scanner::scan( - &db, + &db.queries, &config.library_path, &tokenizer, &embedder, diff --git a/src/query.rs b/src/query.rs index 5a9adc3..599db3f 100644 --- a/src/query.rs +++ b/src/query.rs @@ -3,7 +3,10 @@ use snafu::{ResultExt, Snafu}; use crate::{ - storage::{self, DocumentMatch, Postgres}, + storage::{ + self, + queries::{DocumentMatch, Queries}, + }, text_encoder::{self, TextEncoder}, tokenize::{self, Tokenizer}, }; @@ -16,7 +19,9 @@ pub enum AskError { #[snafu(display("Failed to embed query."))] Embed { source: text_encoder::EmbedError }, #[snafu(display("Failed to retrieve similar documents."))] - Query { source: storage::QueryError }, + Query { + source: storage::queries::QueryError, + }, #[snafu(display("Failed to rerank documents."))] Rerank { source: text_encoder::RerankError }, } @@ -24,7 +29,7 @@ pub enum AskError { /// Process a user query and return ranked document matches. pub async fn ask( query: &str, - db: &Postgres, + db: &Queries, tokenizer: &Tokenizer, embedder: &TextEncoder, reranker: &TextEncoder, diff --git a/src/scanner.rs b/src/scanner.rs index 2c16d9a..3a9e4d4 100644 --- a/src/scanner.rs +++ b/src/scanner.rs @@ -9,14 +9,14 @@ use crate::{ calibre::{Book, BookIterator}, extractors::{self, extractor::ExtractionError}, hash, - storage::{self, Postgres}, + storage::{self, queries::Queries}, text_encoder::{self, TextEncoder}, tokenize::{self, Tokenizer}, }; /// Scan a Calibre library and process all books for embedding generation. pub async fn scan( - db: &Postgres, + db: &Queries, library_path: &PathBuf, tokenizer: &Tokenizer, text_encoder: &TextEncoder, @@ -43,7 +43,9 @@ pub enum ProcessBookError { #[snafu(display("Failed to hash book."))] BookHash { source: io::Error }, #[snafu(display("Failed to check in db if document needs update."))] - UpdateCheck { source: storage::CheckUpdateError }, + UpdateCheck { + source: storage::queries::CheckUpdateError, + }, #[snafu(display("Failed to open book."))] Open { source: io::Error }, #[snafu(display("Failed to extract document content."))] @@ -53,21 +55,27 @@ pub enum ProcessBookError { #[snafu(display("Failed to create embeddings from encodings."))] Embedding { source: text_encoder::EmbedError }, #[snafu(display("Failed to look up document versions."))] - DocumentVersion { source: storage::VersionError }, + DocumentVersion { + source: storage::queries::VersionError, + }, #[snafu(display("Failed to store embeddings in database."))] - SaveEmbeddings { source: storage::EmbeddingError }, + SaveEmbeddings { + source: storage::queries::EmbeddingError, + }, #[snafu(display("Failed to update document hash in database."))] - UpdateHash { source: storage::UpdateHashError }, + UpdateHash { + source: storage::queries::UpdateHashError, + }, #[snafu(display("Failed to delete old document versions in database."))] RemoveOldVersions { - source: storage::RemoveOldVersionsError, + source: storage::queries::RemoveOldVersionsError, }, } /// Process a single book: extract text, generate embeddings, and store in database. 
 async fn process_book(
     book: &Book,
-    db: &Postgres,
+    db: &Queries,
     tokenizer: &Tokenizer,
     embedder: &TextEncoder,
     chunk_size: usize,
diff --git a/src/storage.rs b/src/storage.rs
index dbf3fe4..27f8089 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -1,27 +1,20 @@
 //! Database storage and retrieval for document embeddings.
 
-use futures::future::try_join_all;
+use std::sync::Arc;
+
+use jobs::Jobs;
+use queries::Queries;
 use snafu::{ResultExt, Snafu};
-use sqlx::{FromRow, Pool, postgres::PgPoolOptions};
+use sqlx::postgres::PgPoolOptions;
 
-use crate::{hash::SHA256_LENGTH, text_encoder::Embeddings};
-
-/// A document chunk with similarity score from vector search.
-#[derive(Debug, FromRow)]
-pub struct DocumentMatch {
-    /// Calibre book ID.
-    pub book_id: i64,
-    /// Text content of the chunk.
-    pub text_chunk: String,
-    /// Cosine similarity score (0.0 to 1.0).
-    pub similarity: f64,
-}
+pub mod jobs;
+pub mod queries;
 
 /// PostgreSQL database connection pool and operations.
 #[derive(Debug, Clone)]
 pub struct Postgres {
-    /// Connection pool for database operations.
-    pool: Pool<sqlx::Postgres>,
+    pub jobs: Jobs,
+    pub queries: Queries,
 }
 
 /// Error when connecting to database.
@@ -31,195 +24,20 @@ pub struct ConnectionError {
     source: sqlx::Error,
 }
 
-/// Error when retrieving document version.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to get highest document version."))]
-pub struct VersionError {
-    source: sqlx::Error,
-}
-
-/// Error when storing embeddings.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to store embeddings."))]
-pub struct EmbeddingError {
-    source: sqlx::Error,
-}
-
-/// Error when querying for similar documents.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to query similarity data."))]
-pub struct QueryError {
-    source: sqlx::Error,
-}
-
-/// Error when updating document hash.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to update document hash."))]
-pub struct UpdateHashError {
-    source: sqlx::Error,
-}
-
-/// Error when checking if document needs update.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to check if document needs update."))]
-pub struct CheckUpdateError {
-    source: sqlx::Error,
-}
-
-/// Error when removing old document versions.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to delete old document versions."))]
-pub struct RemoveOldVersionsError {
-    source: sqlx::Error,
-}
-
 impl Postgres {
     /// Create a new database connection pool.
     pub async fn connect(db_url: &str, max_connections: u32) -> Result<Self, ConnectionError> {
-        let pool = PgPoolOptions::new()
-            .max_connections(max_connections)
-            .connect(db_url)
-            .await
-            .context(ConnectionSnafu)?;
+        let pool = Arc::new(
+            PgPoolOptions::new()
+                .max_connections(max_connections)
+                .connect(db_url)
+                .await
+                .context(ConnectionSnafu)?,
+        );
 
-        Ok(Self { pool })
-    }
+        let jobs = Jobs::new(pool.clone());
+        let queries = Queries::new(pool.clone());
 
-    /// Get the highest version number for a book's embeddings.
-    pub async fn embedding_version(&self, book_id: i64) -> Result<i32, VersionError> {
-        let version = sqlx::query_scalar(
-            "SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1",
-        )
-        .bind(book_id)
-        .fetch_one(&self.pool)
-        .await
-        .context(VersionSnafu)?;
-
-        Ok(version)
-    }
-
-    /// Store multiple embeddings for a book in parallel.
-    pub async fn add_multiple_embeddings(
-        &self,
-        book_id: i64,
-        version: i32,
-        embeddings: Vec<Embeddings>,
-    ) -> Result<(), EmbeddingError> {
-        let inserts: Vec<_> = embeddings
-            .into_iter()
-            .map(|e| self.add_embeddings(book_id, version, e))
-            .collect();
-
-        try_join_all(inserts).await?;
-        Ok(())
-    }
-
-    /// Store a single embedding for a book.
-    pub async fn add_embeddings(
-        &self,
-        book_id: i64,
-        version: i32,
-        embeddings: Embeddings,
-    ) -> Result<(), EmbeddingError> {
-        let vector_string = embeddings_to_vec_string(embeddings.0);
-        sqlx::query(
-            "INSERT INTO documents (book_id, version, embedding, text_chunk) VALUES ($1, $2, $3::vector, $4)",
-        )
-        .bind(book_id)
-        .bind(version)
-        .bind(vector_string)
-        .bind(embeddings.2)
-        .execute(&self.pool)
-        .await
-        .context(EmbeddingSnafu)?;
-
-        Ok(())
-    }
-
-    /// Find documents similar to the query embedding using cosine similarity.
-    pub async fn query(
-        &self,
-        query: Embeddings,
-        limit: i32,
-    ) -> Result<Vec<DocumentMatch>, QueryError> {
-        let vector_string = embeddings_to_vec_string(query.0);
-        let books = sqlx::query_as::<_, DocumentMatch>(
-            "SELECT book_id, text_chunk, 1 - (embedding <=> $1::vector) as similarity
-             FROM documents
-             ORDER BY embedding <=> $1::vector
-             LIMIT $2",
-        )
-        .bind(vector_string)
-        .bind(limit)
-        .fetch_all(&self.pool)
-        .await
-        .context(QuerySnafu)?;
-
-        Ok(books)
-    }
-
-    /// Update or insert the hash for a book.
-    pub async fn update_hash(
-        &self,
-        book_id: i64,
-        hash: [u8; SHA256_LENGTH],
-    ) -> Result<(), UpdateHashError> {
-        sqlx::query(
-            "INSERT INTO hashes (book_id, hash) VALUES ($1, $2)
-             ON CONFLICT (book_id) DO UPDATE SET hash = $2",
-        )
-        .bind(book_id)
-        .bind(&hash[..])
-        .execute(&self.pool)
-        .await
-        .context(UpdateHashSnafu)?;
-
-        Ok(())
-    }
-
-    /// Check if a book needs to be reprocessed based on its hash.
-    pub async fn book_needs_update(
-        &self,
-        book_id: i64,
-        hash: [u8; SHA256_LENGTH],
-    ) -> Result<bool, CheckUpdateError> {
-        let exists = sqlx::query_scalar::<_, bool>(
-            "SELECT NOT EXISTS (
-                SELECT 1 FROM hashes
-                WHERE book_id = $1 AND hash = $2
-            )",
-        )
-        .bind(book_id)
-        .bind(&hash[..])
-        .fetch_one(&self.pool)
-        .await
-        .context(CheckUpdateSnafu)?;
-
-        Ok(exists)
-    }
-
-    /// Remove all but the latest version of embeddings for a book.
-    pub async fn remove_old_versions(&self, book_id: i64) -> Result<(), RemoveOldVersionsError> {
-        sqlx::query(
-            "DELETE FROM documents WHERE version < (SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1)",
-        )
-        .bind(book_id)
-        .execute(&self.pool)
-        .await
-        .context(RemoveOldVersionsSnafu)?;
-
-        Ok(())
+        Ok(Self { jobs, queries })
     }
 }
-
-/// Convert a vector of floats to PostgreSQL vector format string.
-fn embeddings_to_vec_string(vector: Vec<f32>) -> String {
-    format!(
-        "[{}]",
-        vector
-            .iter()
-            .map(|x| x.to_string())
-            .collect::<Vec<String>>()
-            .join(",")
-    )
-}
diff --git a/src/storage/jobs.rs b/src/storage/jobs.rs
new file mode 100644
index 0000000..94942d8
--- /dev/null
+++ b/src/storage/jobs.rs
@@ -0,0 +1,143 @@
+use std::sync::Arc;
+
+use serde::{Deserialize, Serialize};
+use snafu::{ResultExt, Snafu};
+use sqlx::{PgPool, Row, Type};
+use uuid::Uuid;
+
+use super::queries::DocumentMatch;
+
+#[derive(Debug, Clone)]
+pub struct Jobs {
+    /// Connection pool for database operations.
+    pool: Arc<PgPool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, Type)]
+#[sqlx(type_name = "job_status", rename_all = "lowercase")]
+pub enum JobStatus {
+    Running,
+    Completed,
+    Failed,
+}
+
+impl std::fmt::Display for JobStatus {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            JobStatus::Running => write!(f, "running"),
+            JobStatus::Completed => write!(f, "completed"),
+            JobStatus::Failed => write!(f, "failed"),
+        }
+    }
+}
+
+impl std::str::FromStr for JobStatus {
+    type Err = String;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "running" => Ok(JobStatus::Running),
+            "completed" => Ok(JobStatus::Completed),
+            "failed" => Ok(JobStatus::Failed),
+            _ => Err(format!("Invalid job status: {}", s)),
+        }
+    }
+}
+
+/// Error adding a new job.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to add background job."))]
+pub struct AddError {
+    source: sqlx::Error,
+}
+
+/// Error when adding a result to a job.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to add result to background job."))]
+pub struct ResultError {
+    source: sqlx::Error,
+}
+
+/// Error when marking a job as failed.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to add failure to background job."))]
+pub struct FailureError {
+    source: sqlx::Error,
+}
+
+/// Errors that occur during job retrieval.
+#[derive(Debug, Snafu)]
+pub enum RetrieveError {
+    #[snafu(display("Failed to retrieve background job data."))]
+    Retrieve { source: sqlx::Error },
+    #[snafu(display("Failed to deserialize background job result."))]
+    Deserialize { source: postcard::Error },
+}
+
+impl Jobs {
+    pub fn new(pool: Arc<PgPool>) -> Self {
+        Self { pool }
+    }
+
+    /// Add a new job to the database and mark it as running.
+    pub async fn add(&self, id: &Uuid) -> Result<(), AddError> {
+        sqlx::query("INSERT INTO jobs (id, expires_at) VALUES ($1, NOW() + INTERVAL '24 hours')")
+            .bind(id)
+            .execute(self.pool.as_ref())
+            .await
+            .context(AddSnafu)?;
+
+        Ok(())
+    }
+
+    /// Add a result to a job and mark it as finished.
+    pub async fn finish(&self, id: &Uuid, data: &[u8]) -> Result<(), ResultError> {
+        sqlx::query("UPDATE jobs SET status = 'completed', result = $1 WHERE id = $2")
+            .bind(data)
+            .bind(id)
+            .execute(self.pool.as_ref())
+            .await
+            .context(ResultSnafu)?;
+
+        Ok(())
+    }
+
+    /// Add an error to a job and mark it as failed.
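+    /// The job row is kept, so clients polling the id still see the final status.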
+    pub async fn fail(&self, id: &Uuid, error: &str) -> Result<(), FailureError> {
+        sqlx::query("UPDATE jobs SET status = 'failed', error = $1 WHERE id = $2")
+            .bind(error)
+            .bind(id)
+            .execute(self.pool.as_ref())
+            .await
+            .context(FailureSnafu)?;
+
+        Ok(())
+    }
+
+    /// Look up a job, returning its status and any stored result documents.
+    pub async fn retrieve(
+        &self,
+        id: &Uuid,
+    ) -> Result<Option<(String, Vec<DocumentMatch>)>, RetrieveError> {
+        let row = sqlx::query("SELECT status, result FROM jobs WHERE id = $1")
+            .bind(id)
+            .fetch_optional(self.pool.as_ref())
+            .await
+            .context(RetrieveSnafu)?;
+
+        match row {
+            Some(row) => {
+                let status: JobStatus = row.get("status");
+                let status = status.to_string();
+                match row.try_get("result") {
+                    Ok(data) => {
+                        let documents: Vec<DocumentMatch> =
+                            postcard::from_bytes(data).context(DeserializeSnafu)?;
+                        Ok(Some((status, documents)))
+                    }
+                    // A NULL result (job still running or failed) yields no documents.
+                    Err(_) => Ok(Some((status, Vec::new()))),
+                }
+            }
+            None => Ok(None),
+        }
+    }
+}
diff --git a/src/storage/queries.rs b/src/storage/queries.rs
new file mode 100644
index 0000000..65c8bc8
--- /dev/null
+++ b/src/storage/queries.rs
@@ -0,0 +1,211 @@
+use std::sync::Arc;
+
+use futures::future::try_join_all;
+use serde::{Deserialize, Serialize};
+use snafu::{ResultExt, Snafu};
+use sqlx::{FromRow, PgPool};
+
+use crate::{hash::SHA256_LENGTH, text_encoder::Embeddings};
+
+#[derive(Debug, Clone)]
+pub struct Queries {
+    /// Connection pool for database operations.
+    pool: Arc<PgPool>,
+}
+
+/// Error when retrieving document version.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to get highest document version."))]
+pub struct VersionError {
+    source: sqlx::Error,
+}
+
+/// Error when storing embeddings.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to store embeddings."))]
+pub struct EmbeddingError {
+    source: sqlx::Error,
+}
+
+/// Error when querying for similar documents.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to query similarity data."))]
+pub struct QueryError {
+    source: sqlx::Error,
+}
+
+/// Error when updating document hash.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to update document hash."))]
+pub struct UpdateHashError {
+    source: sqlx::Error,
+}
+
+/// Error when checking if document needs update.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to check if document needs update."))]
+pub struct CheckUpdateError {
+    source: sqlx::Error,
+}
+
+/// Error when removing old document versions.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to delete old document versions."))]
+pub struct RemoveOldVersionsError {
+    source: sqlx::Error,
+}
+
+/// A document chunk with similarity score from vector search.
+#[derive(Debug, FromRow, Serialize, Deserialize)]
+pub struct DocumentMatch {
+    /// Calibre book ID.
+    pub book_id: i64,
+    /// Text content of the chunk.
+    pub text_chunk: String,
+    /// Cosine similarity score (0.0 to 1.0).
+    pub similarity: f64,
+}
+
+impl Queries {
+    pub fn new(pool: Arc<PgPool>) -> Self {
+        Self { pool }
+    }
+
+    /// Get the highest version number for a book's embeddings.
+    pub async fn embedding_version(&self, book_id: i64) -> Result<i32, VersionError> {
+        let version = sqlx::query_scalar(
+            "SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1",
+        )
+        .bind(book_id)
+        .fetch_one(self.pool.as_ref())
+        .await
+        .context(VersionSnafu)?;
+
+        Ok(version)
+    }
+
+    /// Store multiple embeddings for a book in parallel.
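+    /// All inserts run concurrently; the first failure aborts the whole batch.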
+    pub async fn add_multiple_embeddings(
+        &self,
+        book_id: i64,
+        version: i32,
+        embeddings: Vec<Embeddings>,
+    ) -> Result<(), EmbeddingError> {
+        let inserts: Vec<_> = embeddings
+            .into_iter()
+            .map(|e| self.add_embeddings(book_id, version, e))
+            .collect();
+
+        try_join_all(inserts).await?;
+        Ok(())
+    }
+
+    /// Store a single embedding for a book.
+    pub async fn add_embeddings(
+        &self,
+        book_id: i64,
+        version: i32,
+        embeddings: Embeddings,
+    ) -> Result<(), EmbeddingError> {
+        let vector_string = embeddings_to_vec_string(embeddings.0);
+        sqlx::query(
+            "INSERT INTO documents (book_id, version, embedding, text_chunk) VALUES ($1, $2, $3::vector, $4)",
+        )
+        .bind(book_id)
+        .bind(version)
+        .bind(vector_string)
+        .bind(embeddings.2)
+        .execute(self.pool.as_ref())
+        .await
+        .context(EmbeddingSnafu)?;
+
+        Ok(())
+    }
+
+    /// Find documents similar to the query embedding using cosine similarity.
+    pub async fn query(
+        &self,
+        query: Embeddings,
+        limit: i32,
+    ) -> Result<Vec<DocumentMatch>, QueryError> {
+        let vector_string = embeddings_to_vec_string(query.0);
+        let books = sqlx::query_as::<_, DocumentMatch>(
+            "SELECT book_id, text_chunk, 1 - (embedding <=> $1::vector) as similarity
+             FROM documents
+             ORDER BY embedding <=> $1::vector
+             LIMIT $2",
+        )
+        .bind(vector_string)
+        .bind(limit)
+        .fetch_all(self.pool.as_ref())
+        .await
+        .context(QuerySnafu)?;
+
+        Ok(books)
+    }
+
+    /// Update or insert the hash for a book.
+    pub async fn update_hash(
+        &self,
+        book_id: i64,
+        hash: [u8; SHA256_LENGTH],
+    ) -> Result<(), UpdateHashError> {
+        sqlx::query(
+            "INSERT INTO hashes (book_id, hash) VALUES ($1, $2)
+             ON CONFLICT (book_id) DO UPDATE SET hash = $2",
+        )
+        .bind(book_id)
+        .bind(&hash[..])
+        .execute(self.pool.as_ref())
+        .await
+        .context(UpdateHashSnafu)?;
+
+        Ok(())
+    }
+
+    /// Check if a book needs to be reprocessed based on its hash.
+    pub async fn book_needs_update(
+        &self,
+        book_id: i64,
+        hash: [u8; SHA256_LENGTH],
+    ) -> Result<bool, CheckUpdateError> {
+        let exists = sqlx::query_scalar::<_, bool>(
+            "SELECT NOT EXISTS (
+                SELECT 1 FROM hashes
+                WHERE book_id = $1 AND hash = $2
+            )",
+        )
+        .bind(book_id)
+        .bind(&hash[..])
+        .fetch_one(self.pool.as_ref())
+        .await
+        .context(CheckUpdateSnafu)?;
+
+        Ok(exists)
+    }
+
+    /// Remove all but the latest version of embeddings for a book.
+    pub async fn remove_old_versions(&self, book_id: i64) -> Result<(), RemoveOldVersionsError> {
+        sqlx::query(
+            "DELETE FROM documents WHERE book_id = $1 AND version < (SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1)",
+        )
+        .bind(book_id)
+        .execute(self.pool.as_ref())
+        .await
+        .context(RemoveOldVersionsSnafu)?;
+
+        Ok(())
+    }
+}
+
+/// Convert a vector of floats to PostgreSQL vector format string.
+fn embeddings_to_vec_string(vector: Vec<f32>) -> String {
+    format!(
+        "[{}]",
+        vector
+            .iter()
+            .map(|x| x.to_string())
+            .collect::<Vec<String>>()
+            .join(",")
+    )
+}
diff --git a/src/text_encoder.rs b/src/text_encoder.rs
index 87f9335..8e4a5a1 100644
--- a/src/text_encoder.rs
+++ b/src/text_encoder.rs
@@ -12,7 +12,7 @@ use snafu::{ResultExt, Snafu};
 use tract_onnx::prelude::*;
 
 use crate::{
-    storage::DocumentMatch,
+    storage::queries::DocumentMatch,
     text_encoder::tract_ndarray::ShapeError,
    tokenize::{self, Encoding, Tokenizer},
 };
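
Usage sketch for the new endpoints (host and port are illustrative; the actual
bind address comes from the server configuration, not from this patch):

    # enqueue a query job; the body is the query text, the response carries the job id
    curl -X POST 'http://localhost:3000/query?limit=5' --data 'ship navigation'
    # => {"id":"0d9af00b-…"}

    # poll the job: 400 for a malformed id, 404 for an unknown one,
    # otherwise the status plus any stored results
    curl 'http://localhost:3000/query/0d9af00b-…'
    # => {"status":"completed","results":[…]}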