diff --git a/Cargo.lock b/Cargo.lock index 9f6326a..a3ef389 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -158,12 +158,72 @@ dependencies = [ "num-traits", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "axum" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -1058,6 +1118,25 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "h2" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a9421a676d1b147b16b82c9225157dc629087ef8ec4d5e2960f9437a90dac0a5" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "half" version = "2.4.1" @@ -1149,6 +1228,88 @@ dependencies = [ "match_token", ] +[[package]] +name = "http" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", +] + +[[package]] +name = "hyper-util" +version = "0.1.14" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc2fdfdbff08affe55bb779f33b053aa1fe5dd5b54c257343c17edfa55711bdb" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "hyper", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "iana-time-zone" version = "0.1.63" @@ -1294,6 +1455,7 @@ checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ "equivalent", "hashbrown 0.15.3", + "serde", ] [[package]] @@ -1582,6 +1744,7 @@ checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" name = "little-librarian" version = "0.1.0" dependencies = [ + "axum", "clap", "futures", "lopdf", @@ -1589,16 +1752,22 @@ dependencies = [ "rayon", "ring", "scraper", + "serde", "sha2", "snafu", "sqlx", "tokenizers", "tokio", + "tower-http", "tracing", "tracing-subscriber", "tract-onnx", + "utoipa", + "utoipa-axum", + "utoipa-swagger-ui", + "uuid", "walkdir", - "zip", + "zip 4.2.0", ] [[package]] @@ -1700,6 +1869,21 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "matrixmultiply" version = "0.3.10" @@ -1735,6 +1919,22 @@ dependencies = [ "libc", ] +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -2395,8 +2595,17 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", ] [[package]] @@ -2407,9 +2616,15 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.8.5", ] +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + [[package]] name = "regex-syntax" version = "0.8.5" @@ -2450,6 +2665,40 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rust-embed" +version = "8.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "025908b8682a26ba8d12f6f2d66b987584a4a87bc024abc5bbc12553a8cd178a" +dependencies = [ + "rust-embed-impl", + "rust-embed-utils", + "walkdir", +] + +[[package]] +name = "rust-embed-impl" +version = "8.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6065f1a4392b71819ec1ea1df1120673418bf386f50de1d6f54204d836d4349c" +dependencies = [ + "proc-macro2", + "quote", + "rust-embed-utils", + "syn 2.0.101", + "walkdir", +] + +[[package]] +name = "rust-embed-utils" +version = "8.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6cc0c81648b20b70c491ff8cce00c1c3b223bb8ed2b5d41f0e54c6c4c0a3594" +dependencies = [ + "sha2", + "walkdir", +] + 
[[package]] name = "rustc-demangle" version = "0.1.24" @@ -2620,6 +2869,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -3069,6 +3328,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" + [[package]] name = "synstructure" version = "0.13.2" @@ -3212,7 +3477,7 @@ dependencies = [ "rayon", "rayon-cond", "regex", - "regex-syntax", + "regex-syntax 0.8.5", "serde", "serde_json", "spm_precompiled", @@ -3260,6 +3525,63 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +dependencies = [ + "bitflags", + "bytes", + "http", + "http-body", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + [[package]] name = "tracing" version = "0.1.41" @@ -3310,10 +3632,14 @@ version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" dependencies = [ + "matchers", "nu-ansi-term", + "once_cell", + "regex", "sharded-slab", "smallvec", "thread_local", + "tracing", "tracing-core", "tracing-log", ] @@ -3477,6 +3803,12 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-bidi" version = "0.3.18" @@ -3572,6 +3904,79 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "utoipa" +version = "5.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fcc29c80c21c31608227e0912b2d7fddba57ad76b606890627ba8ee7964e993" +dependencies = [ + "indexmap", + "serde", + "serde_json", + "utoipa-gen", +] + +[[package]] +name = "utoipa-axum" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c25bae5bccc842449ec0c5ddc5cbb6a3a1eaeac4503895dc105a1138f8234a0" +dependencies = [ + "axum", + "paste", + "tower-layer", + "tower-service", + "utoipa", +] + +[[package]] +name = "utoipa-gen" +version = "5.4.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d79d08d92ab8af4c5e8a6da20c47ae3f61a0f1dabc1997cdf2d082b757ca08b" +dependencies = [ + "proc-macro2", + "quote", + "regex", + "syn 2.0.101", +] + +[[package]] +name = "utoipa-swagger-ui" +version = "9.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d047458f1b5b65237c2f6dc6db136945667f40a7668627b3490b9513a3d43a55" +dependencies = [ + "axum", + "base64 0.22.1", + "mime_guess", + "regex", + "rust-embed", + "serde", + "serde_json", + "url", + "utoipa", + "utoipa-swagger-ui-vendored", + "zip 3.0.0", +] + +[[package]] +name = "utoipa-swagger-ui-vendored" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2eebbbfe4093922c2b6734d7c679ebfebd704a0d7e56dfcb0d05818ce28977d" + +[[package]] +name = "uuid" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" +dependencies = [ + "getrandom 0.3.3", + "js-sys", + "wasm-bindgen", +] + [[package]] name = "valuable" version = "0.1.1" @@ -4104,6 +4509,20 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "zip" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12598812502ed0105f607f941c386f43d441e00148fce9dec3ca5ffb0bde9308" +dependencies = [ + "arbitrary", + "crc32fast", + "flate2", + "indexmap", + "memchr", + "zopfli", +] + [[package]] name = "zip" version = "4.2.0" diff --git a/Cargo.toml b/Cargo.toml index 52bd0df..d44a1e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ license = "AGPL-3.0" edition = "2024" [dependencies] +axum = { version = "0.8.4", features = ["http2", "tracing"] } clap = { version = "4.5.40", features = ["derive"] } futures = "0.3.31" lopdf = "0.36.0" @@ -12,13 +13,19 @@ quick-xml = "0.38.0" rayon = "1.10.0" ring = "0.17.14" scraper = "0.23.1" +serde = { version = "1.0.219", features = 
["derive"] }
 sha2 = "0.10.9"
 snafu = { version = "0.8.6", features = ["rust_1_81"] }
 sqlx = { version = "0.8.6", features = [ "runtime-tokio", "tls-rustls-ring", "migrate", "postgres", "derive" ] }
 tokenizers = "0.21.2"
 tokio = { version = "1.45.1", features = ["rt-multi-thread", "macros"]}
