WIP job queue

Sebastian Hugentobler 2025-07-01 21:10:35 +02:00
parent 525e278a4e
commit 6a5b309391
Signed by: shu
SSH key fingerprint: SHA256:ppcx6MlixdNZd5EUM1nkHOKoyQYoJwzuQKXM6J/t66M
15 changed files with 685 additions and 256 deletions

Cargo.lock (generated)

@@ -158,6 +158,15 @@ dependencies = [
"num-traits",
]
[[package]]
name = "atomic-polyfill"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4"
dependencies = [
"critical-section",
]
[[package]]
name = "atomic-waker"
version = "1.1.2"
@@ -449,6 +458,15 @@ version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675"
[[package]]
name = "cobs"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fa961b519f0b462e3a3b4a34b64d119eeaca1d59af726fe450bbba07a9fc0a1"
dependencies = [
"thiserror",
]
[[package]]
name = "colorchoice"
version = "1.0.4"
@@ -543,6 +561,12 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "critical-section"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b"
[[package]]
name = "crossbeam-deque"
version = "0.8.6"
@@ -842,6 +866,18 @@ dependencies = [
"serde",
]
[[package]]
name = "embedded-io"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced"
[[package]]
name = "embedded-io"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d"
[[package]]
name = "encode_unicode"
version = "1.0.0"
@@ -1171,6 +1207,15 @@ dependencies = [
"num-traits",
]
[[package]]
name = "hash32"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67"
dependencies = [
"byteorder",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
@@ -1200,6 +1245,20 @@ dependencies = [
"hashbrown 0.15.3",
]
[[package]]
name = "heapless"
version = "0.7.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f"
dependencies = [
"atomic-polyfill",
"hash32",
"rustc_version",
"serde",
"spin",
"stable_deref_trait",
]
[[package]]
name = "heck"
version = "0.5.0"
@@ -1788,6 +1847,7 @@ dependencies = [
"futures",
"ignore",
"lopdf",
"postcard",
"quick-xml",
"rayon",
"ring",
@@ -2411,6 +2471,19 @@ dependencies = [
"portable-atomic",
]
[[package]]
name = "postcard"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c1de96e20f51df24ca73cafcc4690e044854d803259db27a00a461cb3b9d17a"
dependencies = [
"cobs",
"embedded-io 0.4.0",
"embedded-io 0.6.1",
"heapless",
"serde",
]
[[package]]
name = "potential_utf"
version = "0.1.2"
@@ -2745,6 +2818,15 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc_version"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
dependencies = [
"semver",
]
[[package]]
name = "rustfft"
version = "6.3.0"
@@ -2877,6 +2959,12 @@ dependencies = [
"smallvec",
]
[[package]]
name = "semver"
version = "1.0.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0"
[[package]]
name = "serde"
version = "1.0.219"
@@ -3125,6 +3213,7 @@ dependencies = [
"tokio-stream",
"tracing",
"url",
"uuid",
"webpki-roots 0.26.11",
]
@@ -3205,6 +3294,7 @@ dependencies = [
"stringprep",
"thiserror",
"tracing",
"uuid",
"whoami",
]
@@ -3242,6 +3332,7 @@ dependencies = [
"stringprep",
"thiserror",
"tracing",
"uuid",
"whoami",
]
@@ -3267,6 +3358,7 @@ dependencies = [
"thiserror",
"tracing",
"url",
"uuid",
]
[[package]]

Cargo.toml

@@ -9,6 +9,7 @@ axum = { version = "0.8.4", features = ["http2", "tracing"] }
clap = { version = "4.5.40", features = ["derive"] }
futures = "0.3.31"
lopdf = "0.36.0"
postcard = { version = "1.1.2", features = ["alloc"] }
quick-xml = "0.38.0"
rayon = "1.10.0"
ring = "0.17.14"
@@ -16,7 +17,7 @@ scraper = "0.23.1"
serde = { version = "1.0.219", features = ["derive"] }
sha2 = "0.10.9"
snafu = { version = "0.8.6", features = ["rust_1_81"] }
sqlx = { version = "0.8.6", features = [ "runtime-tokio", "tls-rustls-ring", "migrate", "postgres", "derive" ] }
sqlx = { version = "0.8.6", features = [ "runtime-tokio", "tls-rustls-ring", "migrate", "postgres", "derive", "uuid", "macros" ] }
tokenizers = "0.21.2"
tokio = { version = "1.45.1", features = ["rt-multi-thread", "macros"]}
tower-http = { version = "0.6.6", features = ["trace"] }

New SQL migration

@@ -0,0 +1,12 @@
CREATE TYPE job_status AS ENUM ('running', 'completed', 'failed');
CREATE TABLE jobs (
id UUID PRIMARY KEY,
status job_status NOT NULL DEFAULT 'running',
error TEXT,
result BYTEA,
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
);
CREATE INDEX idx_jobs_status ON jobs(status);
CREATE INDEX idx_jobs_expires_at ON jobs(expires_at);
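The expires_at column and its index suggest expired job rows are meant to be purged eventually, but no cleanup task is part of this commit. A minimal sketch of what such a sweeper could look like, assuming the jobs table above and an sqlx PgPool; the interval and error handling are placeholders:

// Hypothetical background sweeper -- not part of this commit.
use std::time::Duration;
use sqlx::PgPool;

async fn sweep_expired_jobs(pool: PgPool) {
    let mut interval = tokio::time::interval(Duration::from_secs(600));
    loop {
        interval.tick().await;
        // Delete rows whose retention window has passed.
        if let Err(error) = sqlx::query("DELETE FROM jobs WHERE expires_at < NOW()")
            .execute(&pool)
            .await
        {
            tracing::error!("failed to delete expired jobs: {error}");
        }
    }
}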


@@ -4,8 +4,10 @@ use std::{
io,
net::{AddrParseError, SocketAddr},
str::FromStr,
sync::Arc,
};
use jobs::JobManager;
use snafu::{ResultExt, Snafu};
use state::AppState;
use tokio::net::TcpListener;
@@ -17,6 +19,7 @@ use crate::{storage::Postgres, text_encoder::TextEncoder, tokenize::Tokenizer};
pub mod code;
pub mod error;
pub mod jobs;
pub mod query;
pub mod routes;
pub mod state;
@@ -55,12 +58,14 @@ pub async fn serve(
reranker: TextEncoder,
chunk_size: usize,
) -> Result<(), ServeError> {
let jobs = JobManager::new(Arc::new(db.jobs.clone()));
let state = AppState {
db,
tokenizer,
embedder,
reranker,
chunk_size,
jobs,
};
let (router, api) = OpenApiRouter::with_openapi(ApiDoc::openapi())

src/api/jobs.rs (new file)

@@ -0,0 +1,81 @@
use snafu::Report;
use snafu::{ResultExt, Snafu};
use std::{sync::Arc, time::Duration};
use tracing::error;
use uuid::Uuid;
use crate::storage;
use crate::storage::jobs::AddError;
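/// Spawns background jobs and records their outcome in the database.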
#[derive(Clone, Debug)]
pub struct JobManager {
db: Arc<storage::jobs::Jobs>,
}
/// Errors that occur when running a background job.
#[derive(Debug, Snafu)]
pub enum ExecuteError {
#[snafu(display("Failed to add new background job."))]
Add { source: AddError },
}
impl JobManager {
pub fn new(db: Arc<storage::jobs::Jobs>) -> Self {
Self { db }
}
pub async fn execute<F, Fut, T, E>(&self, task: F) -> Result<Uuid, ExecuteError>
where
F: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = Result<T, E>> + Send + 'static,
T: serde::Serialize + Send + Sync + 'static,
E: snafu::Error + Send + Sync + 'static,
{
let job_id = Uuid::new_v4();
self.db.as_ref().add(&job_id).await.context(AddSnafu)?;
let db = self.db.clone();
tokio::spawn(async move {
match task().await {
Ok(result) => success(&job_id, &result, &db).await,
Err(error) => failure(&job_id, error, &db).await,
}
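// WIP: hold the spawned task for five minutes after recording the outcome;
// presumably a placeholder for result expiry/cleanup.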
tokio::time::sleep(Duration::from_secs(300)).await;
});
Ok(job_id)
}
}
async fn success<T>(job_id: &Uuid, result: &T, db: &storage::jobs::Jobs)
where
T: serde::Serialize + Send + Sync + 'static,
{
let data = match postcard::to_allocvec(result) {
Ok(data) => data,
Err(error) => {
failure(job_id, error, db).await;
return;
}
};
if let Err(error) = db.finish(job_id, &data).await {
failure(job_id, error, db).await;
}
}
async fn failure<E>(job_id: &Uuid, error: E, db: &storage::jobs::Jobs)
where
E: snafu::Error + Send + Sync + 'static,
{
let error = Report::from_error(&error);
error!("{job_id}: {error}");
if let Err(error) = db.fail(job_id, &error.to_string()).await {
error!(
"Failed to save job {job_id} result: {}",
Report::from_error(&error)
);
}
}
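success() stores job results with postcard::to_allocvec and Jobs::retrieve decodes them with postcard::from_bytes, so both sides must agree on the serialized shape. A minimal round-trip sketch of that contract, assuming the "alloc" feature enabled in Cargo.toml above:

use crate::storage::queries::DocumentMatch;

/// Round-trip sketch of the postcard encoding used for job results.
fn postcard_roundtrip() -> Result<(), postcard::Error> {
    let results = vec![DocumentMatch {
        book_id: 42,
        text_chunk: "example chunk".to_string(),
        similarity: 0.87,
    }];
    // Serialize into a heap-allocated byte vector, as success() does.
    let bytes = postcard::to_allocvec(&results)?;
    // Decode back into the same shape, as Jobs::retrieve does.
    let decoded: Vec<DocumentMatch> = postcard::from_bytes(&bytes)?;
    assert_eq!(decoded.len(), results.len());
    Ok(())
}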


@@ -1,18 +1,22 @@
//! Query endpoint handlers and response types.
use std::sync::Arc;
use std::{str::FromStr, sync::Arc};
use axum::{
Json,
extract::{Query, State},
extract::{Path, Query, State},
http::StatusCode,
};
use serde::{Deserialize, Serialize};
use snafu::{ResultExt, Snafu, ensure};
use utoipa::ToSchema;
use uuid::Uuid;
use super::{TAG, error::HttpStatus, state::AppState};
use crate::{http_error, query, storage::DocumentMatch};
use super::{TAG, error::HttpStatus, jobs, state::AppState};
use crate::{
http_error, query,
storage::{self, queries::DocumentMatch},
};
const MAX_LIMIT: usize = 10;
@@ -22,14 +26,14 @@ pub enum QueryError {
#[snafu(display("'limit' query parameter must be a positive integer <= {MAX_LIMIT}."))]
Limit,
#[snafu(display("Failed to run query."))]
Query { source: query::AskError },
QueryExecute { source: jobs::ExecuteError },
}
impl HttpStatus for QueryError {
fn status_code(&self) -> StatusCode {
match self {
QueryError::Limit => StatusCode::BAD_REQUEST,
QueryError::Query { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
QueryError::QueryExecute { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
}
}
}
@@ -46,25 +50,9 @@ pub struct QueryParams {
/// Response format for successful query requests.
#[derive(Debug, Serialize, ToSchema)]
pub struct QueryResponse {
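/// Current status of the query job.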
pub status: String,
/// List of matching document chunks.
pub results: Vec<DocumentResult>,
/// Total number of results returned.
pub count: usize,
/// Original query text that was searched.
pub query: String,
}
impl From<(Vec<DocumentMatch>, String)> for QueryResponse {
fn from((documents, query): (Vec<DocumentMatch>, String)) -> Self {
let results: Vec<DocumentResult> = documents.into_iter().map(Into::into).collect();
let count = results.len();
Self {
results,
count,
query,
}
}
}
/// A single document search result.
@@ -88,6 +76,13 @@ impl From<DocumentMatch> for DocumentResult {
}
}
/// Response format for a successfully started query job.
#[derive(Debug, Serialize, ToSchema)]
pub struct QueryStartResponse {
/// Job id.
pub id: String,
}
/// Start a semantic search query against the document database as a background job.
#[utoipa::path(
post,
@@ -98,31 +93,84 @@ impl From<DocumentMatch> for DocumentResult {
),
request_body = String,
responses(
(status = OK, body = QueryResponse),
(status = 202, body = QueryStartResponse),
(status = 400, description = "Wrong parameter."),
(status = 500, description = "Failure to query database.")
)
)]
pub async fn route(
pub async fn start(
Query(params): Query<QueryParams>,
State(state): State<Arc<AppState>>,
body: String,
) -> Result<Json<QueryResponse>, QueryError> {
) -> Result<Json<QueryStartResponse>, QueryError> {
let limit = params.limit.unwrap_or(5);
ensure!(limit <= MAX_LIMIT, LimitSnafu);
let results = query::ask(
&body,
&state.db,
&state.tokenizer,
&state.embedder,
&state.reranker,
state.chunk_size,
limit,
)
.await
.context(QuerySnafu)?;
let response = QueryResponse::from((results, body));
let jobs = state.jobs.clone();
let id = jobs
.execute(move || async move {
query::ask(
&body,
&state.db.queries,
&state.tokenizer,
&state.embedder,
&state.reranker,
state.chunk_size,
limit,
)
.await
})
.await
.context(QueryExecuteSnafu)?;
Ok(Json(response))
Ok(Json(QueryStartResponse { id: id.to_string() }))
}
/// Errors that occur during query retrieval.
#[derive(Debug, Snafu)]
pub enum RetrieveError {
#[snafu(display("Failed to retrieve job from database."))]
Db {
source: storage::jobs::RetrieveError,
},
#[snafu(display("No job with id {id}."))]
NoJob { id: Uuid },
#[snafu(display("'{id}' is not a valid v4 UUID."))]
Uuid { source: uuid::Error, id: String },
}
impl HttpStatus for RetrieveError {
fn status_code(&self) -> StatusCode {
match self {
RetrieveError::Db { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
RetrieveError::NoJob { id: _ } => StatusCode::NOT_FOUND,
RetrieveError::Uuid { source: _, id: _ } => StatusCode::BAD_REQUEST,
}
}
}
http_error!(RetrieveError);
#[utoipa::path(
get,
path = "/query/{id}",
tag = TAG,
responses(
(status = OK, body = QueryResponse),
(status = 404, description = "No job with the requested id."),
(status = 500, description = "Failure to query database.")
)
)]
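/// Retrieve the status and results of a previously started query job.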
pub async fn result(
Path(id): Path<String>,
State(state): State<Arc<AppState>>,
) -> Result<Json<QueryResponse>, RetrieveError> {
let id = Uuid::from_str(&id).context(UuidSnafu { id })?;
match state.db.jobs.retrieve(&id).await.context(DbSnafu)? {
Some((status, results)) => Ok(Json(QueryResponse {
status,
results: results.into_iter().map(Into::into).collect(),
})),
None => NoJobSnafu { id }.fail(),
}
}
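With this change a query becomes a two-step exchange: POST /query returns a job id immediately, and GET /query/{id} is polled until the status leaves "running". A hedged client sketch; reqwest (with its json feature) and serde_json are assumptions here, not dependencies of this crate:

use serde::Deserialize;

#[derive(Deserialize)]
struct Started {
    id: String,
}

/// Start a query job and poll until it completes or fails.
async fn ask_remote(base: &str, query: &str) -> Result<serde_json::Value, reqwest::Error> {
    let client = reqwest::Client::new();
    // Kick off the background job; the server answers immediately with a job id.
    let started: Started = client
        .post(format!("{base}/query"))
        .body(query.to_string())
        .send()
        .await?
        .json()
        .await?;
    // Poll the result endpoint until the job leaves the "running" state.
    loop {
        let response: serde_json::Value = client
            .get(format!("{base}/query/{}", started.id))
            .send()
            .await?
            .json()
            .await?;
        if response["status"] != "running" {
            return Ok(response);
        }
        tokio::time::sleep(std::time::Duration::from_millis(500)).await;
    }
}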


@@ -12,7 +12,8 @@ use crate::api::{code, query};
pub fn router(state: AppState) -> OpenApiRouter {
let store = Arc::new(state);
OpenApiRouter::new()
.routes(routes!(query::route, code::route))
.routes(routes!(query::start, query::result))
.routes(routes!(code::route))
.layer(TraceLayer::new_for_http())
.with_state(store)
}


@@ -2,6 +2,8 @@
use crate::{storage::Postgres, text_encoder::TextEncoder, tokenize::Tokenizer};
use super::jobs::JobManager;
/// Application state shared across all HTTP request handlers.
#[derive(Debug, Clone)]
pub struct AppState {
@@ -15,4 +17,6 @@ pub struct AppState {
pub reranker: TextEncoder,
/// Text chunk size in words for processing.
pub chunk_size: usize,
/// Background jobs handling.
pub jobs: JobManager,
}


@@ -75,7 +75,7 @@ async fn main() -> Result<(), Error> {
TextEncoder::from_file(&config.reranking_path).context(RerankerLoadingSnafu)?;
query::ask(
&config.query,
&db,
&db.queries,
&tokenizer,
&embedder,
&reranker,
@@ -87,7 +87,7 @@ }
}
Commands::Scan(config) => {
scanner::scan(
&db,
&db.queries,
&config.library_path,
&tokenizer,
&embedder,


@@ -3,7 +3,10 @@
use snafu::{ResultExt, Snafu};
use crate::{
storage::{self, DocumentMatch, Postgres},
storage::{
self,
queries::{DocumentMatch, Queries},
},
text_encoder::{self, TextEncoder},
tokenize::{self, Tokenizer},
};
@@ -16,7 +19,9 @@ pub enum AskError {
#[snafu(display("Failed to embed query."))]
Embed { source: text_encoder::EmbedError },
#[snafu(display("Failed to retrieve similar documents."))]
Query { source: storage::QueryError },
Query {
source: storage::queries::QueryError,
},
#[snafu(display("Failed to rerank documents."))]
Rerank { source: text_encoder::RerankError },
}
@@ -24,7 +29,7 @@ pub enum AskError {
/// Process a user query and return ranked document matches.
pub async fn ask(
query: &str,
db: &Postgres,
db: &Queries,
tokenizer: &Tokenizer,
embedder: &TextEncoder,
reranker: &TextEncoder,


@@ -9,14 +9,14 @@ use crate::{
calibre::{Book, BookIterator},
extractors::{self, extractor::ExtractionError},
hash,
storage::{self, Postgres},
storage::{self, queries::Queries},
text_encoder::{self, TextEncoder},
tokenize::{self, Tokenizer},
};
/// Scan a Calibre library and process all books for embedding generation.
pub async fn scan(
db: &Postgres,
db: &Queries,
library_path: &PathBuf,
tokenizer: &Tokenizer,
text_encoder: &TextEncoder,
@@ -43,7 +43,9 @@ pub enum ProcessBookError {
#[snafu(display("Failed to hash book."))]
BookHash { source: io::Error },
#[snafu(display("Failed to check in db if document needs update."))]
UpdateCheck { source: storage::CheckUpdateError },
UpdateCheck {
source: storage::queries::CheckUpdateError,
},
#[snafu(display("Failed to open book."))]
Open { source: io::Error },
#[snafu(display("Failed to extract document content."))]
@@ -53,21 +55,27 @@ pub enum ProcessBookError {
#[snafu(display("Failed to create embeddings from encodings."))]
Embedding { source: text_encoder::EmbedError },
#[snafu(display("Failed to look up document versions."))]
DocumentVersion { source: storage::VersionError },
DocumentVersion {
source: storage::queries::VersionError,
},
#[snafu(display("Failed to store embeddings in database."))]
SaveEmbeddings { source: storage::EmbeddingError },
SaveEmbeddings {
source: storage::queries::EmbeddingError,
},
#[snafu(display("Failed to update document hash in database."))]
UpdateHash { source: storage::UpdateHashError },
UpdateHash {
source: storage::queries::UpdateHashError,
},
#[snafu(display("Failed to delete old document versions in database."))]
RemoveOldVersions {
source: storage::RemoveOldVersionsError,
source: storage::queries::RemoveOldVersionsError,
},
}
/// Process a single book: extract text, generate embeddings, and store in database.
async fn process_book(
book: &Book,
db: &Postgres,
db: &Queries,
tokenizer: &Tokenizer,
embedder: &TextEncoder,
chunk_size: usize,


@@ -1,27 +1,20 @@
//! Database storage and retrieval for document embeddings.
use futures::future::try_join_all;
use std::sync::Arc;
use jobs::Jobs;
use queries::Queries;
use snafu::{ResultExt, Snafu};
use sqlx::{FromRow, Pool, postgres::PgPoolOptions};
use sqlx::postgres::PgPoolOptions;
use crate::{hash::SHA256_LENGTH, text_encoder::Embeddings};
/// A document chunk with similarity score from vector search.
#[derive(Debug, FromRow)]
pub struct DocumentMatch {
/// Calibre book ID.
pub book_id: i64,
/// Text content of the chunk.
pub text_chunk: String,
/// Cosine similarity score (0.0 to 1.0).
pub similarity: f64,
}
pub mod jobs;
pub mod queries;
/// PostgreSQL database connection pool and operations.
#[derive(Debug, Clone)]
pub struct Postgres {
/// Connection pool for database operations.
pool: Pool<sqlx::Postgres>,
pub jobs: Jobs,
pub queries: Queries,
}
/// Error when connecting to database.
@@ -31,195 +24,20 @@ pub struct ConnectionError {
source: sqlx::Error,
}
/// Error when retrieving document version.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to get highest document version."))]
pub struct VersionError {
source: sqlx::Error,
}
/// Error when storing embeddings.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to store embeddings."))]
pub struct EmbeddingError {
source: sqlx::Error,
}
/// Error when querying for similar documents.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to query similarity data."))]
pub struct QueryError {
source: sqlx::Error,
}
/// Error when updating document hash.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to update document hash."))]
pub struct UpdateHashError {
source: sqlx::Error,
}
/// Error when checking if document needs update.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to check if document needs update."))]
pub struct CheckUpdateError {
source: sqlx::Error,
}
/// Error when removing old document versions.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to delete old document versions."))]
pub struct RemoveOldVersionsError {
source: sqlx::Error,
}
impl Postgres {
/// Create a new database connection pool.
pub async fn connect(db_url: &str, max_connections: u32) -> Result<Self, ConnectionError> {
let pool = PgPoolOptions::new()
.max_connections(max_connections)
.connect(db_url)
.await
.context(ConnectionSnafu)?;
let pool = Arc::new(
PgPoolOptions::new()
.max_connections(max_connections)
.connect(db_url)
.await
.context(ConnectionSnafu)?,
);
Ok(Self { pool })
}
let jobs = Jobs::new(pool.clone());
let queries = Queries::new(pool.clone());
/// Get the highest version number for a book's embeddings.
pub async fn embedding_version(&self, book_id: i64) -> Result<i32, VersionError> {
let version = sqlx::query_scalar(
"SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1",
)
.bind(book_id)
.fetch_one(&self.pool)
.await
.context(VersionSnafu)?;
Ok(version)
}
/// Store multiple embeddings for a book in parallel.
pub async fn add_multiple_embeddings(
&self,
book_id: i64,
version: i32,
embeddings: Vec<Embeddings>,
) -> Result<(), EmbeddingError> {
let inserts: Vec<_> = embeddings
.into_iter()
.map(|e| self.add_embeddings(book_id, version, e))
.collect();
try_join_all(inserts).await?;
Ok(())
}
/// Store a single embedding for a book.
pub async fn add_embeddings(
&self,
book_id: i64,
version: i32,
embeddings: Embeddings,
) -> Result<(), EmbeddingError> {
let vector_string = embeddings_to_vec_string(embeddings.0);
sqlx::query(
"INSERT INTO documents (book_id, version, embedding, text_chunk) VALUES ($1, $2, $3::vector, $4)",
)
.bind(book_id)
.bind(version)
.bind(vector_string)
.bind(embeddings.2)
.execute(&self.pool)
.await
.context(EmbeddingSnafu)?;
Ok(())
}
/// Find documents similar to the query embedding using cosine similarity.
pub async fn query(
&self,
query: Embeddings,
limit: i32,
) -> Result<Vec<DocumentMatch>, QueryError> {
let vector_string = embeddings_to_vec_string(query.0);
let books = sqlx::query_as::<_, DocumentMatch>(
"SELECT book_id, text_chunk, 1 - (embedding <=> $1::vector) as similarity
FROM documents
ORDER BY embedding <=> $1::vector
LIMIT $2",
)
.bind(vector_string)
.bind(limit)
.fetch_all(&self.pool)
.await
.context(QuerySnafu)?;
Ok(books)
}
/// Update or insert the hash for a book.
pub async fn update_hash(
&self,
book_id: i64,
hash: [u8; SHA256_LENGTH],
) -> Result<(), UpdateHashError> {
sqlx::query(
"INSERT INTO hashes (book_id, hash) VALUES ($1, $2)
ON CONFLICT (book_id) DO UPDATE SET hash = $2",
)
.bind(book_id)
.bind(&hash[..])
.execute(&self.pool)
.await
.context(UpdateHashSnafu)?;
Ok(())
}
/// Check if a book needs to be reprocessed based on its hash.
pub async fn book_needs_update(
&self,
book_id: i64,
hash: [u8; SHA256_LENGTH],
) -> Result<bool, CheckUpdateError> {
let exists = sqlx::query_scalar::<_, bool>(
"SELECT NOT EXISTS (
SELECT 1 FROM hashes
WHERE book_id = $1 AND hash = $2
)",
)
.bind(book_id)
.bind(&hash[..])
.fetch_one(&self.pool)
.await
.context(CheckUpdateSnafu)?;
Ok(exists)
}
/// Remove all but the latest version of embeddings for a book.
pub async fn remove_old_versions(&self, book_id: i64) -> Result<(), RemoveOldVersionsError> {
sqlx::query(
"DELETE FROM documents WHERE version < (SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1)",
)
.bind(book_id)
.execute(&self.pool)
.await
.context(RemoveOldVersionsSnafu)?;
Ok(())
Ok(Self { jobs, queries })
}
}
/// Convert a vector of floats to PostgreSQL vector format string.
fn embeddings_to_vec_string(vector: Vec<f32>) -> String {
format!(
"[{}]",
vector
.iter()
.map(|x| x.to_string())
.collect::<Vec<String>>()
.join(",")
)
}
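Postgres is now a thin facade over two handles that share one Arc<PgPool>. A sketch of how call sites pick a handle after connecting; the connection string and ids are placeholders:

/// Hypothetical call site for the refactored storage facade.
async fn storage_example() -> Result<(), Box<dyn std::error::Error>> {
    let db = Postgres::connect("postgres://localhost/library", 5).await?;
    // Document and embedding operations go through db.queries ...
    let version = db.queries.embedding_version(42).await?;
    // ... while job bookkeeping goes through db.jobs.
    let job_id = uuid::Uuid::new_v4();
    db.jobs.add(&job_id).await?;
    println!("book 42 at embedding version {version}, job {job_id} registered");
    Ok(())
}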

src/storage/jobs.rs (new file)

@@ -0,0 +1,143 @@
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use snafu::{ResultExt, Snafu};
use sqlx::{PgPool, Row, Type};
use uuid::Uuid;
use super::queries::DocumentMatch;
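/// Background job bookkeeping backed by the jobs table.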
#[derive(Debug, Clone)]
pub struct Jobs {
/// Connection pool for database operations.
pool: Arc<PgPool>,
}
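/// Lifecycle states of a background job; maps to the Postgres job_status enum.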
#[derive(Debug, Clone, Serialize, Deserialize, Type)]
#[sqlx(type_name = "job_status", rename_all = "lowercase")]
pub enum JobStatus {
Running,
Completed,
Failed,
}
impl std::fmt::Display for JobStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
JobStatus::Running => write!(f, "running"),
JobStatus::Completed => write!(f, "completed"),
JobStatus::Failed => write!(f, "failed"),
}
}
}
impl std::str::FromStr for JobStatus {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"running" => Ok(JobStatus::Running),
"completed" => Ok(JobStatus::Completed),
"failed" => Ok(JobStatus::Failed),
_ => Err(format!("Invalid job status: {}", s)),
}
}
}
/// Error adding a new job.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to add background job."))]
pub struct AddError {
source: sqlx::Error,
}
/// Error when adding a result to a job.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to add result to background job."))]
pub struct ResultError {
source: sqlx::Error,
}
/// Error when adding a failure to a job.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to add failure to background job."))]
pub struct FailureError {
source: sqlx::Error,
}
/// Errors that occur during job retrieval.
#[derive(Debug, Snafu)]
pub enum RetrieveError {
#[snafu(display("Failed to retrieve background job data"))]
Retrieve { source: sqlx::Error },
#[snafu(display("Failed to run query."))]
Deserialize { source: postcard::Error },
}
impl Jobs {
pub fn new(pool: Arc<PgPool>) -> Self {
Self { pool }
}
/// Add a new job to the database and mark it as running.
pub async fn add(&self, id: &Uuid) -> Result<(), AddError> {
sqlx::query("INSERT INTO jobs (id, expires_at) VALUES ($1, NOW() + INTERVAL '24 hours')")
.bind(id)
.execute(self.pool.as_ref())
.await
.context(AddSnafu)?;
Ok(())
}
/// Add a result to a job and mark it as finished.
pub async fn finish(&self, id: &Uuid, data: &[u8]) -> Result<(), ResultError> {
sqlx::query("UPDATE jobs SET status = 'completed', result = $1 WHERE id = $2")
.bind(data)
.bind(id)
.execute(self.pool.as_ref())
.await
.context(ResultSnafu)?;
Ok(())
}
/// Add an error to a job and mark it as failed.
pub async fn fail(&self, id: &Uuid, error: &str) -> Result<(), FailureError> {
sqlx::query("UPDATE jobs SET status = 'failed', error = $1 WHERE id = $2")
.bind(error)
.bind(id)
.execute(self.pool.as_ref())
.await
.context(FailureSnafu)?;
Ok(())
}
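/// Fetch a job's status and, if present, its deserialized results.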
pub async fn retrieve(
&self,
id: &Uuid,
) -> Result<Option<(String, Vec<DocumentMatch>)>, RetrieveError> {
let row = sqlx::query("SELECT status, result FROM jobs WHERE id = $1")
.bind(id)
.fetch_optional(self.pool.as_ref())
.await
.context(RetrieveSnafu)?;
match row {
Some(row) => {
let status: JobStatus = row.get("status");
let status = status.to_string();
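// The result column is nullable BYTEA: decoding NULL into bytes fails,
// which is treated below as "no results yet" (job still running or failed).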
match row.try_get("result") {
Ok(data) => {
let documents: Vec<DocumentMatch> =
postcard::from_bytes(data).context(DeserializeSnafu)?;
Ok(Some((status, documents)))
}
Err(_) => Ok(Some((status, Vec::new()))),
}
}
None => Ok(None),
}
}
}

src/storage/queries.rs (new file)

@@ -0,0 +1,211 @@
use std::sync::Arc;
use futures::future::try_join_all;
use serde::{Deserialize, Serialize};
use snafu::{ResultExt, Snafu};
use sqlx::{FromRow, PgPool};
use crate::{hash::SHA256_LENGTH, text_encoder::Embeddings};
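/// Document and embedding queries backed by the documents and hashes tables.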
#[derive(Debug, Clone)]
pub struct Queries {
/// Connection pool for database operations.
pool: Arc<PgPool>,
}
/// Error when retrieving document version.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to get highest document version."))]
pub struct VersionError {
source: sqlx::Error,
}
/// Error when storing embeddings.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to store embeddings."))]
pub struct EmbeddingError {
source: sqlx::Error,
}
/// Error when querying for similar documents.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to query similarity data."))]
pub struct QueryError {
source: sqlx::Error,
}
/// Error when updating document hash.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to update document hash."))]
pub struct UpdateHashError {
source: sqlx::Error,
}
/// Error when checking if document needs update.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to check if document needs update."))]
pub struct CheckUpdateError {
source: sqlx::Error,
}
/// Error when removing old document versions.
#[derive(Debug, Snafu)]
#[snafu(display("Failed to delete old document versions."))]
pub struct RemoveOldVersionsError {
source: sqlx::Error,
}
/// A document chunk with similarity score from vector search.
#[derive(Debug, FromRow, Serialize, Deserialize)]
pub struct DocumentMatch {
/// Calibre book ID.
pub book_id: i64,
/// Text content of the chunk.
pub text_chunk: String,
/// Cosine similarity score (0.0 to 1.0).
pub similarity: f64,
}
impl Queries {
pub fn new(pool: Arc<PgPool>) -> Self {
Self { pool }
}
/// Get the highest version number for a book's embeddings.
pub async fn embedding_version(&self, book_id: i64) -> Result<i32, VersionError> {
let version = sqlx::query_scalar(
"SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1",
)
.bind(book_id)
.fetch_one(self.pool.as_ref())
.await
.context(VersionSnafu)?;
Ok(version)
}
/// Store multiple embeddings for a book in parallel.
pub async fn add_multiple_embeddings(
&self,
book_id: i64,
version: i32,
embeddings: Vec<Embeddings>,
) -> Result<(), EmbeddingError> {
let inserts: Vec<_> = embeddings
.into_iter()
.map(|e| self.add_embeddings(book_id, version, e))
.collect();
try_join_all(inserts).await?;
Ok(())
}
/// Store a single embedding for a book.
pub async fn add_embeddings(
&self,
book_id: i64,
version: i32,
embeddings: Embeddings,
) -> Result<(), EmbeddingError> {
let vector_string = embeddings_to_vec_string(embeddings.0);
sqlx::query(
"INSERT INTO documents (book_id, version, embedding, text_chunk) VALUES ($1, $2, $3::vector, $4)",
)
.bind(book_id)
.bind(version)
.bind(vector_string)
.bind(embeddings.2)
.execute(self.pool.as_ref())
.await
.context(EmbeddingSnafu)?;
Ok(())
}
/// Find documents similar to the query embedding using cosine similarity.
pub async fn query(
&self,
query: Embeddings,
limit: i32,
) -> Result<Vec<DocumentMatch>, QueryError> {
let vector_string = embeddings_to_vec_string(query.0);
let books = sqlx::query_as::<_, DocumentMatch>(
"SELECT book_id, text_chunk, 1 - (embedding <=> $1::vector) as similarity
FROM documents
ORDER BY embedding <=> $1::vector
LIMIT $2",
)
.bind(vector_string)
.bind(limit)
.fetch_all(self.pool.as_ref())
.await
.context(QuerySnafu)?;
Ok(books)
}
/// Update or insert the hash for a book.
pub async fn update_hash(
&self,
book_id: i64,
hash: [u8; SHA256_LENGTH],
) -> Result<(), UpdateHashError> {
sqlx::query(
"INSERT INTO hashes (book_id, hash) VALUES ($1, $2)
ON CONFLICT (book_id) DO UPDATE SET hash = $2",
)
.bind(book_id)
.bind(&hash[..])
.execute(self.pool.as_ref())
.await
.context(UpdateHashSnafu)?;
Ok(())
}
/// Check if a book needs to be reprocessed based on its hash.
pub async fn book_needs_update(
&self,
book_id: i64,
hash: [u8; SHA256_LENGTH],
) -> Result<bool, CheckUpdateError> {
let exists = sqlx::query_scalar::<_, bool>(
"SELECT NOT EXISTS (
SELECT 1 FROM hashes
WHERE book_id = $1 AND hash = $2
)",
)
.bind(book_id)
.bind(&hash[..])
.fetch_one(self.pool.as_ref())
.await
.context(CheckUpdateSnafu)?;
Ok(exists)
}
/// Remove all but the latest version of embeddings for a book.
pub async fn remove_old_versions(&self, book_id: i64) -> Result<(), RemoveOldVersionsError> {
sqlx::query(
"DELETE FROM documents WHERE version < (SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1)",
)
.bind(book_id)
.execute(self.pool.as_ref())
.await
.context(RemoveOldVersionsSnafu)?;
Ok(())
}
}
/// Convert a vector of floats to PostgreSQL vector format string.
fn embeddings_to_vec_string(vector: Vec<f32>) -> String {
format!(
"[{}]",
vector
.iter()
.map(|x| x.to_string())
.collect::<Vec<String>>()
.join(",")
)
}
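embeddings_to_vec_string is private, so a module-local test is the natural place to pin down the pgvector text format it produces. A hedged test sketch, not part of this commit:

#[cfg(test)]
mod tests {
    use super::embeddings_to_vec_string;

    #[test]
    fn formats_as_pgvector_literal() {
        // pgvector parses a bracketed, comma-separated list of floats.
        assert_eq!(
            embeddings_to_vec_string(vec![0.25, -1.0, 3.5]),
            "[0.25,-1,3.5]"
        );
    }
}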


@@ -12,7 +12,7 @@ use snafu::{ResultExt, Snafu};
use tract_onnx::prelude::*;
use crate::{
storage::DocumentMatch,
storage::queries::DocumentMatch,
text_encoder::tract_ndarray::ShapeError,
tokenize::{self, Encoding, Tokenizer},
};