WIP job queue

Sebastian Hugentobler 2025-07-01 21:10:35 +02:00
parent 525e278a4e
commit 6a5b309391
Signed by: shu
SSH key fingerprint: SHA256:ppcx6MlixdNZd5EUM1nkHOKoyQYoJwzuQKXM6J/t66M
15 changed files with 685 additions and 256 deletions

Cargo.lock (generated, 92 lines changed)

@@ -158,6 +158,15 @@ dependencies = [
 "num-traits",
]

+[[package]]
+name = "atomic-polyfill"
+version = "1.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4"
+dependencies = [
+ "critical-section",
+]
+
[[package]]
name = "atomic-waker"
version = "1.1.2"

@@ -449,6 +458,15 @@ version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675"

+[[package]]
+name = "cobs"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0fa961b519f0b462e3a3b4a34b64d119eeaca1d59af726fe450bbba07a9fc0a1"
+dependencies = [
+ "thiserror",
+]
+
[[package]]
name = "colorchoice"
version = "1.0.4"

@@ -543,6 +561,12 @@ dependencies = [
 "cfg-if",
]

+[[package]]
+name = "critical-section"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b"
+
[[package]]
name = "crossbeam-deque"
version = "0.8.6"

@@ -842,6 +866,18 @@ dependencies = [
 "serde",
]

+[[package]]
+name = "embedded-io"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced"
+
+[[package]]
+name = "embedded-io"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d"
+
[[package]]
name = "encode_unicode"
version = "1.0.0"

@@ -1171,6 +1207,15 @@ dependencies = [
 "num-traits",
]

+[[package]]
+name = "hash32"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67"
+dependencies = [
+ "byteorder",
+]
+
[[package]]
name = "hashbrown"
version = "0.14.5"

@@ -1200,6 +1245,20 @@ dependencies = [
 "hashbrown 0.15.3",
]

+[[package]]
+name = "heapless"
+version = "0.7.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f"
+dependencies = [
+ "atomic-polyfill",
+ "hash32",
+ "rustc_version",
+ "serde",
+ "spin",
+ "stable_deref_trait",
+]
+
[[package]]
name = "heck"
version = "0.5.0"

@@ -1788,6 +1847,7 @@ dependencies = [
 "futures",
 "ignore",
 "lopdf",
+ "postcard",
 "quick-xml",
 "rayon",
 "ring",

@@ -2411,6 +2471,19 @@ dependencies = [
 "portable-atomic",
]

+[[package]]
+name = "postcard"
+version = "1.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c1de96e20f51df24ca73cafcc4690e044854d803259db27a00a461cb3b9d17a"
+dependencies = [
+ "cobs",
+ "embedded-io 0.4.0",
+ "embedded-io 0.6.1",
+ "heapless",
+ "serde",
+]
+
[[package]]
name = "potential_utf"
version = "0.1.2"

@@ -2745,6 +2818,15 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"

+[[package]]
+name = "rustc_version"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
+dependencies = [
+ "semver",
+]
+
[[package]]
name = "rustfft"
version = "6.3.0"

@@ -2877,6 +2959,12 @@ dependencies = [
 "smallvec",
]

+[[package]]
+name = "semver"
+version = "1.0.26"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0"
+
[[package]]
name = "serde"
version = "1.0.219"

@@ -3125,6 +3213,7 @@ dependencies = [
 "tokio-stream",
 "tracing",
 "url",
+ "uuid",
 "webpki-roots 0.26.11",
]

@@ -3205,6 +3294,7 @@ dependencies = [
 "stringprep",
 "thiserror",
 "tracing",
+ "uuid",
 "whoami",
]

@@ -3242,6 +3332,7 @@ dependencies = [
 "stringprep",
 "thiserror",
 "tracing",
+ "uuid",
 "whoami",
]

@@ -3267,6 +3358,7 @@ dependencies = [
 "thiserror",
 "tracing",
 "url",
+ "uuid",
]

[[package]]

Cargo.toml

@@ -9,6 +9,7 @@ axum = { version = "0.8.4", features = ["http2", "tracing"] }
clap = { version = "4.5.40", features = ["derive"] }
futures = "0.3.31"
lopdf = "0.36.0"
+postcard = { version = "1.1.2", features = ["alloc"] }
quick-xml = "0.38.0"
rayon = "1.10.0"
ring = "0.17.14"
@@ -16,7 +17,7 @@ scraper = "0.23.1"
serde = { version = "1.0.219", features = ["derive"] }
sha2 = "0.10.9"
snafu = { version = "0.8.6", features = ["rust_1_81"] }
-sqlx = { version = "0.8.6", features = [ "runtime-tokio", "tls-rustls-ring", "migrate", "postgres", "derive" ] }
+sqlx = { version = "0.8.6", features = [ "runtime-tokio", "tls-rustls-ring", "migrate", "postgres", "derive", "uuid", "macros" ] }
tokenizers = "0.21.2"
tokio = { version = "1.45.1", features = ["rt-multi-thread", "macros"]}
tower-http = { version = "0.6.6", features = ["trace"] }

SQL migration (new file): jobs table and job_status enum

@@ -0,0 +1,12 @@
CREATE TYPE job_status AS ENUM ('running', 'completed', 'failed');

CREATE TABLE jobs (
    id UUID PRIMARY KEY,
    status job_status NOT NULL DEFAULT 'running',
    error TEXT,
    result BYTEA,
    expires_at TIMESTAMP WITH TIME ZONE NOT NULL
);

CREATE INDEX idx_jobs_status ON jobs(status);
CREATE INDEX idx_jobs_expires_at ON jobs(expires_at);
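
The table keeps an expires_at timestamp, but nothing in this WIP commit deletes expired rows yet (the spawned task in src/api/jobs.rs only sleeps for five minutes after finishing). A periodic cleanup pass could look like the sketch below; the function and its wiring are assumptions, not part of this changeset:

// Hypothetical cleanup task, not included in this commit.
async fn remove_expired_jobs(pool: &sqlx::PgPool) -> Result<u64, sqlx::Error> {
    let result = sqlx::query("DELETE FROM jobs WHERE expires_at < NOW()")
        .execute(pool)
        .await?;
    Ok(result.rows_affected())
}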

src/api (module root)

@@ -4,8 +4,10 @@ use std::{
    io,
    net::{AddrParseError, SocketAddr},
    str::FromStr,
+    sync::Arc,
};

+use jobs::JobManager;
use snafu::{ResultExt, Snafu};
use state::AppState;
use tokio::net::TcpListener;
@@ -17,6 +19,7 @@ use crate::{storage::Postgres, text_encoder::TextEncoder, tokenize::Tokenizer};
pub mod code;
pub mod error;
+pub mod jobs;
pub mod query;
pub mod routes;
pub mod state;
@@ -55,12 +58,14 @@ pub async fn serve(
    reranker: TextEncoder,
    chunk_size: usize,
) -> Result<(), ServeError> {
+    let jobs = JobManager::new(Arc::new(db.jobs.clone()));
    let state = AppState {
        db,
        tokenizer,
        embedder,
        reranker,
        chunk_size,
+        jobs,
    };

    let (router, api) = OpenApiRouter::with_openapi(ApiDoc::openapi())

src/api/jobs.rs (new file, 81 lines)

@@ -0,0 +1,81 @@
use snafu::Report;
use snafu::{ResultExt, Snafu};
use std::{sync::Arc, time::Duration};
use tracing::error;
use uuid::Uuid;

use crate::storage;
use crate::storage::jobs::AddError;

#[derive(Clone, Debug)]
pub struct JobManager {
    db: Arc<storage::jobs::Jobs>,
}

/// Errors that occur when running a background job.
#[derive(Debug, Snafu)]
pub enum ExecuteError {
    #[snafu(display("Failed to add new background job."))]
    Add { source: AddError },
}

impl JobManager {
    pub fn new(db: Arc<storage::jobs::Jobs>) -> Self {
        Self { db }
    }

    pub async fn execute<F, Fut, T, E>(&self, task: F) -> Result<Uuid, ExecuteError>
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = Result<T, E>> + Send,
        T: serde::Serialize + Send + Sync + 'static,
        E: snafu::Error + Send + Sync + 'static,
    {
        let job_id = Uuid::new_v4();
        self.db.as_ref().add(&job_id).await.context(AddSnafu)?;

        let db = self.db.clone();
        tokio::spawn(async move {
            match task().await {
                Ok(result) => success(&job_id, &result, &db).await,
                Err(error) => failure(&job_id, error, &db).await,
            }

            tokio::time::sleep(Duration::from_secs(300)).await;
        });

        Ok(job_id)
    }
}

async fn success<T>(job_id: &Uuid, result: &T, db: &storage::jobs::Jobs)
where
    T: serde::Serialize + Send + Sync + 'static,
{
    let data = match postcard::to_allocvec(result) {
        Ok(data) => data,
        Err(error) => {
            failure(job_id, error, db).await;
            return;
        }
    };

    if let Err(error) = db.finish(job_id, &data).await {
        failure(job_id, error, db).await;
    }
}

async fn failure<E>(job_id: &Uuid, error: E, db: &storage::jobs::Jobs)
where
    E: snafu::Error + Send + Sync + 'static,
{
    let error = Report::from_error(&error);
    error!("{job_id}: {error}");

    if let Err(error) = db.fail(job_id, &error.to_string()).await {
        error!(
            "Failed to save job {job_id} result: {}",
            Report::from_error(&error)
        );
    }
}
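
As a usage sketch (illustrative, not part of this diff): any FnOnce closure whose future yields Result<T, E> with T: Serialize can be handed to execute; the job row is written first, the id is returned immediately, and the work runs in the spawned task.

// Minimal sketch; the closure, result type, and error type are stand-ins.
async fn enqueue_demo(jobs: &JobManager) -> Result<Uuid, ExecuteError> {
    jobs.execute(|| async {
        // Any T: Serialize and E: snafu::Error + Send + Sync satisfies the bounds.
        Ok::<Vec<String>, std::io::Error>(vec!["done".to_string()])
    })
    .await
}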

src/api/query.rs

@@ -1,18 +1,22 @@
//! Query endpoint handlers and response types.

-use std::sync::Arc;
+use std::{str::FromStr, sync::Arc};

use axum::{
    Json,
-    extract::{Query, State},
+    extract::{Path, Query, State},
    http::StatusCode,
};
use serde::{Deserialize, Serialize};
use snafu::{ResultExt, Snafu, ensure};
use utoipa::ToSchema;
+use uuid::Uuid;

-use super::{TAG, error::HttpStatus, state::AppState};
-use crate::{http_error, query, storage::DocumentMatch};
+use super::{TAG, error::HttpStatus, jobs, state::AppState};
+use crate::{
+    http_error, query,
+    storage::{self, queries::DocumentMatch},
+};

const MAX_LIMIT: usize = 10;
@@ -22,14 +26,14 @@ pub enum QueryError {
    #[snafu(display("'limit' query parameter must be a positive integer <= {MAX_LIMIT}."))]
    Limit,
    #[snafu(display("Failed to run query."))]
-    Query { source: query::AskError },
+    QueryExecute { source: jobs::ExecuteError },
}

impl HttpStatus for QueryError {
    fn status_code(&self) -> StatusCode {
        match self {
            QueryError::Limit => StatusCode::BAD_REQUEST,
-            QueryError::Query { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
+            QueryError::QueryExecute { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}
@@ -46,25 +50,9 @@ pub struct QueryParams {
/// Response format for successful query requests.
#[derive(Debug, Serialize, ToSchema)]
pub struct QueryResponse {
+    pub status: String,
    /// List of matching document chunks.
    pub results: Vec<DocumentResult>,
-    /// Total number of results returned.
-    pub count: usize,
-    /// Original query text that was searched.
-    pub query: String,
}
-
-impl From<(Vec<DocumentMatch>, String)> for QueryResponse {
-    fn from((documents, query): (Vec<DocumentMatch>, String)) -> Self {
-        let results: Vec<DocumentResult> = documents.into_iter().map(Into::into).collect();
-        let count = results.len();
-        Self {
-            results,
-            count,
-            query,
-        }
-    }
-}

/// A single document search result.
@@ -88,6 +76,13 @@ impl From<DocumentMatch> for DocumentResult {
    }
}

+/// Response format for a successfully started query job.
+#[derive(Debug, Serialize, ToSchema)]
+pub struct QueryStartResponse {
+    /// Job id.
+    pub id: String,
+}
+
/// Execute a semantic search query against the document database.
#[utoipa::path(
    post,
@@ -98,31 +93,84 @@ impl From<DocumentMatch> for DocumentResult {
    ),
    request_body = String,
    responses(
-        (status = OK, body = QueryResponse),
+        (status = 202, body = QueryStartResponse),
        (status = 400, description = "Wrong parameter."),
        (status = 500, description = "Failure to query database.")
    )
)]
-pub async fn route(
+pub async fn start(
    Query(params): Query<QueryParams>,
    State(state): State<Arc<AppState>>,
    body: String,
-) -> Result<Json<QueryResponse>, QueryError> {
+) -> Result<Json<QueryStartResponse>, QueryError> {
    let limit = params.limit.unwrap_or(5);
    ensure!(limit <= MAX_LIMIT, LimitSnafu);

-    let results = query::ask(
-        &body,
-        &state.db,
-        &state.tokenizer,
-        &state.embedder,
-        &state.reranker,
-        state.chunk_size,
-        limit,
-    )
-    .await
-    .context(QuerySnafu)?;
-
-    let response = QueryResponse::from((results, body));
-
-    Ok(Json(response))
+    let jobs = state.jobs.clone();
+    let id = jobs
+        .execute(move || async move {
+            query::ask(
+                &body,
+                &state.db.queries,
+                &state.tokenizer,
+                &state.embedder,
+                &state.reranker,
+                state.chunk_size,
+                limit,
+            )
+            .await
+        })
+        .await
+        .context(QueryExecuteSnafu)?;
+
+    Ok(Json(QueryStartResponse { id: id.to_string() }))
}
+
+/// Errors that occur during query retrieval.
+#[derive(Debug, Snafu)]
+pub enum RetrieveError {
+    #[snafu(display("Failed to retrieve job from database."))]
+    Db {
+        source: storage::jobs::RetrieveError,
+    },
+    #[snafu(display("No job with id {id}."))]
+    NoJob { id: Uuid },
+    #[snafu(display("'{id}' is not a valid v4 UUID."))]
+    Uuid { source: uuid::Error, id: String },
+}
+
+impl HttpStatus for RetrieveError {
+    fn status_code(&self) -> StatusCode {
+        match self {
+            RetrieveError::Db { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
+            RetrieveError::NoJob { id: _ } => StatusCode::NOT_FOUND,
+            RetrieveError::Uuid { source: _, id: _ } => StatusCode::BAD_REQUEST,
+        }
+    }
+}
+
+http_error!(RetrieveError);
+
+#[utoipa::path(
+    get,
+    path = "/query/{id}",
+    tag = TAG,
+    responses(
+        (status = OK, body = QueryResponse),
+        (status = 404, description = "No job with the requested id."),
+        (status = 500, description = "Failure to query database.")
+    )
+)]
+pub async fn result(
+    Path(id): Path<String>,
+    State(state): State<Arc<AppState>>,
+) -> Result<Json<QueryResponse>, RetrieveError> {
+    let id = Uuid::from_str(&id).context(UuidSnafu { id })?;
+
+    match state.db.jobs.retrieve(&id).await.context(DbSnafu)? {
+        Some((status, results)) => Ok(Json(QueryResponse {
+            status,
+            results: results.into_iter().map(Into::into).collect(),
+        })),
+        None => NoJobSnafu { id }.fail(),
+    }
}
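
The query API is now a two-step flow: POST /query answers 202 with a job id, and GET /query/{id} serves status plus results. A client-side sketch of that flow, assuming a local server; reqwest and serde_json are used for illustration only and are not dependencies of this project:

// Hypothetical client; the base URL and port are assumptions.
async fn query_flow() -> Result<(), reqwest::Error> {
    let client = reqwest::Client::new();
    let started: serde_json::Value = client
        .post("http://localhost:3000/query?limit=5")
        .body("example search text")
        .send()
        .await?
        .json()
        .await?;
    let id = started["id"].as_str().unwrap_or_default().to_string();

    // Poll until the job status is "completed" or "failed".
    let job: serde_json::Value = client
        .get(format!("http://localhost:3000/query/{id}"))
        .send()
        .await?
        .json()
        .await?;
    println!("status: {}, results: {}", job["status"], job["results"]);
    Ok(())
}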

src/api/routes.rs

@@ -12,7 +12,8 @@ use crate::api::{code, query};
pub fn router(state: AppState) -> OpenApiRouter {
    let store = Arc::new(state);
    OpenApiRouter::new()
-        .routes(routes!(query::route, code::route))
+        .routes(routes!(query::start, query::result))
+        .routes(routes!(code::route))
        .layer(TraceLayer::new_for_http())
        .with_state(store)
}

src/api/state.rs

@@ -2,6 +2,8 @@
use crate::{storage::Postgres, text_encoder::TextEncoder, tokenize::Tokenizer};

+use super::jobs::JobManager;
+
/// Application state shared across all HTTP request handlers.
#[derive(Debug, Clone)]
pub struct AppState {
@@ -15,4 +17,6 @@ pub struct AppState {
    pub reranker: TextEncoder,
    /// Text chunk size in words for processing.
    pub chunk_size: usize,
+    /// Background jobs handling.
+    pub jobs: JobManager,
}

src/main.rs

@@ -75,7 +75,7 @@ async fn main() -> Result<(), Error> {
                TextEncoder::from_file(&config.reranking_path).context(RerankerLoadingSnafu)?;
            query::ask(
                &config.query,
-                &db,
+                &db.queries,
                &tokenizer,
                &embedder,
                &reranker,
@@ -87,7 +87,7 @@ async fn main() -> Result<(), Error> {
        }
        Commands::Scan(config) => {
            scanner::scan(
-                &db,
+                &db.queries,
                &config.library_path,
                &tokenizer,
                &embedder,

src/query.rs

@@ -3,7 +3,10 @@
use snafu::{ResultExt, Snafu};

use crate::{
-    storage::{self, DocumentMatch, Postgres},
+    storage::{
+        self,
+        queries::{DocumentMatch, Queries},
+    },
    text_encoder::{self, TextEncoder},
    tokenize::{self, Tokenizer},
};
@@ -16,7 +19,9 @@ pub enum AskError {
    #[snafu(display("Failed to embed query."))]
    Embed { source: text_encoder::EmbedError },
    #[snafu(display("Failed to retrieve similar documents."))]
-    Query { source: storage::QueryError },
+    Query {
+        source: storage::queries::QueryError,
+    },
    #[snafu(display("Failed to rerank documents."))]
    Rerank { source: text_encoder::RerankError },
}
@@ -24,7 +29,7 @@ pub enum AskError {
/// Process a user query and return ranked document matches.
pub async fn ask(
    query: &str,
-    db: &Postgres,
+    db: &Queries,
    tokenizer: &Tokenizer,
    embedder: &TextEncoder,
    reranker: &TextEncoder,

src/scanner.rs

@@ -9,14 +9,14 @@ use crate::{
    calibre::{Book, BookIterator},
    extractors::{self, extractor::ExtractionError},
    hash,
-    storage::{self, Postgres},
+    storage::{self, queries::Queries},
    text_encoder::{self, TextEncoder},
    tokenize::{self, Tokenizer},
};

/// Scan a Calibre library and process all books for embedding generation.
pub async fn scan(
-    db: &Postgres,
+    db: &Queries,
    library_path: &PathBuf,
    tokenizer: &Tokenizer,
    text_encoder: &TextEncoder,
@@ -43,7 +43,9 @@ pub enum ProcessBookError {
    #[snafu(display("Failed to hash book."))]
    BookHash { source: io::Error },
    #[snafu(display("Failed to check in db if document needs update."))]
-    UpdateCheck { source: storage::CheckUpdateError },
+    UpdateCheck {
+        source: storage::queries::CheckUpdateError,
+    },
    #[snafu(display("Failed to open book."))]
    Open { source: io::Error },
    #[snafu(display("Failed to extract document content."))]
@@ -53,21 +55,27 @@ pub enum ProcessBookError {
    #[snafu(display("Failed to create embeddings from encodings."))]
    Embedding { source: text_encoder::EmbedError },
    #[snafu(display("Failed to look up document versions."))]
-    DocumentVersion { source: storage::VersionError },
+    DocumentVersion {
+        source: storage::queries::VersionError,
+    },
    #[snafu(display("Failed to store embeddings in database."))]
-    SaveEmbeddings { source: storage::EmbeddingError },
+    SaveEmbeddings {
+        source: storage::queries::EmbeddingError,
+    },
    #[snafu(display("Failed to update document hash in database."))]
-    UpdateHash { source: storage::UpdateHashError },
+    UpdateHash {
+        source: storage::queries::UpdateHashError,
+    },
    #[snafu(display("Failed to delete old document versions in database."))]
    RemoveOldVersions {
-        source: storage::RemoveOldVersionsError,
+        source: storage::queries::RemoveOldVersionsError,
    },
}

/// Process a single book: extract text, generate embeddings, and store in database.
async fn process_book(
    book: &Book,
-    db: &Postgres,
+    db: &Queries,
    tokenizer: &Tokenizer,
    embedder: &TextEncoder,
    chunk_size: usize,

src/storage (module root)

@@ -1,27 +1,20 @@
//! Database storage and retrieval for document embeddings.

-use futures::future::try_join_all;
+use std::sync::Arc;
+
+use jobs::Jobs;
+use queries::Queries;
use snafu::{ResultExt, Snafu};
-use sqlx::{FromRow, Pool, postgres::PgPoolOptions};
+use sqlx::postgres::PgPoolOptions;

-use crate::{hash::SHA256_LENGTH, text_encoder::Embeddings};
+pub mod jobs;
+pub mod queries;

-/// A document chunk with similarity score from vector search.
-#[derive(Debug, FromRow)]
-pub struct DocumentMatch {
-    /// Calibre book ID.
-    pub book_id: i64,
-    /// Text content of the chunk.
-    pub text_chunk: String,
-    /// Cosine similarity score (0.0 to 1.0).
-    pub similarity: f64,
-}
-
/// PostgreSQL database connection pool and operations.
#[derive(Debug, Clone)]
pub struct Postgres {
-    /// Connection pool for database operations.
-    pool: Pool<sqlx::Postgres>,
+    pub jobs: Jobs,
+    pub queries: Queries,
}

/// Error when connecting to database.
@ -31,195 +24,20 @@ pub struct ConnectionError {
source: sqlx::Error, source: sqlx::Error,
} }
/// Error when retrieving document version.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to get highest document version."))]
pub struct VersionError {
source: sqlx::Error,
}
/// Error when storing embeddings.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to store embeddings."))]
pub struct EmbeddingError {
source: sqlx::Error,
}
/// Error when querying for similar documents.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to query similarity data."))]
pub struct QueryError {
source: sqlx::Error,
}
/// Error when updating document hash.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to update document hash."))]
pub struct UpdateHashError {
source: sqlx::Error,
}
/// Error when checking if document needs update.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to check if document needs update."))]
pub struct CheckUpdateError {
source: sqlx::Error,
}
/// Error when removing old document versions.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to delete old document versions."))]
pub struct RemoveOldVersionsError {
source: sqlx::Error,
}
impl Postgres { impl Postgres {
/// Create a new database connection pool. /// Create a new database connection pool.
pub async fn connect(db_url: &str, max_connections: u32) -> Result<Self, ConnectionError> { pub async fn connect(db_url: &str, max_connections: u32) -> Result<Self, ConnectionError> {
let pool = PgPoolOptions::new() let pool = Arc::new(
.max_connections(max_connections) PgPoolOptions::new()
.connect(db_url) .max_connections(max_connections)
.await .connect(db_url)
.context(ConnectionSnafu)?; .await
.context(ConnectionSnafu)?,
);
Ok(Self { pool }) let jobs = Jobs::new(pool.clone());
} let queries = Queries::new(pool.clone());
/// Get the highest version number for a book's embeddings. Ok(Self { jobs, queries })
pub async fn embedding_version(&self, book_id: i64) -> Result<i32, VersionError> {
let version = sqlx::query_scalar(
"SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1",
)
.bind(book_id)
.fetch_one(&self.pool)
.await
.context(VersionSnafu)?;
Ok(version)
}
/// Store multiple embeddings for a book in parallel.
pub async fn add_multiple_embeddings(
&self,
book_id: i64,
version: i32,
embeddings: Vec<Embeddings>,
) -> Result<(), EmbeddingError> {
let inserts: Vec<_> = embeddings
.into_iter()
.map(|e| self.add_embeddings(book_id, version, e))
.collect();
try_join_all(inserts).await?;
Ok(())
}
/// Store a single embedding for a book.
pub async fn add_embeddings(
&self,
book_id: i64,
version: i32,
embeddings: Embeddings,
) -> Result<(), EmbeddingError> {
let vector_string = embeddings_to_vec_string(embeddings.0);
sqlx::query(
"INSERT INTO documents (book_id, version, embedding, text_chunk) VALUES ($1, $2, $3::vector, $4)",
)
.bind(book_id)
.bind(version)
.bind(vector_string)
.bind(embeddings.2)
.execute(&self.pool)
.await
.context(EmbeddingSnafu)?;
Ok(())
}
/// Find documents similar to the query embedding using cosine similarity.
pub async fn query(
&self,
query: Embeddings,
limit: i32,
) -> Result<Vec<DocumentMatch>, QueryError> {
let vector_string = embeddings_to_vec_string(query.0);
let books = sqlx::query_as::<_, DocumentMatch>(
"SELECT book_id, text_chunk, 1 - (embedding <=> $1::vector) as similarity
FROM documents
ORDER BY embedding <=> $1::vector
LIMIT $2",
)
.bind(vector_string)
.bind(limit)
.fetch_all(&self.pool)
.await
.context(QuerySnafu)?;
Ok(books)
}
/// Update or insert the hash for a book.
pub async fn update_hash(
&self,
book_id: i64,
hash: [u8; SHA256_LENGTH],
) -> Result<(), UpdateHashError> {
sqlx::query(
"INSERT INTO hashes (book_id, hash) VALUES ($1, $2)
ON CONFLICT (book_id) DO UPDATE SET hash = $2",
)
.bind(book_id)
.bind(&hash[..])
.execute(&self.pool)
.await
.context(UpdateHashSnafu)?;
Ok(())
}
/// Check if a book needs to be reprocessed based on its hash.
pub async fn book_needs_update(
&self,
book_id: i64,
hash: [u8; SHA256_LENGTH],
) -> Result<bool, CheckUpdateError> {
let exists = sqlx::query_scalar::<_, bool>(
"SELECT NOT EXISTS (
SELECT 1 FROM hashes
WHERE book_id = $1 AND hash = $2
)",
)
.bind(book_id)
.bind(&hash[..])
.fetch_one(&self.pool)
.await
.context(CheckUpdateSnafu)?;
Ok(exists)
}
/// Remove all but the latest version of embeddings for a book.
pub async fn remove_old_versions(&self, book_id: i64) -> Result<(), RemoveOldVersionsError> {
sqlx::query(
"DELETE FROM documents WHERE version < (SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1)",
)
.bind(book_id)
.execute(&self.pool)
.await
.context(RemoveOldVersionsSnafu)?;
Ok(())
} }
} }
/// Convert a vector of floats to PostgreSQL vector format string.
fn embeddings_to_vec_string(vector: Vec<f32>) -> String {
format!(
"[{}]",
vector
.iter()
.map(|x| x.to_string())
.collect::<Vec<String>>()
.join(",")
)
}

src/storage/jobs.rs (new file, 143 lines)

@@ -0,0 +1,143 @@
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use snafu::{ResultExt, Snafu};
use sqlx::{PgPool, Row, Type};
use uuid::Uuid;

use super::queries::DocumentMatch;

#[derive(Debug, Clone)]
pub struct Jobs {
    /// Connection pool for database operations.
    pool: Arc<PgPool>,
}

#[derive(Debug, Clone, Serialize, Deserialize, Type)]
#[sqlx(type_name = "job_status", rename_all = "lowercase")]
pub enum JobStatus {
    Running,
    Completed,
    Failed,
}

impl std::fmt::Display for JobStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            JobStatus::Running => write!(f, "running"),
            JobStatus::Completed => write!(f, "completed"),
            JobStatus::Failed => write!(f, "failed"),
        }
    }
}

impl std::str::FromStr for JobStatus {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "running" => Ok(JobStatus::Running),
            "completed" => Ok(JobStatus::Completed),
            "failed" => Ok(JobStatus::Failed),
            _ => Err(format!("Invalid job status: {}", s)),
        }
    }
}

/// Error adding a new job.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to add background job."))]
pub struct AddError {
    source: sqlx::Error,
}

/// Error when adding a result to a job.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to add result to background job."))]
pub struct ResultError {
    source: sqlx::Error,
}

/// Error when adding a failure to a job.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to add failure to background job."))]
pub struct FailureError {
    source: sqlx::Error,
}

/// Errors that occur during job retrieval.
#[derive(Debug, Snafu)]
pub enum RetrieveError {
    #[snafu(display("Failed to retrieve background job data."))]
    Retrieve { source: sqlx::Error },
    #[snafu(display("Failed to deserialize background job result."))]
    Deserialize { source: postcard::Error },
}

impl Jobs {
    pub fn new(pool: Arc<PgPool>) -> Self {
        Self { pool }
    }

    /// Add a new job to the database and mark it as running.
    pub async fn add(&self, id: &Uuid) -> Result<(), AddError> {
        sqlx::query("INSERT INTO jobs (id, expires_at) VALUES ($1, NOW() + INTERVAL '24 hours')")
            .bind(id)
            .execute(self.pool.as_ref())
            .await
            .context(AddSnafu)?;

        Ok(())
    }

    /// Add a result to a job and mark it as finished.
    pub async fn finish(&self, id: &Uuid, data: &[u8]) -> Result<(), ResultError> {
        sqlx::query("UPDATE jobs SET status = 'completed', result = $1 WHERE id = $2")
            .bind(data)
            .bind(id)
            .execute(self.pool.as_ref())
            .await
            .context(ResultSnafu)?;

        Ok(())
    }

    /// Add an error to a job and mark it as failed.
    pub async fn fail(&self, id: &Uuid, error: &str) -> Result<(), FailureError> {
        sqlx::query("UPDATE jobs SET status = 'failed', error = $1 WHERE id = $2")
            .bind(error)
            .bind(id)
            .execute(self.pool.as_ref())
            .await
            .context(FailureSnafu)?;

        Ok(())
    }

    /// Fetch a job's status and, if a result is stored, its deserialized documents.
    pub async fn retrieve(
        &self,
        id: &Uuid,
    ) -> Result<Option<(String, Vec<DocumentMatch>)>, RetrieveError> {
        let row = sqlx::query("SELECT status, result FROM jobs WHERE id = $1")
            .bind(id)
            .fetch_optional(self.pool.as_ref())
            .await
            .context(RetrieveSnafu)?;

        match row {
            Some(row) => {
                let status: JobStatus = row.get("status");
                let status = status.to_string();

                match row.try_get("result") {
                    Ok(data) => {
                        let documents: Vec<DocumentMatch> =
                            postcard::from_bytes(data).context(DeserializeSnafu)?;
                        Ok(Some((status, documents)))
                    }
                    Err(_) => Ok(Some((status, Vec::new()))),
                }
            }
            None => Ok(None),
        }
    }
}
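
The result column holds a postcard-encoded Vec<DocumentMatch>: finish stores the bytes produced by postcard::to_allocvec (hence the "alloc" feature in Cargo.toml), and retrieve decodes them with postcard::from_bytes. A self-contained round-trip sketch:

// Round-trip of the encoding used for the BYTEA result column.
fn roundtrip_demo() -> Result<(), postcard::Error> {
    let docs = vec![DocumentMatch {
        book_id: 42,
        text_chunk: "example chunk".to_string(),
        similarity: 0.87,
    }];
    let bytes = postcard::to_allocvec(&docs)?; // what finish() writes
    let back: Vec<DocumentMatch> = postcard::from_bytes(&bytes)?; // what retrieve() reads
    assert_eq!(back.len(), docs.len());
    Ok(())
}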

src/storage/queries.rs (new file, 211 lines)

@@ -0,0 +1,211 @@
use std::sync::Arc;

use futures::future::try_join_all;
use serde::{Deserialize, Serialize};
use snafu::{ResultExt, Snafu};
use sqlx::{FromRow, PgPool};

use crate::{hash::SHA256_LENGTH, text_encoder::Embeddings};

#[derive(Debug, Clone)]
pub struct Queries {
    /// Connection pool for database operations.
    pool: Arc<PgPool>,
}

/// Error when retrieving document version.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to get highest document version."))]
pub struct VersionError {
    source: sqlx::Error,
}

/// Error when storing embeddings.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to store embeddings."))]
pub struct EmbeddingError {
    source: sqlx::Error,
}

/// Error when querying for similar documents.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to query similarity data."))]
pub struct QueryError {
    source: sqlx::Error,
}

/// Error when updating document hash.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to update document hash."))]
pub struct UpdateHashError {
    source: sqlx::Error,
}

/// Error when checking if document needs update.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to check if document needs update."))]
pub struct CheckUpdateError {
    source: sqlx::Error,
}

/// Error when removing old document versions.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to delete old document versions."))]
pub struct RemoveOldVersionsError {
    source: sqlx::Error,
}

/// A document chunk with similarity score from vector search.
#[derive(Debug, FromRow, Serialize, Deserialize)]
pub struct DocumentMatch {
    /// Calibre book ID.
    pub book_id: i64,
    /// Text content of the chunk.
    pub text_chunk: String,
    /// Cosine similarity score (0.0 to 1.0).
    pub similarity: f64,
}

impl Queries {
    pub fn new(pool: Arc<PgPool>) -> Self {
        Self { pool }
    }

    /// Get the highest version number for a book's embeddings.
    pub async fn embedding_version(&self, book_id: i64) -> Result<i32, VersionError> {
        let version = sqlx::query_scalar(
            "SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1",
        )
        .bind(book_id)
        .fetch_one(self.pool.as_ref())
        .await
        .context(VersionSnafu)?;

        Ok(version)
    }

    /// Store multiple embeddings for a book in parallel.
    pub async fn add_multiple_embeddings(
        &self,
        book_id: i64,
        version: i32,
        embeddings: Vec<Embeddings>,
    ) -> Result<(), EmbeddingError> {
        let inserts: Vec<_> = embeddings
            .into_iter()
            .map(|e| self.add_embeddings(book_id, version, e))
            .collect();
        try_join_all(inserts).await?;

        Ok(())
    }

    /// Store a single embedding for a book.
    pub async fn add_embeddings(
        &self,
        book_id: i64,
        version: i32,
        embeddings: Embeddings,
    ) -> Result<(), EmbeddingError> {
        let vector_string = embeddings_to_vec_string(embeddings.0);

        sqlx::query(
            "INSERT INTO documents (book_id, version, embedding, text_chunk) VALUES ($1, $2, $3::vector, $4)",
        )
        .bind(book_id)
        .bind(version)
        .bind(vector_string)
        .bind(embeddings.2)
        .execute(self.pool.as_ref())
        .await
        .context(EmbeddingSnafu)?;

        Ok(())
    }

    /// Find documents similar to the query embedding using cosine similarity.
    pub async fn query(
        &self,
        query: Embeddings,
        limit: i32,
    ) -> Result<Vec<DocumentMatch>, QueryError> {
        let vector_string = embeddings_to_vec_string(query.0);

        let books = sqlx::query_as::<_, DocumentMatch>(
            "SELECT book_id, text_chunk, 1 - (embedding <=> $1::vector) as similarity
             FROM documents
             ORDER BY embedding <=> $1::vector
             LIMIT $2",
        )
        .bind(vector_string)
        .bind(limit)
        .fetch_all(self.pool.as_ref())
        .await
        .context(QuerySnafu)?;

        Ok(books)
    }

    /// Update or insert the hash for a book.
    pub async fn update_hash(
        &self,
        book_id: i64,
        hash: [u8; SHA256_LENGTH],
    ) -> Result<(), UpdateHashError> {
        sqlx::query(
            "INSERT INTO hashes (book_id, hash) VALUES ($1, $2)
             ON CONFLICT (book_id) DO UPDATE SET hash = $2",
        )
        .bind(book_id)
        .bind(&hash[..])
        .execute(self.pool.as_ref())
        .await
        .context(UpdateHashSnafu)?;

        Ok(())
    }

    /// Check if a book needs to be reprocessed based on its hash.
    pub async fn book_needs_update(
        &self,
        book_id: i64,
        hash: [u8; SHA256_LENGTH],
    ) -> Result<bool, CheckUpdateError> {
        let exists = sqlx::query_scalar::<_, bool>(
            "SELECT NOT EXISTS (
                SELECT 1 FROM hashes
                WHERE book_id = $1 AND hash = $2
            )",
        )
        .bind(book_id)
        .bind(&hash[..])
        .fetch_one(self.pool.as_ref())
        .await
        .context(CheckUpdateSnafu)?;

        Ok(exists)
    }

    /// Remove all but the latest version of embeddings for a book.
    pub async fn remove_old_versions(&self, book_id: i64) -> Result<(), RemoveOldVersionsError> {
        sqlx::query(
            "DELETE FROM documents WHERE version < (SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1)",
        )
        .bind(book_id)
        .execute(self.pool.as_ref())
        .await
        .context(RemoveOldVersionsSnafu)?;

        Ok(())
    }
}

/// Convert a vector of floats to PostgreSQL vector format string.
fn embeddings_to_vec_string(vector: Vec<f32>) -> String {
    format!(
        "[{}]",
        vector
            .iter()
            .map(|x| x.to_string())
            .collect::<Vec<String>>()
            .join(",")
    )
}
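
For reference, a worked example of the literal this helper feeds to the $1::vector casts above; note that f32::to_string renders 1.0 as "1":

assert_eq!(
    embeddings_to_vec_string(vec![0.25, -0.5, 1.0]),
    "[0.25,-0.5,1]"
);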

src/text_encoder.rs

@@ -12,7 +12,7 @@ use snafu::{ResultExt, Snafu};
use tract_onnx::prelude::*;

use crate::{
-    storage::DocumentMatch,
+    storage::queries::DocumentMatch,
    text_encoder::tract_ndarray::ShapeError,
    tokenize::{self, Encoding, Tokenizer},
};