+tower-http = { version = "0.6.6", features = ["trace"] }
 tracing = "0.1.41"
-tracing-subscriber = "0.3.19"
+tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
 tract-onnx = "0.21.13"
+utoipa = { version = "5.4.0", features = ["axum_extras"] }
+utoipa-axum = "0.2.0"
+utoipa-swagger-ui = { version = "9.0.2", features = ["axum", "vendored"] }
+uuid = { version = "1.17.0", features = ["v4"] }
 walkdir = "2.5.0"
 zip = "4.2.0"
diff --git a/src/api.rs b/src/api.rs
new file mode 100644
index 0000000..2456107
--- /dev/null
+++ b/src/api.rs
@@ -0,0 +1,89 @@
+//! HTTP API exposing the semantic search as a JSON service.
+
+use std::{
+    io,
+    net::{AddrParseError, SocketAddr},
+    str::FromStr,
+};
+
+use snafu::{ResultExt, Snafu};
+use state::AppState;
+use tokio::net::TcpListener;
+use utoipa::OpenApi;
+use utoipa_axum::router::OpenApiRouter;
+use utoipa_swagger_ui::SwaggerUi;
+
+use crate::{storage::Postgres, text_encoder::TextEncoder, tokenize::Tokenizer};
+
+pub mod error;
+pub mod query;
+pub mod routes;
+pub mod state;
+
+/// OpenAPI tag under which the routes of this API are grouped.
+const TAG: &str = "little-librarian";
+
+/// Root of the generated OpenAPI document.
+#[derive(OpenApi)]
+#[openapi(
+    tags(
+        (name = TAG, description = "Calibre Semantic Search API")
+    )
+)]
+struct ApiDoc;
+
+/// Errors that can occur while starting or running the HTTP server.
+#[derive(Debug, Snafu)]
+pub enum ServeError {
+    #[snafu(display("Failed to parse address into <ip>:<port>."))]
+    AddressParse { source: AddrParseError },
+    #[snafu(display("Failed to bind to {address}."))]
+    Bind {
+        source: io::Error,
+        address: SocketAddr,
+    },
+    #[snafu(display("Failed to run http server."))]
+    Serve { source: io::Error },
+}
+
+/// Serve the query API on `address`.
+///
+/// The API routes are nested under `/api/v1/`; the Swagger UI is served at
+/// `/swagger-ui` and the OpenAPI document at `/api-docs/openapi.json`.
+///
+/// # Errors
+///
+/// Fails if `address` cannot be parsed as a socket address, if the listener
+/// cannot be bound, or if the server stops with an I/O error.
+pub async fn serve(
+    address: &str,
+    db: Postgres,
+    tokenizer: Tokenizer,
+    embedder: TextEncoder,
+    reranker: TextEncoder,
+    chunk_size: usize,
+) -> Result<(), ServeError> {
+    let state = AppState {
+        db,
+        tokenizer,
+        embedder,
+        reranker,
+        chunk_size,
+    };
+
+    let (router, api) = OpenApiRouter::with_openapi(ApiDoc::openapi())
+        .nest("/api/v1/", routes::router(state))
+        .split_for_parts();
+
+    // `api` is not used afterwards, so it is moved rather than cloned.
+    let router = router.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", api));
+
+    let address = SocketAddr::from_str(address).context(AddressParseSnafu)?;
+    let listener = TcpListener::bind(&address)
+        .await
+        .context(BindSnafu { address })?;
+
+    axum::serve(listener, router.into_make_service())
+        .await
+        .context(ServeSnafu)
+}
diff --git a/src/api/error.rs b/src/api/error.rs
new file mode 100644
index 0000000..15fbe91
--- /dev/null
+++ b/src/api/error.rs
@@ -0,0 +1,47 @@
+use axum::http::StatusCode;
+use serde::Serialize;
+
+/// JSON body sent to the client when a request fails.
+#[derive(Debug, Serialize)]
+pub struct ErrorResponse {
+    /// Identifier that is also written to the error log, to correlate both.
+    pub id: String,
+    /// Display message of the error.
+    pub error: String,
+}
+
+/// Maps an error to the HTTP status code it is reported with.
+pub trait HttpStatus {
+    /// Status code to respond with for this error.
+    fn status_code(&self) -> StatusCode;
+}
+
+/// Implements `axum::response::IntoResponse` for an error type that implements
+/// `HttpStatus`.
+///
+/// The full error report is logged under a fresh UUID; the client receives that
+/// id and the error's display message as an `ErrorResponse`.
+#[macro_export]
+macro_rules! http_error {
+    ($error_type:ty) => {
+        impl axum::response::IntoResponse for $error_type {
+            fn into_response(self) -> axum::response::Response {
+                // Trait methods are called by path so that the expansion does not
+                // depend on which traits the calling module has imported.
+                let status = $crate::api::error::HttpStatus::status_code(&self);
+                let id = uuid::Uuid::new_v4().to_string();
+                tracing::error!("{}: {}", &id, snafu::Report::from_error(&self));
+
+                let error_response = $crate::api::error::ErrorResponse {
+                    id,
+                    error: self.to_string(),
+                };
+
+                axum::response::IntoResponse::into_response((
+                    status,
+                    axum::Json(error_response),
+                ))
+            }
+        }
+    };
+}
diff --git a/src/api/query.rs b/src/api/query.rs
new file mode 100644
index 0000000..65185ed
--- /dev/null
+++ b/src/api/query.rs
@@ -0,0 +1,134 @@
+use std::sync::Arc;
+
+use axum::{
+    Json,
+    extract::{Query, State},
+    http::StatusCode,
+};
+use serde::{Deserialize, Serialize};
+use snafu::{ResultExt, Snafu, ensure};
+use utoipa::ToSchema;
+
+use super::{TAG, error::HttpStatus, state::AppState};
+use crate::{http_error, query, storage::DocumentMatch};
+
+/// Largest number of results a client may request.
+const MAX_LIMIT: usize = 10;
+/// Number of results returned when the client does not pass `limit`.
+const DEFAULT_LIMIT: usize = 5;
+
+/// Errors that can occur while answering a query request.
+#[derive(Debug, Snafu)]
+pub enum QueryError {
+    #[snafu(display("'limit' query parameter must be a positive integer <= {MAX_LIMIT}."))]
+    Limit,
+    #[snafu(display("Failed to run query."))]
+    Query { source: query::AskError },
+}
+
+impl HttpStatus for QueryError {
+    fn status_code(&self) -> StatusCode {
+        match self {
+            QueryError::Limit => StatusCode::BAD_REQUEST,
+            QueryError::Query { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
+        }
+    }
+}
+
+http_error!(QueryError);
+
+/// URL query parameters accepted by [`route`].
+#[derive(Debug, Deserialize)]
+pub struct QueryParams {
+    /// Maximum number of results; `DEFAULT_LIMIT` is used when absent.
+    pub limit: Option<usize>,
+}
+
+/// Response body of a successful query.
+#[derive(Debug, Serialize, ToSchema)]
+pub struct QueryResponse {
+    /// Matching text chunks, in the order returned by the reranker.
+    pub results: Vec<DocumentResult>,
+    /// Number of entries in `results`.
+    pub count: usize,
+    /// The query text that was asked.
+    pub query: String,
+}
+
+impl From<(Vec<DocumentMatch>, String)> for QueryResponse {
+    fn from((documents, query): (Vec<DocumentMatch>, String)) -> Self {
+        let results: Vec<DocumentResult> = documents.into_iter().map(Into::into).collect();
+        let count = results.len();
+
+        Self {
+            results,
+            count,
+            query,
+        }
+    }
+}
+
+/// A single text chunk matching the query.
+#[derive(Debug, Serialize, ToSchema)]
+pub struct DocumentResult {
+    /// Identifier of the book the chunk belongs to.
+    pub book_id: i64,
+    /// Text of the matching chunk.
+    pub text_chunk: String,
+    /// Score of the chunk for the query.
+    pub similarity: f64,
+}
+
+impl From<DocumentMatch> for DocumentResult {
+    fn from(doc: DocumentMatch) -> Self {
+        Self {
+            book_id: doc.book_id,
+            text_chunk: doc.text_chunk,
+            similarity: doc.similarity,
+        }
+    }
+}
+
+/// Run a semantic search query.
+///
+/// The request body is the plain-text query; at most `limit` results are returned.
+#[utoipa::path(
+    post,
+    path = "/query",
+    tag = TAG,
+    params(
+        ("limit" = Option<usize>, Query, description = "Maximum number of results")
+    ),
+    request_body = String,
+    responses(
+        (status = OK, body = QueryResponse),
+        (status = 400, description = "Wrong parameter."),
+        (status = 500, description = "Failure to query database.")
+    )
+)]
+pub async fn route(
+    Query(params): Query<QueryParams>,
+    State(state): State<Arc<AppState>>,
+    body: String,
+) -> Result<Json<QueryResponse>, QueryError> {
+    let limit = params.limit.unwrap_or(DEFAULT_LIMIT);
+    // Zero is rejected as well: the error message promises a positive integer.
+    ensure!((1..=MAX_LIMIT).contains(&limit), LimitSnafu);
+
+    // NOTE(review): `query::ask` calls the embedder and reranker synchronously; if
+    // inference is slow this stalls the async worker — consider `spawn_blocking`.
+    let results = query::ask(
+        &body,
+        &state.db,
+        &state.tokenizer,
+        &state.embedder,
+        &state.reranker,
+        state.chunk_size,
+        limit,
+    )
+    .await
+    .context(QuerySnafu)?;
+    let response = QueryResponse::from((results, body));
+
+    Ok(Json(response))
+}
diff --git
a/src/api/routes.rs b/src/api/routes.rs
new file mode 100644
index 0000000..50f1036
--- /dev/null
+++ b/src/api/routes.rs
@@ -0,0 +1,19 @@
+use std::sync::Arc;
+
+use tower_http::trace::TraceLayer;
+use utoipa_axum::{router::OpenApiRouter, routes};
+
+use super::state::AppState;
+use crate::api::query;
+
+/// Build the router serving the API routes.
+///
+/// `state` is shared with the handlers behind an [`Arc`], and requests are
+/// traced with [`TraceLayer`].
+pub fn router(state: AppState) -> OpenApiRouter {
+    let store = Arc::new(state);
+    OpenApiRouter::new()
+        .routes(routes!(query::route))
+        .layer(TraceLayer::new_for_http())
+        .with_state(store)
+}
diff --git a/src/api/state.rs b/src/api/state.rs
new file mode 100644
index 0000000..6871e05
--- /dev/null
+++ b/src/api/state.rs
@@ -0,0 +1,16 @@
+use crate::{storage::Postgres, text_encoder::TextEncoder, tokenize::Tokenizer};
+
+/// Dependencies shared by all request handlers.
+#[derive(Debug, Clone)]
+pub struct AppState {
+    /// Database queried for documents similar to the query.
+    pub db: Postgres,
+    /// Tokenizer used to encode the query text.
+    pub tokenizer: Tokenizer,
+    /// Model that embeds the encoded query.
+    pub embedder: TextEncoder,
+    /// Model that reranks the retrieved documents.
+    pub reranker: TextEncoder,
+    /// Chunk size passed to the tokenizer.
+    pub chunk_size: usize,
+}
diff --git a/src/cli.rs b/src/cli.rs
index 1b54f1c..e40025f 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -30,7 +30,7 @@ pub struct Cli {
     pub command: Commands,
 }
 
