WIP job queue

parent 525e278a4e
commit 6a5b309391

15 changed files with 685 additions and 256 deletions
Cargo.lock (generated, 92 changes)

@@ -158,6 +158,15 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "atomic-polyfill"
+version = "1.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4"
+dependencies = [
+ "critical-section",
+]
+
 [[package]]
 name = "atomic-waker"
 version = "1.1.2"

@@ -449,6 +458,15 @@ version = "0.7.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675"
 
+[[package]]
+name = "cobs"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0fa961b519f0b462e3a3b4a34b64d119eeaca1d59af726fe450bbba07a9fc0a1"
+dependencies = [
+ "thiserror",
+]
+
 [[package]]
 name = "colorchoice"
 version = "1.0.4"

@@ -543,6 +561,12 @@ dependencies = [
  "cfg-if",
 ]
 
+[[package]]
+name = "critical-section"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b"
+
 [[package]]
 name = "crossbeam-deque"
 version = "0.8.6"

@@ -842,6 +866,18 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "embedded-io"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced"
+
+[[package]]
+name = "embedded-io"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d"
+
 [[package]]
 name = "encode_unicode"
 version = "1.0.0"

@@ -1171,6 +1207,15 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "hash32"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67"
+dependencies = [
+ "byteorder",
+]
+
 [[package]]
 name = "hashbrown"
 version = "0.14.5"

@@ -1200,6 +1245,20 @@ dependencies = [
  "hashbrown 0.15.3",
 ]
 
+[[package]]
+name = "heapless"
+version = "0.7.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f"
+dependencies = [
+ "atomic-polyfill",
+ "hash32",
+ "rustc_version",
+ "serde",
+ "spin",
+ "stable_deref_trait",
+]
+
 [[package]]
 name = "heck"
 version = "0.5.0"

@@ -1788,6 +1847,7 @@ dependencies = [
  "futures",
  "ignore",
  "lopdf",
+ "postcard",
  "quick-xml",
  "rayon",
  "ring",

@@ -2411,6 +2471,19 @@ dependencies = [
  "portable-atomic",
 ]
 
+[[package]]
+name = "postcard"
+version = "1.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c1de96e20f51df24ca73cafcc4690e044854d803259db27a00a461cb3b9d17a"
+dependencies = [
+ "cobs",
+ "embedded-io 0.4.0",
+ "embedded-io 0.6.1",
+ "heapless",
+ "serde",
+]
+
 [[package]]
 name = "potential_utf"
 version = "0.1.2"

@@ -2745,6 +2818,15 @@ version = "0.1.24"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
 
+[[package]]
+name = "rustc_version"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
+dependencies = [
+ "semver",
+]
+
 [[package]]
 name = "rustfft"
 version = "6.3.0"

@@ -2877,6 +2959,12 @@ dependencies = [
  "smallvec",
 ]
 
+[[package]]
+name = "semver"
+version = "1.0.26"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0"
+
 [[package]]
 name = "serde"
 version = "1.0.219"

@@ -3125,6 +3213,7 @@ dependencies = [
  "tokio-stream",
  "tracing",
  "url",
+ "uuid",
  "webpki-roots 0.26.11",
 ]
 

@@ -3205,6 +3294,7 @@ dependencies = [
  "stringprep",
  "thiserror",
  "tracing",
+ "uuid",
  "whoami",
 ]
 

@@ -3242,6 +3332,7 @@ dependencies = [
  "stringprep",
  "thiserror",
  "tracing",
+ "uuid",
  "whoami",
 ]
 

@@ -3267,6 +3358,7 @@ dependencies = [
  "thiserror",
  "tracing",
  "url",
+ "uuid",
 ]
 
 [[package]]
Cargo.toml

@@ -9,6 +9,7 @@ axum = { version = "0.8.4", features = ["http2", "tracing"] }
 clap = { version = "4.5.40", features = ["derive"] }
 futures = "0.3.31"
 lopdf = "0.36.0"
+postcard = { version = "1.1.2", features = ["alloc"] }
 quick-xml = "0.38.0"
 rayon = "1.10.0"
 ring = "0.17.14"

@@ -16,7 +17,7 @@ scraper = "0.23.1"
 serde = { version = "1.0.219", features = ["derive"] }
 sha2 = "0.10.9"
 snafu = { version = "0.8.6", features = ["rust_1_81"] }
-sqlx = { version = "0.8.6", features = [ "runtime-tokio", "tls-rustls-ring", "migrate", "postgres", "derive" ] }
+sqlx = { version = "0.8.6", features = [ "runtime-tokio", "tls-rustls-ring", "migrate", "postgres", "derive", "uuid", "macros" ] }
 tokenizers = "0.21.2"
 tokio = { version = "1.45.1", features = ["rt-multi-thread", "macros"]}
 tower-http = { version = "0.6.6", features = ["trace"] }
migrations/20250701133953_jobs.sql (new file, 12 lines)

@@ -0,0 +1,12 @@
+CREATE TYPE job_status AS ENUM ('running', 'completed', 'failed');
+
+CREATE TABLE jobs (
+    id UUID PRIMARY KEY,
+    status job_status NOT NULL DEFAULT 'running',
+    error TEXT,
+    result BYTEA,
+    expires_at TIMESTAMP WITH TIME ZONE NOT NULL
+);
+
+CREATE INDEX idx_jobs_status ON jobs(status);
+CREATE INDEX idx_jobs_expires_at ON jobs(expires_at);
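The schema gives every job an expires_at, but nothing in this WIP commit deletes expired rows yet. A minimal sketch of the sweeper that would pair with idx_jobs_expires_at — the function name cleanup_expired_jobs and the hourly interval are assumptions, not part of the commit:

use std::time::Duration;

use sqlx::PgPool;

/// Hypothetical periodic sweeper for the jobs table created above.
async fn cleanup_expired_jobs(pool: PgPool) {
    let mut ticker = tokio::time::interval(Duration::from_secs(3600));
    loop {
        ticker.tick().await;
        // The expiry comparison runs inside Postgres, so idx_jobs_expires_at can serve it.
        if let Err(error) = sqlx::query("DELETE FROM jobs WHERE expires_at < NOW()")
            .execute(&pool)
            .await
        {
            tracing::error!("Failed to delete expired jobs: {error}");
        }
    }
}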
src/api.rs

@@ -4,8 +4,10 @@ use std::{
     io,
     net::{AddrParseError, SocketAddr},
     str::FromStr,
+    sync::Arc,
 };
 
+use jobs::JobManager;
 use snafu::{ResultExt, Snafu};
 use state::AppState;
 use tokio::net::TcpListener;

@@ -17,6 +19,7 @@ use crate::{storage::Postgres, text_encoder::TextEncoder, tokenize::Tokenizer};
 
 pub mod code;
 pub mod error;
+pub mod jobs;
 pub mod query;
 pub mod routes;
 pub mod state;

@@ -55,12 +58,14 @@ pub async fn serve(
     reranker: TextEncoder,
     chunk_size: usize,
 ) -> Result<(), ServeError> {
+    let jobs = JobManager::new(Arc::new(db.jobs.clone()));
     let state = AppState {
         db,
         tokenizer,
         embedder,
         reranker,
         chunk_size,
+        jobs,
     };
 
     let (router, api) = OpenApiRouter::with_openapi(ApiDoc::openapi())
src/api/jobs.rs (new file, 81 lines)

@@ -0,0 +1,81 @@
+use snafu::Report;
+use snafu::{ResultExt, Snafu};
+use std::{sync::Arc, time::Duration};
+use tracing::error;
+
+use uuid::Uuid;
+
+use crate::storage;
+use crate::storage::jobs::AddError;
+
+#[derive(Clone, Debug)]
+pub struct JobManager {
+    db: Arc<storage::jobs::Jobs>,
+}
+
+/// Errors that occur when running a background job.
+#[derive(Debug, Snafu)]
+pub enum ExecuteError {
+    #[snafu(display("Failed to add new background job."))]
+    Add { source: AddError },
+}
+
+impl JobManager {
+    pub fn new(db: Arc<storage::jobs::Jobs>) -> Self {
+        Self { db }
+    }
+
+    pub async fn execute<F, Fut, T, E>(&self, task: F) -> Result<Uuid, ExecuteError>
+    where
+        F: FnOnce() -> Fut + Send + 'static,
+        Fut: Future<Output = Result<T, E>> + Send,
+        T: serde::Serialize + Send + Sync + 'static,
+        E: snafu::Error + Send + Sync + 'static,
+    {
+        let job_id = Uuid::new_v4();
+        self.db.as_ref().add(&job_id).await.context(AddSnafu)?;
+        let db = self.db.clone();
+
+        tokio::spawn(async move {
+            match task().await {
+                Ok(result) => success(&job_id, &result, &db).await,
+                Err(error) => failure(&job_id, error, &db).await,
+            }
+
+            tokio::time::sleep(Duration::from_secs(300)).await;
+        });
+
+        Ok(job_id)
+    }
+}
+
+async fn success<T>(job_id: &Uuid, result: &T, db: &storage::jobs::Jobs)
+where
+    T: serde::Serialize + Send + Sync + 'static,
+{
+    let data = match postcard::to_allocvec(result) {
+        Ok(data) => data,
+        Err(error) => {
+            failure(job_id, error, db).await;
+            return;
+        }
+    };
+
+    if let Err(error) = db.finish(job_id, &data).await {
+        failure(job_id, error, db).await;
+    }
+}
+
+async fn failure<E>(job_id: &Uuid, error: E, db: &storage::jobs::Jobs)
+where
+    E: snafu::Error + Send + Sync + 'static,
+{
+    let error = Report::from_error(&error);
+    error!("{job_id}: {error}");
+    if let Err(error) = db.fail(job_id, &error.to_string()).await {
+        error!(
+            "Failed to save job {job_id} result: {}",
+            Report::from_error(&error)
+        );
+    }
+}
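For orientation: a caller hands execute a serializable task and keeps only the returned id; the result is persisted via finish/fail before the trailing five-minute sleep, so polling works as soon as the task completes. A minimal sketch, assuming a storage::jobs::Jobs handle is already available (the demo function itself is illustrative only):

use std::sync::Arc;

use uuid::Uuid;

use crate::{
    api::jobs::{ExecuteError, JobManager},
    storage,
};

/// Hypothetical caller: spawn a background task, get its job id back immediately.
async fn demo(jobs: storage::jobs::Jobs) -> Result<Uuid, ExecuteError> {
    let manager = JobManager::new(Arc::new(jobs));
    // The Ok type only needs serde::Serialize; a byte vector works fine.
    let id = manager
        .execute(move || async move { Ok::<_, std::io::Error>(vec![1u8, 2, 3]) })
        .await?;
    Ok(id)
}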
src/api/query.rs (126 changes)

@@ -1,18 +1,22 @@
 //! Query endpoint handlers and response types.
 
-use std::sync::Arc;
+use std::{str::FromStr, sync::Arc};
 
 use axum::{
     Json,
-    extract::{Query, State},
+    extract::{Path, Query, State},
     http::StatusCode,
 };
 use serde::{Deserialize, Serialize};
 use snafu::{ResultExt, Snafu, ensure};
 use utoipa::ToSchema;
+use uuid::Uuid;
 
-use super::{TAG, error::HttpStatus, state::AppState};
-use crate::{http_error, query, storage::DocumentMatch};
+use super::{TAG, error::HttpStatus, jobs, state::AppState};
+use crate::{
+    http_error, query,
+    storage::{self, queries::DocumentMatch},
+};
 
 const MAX_LIMIT: usize = 10;

@@ -22,14 +26,14 @@ pub enum QueryError {
     #[snafu(display("'limit' query parameter must be a positive integer <= {MAX_LIMIT}."))]
     Limit,
     #[snafu(display("Failed to run query."))]
-    Query { source: query::AskError },
+    QueryExecute { source: jobs::ExecuteError },
 }
 
 impl HttpStatus for QueryError {
     fn status_code(&self) -> StatusCode {
         match self {
             QueryError::Limit => StatusCode::BAD_REQUEST,
-            QueryError::Query { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
+            QueryError::QueryExecute { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
         }
     }
 }

@@ -46,25 +50,9 @@ pub struct QueryParams {
 /// Response format for successful query requests.
 #[derive(Debug, Serialize, ToSchema)]
 pub struct QueryResponse {
+    pub status: String,
     /// List of matching document chunks.
     pub results: Vec<DocumentResult>,
-    /// Total number of results returned.
-    pub count: usize,
-    /// Original query text that was searched.
-    pub query: String,
-}
-
-impl From<(Vec<DocumentMatch>, String)> for QueryResponse {
-    fn from((documents, query): (Vec<DocumentMatch>, String)) -> Self {
-        let results: Vec<DocumentResult> = documents.into_iter().map(Into::into).collect();
-        let count = results.len();
-
-        Self {
-            results,
-            count,
-            query,
-        }
-    }
 }
 
 /// A single document search result.

@@ -88,6 +76,13 @@ impl From<DocumentMatch> for DocumentResult {
     }
 }
 
+/// Response format for successful query requests.
+#[derive(Debug, Serialize, ToSchema)]
+pub struct QueryStartResponse {
+    /// Job id.
+    pub id: String,
+}
+
 /// Execute a semantic search query against the document database.
 #[utoipa::path(
     post,

@@ -98,31 +93,84 @@ impl From<DocumentMatch> for DocumentResult {
     ),
     request_body = String,
     responses(
-        (status = OK, body = QueryResponse),
+        (status = 202, body = QueryStartResponse),
         (status = 400, description = "Wrong parameter."),
         (status = 500, description = "Failure to query database.")
     )
 )]
-pub async fn route(
+pub async fn start(
     Query(params): Query<QueryParams>,
     State(state): State<Arc<AppState>>,
     body: String,
-) -> Result<Json<QueryResponse>, QueryError> {
+) -> Result<Json<QueryStartResponse>, QueryError> {
     let limit = params.limit.unwrap_or(5);
     ensure!(limit <= MAX_LIMIT, LimitSnafu);
 
-    let results = query::ask(
-        &body,
-        &state.db,
-        &state.tokenizer,
-        &state.embedder,
-        &state.reranker,
-        state.chunk_size,
-        limit,
-    )
-    .await
-    .context(QuerySnafu)?;
-    let response = QueryResponse::from((results, body));
+    let jobs = state.jobs.clone();
+    let id = jobs
+        .execute(move || async move {
+            query::ask(
+                &body,
+                &state.db.queries,
+                &state.tokenizer,
+                &state.embedder,
+                &state.reranker,
+                state.chunk_size,
+                limit,
+            )
+            .await
+        })
+        .await
+        .context(QueryExecuteSnafu)?;
 
-    Ok(Json(response))
+    Ok(Json(QueryStartResponse { id: id.to_string() }))
 }
+
+/// Errors that occur during query retrieval.
+#[derive(Debug, Snafu)]
+pub enum RetrieveError {
+    #[snafu(display("Failed to retrieve job from database."))]
+    Db {
+        source: storage::jobs::RetrieveError,
+    },
+    #[snafu(display("No job with id {id}."))]
+    NoJob { id: Uuid },
+    #[snafu(display("'{id}' is not a valid v4 UUID."))]
+    Uuid { source: uuid::Error, id: String },
+}
+
+impl HttpStatus for RetrieveError {
+    fn status_code(&self) -> StatusCode {
+        match self {
+            RetrieveError::Db { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
+            RetrieveError::NoJob { id: _ } => StatusCode::NOT_FOUND,
+            RetrieveError::Uuid { source: _, id: _ } => StatusCode::BAD_REQUEST,
+        }
+    }
+}
+
+http_error!(RetrieveError);
+
+#[utoipa::path(
+    get,
+    path = "/query/{id}",
+    tag = TAG,
+    responses(
+        (status = OK, body = QueryResponse),
+        (status = 404, description = "No job with the requested id."),
+        (status = 500, description = "Failure to query database.")
+    )
+)]
+pub async fn result(
+    Path(id): Path<String>,
+    State(state): State<Arc<AppState>>,
+) -> Result<Json<QueryResponse>, RetrieveError> {
+    let id = Uuid::from_str(&id).context(UuidSnafu { id })?;
+    match state.db.jobs.retrieve(&id).await.context(DbSnafu)? {
+        Some((status, results)) => Ok(Json(QueryResponse {
+            status,
+            results: results.into_iter().map(Into::into).collect(),
+        })),
+        None => NoJobSnafu { id }.fail(),
+    }
 }
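Together the two handlers replace the old synchronous search with a poll-based flow: POST /query answers 202 with a job id, and GET /query/{id} returns the stored status plus any results. A client-side sketch of that round trip — the reqwest dependency, the localhost:3000 base URL, and the poll interval are assumptions, not part of the commit:

use std::time::Duration;

use serde::Deserialize;

#[derive(Deserialize)]
struct Started {
    id: String,
}

/// Hypothetical client: start a query job, then poll until it leaves "running".
async fn search(question: &str) -> Result<String, reqwest::Error> {
    let client = reqwest::Client::new();
    let started: Started = client
        .post("http://localhost:3000/query?limit=5")
        .body(question.to_owned())
        .send()
        .await?
        .json()
        .await?;

    loop {
        let body = client
            .get(format!("http://localhost:3000/query/{}", started.id))
            .send()
            .await?
            .text()
            .await?;
        // QueryResponse.status mirrors the job_status enum: running, completed, or failed.
        if !body.contains("\"status\":\"running\"") {
            return Ok(body);
        }
        tokio::time::sleep(Duration::from_millis(500)).await;
    }
}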
src/api/routes.rs

@@ -12,7 +12,8 @@ use crate::api::{code, query};
 pub fn router(state: AppState) -> OpenApiRouter {
     let store = Arc::new(state);
     OpenApiRouter::new()
-        .routes(routes!(query::route, code::route))
+        .routes(routes!(query::start, query::result))
+        .routes(routes!(code::route))
         .layer(TraceLayer::new_for_http())
         .with_state(store)
 }
src/api/state.rs

@@ -2,6 +2,8 @@
 
 use crate::{storage::Postgres, text_encoder::TextEncoder, tokenize::Tokenizer};
 
+use super::jobs::JobManager;
+
 /// Application state shared across all HTTP request handlers.
 #[derive(Debug, Clone)]
 pub struct AppState {

@@ -15,4 +17,6 @@ pub struct AppState {
     pub reranker: TextEncoder,
     /// Text chunk size in words for processing.
     pub chunk_size: usize,
+    /// Background jobs handling.
+    pub jobs: JobManager,
 }
src/main.rs

@@ -75,7 +75,7 @@ async fn main() -> Result<(), Error> {
         TextEncoder::from_file(&config.reranking_path).context(RerankerLoadingSnafu)?;
             query::ask(
                 &config.query,
-                &db,
+                &db.queries,
                 &tokenizer,
                 &embedder,
                 &reranker,

@@ -87,7 +87,7 @@ async fn main() -> Result<(), Error> {
         }
         Commands::Scan(config) => {
             scanner::scan(
-                &db,
+                &db.queries,
                 &config.library_path,
                 &tokenizer,
                 &embedder,
src/query.rs (11 changes)

@@ -3,7 +3,10 @@
 use snafu::{ResultExt, Snafu};
 
 use crate::{
-    storage::{self, DocumentMatch, Postgres},
+    storage::{
+        self,
+        queries::{DocumentMatch, Queries},
+    },
     text_encoder::{self, TextEncoder},
     tokenize::{self, Tokenizer},
 };

@@ -16,7 +19,9 @@ pub enum AskError {
     #[snafu(display("Failed to embed query."))]
     Embed { source: text_encoder::EmbedError },
     #[snafu(display("Failed to retrieve similar documents."))]
-    Query { source: storage::QueryError },
+    Query {
+        source: storage::queries::QueryError,
+    },
     #[snafu(display("Failed to rerank documents."))]
     Rerank { source: text_encoder::RerankError },
 }

@@ -24,7 +29,7 @@ pub enum AskError {
 /// Process a user query and return ranked document matches.
 pub async fn ask(
     query: &str,
-    db: &Postgres,
+    db: &Queries,
     tokenizer: &Tokenizer,
     embedder: &TextEncoder,
     reranker: &TextEncoder,
src/scanner.rs

@@ -9,14 +9,14 @@ use crate::{
     calibre::{Book, BookIterator},
     extractors::{self, extractor::ExtractionError},
     hash,
-    storage::{self, Postgres},
+    storage::{self, queries::Queries},
     text_encoder::{self, TextEncoder},
     tokenize::{self, Tokenizer},
 };
 
 /// Scan a Calibre library and process all books for embedding generation.
 pub async fn scan(
-    db: &Postgres,
+    db: &Queries,
     library_path: &PathBuf,
     tokenizer: &Tokenizer,
     text_encoder: &TextEncoder,

@@ -43,7 +43,9 @@ pub enum ProcessBookError {
     #[snafu(display("Failed to hash book."))]
     BookHash { source: io::Error },
     #[snafu(display("Failed to check in db if document needs update."))]
-    UpdateCheck { source: storage::CheckUpdateError },
+    UpdateCheck {
+        source: storage::queries::CheckUpdateError,
+    },
     #[snafu(display("Failed to open book."))]
     Open { source: io::Error },
     #[snafu(display("Failed to extract document content."))]

@@ -53,21 +55,27 @@ pub enum ProcessBookError {
     #[snafu(display("Failed to create embeddings from encodings."))]
     Embedding { source: text_encoder::EmbedError },
     #[snafu(display("Failed to look up document versions."))]
-    DocumentVersion { source: storage::VersionError },
+    DocumentVersion {
+        source: storage::queries::VersionError,
+    },
     #[snafu(display("Failed to store embeddings in database."))]
-    SaveEmbeddings { source: storage::EmbeddingError },
+    SaveEmbeddings {
+        source: storage::queries::EmbeddingError,
+    },
     #[snafu(display("Failed to update document hash in database."))]
-    UpdateHash { source: storage::UpdateHashError },
+    UpdateHash {
+        source: storage::queries::UpdateHashError,
+    },
     #[snafu(display("Failed to delete old document versions in database."))]
     RemoveOldVersions {
-        source: storage::RemoveOldVersionsError,
+        source: storage::queries::RemoveOldVersionsError,
     },
 }
 
 /// Process a single book: extract text, generate embeddings, and store in database.
 async fn process_book(
     book: &Book,
-    db: &Postgres,
+    db: &Queries,
     tokenizer: &Tokenizer,
     embedder: &TextEncoder,
     chunk_size: usize,
src/storage.rs (220 changes)

@@ -1,27 +1,20 @@
 //! Database storage and retrieval for document embeddings.
 
-use futures::future::try_join_all;
+use std::sync::Arc;
 
+use jobs::Jobs;
+use queries::Queries;
 use snafu::{ResultExt, Snafu};
-use sqlx::{FromRow, Pool, postgres::PgPoolOptions};
+use sqlx::postgres::PgPoolOptions;
 
-use crate::{hash::SHA256_LENGTH, text_encoder::Embeddings};
+pub mod jobs;
+pub mod queries;
 
-/// A document chunk with similarity score from vector search.
-#[derive(Debug, FromRow)]
-pub struct DocumentMatch {
-    /// Calibre book ID.
-    pub book_id: i64,
-    /// Text content of the chunk.
-    pub text_chunk: String,
-    /// Cosine similarity score (0.0 to 1.0).
-    pub similarity: f64,
-}
-
 /// PostgreSQL database connection pool and operations.
 #[derive(Debug, Clone)]
 pub struct Postgres {
-    /// Connection pool for database operations.
-    pool: Pool<sqlx::Postgres>,
+    pub jobs: Jobs,
+    pub queries: Queries,
 }
 
 /// Error when connecting to database.

@@ -31,195 +24,20 @@ pub struct ConnectionError {
     source: sqlx::Error,
 }
 
-/// Error when retrieving document version.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to get highest document version."))]
-pub struct VersionError {
-    source: sqlx::Error,
-}
-
-/// Error when storing embeddings.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to store embeddings."))]
-pub struct EmbeddingError {
-    source: sqlx::Error,
-}
-
-/// Error when querying for similar documents.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to query similarity data."))]
-pub struct QueryError {
-    source: sqlx::Error,
-}
-
-/// Error when updating document hash.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to update document hash."))]
-pub struct UpdateHashError {
-    source: sqlx::Error,
-}
-
-/// Error when checking if document needs update.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to check if document needs update."))]
-pub struct CheckUpdateError {
-    source: sqlx::Error,
-}
-
-/// Error when removing old document versions.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to delete old document versions."))]
-pub struct RemoveOldVersionsError {
-    source: sqlx::Error,
-}
-
 impl Postgres {
     /// Create a new database connection pool.
     pub async fn connect(db_url: &str, max_connections: u32) -> Result<Self, ConnectionError> {
-        let pool = PgPoolOptions::new()
-            .max_connections(max_connections)
-            .connect(db_url)
-            .await
-            .context(ConnectionSnafu)?;
+        let pool = Arc::new(
+            PgPoolOptions::new()
+                .max_connections(max_connections)
+                .connect(db_url)
+                .await
+                .context(ConnectionSnafu)?,
+        );
 
-        Ok(Self { pool })
-    }
+        let jobs = Jobs::new(pool.clone());
+        let queries = Queries::new(pool.clone());
 
-    /// Get the highest version number for a book's embeddings.
-    pub async fn embedding_version(&self, book_id: i64) -> Result<i32, VersionError> {
-        let version = sqlx::query_scalar(
-            "SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1",
-        )
-        .bind(book_id)
-        .fetch_one(&self.pool)
-        .await
-        .context(VersionSnafu)?;
-
-        Ok(version)
-    }
-
-    /// Store multiple embeddings for a book in parallel.
-    pub async fn add_multiple_embeddings(
-        &self,
-        book_id: i64,
-        version: i32,
-        embeddings: Vec<Embeddings>,
-    ) -> Result<(), EmbeddingError> {
-        let inserts: Vec<_> = embeddings
-            .into_iter()
-            .map(|e| self.add_embeddings(book_id, version, e))
-            .collect();
-
-        try_join_all(inserts).await?;
-        Ok(())
-    }
-
-    /// Store a single embedding for a book.
-    pub async fn add_embeddings(
-        &self,
-        book_id: i64,
-        version: i32,
-        embeddings: Embeddings,
-    ) -> Result<(), EmbeddingError> {
-        let vector_string = embeddings_to_vec_string(embeddings.0);
-        sqlx::query(
-            "INSERT INTO documents (book_id, version, embedding, text_chunk) VALUES ($1, $2, $3::vector, $4)",
-        )
-        .bind(book_id)
-        .bind(version)
-        .bind(vector_string)
-        .bind(embeddings.2)
-        .execute(&self.pool)
-        .await
-        .context(EmbeddingSnafu)?;
-
-        Ok(())
-    }
-
-    /// Find documents similar to the query embedding using cosine similarity.
-    pub async fn query(
-        &self,
-        query: Embeddings,
-        limit: i32,
-    ) -> Result<Vec<DocumentMatch>, QueryError> {
-        let vector_string = embeddings_to_vec_string(query.0);
-        let books = sqlx::query_as::<_, DocumentMatch>(
-            "SELECT book_id, text_chunk, 1 - (embedding <=> $1::vector) as similarity
-             FROM documents
-             ORDER BY embedding <=> $1::vector
-             LIMIT $2",
-        )
-        .bind(vector_string)
-        .bind(limit)
-        .fetch_all(&self.pool)
-        .await
-        .context(QuerySnafu)?;
-
-        Ok(books)
-    }
-
-    /// Update or insert the hash for a book.
-    pub async fn update_hash(
-        &self,
-        book_id: i64,
-        hash: [u8; SHA256_LENGTH],
-    ) -> Result<(), UpdateHashError> {
-        sqlx::query(
-            "INSERT INTO hashes (book_id, hash) VALUES ($1, $2)
-             ON CONFLICT (book_id) DO UPDATE SET hash = $2",
-        )
-        .bind(book_id)
-        .bind(&hash[..])
-        .execute(&self.pool)
-        .await
-        .context(UpdateHashSnafu)?;
-
-        Ok(())
-    }
-
-    /// Check if a book needs to be reprocessed based on its hash.
-    pub async fn book_needs_update(
-        &self,
-        book_id: i64,
-        hash: [u8; SHA256_LENGTH],
-    ) -> Result<bool, CheckUpdateError> {
-        let exists = sqlx::query_scalar::<_, bool>(
-            "SELECT NOT EXISTS (
-                SELECT 1 FROM hashes
-                WHERE book_id = $1 AND hash = $2
-            )",
-        )
-        .bind(book_id)
-        .bind(&hash[..])
-        .fetch_one(&self.pool)
-        .await
-        .context(CheckUpdateSnafu)?;
-
-        Ok(exists)
-    }
-
-    /// Remove all but the latest version of embeddings for a book.
-    pub async fn remove_old_versions(&self, book_id: i64) -> Result<(), RemoveOldVersionsError> {
-        sqlx::query(
-            "DELETE FROM documents WHERE version < (SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1)",
-        )
-        .bind(book_id)
-        .execute(&self.pool)
-        .await
-        .context(RemoveOldVersionsSnafu)?;
-
-        Ok(())
+        Ok(Self { jobs, queries })
     }
 }
-
-/// Convert a vector of floats to PostgreSQL vector format string.
-fn embeddings_to_vec_string(vector: Vec<f32>) -> String {
-    format!(
-        "[{}]",
-        vector
-            .iter()
-            .map(|x| x.to_string())
-            .collect::<Vec<String>>()
-            .join(",")
-    )
-}
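After this split, Postgres is a thin aggregate: one shared Arc<PgPool> behind two facades, and call sites choose db.queries or db.jobs explicitly. A minimal wiring sketch (the connection string and pool size are placeholders):

use crate::storage::{ConnectionError, Postgres};

/// Hypothetical setup: one pool, two cheaply clonable facades over it.
async fn setup() -> Result<Postgres, ConnectionError> {
    let db = Postgres::connect("postgres://localhost/library", 5).await?;
    // Both handles share the same Arc<PgPool>, so cloning is cheap.
    let _jobs = db.jobs.clone();
    let _queries = db.queries.clone();
    Ok(db)
}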
src/storage/jobs.rs (new file, 143 lines)

@@ -0,0 +1,143 @@
+use std::sync::Arc;
+
+use serde::{Deserialize, Serialize};
+use snafu::{ResultExt, Snafu};
+use sqlx::{PgPool, Row, Type};
+use uuid::Uuid;
+
+use super::queries::DocumentMatch;
+
+#[derive(Debug, Clone)]
+pub struct Jobs {
+    /// Connection pool for database operations.
+    pool: Arc<PgPool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, Type)]
+#[sqlx(type_name = "job_status", rename_all = "lowercase")]
+pub enum JobStatus {
+    Running,
+    Completed,
+    Failed,
+}
+
+impl std::fmt::Display for JobStatus {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            JobStatus::Running => write!(f, "running"),
+            JobStatus::Completed => write!(f, "completed"),
+            JobStatus::Failed => write!(f, "failed"),
+        }
+    }
+}
+
+impl std::str::FromStr for JobStatus {
+    type Err = String;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "running" => Ok(JobStatus::Running),
+            "completed" => Ok(JobStatus::Completed),
+            "failed" => Ok(JobStatus::Failed),
+            _ => Err(format!("Invalid job status: {}", s)),
+        }
+    }
+}
+
+/// Error adding a new job.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to add background job."))]
+pub struct AddError {
+    source: sqlx::Error,
+}
+
+/// Error when adding a result to a job.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to add result to background job."))]
+pub struct ResultError {
+    source: sqlx::Error,
+}
+
+/// Error when adding a failure to a job.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to add failure to background job."))]
+pub struct FailureError {
+    source: sqlx::Error,
+}
+
+/// Errors that occur during job retrieval.
+#[derive(Debug, Snafu)]
+pub enum RetrieveError {
+    #[snafu(display("Failed to retrieve background job data"))]
+    Retrieve { source: sqlx::Error },
+    #[snafu(display("Failed to run query."))]
+    Deserialize { source: postcard::Error },
+}
+
+impl Jobs {
+    pub fn new(pool: Arc<PgPool>) -> Self {
+        Self { pool }
+    }
+
+    /// Add a new job to the database and mark it as running.
+    pub async fn add(&self, id: &Uuid) -> Result<(), AddError> {
+        sqlx::query("INSERT INTO jobs (id, expires_at) VALUES ($1, NOW() + INTERVAL '24 hours')")
+            .bind(id)
+            .execute(self.pool.as_ref())
+            .await
+            .context(AddSnafu)?;
+
+        Ok(())
+    }
+
+    /// Add a result to a job and mark it as finished.
+    pub async fn finish(&self, id: &Uuid, data: &[u8]) -> Result<(), ResultError> {
+        sqlx::query("UPDATE jobs SET status = 'completed', result = $1 WHERE id = $2")
+            .bind(data)
+            .bind(id)
+            .execute(self.pool.as_ref())
+            .await
+            .context(ResultSnafu)?;
+
+        Ok(())
+    }
+
+    /// Add an error to a job and mark it as failed.
+    pub async fn fail(&self, id: &Uuid, error: &str) -> Result<(), FailureError> {
+        sqlx::query("UPDATE jobs SET status = 'failed', error = $1 WHERE id = $2")
+            .bind(error)
+            .bind(id)
+            .execute(self.pool.as_ref())
+            .await
+            .context(FailureSnafu)?;
+
+        Ok(())
+    }
+
+    pub async fn retrieve(
+        &self,
+        id: &Uuid,
+    ) -> Result<Option<(String, Vec<DocumentMatch>)>, RetrieveError> {
+        let row = sqlx::query("SELECT status, result FROM jobs WHERE id = $1")
+            .bind(id)
+            .fetch_optional(self.pool.as_ref())
+            .await
+            .context(RetrieveSnafu)?;
+
+        match row {
+            Some(row) => {
+                let status: JobStatus = row.get("status");
+                let status = status.to_string();
+                match row.try_get("result") {
+                    Ok(data) => {
+                        let documents: Vec<DocumentMatch> =
+                            postcard::from_bytes(data).context(DeserializeSnafu)?;
+                        Ok(Some((status, documents)))
+                    }
+                    Err(_) => Ok(Some((status, Vec::new()))),
+                }
+            }
+            None => Ok(None),
+        }
+    }
+}

src/storage/queries.rs (new file, 211 lines)

@@ -0,0 +1,211 @@
+use std::sync::Arc;
+
+use futures::future::try_join_all;
+use serde::{Deserialize, Serialize};
+use snafu::{ResultExt, Snafu};
+use sqlx::{FromRow, PgPool};
+
+use crate::{hash::SHA256_LENGTH, text_encoder::Embeddings};
+
+#[derive(Debug, Clone)]
+pub struct Queries {
+    /// Connection pool for database operations.
+    pool: Arc<PgPool>,
+}
+
+/// Error when retrieving document version.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to get highest document version."))]
+pub struct VersionError {
+    source: sqlx::Error,
+}
+
+/// Error when storing embeddings.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to store embeddings."))]
+pub struct EmbeddingError {
+    source: sqlx::Error,
+}
+
+/// Error when querying for similar documents.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to query similarity data."))]
+pub struct QueryError {
+    source: sqlx::Error,
+}
+
+/// Error when updating document hash.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to update document hash."))]
+pub struct UpdateHashError {
+    source: sqlx::Error,
+}
+
+/// Error when checking if document needs update.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to check if document needs update."))]
+pub struct CheckUpdateError {
+    source: sqlx::Error,
+}
+
+/// Error when removing old document versions.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to delete old document versions."))]
+pub struct RemoveOldVersionsError {
+    source: sqlx::Error,
+}
+
+/// A document chunk with similarity score from vector search.
+#[derive(Debug, FromRow, Serialize, Deserialize)]
+pub struct DocumentMatch {
+    /// Calibre book ID.
+    pub book_id: i64,
+    /// Text content of the chunk.
+    pub text_chunk: String,
+    /// Cosine similarity score (0.0 to 1.0).
+    pub similarity: f64,
+}
+
+impl Queries {
+    pub fn new(pool: Arc<PgPool>) -> Self {
+        Self { pool }
+    }
+
+    /// Get the highest version number for a book's embeddings.
+    pub async fn embedding_version(&self, book_id: i64) -> Result<i32, VersionError> {
+        let version = sqlx::query_scalar(
+            "SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1",
+        )
+        .bind(book_id)
+        .fetch_one(self.pool.as_ref())
+        .await
+        .context(VersionSnafu)?;
+
+        Ok(version)
+    }
+
+    /// Store multiple embeddings for a book in parallel.
+    pub async fn add_multiple_embeddings(
+        &self,
+        book_id: i64,
+        version: i32,
+        embeddings: Vec<Embeddings>,
+    ) -> Result<(), EmbeddingError> {
+        let inserts: Vec<_> = embeddings
+            .into_iter()
+            .map(|e| self.add_embeddings(book_id, version, e))
+            .collect();
+
+        try_join_all(inserts).await?;
+        Ok(())
+    }
+
+    /// Store a single embedding for a book.
+    pub async fn add_embeddings(
+        &self,
+        book_id: i64,
+        version: i32,
+        embeddings: Embeddings,
+    ) -> Result<(), EmbeddingError> {
+        let vector_string = embeddings_to_vec_string(embeddings.0);
+        sqlx::query(
+            "INSERT INTO documents (book_id, version, embedding, text_chunk) VALUES ($1, $2, $3::vector, $4)",
+        )
+        .bind(book_id)
+        .bind(version)
+        .bind(vector_string)
+        .bind(embeddings.2)
+        .execute(self.pool.as_ref())
+        .await
+        .context(EmbeddingSnafu)?;
+
+        Ok(())
+    }
+
+    /// Find documents similar to the query embedding using cosine similarity.
+    pub async fn query(
+        &self,
+        query: Embeddings,
+        limit: i32,
+    ) -> Result<Vec<DocumentMatch>, QueryError> {
+        let vector_string = embeddings_to_vec_string(query.0);
+        let books = sqlx::query_as::<_, DocumentMatch>(
+            "SELECT book_id, text_chunk, 1 - (embedding <=> $1::vector) as similarity
+             FROM documents
+             ORDER BY embedding <=> $1::vector
+             LIMIT $2",
+        )
+        .bind(vector_string)
+        .bind(limit)
+        .fetch_all(self.pool.as_ref())
+        .await
+        .context(QuerySnafu)?;
+
+        Ok(books)
+    }
+
+    /// Update or insert the hash for a book.
+    pub async fn update_hash(
+        &self,
+        book_id: i64,
+        hash: [u8; SHA256_LENGTH],
+    ) -> Result<(), UpdateHashError> {
+        sqlx::query(
+            "INSERT INTO hashes (book_id, hash) VALUES ($1, $2)
+             ON CONFLICT (book_id) DO UPDATE SET hash = $2",
+        )
+        .bind(book_id)
+        .bind(&hash[..])
+        .execute(self.pool.as_ref())
+        .await
+        .context(UpdateHashSnafu)?;
+
+        Ok(())
+    }
+
+    /// Check if a book needs to be reprocessed based on its hash.
+    pub async fn book_needs_update(
+        &self,
+        book_id: i64,
+        hash: [u8; SHA256_LENGTH],
+    ) -> Result<bool, CheckUpdateError> {
+        let exists = sqlx::query_scalar::<_, bool>(
+            "SELECT NOT EXISTS (
+                SELECT 1 FROM hashes
+                WHERE book_id = $1 AND hash = $2
+            )",
+        )
+        .bind(book_id)
+        .bind(&hash[..])
+        .fetch_one(self.pool.as_ref())
+        .await
+        .context(CheckUpdateSnafu)?;
+
+        Ok(exists)
+    }
+
+    /// Remove all but the latest version of embeddings for a book.
+    pub async fn remove_old_versions(&self, book_id: i64) -> Result<(), RemoveOldVersionsError> {
+        sqlx::query(
+            "DELETE FROM documents WHERE version < (SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1)",
+        )
+        .bind(book_id)
+        .execute(self.pool.as_ref())
+        .await
+        .context(RemoveOldVersionsSnafu)?;
+
+        Ok(())
+    }
+}
+
+/// Convert a vector of floats to PostgreSQL vector format string.
+fn embeddings_to_vec_string(vector: Vec<f32>) -> String {
+    format!(
+        "[{}]",
+        vector
+            .iter()
+            .map(|x| x.to_string())
+            .collect::<Vec<String>>()
+            .join(",")
+    )
+}
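The contract between the two storage submodules above: Jobs::finish stores whatever the job manager serialized with postcard, and Jobs::retrieve decodes the BYTEA column back into Vec<DocumentMatch>, so both sides must agree on the encoding. A standalone round-trip sketch of that contract, using only the calls the commit itself uses (the roundtrip function is illustrative only):

use crate::storage::queries::DocumentMatch;

/// Round-trips a result vector the way the job manager and Jobs::retrieve do.
fn roundtrip() -> Result<(), postcard::Error> {
    let results = vec![DocumentMatch {
        book_id: 42,
        text_chunk: "chapter one".to_owned(),
        similarity: 0.87,
    }];
    // to_allocvec on the write path, from_bytes on the read path.
    let bytes = postcard::to_allocvec(&results)?;
    let decoded: Vec<DocumentMatch> = postcard::from_bytes(&bytes)?;
    assert_eq!(decoded.len(), results.len());
    Ok(())
}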
src/text_encoder.rs

@@ -12,7 +12,7 @@ use snafu::{ResultExt, Snafu};
 use tract_onnx::prelude::*;
 
 use crate::{
-    storage::DocumentMatch,
+    storage::queries::DocumentMatch,
     text_encoder::tract_ndarray::ShapeError,
     tokenize::{self, Encoding, Tokenizer},
 };