WIP job queue
parent 525e278a4e, commit 6a5b309391
15 changed files with 685 additions and 256 deletions

Cargo.lock (generated), 92 changes
@@ -158,6 +158,15 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "atomic-polyfill"
+version = "1.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4"
+dependencies = [
+ "critical-section",
+]
+
 [[package]]
 name = "atomic-waker"
 version = "1.1.2"
@@ -449,6 +458,15 @@ version = "0.7.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675"
 
+[[package]]
+name = "cobs"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0fa961b519f0b462e3a3b4a34b64d119eeaca1d59af726fe450bbba07a9fc0a1"
+dependencies = [
+ "thiserror",
+]
+
 [[package]]
 name = "colorchoice"
 version = "1.0.4"
@@ -543,6 +561,12 @@ dependencies = [
  "cfg-if",
 ]
 
+[[package]]
+name = "critical-section"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b"
+
 [[package]]
 name = "crossbeam-deque"
 version = "0.8.6"
@@ -842,6 +866,18 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "embedded-io"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced"
+
+[[package]]
+name = "embedded-io"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d"
+
 [[package]]
 name = "encode_unicode"
 version = "1.0.0"
@@ -1171,6 +1207,15 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "hash32"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67"
+dependencies = [
+ "byteorder",
+]
+
 [[package]]
 name = "hashbrown"
 version = "0.14.5"
@@ -1200,6 +1245,20 @@ dependencies = [
  "hashbrown 0.15.3",
 ]
 
+[[package]]
+name = "heapless"
+version = "0.7.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f"
+dependencies = [
+ "atomic-polyfill",
+ "hash32",
+ "rustc_version",
+ "serde",
+ "spin",
+ "stable_deref_trait",
+]
+
 [[package]]
 name = "heck"
 version = "0.5.0"
@@ -1788,6 +1847,7 @@ dependencies = [
  "futures",
  "ignore",
  "lopdf",
+ "postcard",
  "quick-xml",
  "rayon",
  "ring",
@@ -2411,6 +2471,19 @@ dependencies = [
  "portable-atomic",
 ]
 
+[[package]]
+name = "postcard"
+version = "1.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c1de96e20f51df24ca73cafcc4690e044854d803259db27a00a461cb3b9d17a"
+dependencies = [
+ "cobs",
+ "embedded-io 0.4.0",
+ "embedded-io 0.6.1",
+ "heapless",
+ "serde",
+]
+
 [[package]]
 name = "potential_utf"
 version = "0.1.2"
@@ -2745,6 +2818,15 @@ version = "0.1.24"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
 
+[[package]]
+name = "rustc_version"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
+dependencies = [
+ "semver",
+]
+
 [[package]]
 name = "rustfft"
 version = "6.3.0"
@@ -2877,6 +2959,12 @@ dependencies = [
  "smallvec",
 ]
 
+[[package]]
+name = "semver"
+version = "1.0.26"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0"
+
 [[package]]
 name = "serde"
 version = "1.0.219"
@@ -3125,6 +3213,7 @@ dependencies = [
  "tokio-stream",
  "tracing",
  "url",
+ "uuid",
  "webpki-roots 0.26.11",
 ]
 
@@ -3205,6 +3294,7 @@ dependencies = [
  "stringprep",
  "thiserror",
  "tracing",
+ "uuid",
  "whoami",
 ]
 
@@ -3242,6 +3332,7 @@ dependencies = [
  "stringprep",
  "thiserror",
  "tracing",
+ "uuid",
  "whoami",
 ]
 
@@ -3267,6 +3358,7 @@ dependencies = [
  "thiserror",
  "tracing",
  "url",
+ "uuid",
 ]
 
 [[package]]

Cargo.toml
@@ -9,6 +9,7 @@ axum = { version = "0.8.4", features = ["http2", "tracing"] }
 clap = { version = "4.5.40", features = ["derive"] }
 futures = "0.3.31"
 lopdf = "0.36.0"
+postcard = { version = "1.1.2", features = ["alloc"] }
 quick-xml = "0.38.0"
 rayon = "1.10.0"
 ring = "0.17.14"
@@ -16,7 +17,7 @@ scraper = "0.23.1"
 serde = { version = "1.0.219", features = ["derive"] }
 sha2 = "0.10.9"
 snafu = { version = "0.8.6", features = ["rust_1_81"] }
-sqlx = { version = "0.8.6", features = [ "runtime-tokio", "tls-rustls-ring", "migrate", "postgres", "derive" ] }
+sqlx = { version = "0.8.6", features = [ "runtime-tokio", "tls-rustls-ring", "migrate", "postgres", "derive", "uuid", "macros" ] }
 tokenizers = "0.21.2"
 tokio = { version = "1.45.1", features = ["rt-multi-thread", "macros"]}
 tower-http = { version = "0.6.6", features = ["trace"] }

migrations/20250701133953_jobs.sql (new file), 12 changes
@@ -0,0 +1,12 @@
+CREATE TYPE job_status AS ENUM ('running', 'completed', 'failed');
+
+CREATE TABLE jobs (
+    id UUID PRIMARY KEY,
+    status job_status NOT NULL DEFAULT 'running',
+    error TEXT,
+    result BYTEA,
+    expires_at TIMESTAMP WITH TIME ZONE NOT NULL
+);
+
+CREATE INDEX idx_jobs_status ON jobs(status);
+CREATE INDEX idx_jobs_expires_at ON jobs(expires_at);

@@ -4,8 +4,10 @@ use std::{
     io,
     net::{AddrParseError, SocketAddr},
     str::FromStr,
+    sync::Arc,
 };
 
+use jobs::JobManager;
 use snafu::{ResultExt, Snafu};
 use state::AppState;
 use tokio::net::TcpListener;
@@ -17,6 +19,7 @@ use crate::{storage::Postgres, text_encoder::TextEncoder, tokenize::Tokenizer};
 
 pub mod code;
 pub mod error;
+pub mod jobs;
 pub mod query;
 pub mod routes;
 pub mod state;
@@ -55,12 +58,14 @@ pub async fn serve(
     reranker: TextEncoder,
     chunk_size: usize,
 ) -> Result<(), ServeError> {
+    let jobs = JobManager::new(Arc::new(db.jobs.clone()));
     let state = AppState {
         db,
         tokenizer,
         embedder,
         reranker,
         chunk_size,
+        jobs,
     };
 
     let (router, api) = OpenApiRouter::with_openapi(ApiDoc::openapi())

src/api/jobs.rs (new file), 81 changes
@@ -0,0 +1,81 @@
+use snafu::Report;
+use snafu::{ResultExt, Snafu};
+use std::{sync::Arc, time::Duration};
+use tracing::error;
+
+use uuid::Uuid;
+
+use crate::storage;
+use crate::storage::jobs::AddError;
+
+#[derive(Clone, Debug)]
+pub struct JobManager {
+    db: Arc<storage::jobs::Jobs>,
+}
+
+/// Errors that occur when running a background job.
+#[derive(Debug, Snafu)]
+pub enum ExecuteError {
+    #[snafu(display("Failed to add new background job."))]
+    Add { source: AddError },
+}
+
+impl JobManager {
+    pub fn new(db: Arc<storage::jobs::Jobs>) -> Self {
+        Self { db }
+    }
+
+    pub async fn execute<F, Fut, T, E>(&self, task: F) -> Result<Uuid, ExecuteError>
+    where
+        F: FnOnce() -> Fut + Send + 'static,
+        Fut: Future<Output = Result<T, E>> + Send,
+        T: serde::Serialize + Send + Sync + 'static,
+        E: snafu::Error + Send + Sync + 'static,
+    {
+        let job_id = Uuid::new_v4();
+        self.db.as_ref().add(&job_id).await.context(AddSnafu)?;
+        let db = self.db.clone();
+
+        tokio::spawn(async move {
+            match task().await {
+                Ok(result) => success(&job_id, &result, &db).await,
+                Err(error) => failure(&job_id, error, &db).await,
+            }
+
+            tokio::time::sleep(Duration::from_secs(300)).await;
+        });
+
+        Ok(job_id)
+    }
+}
+
+async fn success<T>(job_id: &Uuid, result: &T, db: &storage::jobs::Jobs)
+where
+    T: serde::Serialize + Send + Sync + 'static,
+{
+    let data = match postcard::to_allocvec(result) {
+        Ok(data) => data,
+        Err(error) => {
+            failure(job_id, error, db).await;
+            return;
+        }
+    };
+
+    if let Err(error) = db.finish(job_id, &data).await {
+        failure(job_id, error, db).await;
+    }
+}
+
+async fn failure<E>(job_id: &Uuid, error: E, db: &storage::jobs::Jobs)
+where
+    E: snafu::Error + Send + Sync + 'static,
+{
+    let error = Report::from_error(&error);
+    error!("{job_id}: {error}");
+    if let Err(error) = db.fail(job_id, &error.to_string()).await {
+        error!(
+            "Failed to save job {job_id} result: {}",
+            Report::from_error(&error)
+        );
+    }
+}
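
A note on usage: `execute` inserts a `running` row, spawns the task onto the Tokio runtime, and hands the job id back immediately; the caller never waits for the work itself. A minimal sketch, inside some async fn returning a Result (the `jobs_store` handle and `fetch_answer` task are hypothetical stand-ins, not part of this commit):

    let manager = JobManager::new(jobs_store);
    // Returns as soon as the job row exists; the closure runs in the background.
    let job_id = manager
        .execute(move || async move { fetch_answer().await })
        .await?;
    println!("job started: {job_id}");

The row itself expires via `expires_at` 24 hours after insertion (see the `add` query in src/storage/jobs.rs below).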

src/api/query.rs, 126 changes
@@ -1,18 +1,22 @@
 //! Query endpoint handlers and response types.
 
-use std::sync::Arc;
+use std::{str::FromStr, sync::Arc};
 
 use axum::{
     Json,
-    extract::{Query, State},
+    extract::{Path, Query, State},
     http::StatusCode,
 };
 use serde::{Deserialize, Serialize};
 use snafu::{ResultExt, Snafu, ensure};
 use utoipa::ToSchema;
+use uuid::Uuid;
 
-use super::{TAG, error::HttpStatus, state::AppState};
-use crate::{http_error, query, storage::DocumentMatch};
+use super::{TAG, error::HttpStatus, jobs, state::AppState};
+use crate::{
+    http_error, query,
+    storage::{self, queries::DocumentMatch},
+};
 
 const MAX_LIMIT: usize = 10;
 
@@ -22,14 +26,14 @@ pub enum QueryError {
     #[snafu(display("'limit' query parameter must be a positive integer <= {MAX_LIMIT}."))]
     Limit,
     #[snafu(display("Failed to run query."))]
-    Query { source: query::AskError },
+    QueryExecute { source: jobs::ExecuteError },
 }
 
 impl HttpStatus for QueryError {
     fn status_code(&self) -> StatusCode {
         match self {
             QueryError::Limit => StatusCode::BAD_REQUEST,
-            QueryError::Query { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
+            QueryError::QueryExecute { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
         }
     }
 }
@@ -46,25 +50,9 @@ pub struct QueryParams {
 /// Response format for successful query requests.
 #[derive(Debug, Serialize, ToSchema)]
 pub struct QueryResponse {
+    pub status: String,
     /// List of matching document chunks.
     pub results: Vec<DocumentResult>,
-    /// Total number of results returned.
-    pub count: usize,
-    /// Original query text that was searched.
-    pub query: String,
-}
-
-impl From<(Vec<DocumentMatch>, String)> for QueryResponse {
-    fn from((documents, query): (Vec<DocumentMatch>, String)) -> Self {
-        let results: Vec<DocumentResult> = documents.into_iter().map(Into::into).collect();
-        let count = results.len();
-
-        Self {
-            results,
-            count,
-            query,
-        }
-    }
-}
+}
 
 /// A single document search result.
@@ -88,6 +76,13 @@ impl From<DocumentMatch> for DocumentResult {
     }
 }
 
+/// Response format for successful query requests.
+#[derive(Debug, Serialize, ToSchema)]
+pub struct QueryStartResponse {
+    /// Job id.
+    pub id: String,
+}
+
 /// Execute a semantic search query against the document database.
 #[utoipa::path(
     post,
@@ -98,31 +93,84 @@ impl From<DocumentMatch> for DocumentResult {
     ),
     request_body = String,
     responses(
-        (status = OK, body = QueryResponse),
+        (status = 202, body = QueryStartResponse),
         (status = 400, description = "Wrong parameter."),
         (status = 500, description = "Failure to query database.")
     )
 )]
-pub async fn route(
+pub async fn start(
     Query(params): Query<QueryParams>,
     State(state): State<Arc<AppState>>,
     body: String,
-) -> Result<Json<QueryResponse>, QueryError> {
+) -> Result<Json<QueryStartResponse>, QueryError> {
     let limit = params.limit.unwrap_or(5);
     ensure!(limit <= MAX_LIMIT, LimitSnafu);
 
-    let results = query::ask(
-        &body,
-        &state.db,
-        &state.tokenizer,
-        &state.embedder,
-        &state.reranker,
-        state.chunk_size,
-        limit,
-    )
-    .await
-    .context(QuerySnafu)?;
-    let response = QueryResponse::from((results, body));
+    let jobs = state.jobs.clone();
+    let id = jobs
+        .execute(move || async move {
+            query::ask(
+                &body,
+                &state.db.queries,
+                &state.tokenizer,
+                &state.embedder,
+                &state.reranker,
+                state.chunk_size,
+                limit,
+            )
+            .await
+        })
+        .await
+        .context(QueryExecuteSnafu)?;
 
-    Ok(Json(response))
+    Ok(Json(QueryStartResponse { id: id.to_string() }))
 }
+
+/// Errors that occur during query retrieval.
+#[derive(Debug, Snafu)]
+pub enum RetrieveError {
+    #[snafu(display("Failed to retrieve job from database."))]
+    Db {
+        source: storage::jobs::RetrieveError,
+    },
+    #[snafu(display("No job with id {id}."))]
+    NoJob { id: Uuid },
+    #[snafu(display("'{id}' is not a valid v4 UUID."))]
+    Uuid { source: uuid::Error, id: String },
+}
+
+impl HttpStatus for RetrieveError {
+    fn status_code(&self) -> StatusCode {
+        match self {
+            RetrieveError::Db { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
+            RetrieveError::NoJob { id: _ } => StatusCode::NOT_FOUND,
+            RetrieveError::Uuid { source: _, id: _ } => StatusCode::BAD_REQUEST,
+        }
+    }
+}
+
+http_error!(RetrieveError);
+
+#[utoipa::path(
+    get,
+    path = "/query/{id}",
+    tag = TAG,
+    responses(
+        (status = OK, body = QueryResponse),
+        (status = 404, description = "No job with the requested id."),
+        (status = 500, description = "Failure to query database.")
+    )
+)]
+pub async fn result(
+    Path(id): Path<String>,
+    State(state): State<Arc<AppState>>,
+) -> Result<Json<QueryResponse>, RetrieveError> {
+    let id = Uuid::from_str(&id).context(UuidSnafu { id })?;
+    match state.db.jobs.retrieve(&id).await.context(DbSnafu)? {
+        Some((status, results)) => Ok(Json(QueryResponse {
+            status,
+            results: results.into_iter().map(Into::into).collect(),
+        })),
+        None => NoJobSnafu { id }.fail(),
+    }
+}
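
Together, `start` and `result` replace the old synchronous handler with a submit-then-poll flow: POST /query answers 202 with a job id at once, and GET /query/{id} reports "running" until the background task stores a result or an error. A hypothetical client loop for illustration (assumes the `reqwest` crate with its json feature, plus `tokio`, `serde_json`, and `anyhow`; none of this is in the commit):

    async fn poll(base: &str, question: &str) -> anyhow::Result<serde_json::Value> {
        let client = reqwest::Client::new();
        // Submit: the request body is the raw query text; the reply carries only the job id.
        let started: serde_json::Value = client
            .post(format!("{base}/query?limit=5"))
            .body(question.to_owned())
            .send()
            .await?
            .json()
            .await?;
        let id = started["id"].as_str().unwrap_or_default().to_owned();
        // Poll until the job leaves the 'running' state.
        loop {
            let job: serde_json::Value = client
                .get(format!("{base}/query/{id}"))
                .send()
                .await?
                .json()
                .await?;
            if job["status"] != "running" {
                return Ok(job);
            }
            tokio::time::sleep(std::time::Duration::from_millis(500)).await;
        }
    }

Note that `result` only reads the `status` and `result` columns; the stored `error` text is not surfaced yet, so a failed job comes back as status "failed" with an empty result list.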

@@ -12,7 +12,8 @@ use crate::api::{code, query};
 pub fn router(state: AppState) -> OpenApiRouter {
     let store = Arc::new(state);
     OpenApiRouter::new()
-        .routes(routes!(query::route, code::route))
+        .routes(routes!(query::start, query::result))
+        .routes(routes!(code::route))
         .layer(TraceLayer::new_for_http())
         .with_state(store)
 }

@@ -2,6 +2,8 @@
 
 use crate::{storage::Postgres, text_encoder::TextEncoder, tokenize::Tokenizer};
 
+use super::jobs::JobManager;
+
 /// Application state shared across all HTTP request handlers.
 #[derive(Debug, Clone)]
 pub struct AppState {
@@ -15,4 +17,6 @@ pub struct AppState {
     pub reranker: TextEncoder,
     /// Text chunk size in words for processing.
     pub chunk_size: usize,
+    /// Background jobs handling.
+    pub jobs: JobManager,
 }

@@ -75,7 +75,7 @@ async fn main() -> Result<(), Error> {
             TextEncoder::from_file(&config.reranking_path).context(RerankerLoadingSnafu)?;
             query::ask(
                 &config.query,
-                &db,
+                &db.queries,
                 &tokenizer,
                 &embedder,
                 &reranker,
@@ -87,7 +87,7 @@ async fn main() -> Result<(), Error> {
         }
         Commands::Scan(config) => {
             scanner::scan(
-                &db,
+                &db.queries,
                 &config.library_path,
                 &tokenizer,
                 &embedder,

src/query.rs, 11 changes
@@ -3,7 +3,10 @@
 use snafu::{ResultExt, Snafu};
 
 use crate::{
-    storage::{self, DocumentMatch, Postgres},
+    storage::{
+        self,
+        queries::{DocumentMatch, Queries},
+    },
     text_encoder::{self, TextEncoder},
     tokenize::{self, Tokenizer},
 };
@@ -16,7 +19,9 @@ pub enum AskError {
     #[snafu(display("Failed to embed query."))]
     Embed { source: text_encoder::EmbedError },
     #[snafu(display("Failed to retrieve similar documents."))]
-    Query { source: storage::QueryError },
+    Query {
+        source: storage::queries::QueryError,
+    },
     #[snafu(display("Failed to rerank documents."))]
     Rerank { source: text_encoder::RerankError },
 }
@@ -24,7 +29,7 @@ pub enum AskError {
 /// Process a user query and return ranked document matches.
 pub async fn ask(
     query: &str,
-    db: &Postgres,
+    db: &Queries,
     tokenizer: &Tokenizer,
     embedder: &TextEncoder,
     reranker: &TextEncoder,

@@ -9,14 +9,14 @@ use crate::{
     calibre::{Book, BookIterator},
     extractors::{self, extractor::ExtractionError},
     hash,
-    storage::{self, Postgres},
+    storage::{self, queries::Queries},
     text_encoder::{self, TextEncoder},
     tokenize::{self, Tokenizer},
 };
 
 /// Scan a Calibre library and process all books for embedding generation.
 pub async fn scan(
-    db: &Postgres,
+    db: &Queries,
     library_path: &PathBuf,
     tokenizer: &Tokenizer,
     text_encoder: &TextEncoder,
@@ -43,7 +43,9 @@ pub enum ProcessBookError {
     #[snafu(display("Failed to hash book."))]
     BookHash { source: io::Error },
     #[snafu(display("Failed to check in db if document needs update."))]
-    UpdateCheck { source: storage::CheckUpdateError },
+    UpdateCheck {
+        source: storage::queries::CheckUpdateError,
+    },
     #[snafu(display("Failed to open book."))]
     Open { source: io::Error },
     #[snafu(display("Failed to extract document content."))]
@@ -53,21 +55,27 @@ pub enum ProcessBookError {
     #[snafu(display("Failed to create embeddings from encodings."))]
     Embedding { source: text_encoder::EmbedError },
     #[snafu(display("Failed to look up document versions."))]
-    DocumentVersion { source: storage::VersionError },
+    DocumentVersion {
+        source: storage::queries::VersionError,
+    },
     #[snafu(display("Failed to store embeddings in database."))]
-    SaveEmbeddings { source: storage::EmbeddingError },
+    SaveEmbeddings {
+        source: storage::queries::EmbeddingError,
+    },
     #[snafu(display("Failed to update document hash in database."))]
-    UpdateHash { source: storage::UpdateHashError },
+    UpdateHash {
+        source: storage::queries::UpdateHashError,
+    },
     #[snafu(display("Failed to delete old document versions in database."))]
     RemoveOldVersions {
-        source: storage::RemoveOldVersionsError,
+        source: storage::queries::RemoveOldVersionsError,
     },
 }
 
 /// Process a single book: extract text, generate embeddings, and store in database.
 async fn process_book(
     book: &Book,
-    db: &Postgres,
+    db: &Queries,
     tokenizer: &Tokenizer,
     embedder: &TextEncoder,
     chunk_size: usize,

src/storage.rs, 220 changes
@@ -1,27 +1,20 @@
 //! Database storage and retrieval for document embeddings.
 
-use futures::future::try_join_all;
+use std::sync::Arc;
 
+use jobs::Jobs;
+use queries::Queries;
 use snafu::{ResultExt, Snafu};
-use sqlx::{FromRow, Pool, postgres::PgPoolOptions};
+use sqlx::postgres::PgPoolOptions;
 
-use crate::{hash::SHA256_LENGTH, text_encoder::Embeddings};
-
-/// A document chunk with similarity score from vector search.
-#[derive(Debug, FromRow)]
-pub struct DocumentMatch {
-    /// Calibre book ID.
-    pub book_id: i64,
-    /// Text content of the chunk.
-    pub text_chunk: String,
-    /// Cosine similarity score (0.0 to 1.0).
-    pub similarity: f64,
-}
+pub mod jobs;
+pub mod queries;
 
 /// PostgreSQL database connection pool and operations.
 #[derive(Debug, Clone)]
 pub struct Postgres {
-    /// Connection pool for database operations.
-    pool: Pool<sqlx::Postgres>,
+    pub jobs: Jobs,
+    pub queries: Queries,
 }
 
 /// Error when connecting to database.
@@ -31,195 +24,20 @@ pub struct ConnectionError {
     source: sqlx::Error,
 }
 
-/// Error when retrieving document version.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to get highest document version."))]
-pub struct VersionError {
-    source: sqlx::Error,
-}
-
-/// Error when storing embeddings.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to store embeddings."))]
-pub struct EmbeddingError {
-    source: sqlx::Error,
-}
-
-/// Error when querying for similar documents.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to query similarity data."))]
-pub struct QueryError {
-    source: sqlx::Error,
-}
-
-/// Error when updating document hash.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to update document hash."))]
-pub struct UpdateHashError {
-    source: sqlx::Error,
-}
-
-/// Error when checking if document needs update.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to check if document needs update."))]
-pub struct CheckUpdateError {
-    source: sqlx::Error,
-}
-
-/// Error when removing old document versions.
-#[derive(Debug, Snafu)]
-#[snafu(display("Failed to delete old document versions."))]
-pub struct RemoveOldVersionsError {
-    source: sqlx::Error,
-}
-
 impl Postgres {
     /// Create a new database connection pool.
     pub async fn connect(db_url: &str, max_connections: u32) -> Result<Self, ConnectionError> {
-        let pool = PgPoolOptions::new()
-            .max_connections(max_connections)
-            .connect(db_url)
-            .await
-            .context(ConnectionSnafu)?;
+        let pool = Arc::new(
+            PgPoolOptions::new()
+                .max_connections(max_connections)
+                .connect(db_url)
+                .await
+                .context(ConnectionSnafu)?,
+        );
 
-        Ok(Self { pool })
-    }
+        let jobs = Jobs::new(pool.clone());
+        let queries = Queries::new(pool.clone());
 
-    /// Get the highest version number for a book's embeddings.
-    pub async fn embedding_version(&self, book_id: i64) -> Result<i32, VersionError> {
-        let version = sqlx::query_scalar(
-            "SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1",
-        )
-        .bind(book_id)
-        .fetch_one(&self.pool)
-        .await
-        .context(VersionSnafu)?;
-
-        Ok(version)
-    }
-
-    /// Store multiple embeddings for a book in parallel.
-    pub async fn add_multiple_embeddings(
-        &self,
-        book_id: i64,
-        version: i32,
-        embeddings: Vec<Embeddings>,
-    ) -> Result<(), EmbeddingError> {
-        let inserts: Vec<_> = embeddings
-            .into_iter()
-            .map(|e| self.add_embeddings(book_id, version, e))
-            .collect();
-
-        try_join_all(inserts).await?;
-        Ok(())
-    }
-
-    /// Store a single embedding for a book.
-    pub async fn add_embeddings(
-        &self,
-        book_id: i64,
-        version: i32,
-        embeddings: Embeddings,
-    ) -> Result<(), EmbeddingError> {
-        let vector_string = embeddings_to_vec_string(embeddings.0);
-        sqlx::query(
-            "INSERT INTO documents (book_id, version, embedding, text_chunk) VALUES ($1, $2, $3::vector, $4)",
-        )
-        .bind(book_id)
-        .bind(version)
-        .bind(vector_string)
-        .bind(embeddings.2)
-        .execute(&self.pool)
-        .await
-        .context(EmbeddingSnafu)?;
-
-        Ok(())
-    }
-
-    /// Find documents similar to the query embedding using cosine similarity.
-    pub async fn query(
-        &self,
-        query: Embeddings,
-        limit: i32,
-    ) -> Result<Vec<DocumentMatch>, QueryError> {
-        let vector_string = embeddings_to_vec_string(query.0);
-        let books = sqlx::query_as::<_, DocumentMatch>(
-            "SELECT book_id, text_chunk, 1 - (embedding <=> $1::vector) as similarity
-             FROM documents
-             ORDER BY embedding <=> $1::vector
-             LIMIT $2",
-        )
-        .bind(vector_string)
-        .bind(limit)
-        .fetch_all(&self.pool)
-        .await
-        .context(QuerySnafu)?;
-
-        Ok(books)
-    }
-
-    /// Update or insert the hash for a book.
-    pub async fn update_hash(
-        &self,
-        book_id: i64,
-        hash: [u8; SHA256_LENGTH],
-    ) -> Result<(), UpdateHashError> {
-        sqlx::query(
-            "INSERT INTO hashes (book_id, hash) VALUES ($1, $2)
-             ON CONFLICT (book_id) DO UPDATE SET hash = $2",
-        )
-        .bind(book_id)
-        .bind(&hash[..])
-        .execute(&self.pool)
-        .await
-        .context(UpdateHashSnafu)?;
-
-        Ok(())
-    }
-
-    /// Check if a book needs to be reprocessed based on its hash.
-    pub async fn book_needs_update(
-        &self,
-        book_id: i64,
-        hash: [u8; SHA256_LENGTH],
-    ) -> Result<bool, CheckUpdateError> {
-        let exists = sqlx::query_scalar::<_, bool>(
-            "SELECT NOT EXISTS (
-                SELECT 1 FROM hashes
-                WHERE book_id = $1 AND hash = $2
-            )",
-        )
-        .bind(book_id)
-        .bind(&hash[..])
-        .fetch_one(&self.pool)
-        .await
-        .context(CheckUpdateSnafu)?;
-
-        Ok(exists)
-    }
-
-    /// Remove all but the latest version of embeddings for a book.
-    pub async fn remove_old_versions(&self, book_id: i64) -> Result<(), RemoveOldVersionsError> {
-        sqlx::query(
-            "DELETE FROM documents WHERE version < (SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1)",
-        )
-        .bind(book_id)
-        .execute(&self.pool)
-        .await
-        .context(RemoveOldVersionsSnafu)?;
-
-        Ok(())
+        Ok(Self { jobs, queries })
     }
 }
-
-/// Convert a vector of floats to PostgreSQL vector format string.
-fn embeddings_to_vec_string(vector: Vec<f32>) -> String {
-    format!(
-        "[{}]",
-        vector
-            .iter()
-            .map(|x| x.to_string())
-            .collect::<Vec<String>>()
-            .join(",")
-    )
-}
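
The refactored `connect` builds one pool and shares it: both sub-stores hold an `Arc<PgPool>` clone, so splitting storage into `jobs` and `queries` does not multiply database connections. A standalone sketch of the ownership pattern (simplified stand-in types, not the real sqlx pool):

    use std::sync::Arc;

    struct Pool;                                // stands in for sqlx::PgPool
    struct Jobs { pool: Arc<Pool> }
    struct Queries { pool: Arc<Pool> }

    fn main() {
        let pool = Arc::new(Pool);
        // Cloning the Arc copies a pointer, not the pool or its connections.
        let jobs = Jobs { pool: pool.clone() };
        let queries = Queries { pool: pool.clone() };
        assert_eq!(Arc::strong_count(&pool), 3); // pool + jobs + queries
        drop((jobs, queries));
    }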

src/storage/jobs.rs (new file), 143 changes
@@ -0,0 +1,143 @@
+use std::sync::Arc;
+
+use serde::{Deserialize, Serialize};
+use snafu::{ResultExt, Snafu};
+use sqlx::{PgPool, Row, Type};
+use uuid::Uuid;
+
+use super::queries::DocumentMatch;
+
+#[derive(Debug, Clone)]
+pub struct Jobs {
+    /// Connection pool for database operations.
+    pool: Arc<PgPool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, Type)]
+#[sqlx(type_name = "job_status", rename_all = "lowercase")]
+pub enum JobStatus {
+    Running,
+    Completed,
+    Failed,
+}
+
+impl std::fmt::Display for JobStatus {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            JobStatus::Running => write!(f, "running"),
+            JobStatus::Completed => write!(f, "completed"),
+            JobStatus::Failed => write!(f, "failed"),
+        }
+    }
+}
+
+impl std::str::FromStr for JobStatus {
+    type Err = String;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "running" => Ok(JobStatus::Running),
+            "completed" => Ok(JobStatus::Completed),
+            "failed" => Ok(JobStatus::Failed),
+            _ => Err(format!("Invalid job status: {}", s)),
+        }
+    }
+}
+
+/// Error adding a new job.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to add background job."))]
+pub struct AddError {
+    source: sqlx::Error,
+}
+
+/// Error when adding a result to a job.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to add result to background job."))]
+pub struct ResultError {
+    source: sqlx::Error,
+}
+
+/// Error when adding a failure to a job.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to add failure to background job."))]
+pub struct FailureError {
+    source: sqlx::Error,
+}
+
+/// Errors that occur during job retrieval.
+#[derive(Debug, Snafu)]
+pub enum RetrieveError {
+    #[snafu(display("Failed to retrieve background job data."))]
+    Retrieve { source: sqlx::Error },
+    #[snafu(display("Failed to deserialize job result."))]
+    Deserialize { source: postcard::Error },
+}
+
+impl Jobs {
+    pub fn new(pool: Arc<PgPool>) -> Self {
+        Self { pool }
+    }
+
+    /// Add a new job to the database and mark it as running.
+    pub async fn add(&self, id: &Uuid) -> Result<(), AddError> {
+        sqlx::query("INSERT INTO jobs (id, expires_at) VALUES ($1, NOW() + INTERVAL '24 hours')")
+            .bind(id)
+            .execute(self.pool.as_ref())
+            .await
+            .context(AddSnafu)?;
+
+        Ok(())
+    }
+
+    /// Add a result to a job and mark it as finished.
+    pub async fn finish(&self, id: &Uuid, data: &[u8]) -> Result<(), ResultError> {
+        sqlx::query("UPDATE jobs SET status = 'completed', result = $1 WHERE id = $2")
+            .bind(data)
+            .bind(id)
+            .execute(self.pool.as_ref())
+            .await
+            .context(ResultSnafu)?;
+
+        Ok(())
+    }
+
+    /// Add an error to a job and mark it as failed.
+    pub async fn fail(&self, id: &Uuid, error: &str) -> Result<(), FailureError> {
+        sqlx::query("UPDATE jobs SET status = 'failed', error = $1 WHERE id = $2")
+            .bind(error)
+            .bind(id)
+            .execute(self.pool.as_ref())
+            .await
+            .context(FailureSnafu)?;
+
+        Ok(())
+    }
+
+    pub async fn retrieve(
+        &self,
+        id: &Uuid,
+    ) -> Result<Option<(String, Vec<DocumentMatch>)>, RetrieveError> {
+        let row = sqlx::query("SELECT status, result FROM jobs WHERE id = $1")
+            .bind(id)
+            .fetch_optional(self.pool.as_ref())
+            .await
+            .context(RetrieveSnafu)?;
+
+        match row {
+            Some(row) => {
+                let status: JobStatus = row.get("status");
+                let status = status.to_string();
+                match row.try_get("result") {
+                    Ok(data) => {
+                        let documents: Vec<DocumentMatch> =
+                            postcard::from_bytes(data).context(DeserializeSnafu)?;
+                        Ok(Some((status, documents)))
+                    }
+                    Err(_) => Ok(Some((status, Vec::new()))),
+                }
+            }
+            None => Ok(None),
+        }
+    }
+}
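
Job results live in the `result` BYTEA column as postcard bytes: written by `finish` and decoded in `retrieve`. A self-contained round trip using the same two calls the job code uses (the `Match` struct here is a simplified stand-in for `DocumentMatch`):

    use serde::{Deserialize, Serialize};

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Match {
        book_id: i64,
        text_chunk: String,
        similarity: f64,
    }

    fn main() {
        let results = vec![Match { book_id: 7, text_chunk: "chunk".into(), similarity: 0.92 }];
        // `to_allocvec` is what `success` stores; `from_bytes` is what `retrieve` decodes.
        let bytes: Vec<u8> = postcard::to_allocvec(&results).unwrap();
        let decoded: Vec<Match> = postcard::from_bytes(&bytes).unwrap();
        assert_eq!(results, decoded);
    }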

src/storage/queries.rs (new file), 211 changes
@@ -0,0 +1,211 @@
+use std::sync::Arc;
+
+use futures::future::try_join_all;
+use serde::{Deserialize, Serialize};
+use snafu::{ResultExt, Snafu};
+use sqlx::{FromRow, PgPool};
+
+use crate::{hash::SHA256_LENGTH, text_encoder::Embeddings};
+
+#[derive(Debug, Clone)]
+pub struct Queries {
+    /// Connection pool for database operations.
+    pool: Arc<PgPool>,
+}
+
+/// Error when retrieving document version.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to get highest document version."))]
+pub struct VersionError {
+    source: sqlx::Error,
+}
+
+/// Error when storing embeddings.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to store embeddings."))]
+pub struct EmbeddingError {
+    source: sqlx::Error,
+}
+
+/// Error when querying for similar documents.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to query similarity data."))]
+pub struct QueryError {
+    source: sqlx::Error,
+}
+
+/// Error when updating document hash.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to update document hash."))]
+pub struct UpdateHashError {
+    source: sqlx::Error,
+}
+
+/// Error when checking if document needs update.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to check if document needs update."))]
+pub struct CheckUpdateError {
+    source: sqlx::Error,
+}
+
+/// Error when removing old document versions.
+#[derive(Debug, Snafu)]
+#[snafu(display("Failed to delete old document versions."))]
+pub struct RemoveOldVersionsError {
+    source: sqlx::Error,
+}
+
+/// A document chunk with similarity score from vector search.
+#[derive(Debug, FromRow, Serialize, Deserialize)]
+pub struct DocumentMatch {
+    /// Calibre book ID.
+    pub book_id: i64,
+    /// Text content of the chunk.
+    pub text_chunk: String,
+    /// Cosine similarity score (0.0 to 1.0).
+    pub similarity: f64,
+}
+
+impl Queries {
+    pub fn new(pool: Arc<PgPool>) -> Self {
+        Self { pool }
+    }
+
+    /// Get the highest version number for a book's embeddings.
+    pub async fn embedding_version(&self, book_id: i64) -> Result<i32, VersionError> {
+        let version = sqlx::query_scalar(
+            "SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1",
+        )
+        .bind(book_id)
+        .fetch_one(self.pool.as_ref())
+        .await
+        .context(VersionSnafu)?;
+
+        Ok(version)
+    }
+
+    /// Store multiple embeddings for a book in parallel.
+    pub async fn add_multiple_embeddings(
+        &self,
+        book_id: i64,
+        version: i32,
+        embeddings: Vec<Embeddings>,
+    ) -> Result<(), EmbeddingError> {
+        let inserts: Vec<_> = embeddings
+            .into_iter()
+            .map(|e| self.add_embeddings(book_id, version, e))
+            .collect();
+
+        try_join_all(inserts).await?;
+        Ok(())
+    }
+
+    /// Store a single embedding for a book.
+    pub async fn add_embeddings(
+        &self,
+        book_id: i64,
+        version: i32,
+        embeddings: Embeddings,
+    ) -> Result<(), EmbeddingError> {
+        let vector_string = embeddings_to_vec_string(embeddings.0);
+        sqlx::query(
+            "INSERT INTO documents (book_id, version, embedding, text_chunk) VALUES ($1, $2, $3::vector, $4)",
+        )
+        .bind(book_id)
+        .bind(version)
+        .bind(vector_string)
+        .bind(embeddings.2)
+        .execute(self.pool.as_ref())
+        .await
+        .context(EmbeddingSnafu)?;
+
+        Ok(())
+    }
+
+    /// Find documents similar to the query embedding using cosine similarity.
+    pub async fn query(
+        &self,
+        query: Embeddings,
+        limit: i32,
+    ) -> Result<Vec<DocumentMatch>, QueryError> {
+        let vector_string = embeddings_to_vec_string(query.0);
+        let books = sqlx::query_as::<_, DocumentMatch>(
+            "SELECT book_id, text_chunk, 1 - (embedding <=> $1::vector) as similarity
+             FROM documents
+             ORDER BY embedding <=> $1::vector
+             LIMIT $2",
+        )
+        .bind(vector_string)
+        .bind(limit)
+        .fetch_all(self.pool.as_ref())
+        .await
+        .context(QuerySnafu)?;
+
+        Ok(books)
+    }
+
+    /// Update or insert the hash for a book.
+    pub async fn update_hash(
+        &self,
+        book_id: i64,
+        hash: [u8; SHA256_LENGTH],
+    ) -> Result<(), UpdateHashError> {
+        sqlx::query(
+            "INSERT INTO hashes (book_id, hash) VALUES ($1, $2)
+             ON CONFLICT (book_id) DO UPDATE SET hash = $2",
+        )
+        .bind(book_id)
+        .bind(&hash[..])
+        .execute(self.pool.as_ref())
+        .await
+        .context(UpdateHashSnafu)?;
+
+        Ok(())
+    }
+
+    /// Check if a book needs to be reprocessed based on its hash.
+    pub async fn book_needs_update(
+        &self,
+        book_id: i64,
+        hash: [u8; SHA256_LENGTH],
+    ) -> Result<bool, CheckUpdateError> {
+        let exists = sqlx::query_scalar::<_, bool>(
+            "SELECT NOT EXISTS (
+                SELECT 1 FROM hashes
+                WHERE book_id = $1 AND hash = $2
+            )",
+        )
+        .bind(book_id)
+        .bind(&hash[..])
+        .fetch_one(self.pool.as_ref())
+        .await
+        .context(CheckUpdateSnafu)?;
+
+        Ok(exists)
+    }
+
+    /// Remove all but the latest version of embeddings for a book.
+    pub async fn remove_old_versions(&self, book_id: i64) -> Result<(), RemoveOldVersionsError> {
+        sqlx::query(
+            "DELETE FROM documents WHERE version < (SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1)",
+        )
+        .bind(book_id)
+        .execute(self.pool.as_ref())
+        .await
+        .context(RemoveOldVersionsSnafu)?;
+
+        Ok(())
+    }
+}
+
+/// Convert a vector of floats to PostgreSQL vector format string.
+fn embeddings_to_vec_string(vector: Vec<f32>) -> String {
+    format!(
+        "[{}]",
+        vector
+            .iter()
+            .map(|x| x.to_string())
+            .collect::<Vec<String>>()
+            .join(",")
+    )
+}
@ -12,7 +12,7 @@ use snafu::{ResultExt, Snafu};
|
|||
use tract_onnx::prelude::*;
|
||||
|
||||
use crate::{
|
||||
storage::DocumentMatch,
|
||||
storage::queries::DocumentMatch,
|
||||
text_encoder::tract_ndarray::ShapeError,
|
||||
tokenize::{self, Encoding, Tokenizer},
|
||||
};
|
||||
|
|