-/// Configuration for the ask command.
+/// Run a query.
 #[derive(Args, Clone)]
 pub struct AskConfig {
     /// Query to run.
@@ -46,7 +46,7 @@ pub struct AskConfig {
     pub reranking_path: PathBuf,
 }
 
-/// Configuration for the scan command.
+/// Process a calibre library.
 #[derive(Args, Clone)]
 pub struct ScanConfig {
     /// Root directory of calibre library.
@@ -54,9 +54,22 @@ pub struct ScanConfig {
     pub library_path: PathBuf,
 }
 
+/// Serve the query interface as an HTTP API.
+#[derive(Args, Clone)]
+pub struct ServeConfig {
+    /// Address to listen on.
+    #[arg(short, long, default_value = "[::]:8080")]
+    pub address: String,
+
+    /// Path to reranking model.
+    #[arg(short, long, default_value = "./reranking.onnx")]
+    pub reranking_path: PathBuf,
+}
+
 /// Available CLI commands.
 #[derive(Subcommand)]
 pub enum Commands {
     Ask(AskConfig),
     Scan(ScanConfig),
+    Serve(ServeConfig),
 }
diff --git a/src/lib.rs b/src/lib.rs
index bb2b548..2483682 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,4 @@
+pub mod api;
 pub mod calibre;
 pub mod cli;
 pub mod extractors {
diff --git a/src/main.rs b/src/main.rs
index c90ea1b..751e230 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -4,6 +4,7 @@ use std::io;
 
 use clap::Parser;
 use little_librarian::{
+    api,
     cli::{Cli, Commands},
     query,
     scanner::{self},
@@ -12,6 +13,7 @@ use little_librarian::{
     tokenize::{self, Tokenizer},
 };
 use snafu::{ResultExt, Snafu};
+use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
 
 /// Top-level application errors.
 #[derive(Debug, Snafu)]
@@ -26,16 +28,33 @@ pub enum Error {
     ModelLoading {
         source: text_encoder::NewFromFileError,
     },
+    #[snafu(display("Failed to load reranker model."))]
+    RerankerLoading {
+        source: text_encoder::NewFromFileError,
+    },
     #[snafu(display("Failed to connect to database."))]
     DbConnection { source: storage::ConnectionError },
     #[snafu(display("Failed to process query."))]
     Ask { source: query::AskError },
+    #[snafu(display("Failed to serve HTTP API."))]
+    Serve { source: api::ServeError },
 }
 
 #[tokio::main]
 #[snafu::report]
 async fn main() -> Result<(), Error> {
-    tracing_subscriber::fmt::init();
+    tracing_subscriber::registry()
+        .with(
+            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
+                format!(
+                    "{}=debug,tower_http=debug,axum::rejection=trace",
+                    env!("CARGO_CRATE_NAME")
+                )
+                .into()
+            }),
+        )
+        .with(tracing_subscriber::fmt::layer())
+        .init();
 
     let cli = Cli::parse();
     let db_url = if cli.db_url == "-" {
@@ -52,12 +71,16 @@ async fn main() -> Result<(), Error> {
 
     match &cli.command {
         Commands::Ask(config) => {
+            let reranker =
+                TextEncoder::from_file(&config.reranking_path).context(RerankerLoadingSnafu)?;
+            // NOTE(review): `query::ask` now returns the matches instead of printing
+            // them; the unchanged rest of this arm has to print them or `ask` is silent.
             query::ask(
                 &config.query,
                 &db,
                 &tokenizer,
                 &embedder,
-                &config.reranking_path,
+                &reranker,
                 cli.chunk_size,
                 config.limit,
             )
@@ -74,6 +97,20 @@ async fn main() -> Result<(), Error> {
             )
             .await;
         }
+        Commands::Serve(config) => {
+            let reranker =
+                TextEncoder::from_file(&config.reranking_path).context(RerankerLoadingSnafu)?;
+            api::serve(
+                &config.address,
+                db,
+                tokenizer,
+                embedder,
+                reranker,
+                cli.chunk_size,
+            )
+            .await
+            .context(ServeSnafu)?;
+        }
     }
 
     Ok(())
diff --git a/src/query.rs b/src/query.rs
index 131afa3..5a9adc3 100644
--- a/src/query.rs
+++ b/src/query.rs
@@ -1,11 +1,9 @@
 //! Query processing and document retrieval.
 
-use std::path::PathBuf;
-
 use snafu::{ResultExt, Snafu};
 
 use crate::{
-    storage::{self, Postgres},
+    storage::{self, DocumentMatch, Postgres},
     text_encoder::{self, TextEncoder},
     tokenize::{self, Tokenizer},
 };
@@ -19,10 +17,6 @@ pub enum AskError {
     Embed { source: text_encoder::EmbedError },
     #[snafu(display("Failed to retrieve similar documents."))]
     Query { source: storage::QueryError },
-    #[snafu(display("Failed to load reranker model."))]
-    LoadReranker {
-        source: text_encoder::NewFromFileError,
-    },
     #[snafu(display("Failed to rerank documents."))]
     Rerank { source: text_encoder::RerankError },
 }
@@ -32,34 +26,26 @@ pub async fn ask(
     query: &str,
     db: &Postgres,
     tokenizer: &Tokenizer,
-    text_encoder: &TextEncoder,
-    reranker_path: &PathBuf,
+    embedder: &TextEncoder,
+    reranker: &TextEncoder,
     chunk_size: usize,
     limit: usize,
-) -> Result<(), AskError> {
+) -> Result<Vec<DocumentMatch>, AskError> {
     let encodings = tokenizer.encode(query, chunk_size).context(EncodeSnafu)?;
-    let embeddings = text_encoder
-        .embed(encodings[0].clone())
-        .context(EmbedSnafu)?;
+    // Only the first chunk is embedded; a query yielding no chunk (presumably an
+    // empty one) returns no documents instead of panicking on `encodings[0]`.
+    let Some(encoding) = encodings.first().cloned() else {
+        return Ok(Vec::new());
+    };
+    let embeddings = embedder.embed(encoding).context(EmbedSnafu)?;
     let documents = db
         .query(embeddings, (limit * 10) as i32)
         .await
         .context(QuerySnafu)?;
 
-    let reranker = TextEncoder::from_file(reranker_path).context(LoadRerankerSnafu)?;
     let reranked_docs = reranker
         .rerank(query, documents, tokenizer, limit)
         .context(RerankSnafu)?;
 
-    for (i, doc) in reranked_docs.iter().enumerate() {
-        println!(
-            "{}. Book ID: {}, Score: {:.3}",
-            i + 1,
-            doc.book_id,
-            doc.similarity
-        );
-        println!(" {}\n", doc.text_chunk);
-    }
-
-    Ok(())
+    Ok(reranked_docs)
 }
diff --git a/src/storage.rs b/src/storage.rs
index 6824f3b..dbf3fe4 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -18,6 +18,7 @@ pub struct DocumentMatch {
 }
 
 /// PostgreSQL database connection pool and operations.
+#[derive(Debug, Clone)]
 pub struct Postgres {
     /// Connection pool for database operations.
     pool: Pool,
diff --git a/src/text_encoder.rs b/src/text_encoder.rs
index 5865488..87f9335 100644
--- a/src/text_encoder.rs
+++ b/src/text_encoder.rs
@@ -28,6 +28,7 @@ type Model = SimplePlan<
 pub type Embeddings = (Vec, Option, String);
 
 /// ONNX-based text embedder for generating vector representations.
+#[derive(Debug, Clone)]
 pub struct TextEncoder {
     /// Compiled ONNX model for inference.
     model: Model,
diff --git a/src/tokenize.rs b/src/tokenize.rs
index 2541df6..01c7b09 100644
--- a/src/tokenize.rs
+++ b/src/tokenize.rs
@@ -28,6 +28,7 @@ pub enum EncodeError {
 }
 
 /// Wrapper around a tokenizer that handles text chunking.
+#[derive(Debug, Clone)]
 pub struct Tokenizer {
     /// The underlying tokenizer implementation.
     inner: tokenizers::Tokenizer,