HTTP API WIP

This commit is contained in:
Sebastian Hugentobler 2025-07-01 14:02:18 +02:00
parent 0bd97d0ed3
commit 552fce432b
Signed by: shu
SSH key fingerprint: SHA256:ppcx6MlixdNZd5EUM1nkHOKoyQYoJwzuQKXM6J/t66M
14 changed files with 740 additions and 35 deletions

429
Cargo.lock generated
View file

@ -158,12 +158,72 @@ dependencies = [
"num-traits",
]
[[package]]
name = "atomic-waker"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "autocfg"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "axum"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5"
dependencies = [
"axum-core",
"bytes",
"form_urlencoded",
"futures-util",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-util",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-core"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6"
dependencies = [
"bytes",
"futures-core",
"http",
"http-body",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"sync_wrapper",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "backtrace"
version = "0.3.75"
@ -1058,6 +1118,25 @@ version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "h2"
version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9421a676d1b147b16b82c9225157dc629087ef8ec4d5e2960f9437a90dac0a5"
dependencies = [
"atomic-waker",
"bytes",
"fnv",
"futures-core",
"futures-sink",
"http",
"indexmap",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "half"
version = "2.4.1"
@ -1149,6 +1228,88 @@ dependencies = [
"match_token",
]
[[package]]
name = "http"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]]
name = "http-body"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
dependencies = [
"bytes",
"http",
]
[[package]]
name = "http-body-util"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
dependencies = [
"bytes",
"futures-core",
"http",
"http-body",
"pin-project-lite",
]
[[package]]
name = "httparse"
version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
[[package]]
name = "httpdate"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"h2",
"http",
"http-body",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"smallvec",
"tokio",
]
[[package]]
name = "hyper-util"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc2fdfdbff08affe55bb779f33b053aa1fe5dd5b54c257343c17edfa55711bdb"
dependencies = [
"bytes",
"futures-core",
"http",
"http-body",
"hyper",
"pin-project-lite",
"tokio",
"tower-service",
]
[[package]]
name = "iana-time-zone"
version = "0.1.63"
@ -1294,6 +1455,7 @@ checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e"
dependencies = [
"equivalent",
"hashbrown 0.15.3",
"serde",
]
[[package]]
@ -1582,6 +1744,7 @@ checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956"
name = "little-librarian"
version = "0.1.0"
dependencies = [
"axum",
"clap",
"futures",
"lopdf",
@ -1589,16 +1752,22 @@ dependencies = [
"rayon",
"ring",
"scraper",
"serde",
"sha2",
"snafu",
"sqlx",
"tokenizers",
"tokio",
"tower-http",
"tracing",
"tracing-subscriber",
"tract-onnx",
"utoipa",
"utoipa-axum",
"utoipa-swagger-ui",
"uuid",
"walkdir",
"zip",
"zip 4.2.0",
]
[[package]]
@ -1700,6 +1869,21 @@ dependencies = [
"syn 2.0.101",
]
[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata 0.1.10",
]
[[package]]
name = "matchit"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
[[package]]
name = "matrixmultiply"
version = "0.3.10"
@ -1735,6 +1919,22 @@ dependencies = [
"libc",
]
[[package]]
name = "mime"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
@ -2395,8 +2595,17 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-syntax",
"regex-automata 0.4.9",
"regex-syntax 0.8.5",
]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax 0.6.29",
]
[[package]]
@ -2407,9 +2616,15 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
"regex-syntax 0.8.5",
]
[[package]]
name = "regex-syntax"
version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.8.5"
@ -2450,6 +2665,40 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rust-embed"
version = "8.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "025908b8682a26ba8d12f6f2d66b987584a4a87bc024abc5bbc12553a8cd178a"
dependencies = [
"rust-embed-impl",
"rust-embed-utils",
"walkdir",
]
[[package]]
name = "rust-embed-impl"
version = "8.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6065f1a4392b71819ec1ea1df1120673418bf386f50de1d6f54204d836d4349c"
dependencies = [
"proc-macro2",
"quote",
"rust-embed-utils",
"syn 2.0.101",
"walkdir",
]
[[package]]
name = "rust-embed-utils"
version = "8.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6cc0c81648b20b70c491ff8cce00c1c3b223bb8ed2b5d41f0e54c6c4c0a3594"
dependencies = [
"sha2",
"walkdir",
]
[[package]]
name = "rustc-demangle"
version = "0.1.24"
@ -2620,6 +2869,16 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a"
dependencies = [
"itoa",
"serde",
]
[[package]]
name = "serde_urlencoded"
version = "0.7.1"
@ -3069,6 +3328,12 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
[[package]]
name = "synstructure"
version = "0.13.2"
@ -3212,7 +3477,7 @@ dependencies = [
"rayon",
"rayon-cond",
"regex",
"regex-syntax",
"regex-syntax 0.8.5",
"serde",
"serde_json",
"spm_precompiled",
@ -3260,6 +3525,63 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-util"
version = "0.7.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"pin-project-lite",
"tokio",
]
[[package]]
name = "tower"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9"
dependencies = [
"futures-core",
"futures-util",
"pin-project-lite",
"sync_wrapper",
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-http"
version = "0.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2"
dependencies = [
"bitflags",
"bytes",
"http",
"http-body",
"pin-project-lite",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-layer"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
[[package]]
name = "tower-service"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tracing"
version = "0.1.41"
@ -3310,10 +3632,14 @@ version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
]
@ -3477,6 +3803,12 @@ version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-bidi"
version = "0.3.18"
@ -3572,6 +3904,79 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "utoipa"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fcc29c80c21c31608227e0912b2d7fddba57ad76b606890627ba8ee7964e993"
dependencies = [
"indexmap",
"serde",
"serde_json",
"utoipa-gen",
]
[[package]]
name = "utoipa-axum"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c25bae5bccc842449ec0c5ddc5cbb6a3a1eaeac4503895dc105a1138f8234a0"
dependencies = [
"axum",
"paste",
"tower-layer",
"tower-service",
"utoipa",
]
[[package]]
name = "utoipa-gen"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d79d08d92ab8af4c5e8a6da20c47ae3f61a0f1dabc1997cdf2d082b757ca08b"
dependencies = [
"proc-macro2",
"quote",
"regex",
"syn 2.0.101",
]
[[package]]
name = "utoipa-swagger-ui"
version = "9.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d047458f1b5b65237c2f6dc6db136945667f40a7668627b3490b9513a3d43a55"
dependencies = [
"axum",
"base64 0.22.1",
"mime_guess",
"regex",
"rust-embed",
"serde",
"serde_json",
"url",
"utoipa",
"utoipa-swagger-ui-vendored",
"zip 3.0.0",
]
[[package]]
name = "utoipa-swagger-ui-vendored"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2eebbbfe4093922c2b6734d7c679ebfebd704a0d7e56dfcb0d05818ce28977d"
[[package]]
name = "uuid"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d"
dependencies = [
"getrandom 0.3.3",
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "valuable"
version = "0.1.1"
@ -4104,6 +4509,20 @@ dependencies = [
"syn 2.0.101",
]
[[package]]
name = "zip"
version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12598812502ed0105f607f941c386f43d441e00148fce9dec3ca5ffb0bde9308"
dependencies = [
"arbitrary",
"crc32fast",
"flate2",
"indexmap",
"memchr",
"zopfli",
]
[[package]]
name = "zip"
version = "4.2.0"

View file

@ -5,6 +5,7 @@ license = "AGPL-3.0"
edition = "2024"
[dependencies]
axum = { version = "0.8.4", features = ["http2", "tracing"] }
clap = { version = "4.5.40", features = ["derive"] }
futures = "0.3.31"
lopdf = "0.36.0"
@ -12,13 +13,19 @@ quick-xml = "0.38.0"
rayon = "1.10.0"
ring = "0.17.14"
scraper = "0.23.1"
serde = { version = "1.0.219", features = ["derive"] }
sha2 = "0.10.9"
snafu = { version = "0.8.6", features = ["rust_1_81"] }
sqlx = { version = "0.8.6", features = [ "runtime-tokio", "tls-rustls-ring", "migrate", "postgres", "derive" ] }
tokenizers = "0.21.2"
tokio = { version = "1.45.1", features = ["rt-multi-thread", "macros"]}
tower-http = { version = "0.6.6", features = ["trace"] }
tracing = "0.1.41"
tracing-subscriber = "0.3.19"
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
tract-onnx = "0.21.13"
utoipa = { version = "5.4.0", features = ["axum_extras"] }
utoipa-axum = "0.2.0"
utoipa-swagger-ui = { version = "9.0.2", features = ["axum", "vendored"] }
uuid = { version = "1.17.0", features = ["v4"] }
walkdir = "2.5.0"
zip = "4.2.0"

75
src/api.rs Normal file
View file

@ -0,0 +1,75 @@
use std::{
io,
net::{AddrParseError, SocketAddr},
str::FromStr,
};
use snafu::{ResultExt, Snafu};
use state::AppState;
use tokio::net::TcpListener;
use utoipa::OpenApi;
use utoipa_axum::router::OpenApiRouter;
use utoipa_swagger_ui::SwaggerUi;
use crate::{storage::Postgres, text_encoder::TextEncoder, tokenize::Tokenizer};
pub mod error;
pub mod query;
pub mod routes;
pub mod state;
const TAG: &str = "little-librarian";
#[derive(OpenApi)]
#[openapi(
tags(
(name = TAG, description = "Calibre Semantic Search API")
)
)]
struct ApiDoc;
#[derive(Debug, Snafu)]
pub enum ServeError {
#[snafu(display("Failed to parse address into <ip>:<port>."))]
AddressParse { source: AddrParseError },
#[snafu(display("Failed to bind to {address}."))]
Bind {
source: io::Error,
address: SocketAddr,
},
#[snafu(display("Failed to run http server."))]
Serve { source: io::Error },
}
pub async fn serve(
address: &str,
db: Postgres,
tokenizer: Tokenizer,
embedder: TextEncoder,
reranker: TextEncoder,
chunk_size: usize,
) -> Result<(), ServeError> {
let state = AppState {
db,
tokenizer,
embedder,
reranker,
chunk_size,
};
let (router, api) = OpenApiRouter::with_openapi(ApiDoc::openapi())
.nest("/api/v1/", routes::router(state))
.split_for_parts();
let router =
router.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", api.clone()));
let address = SocketAddr::from_str(address).context(AddressParseSnafu)?;
let listener = TcpListener::bind(&address)
.await
.context(BindSnafu { address })?;
axum::serve(listener, router.into_make_service())
.await
.context(ServeSnafu)
}

32
src/api/error.rs Normal file
View file

@ -0,0 +1,32 @@
use axum::http::StatusCode;
use serde::Serialize;
#[derive(Serialize)]
pub struct ErrorResponse {
pub id: String,
pub error: String,
}
pub trait HttpStatus {
fn status_code(&self) -> StatusCode;
}
#[macro_export]
macro_rules! http_error {
($error_type:ty) => {
impl axum::response::IntoResponse for $error_type {
fn into_response(self) -> axum::response::Response {
let status = self.status_code();
let id = uuid::Uuid::new_v4().to_string();
tracing::error!("{}: {}", &id, snafu::Report::from_error(&self));
let error_response = $crate::api::error::ErrorResponse {
id,
error: self.to_string(),
};
(status, axum::Json(error_response)).into_response()
}
}
};
}

114
src/api/query.rs Normal file
View file

@ -0,0 +1,114 @@
use std::sync::Arc;
use axum::{
Json,
extract::{Query, State},
http::StatusCode,
};
use serde::{Deserialize, Serialize};
use snafu::{ResultExt, Snafu, ensure};
use utoipa::ToSchema;
use super::{TAG, error::HttpStatus, state::AppState};
use crate::{http_error, query, storage::DocumentMatch};
const MAX_LIMIT: usize = 10;
#[derive(Debug, Snafu)]
pub enum QueryError {
#[snafu(display("'limit' query parameter must be a positive integer <= {MAX_LIMIT}."))]
Limit,
#[snafu(display("Failed to run query."))]
Query { source: query::AskError },
}
impl HttpStatus for QueryError {
fn status_code(&self) -> StatusCode {
match self {
QueryError::Limit => StatusCode::BAD_REQUEST,
QueryError::Query { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
}
}
}
http_error!(QueryError);
#[derive(Deserialize)]
pub struct QueryParams {
pub limit: Option<usize>,
}
#[derive(Debug, Serialize, ToSchema)]
pub struct QueryResponse {
pub results: Vec<DocumentResult>,
pub count: usize,
pub query: String,
}
impl From<(Vec<DocumentMatch>, String)> for QueryResponse {
fn from((documents, query): (Vec<DocumentMatch>, String)) -> Self {
let results: Vec<DocumentResult> = documents.into_iter().map(Into::into).collect();
let count = results.len();
Self {
results,
count,
query,
}
}
}
#[derive(Debug, Serialize, ToSchema)]
pub struct DocumentResult {
pub book_id: i64,
pub text_chunk: String,
pub similarity: f64,
}
impl From<DocumentMatch> for DocumentResult {
fn from(doc: DocumentMatch) -> Self {
Self {
book_id: doc.book_id,
text_chunk: doc.text_chunk,
similarity: doc.similarity,
}
}
}
#[utoipa::path(
post,
path = "/query",
tag = TAG,
params(
("limit" = Option<usize>, Query, description = "Maximum number of results")
),
request_body = String,
responses(
(status = OK, body = QueryResponse),
(status = 400, description = "Wrong parameter."),
(status = 500, description = "Failure to query database.")
)
)]
pub async fn route(
Query(params): Query<QueryParams>,
State(state): State<Arc<AppState>>,
body: String,
) -> Result<Json<QueryResponse>, QueryError> {
let limit = params.limit.unwrap_or(5);
ensure!(limit <= MAX_LIMIT, LimitSnafu);
let results = query::ask(
&body,
&state.db,
&state.tokenizer,
&state.embedder,
&state.reranker,
state.chunk_size,
limit,
)
.await
.context(QuerySnafu)?;
let response = QueryResponse::from((results, body));
Ok(Json(response))
}

15
src/api/routes.rs Normal file
View file

@ -0,0 +1,15 @@
use std::sync::Arc;
use tower_http::trace::TraceLayer;
use utoipa_axum::{router::OpenApiRouter, routes};
use super::state::AppState;
use crate::api::query;
pub fn router(state: AppState) -> OpenApiRouter {
let store = Arc::new(state);
OpenApiRouter::new()
.routes(routes!(query::route))
.layer(TraceLayer::new_for_http())
.with_state(store)
}

10
src/api/state.rs Normal file
View file

@ -0,0 +1,10 @@
use crate::{storage::Postgres, text_encoder::TextEncoder, tokenize::Tokenizer};
#[derive(Debug, Clone)]
pub struct AppState {
pub db: Postgres,
pub tokenizer: Tokenizer,
pub embedder: TextEncoder,
pub reranker: TextEncoder,
pub chunk_size: usize,
}

View file

@ -30,7 +30,7 @@ pub struct Cli {
pub command: Commands,
}
/// Configuration for the ask command.
/// Run a query.
#[derive(Args, Clone)]
pub struct AskConfig {
/// Query to run.
@ -46,7 +46,7 @@ pub struct AskConfig {
pub reranking_path: PathBuf,
}
/// Configuration for the scan command.
/// Process a calibre library.
#[derive(Args, Clone)]
pub struct ScanConfig {
/// Root directory of calibre library.
@ -54,9 +54,22 @@ pub struct ScanConfig {
pub library_path: PathBuf,
}
/// Serve the query interface as an HTTP API.
#[derive(Args, Clone)]
pub struct ServeConfig {
/// Address to listen on.
#[arg(short, long, default_value = "[::]:8080")]
pub address: String,
/// Path to reranking model.
#[arg(short, long, default_value = "./reranking.onnx")]
pub reranking_path: PathBuf,
}
/// Available CLI commands.
#[derive(Subcommand)]
pub enum Commands {
Ask(AskConfig),
Scan(ScanConfig),
Serve(ServeConfig),
}

View file

@ -1,3 +1,4 @@
pub mod api;
pub mod calibre;
pub mod cli;
pub mod extractors {

View file

@ -4,6 +4,7 @@ use std::io;
use clap::Parser;
use little_librarian::{
api,
cli::{Cli, Commands},
query,
scanner::{self},
@ -12,6 +13,7 @@ use little_librarian::{
tokenize::{self, Tokenizer},
};
use snafu::{ResultExt, Snafu};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
/// Top-level application errors.
#[derive(Debug, Snafu)]
@ -26,16 +28,33 @@ pub enum Error {
ModelLoading {
source: text_encoder::NewFromFileError,
},
#[snafu(display("Failed to load reranker model."))]
RerankerLoading {
source: text_encoder::NewFromFileError,
},
#[snafu(display("Failed to connect to database."))]
DbConnection { source: storage::ConnectionError },
#[snafu(display("Failed to process query."))]
Ask { source: query::AskError },
#[snafu(display("Failed to serve HTTP API."))]
Serve { source: api::ServeError },
}
#[tokio::main]
#[snafu::report]
async fn main() -> Result<(), Error> {
tracing_subscriber::fmt::init();
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
format!(
"{}=debug,tower_http=debug,axum::rejection=trace",
env!("CARGO_CRATE_NAME")
)
.into()
}),
)
.with(tracing_subscriber::fmt::layer())
.init();
let cli = Cli::parse();
let db_url = if cli.db_url == "-" {
@ -52,12 +71,14 @@ async fn main() -> Result<(), Error> {
match &cli.command {
Commands::Ask(config) => {
let reranker =
TextEncoder::from_file(&config.reranking_path).context(RerankerLoadingSnafu)?;
query::ask(
&config.query,
&db,
&tokenizer,
&embedder,
&config.reranking_path,
&reranker,
cli.chunk_size,
config.limit,
)
@ -74,6 +95,20 @@ async fn main() -> Result<(), Error> {
)
.await;
}
Commands::Serve(config) => {
let reranker =
TextEncoder::from_file(&config.reranking_path).context(RerankerLoadingSnafu)?;
api::serve(
&config.address,
db,
tokenizer,
embedder,
reranker,
cli.chunk_size,
)
.await
.context(ServeSnafu)?;
}
}
Ok(())

View file

@ -1,11 +1,9 @@
//! Query processing and document retrieval.
use std::path::PathBuf;
use snafu::{ResultExt, Snafu};
use crate::{
storage::{self, Postgres},
storage::{self, DocumentMatch, Postgres},
text_encoder::{self, TextEncoder},
tokenize::{self, Tokenizer},
};
@ -19,10 +17,6 @@ pub enum AskError {
Embed { source: text_encoder::EmbedError },
#[snafu(display("Failed to retrieve similar documents."))]
Query { source: storage::QueryError },
#[snafu(display("Failed to load reranker model."))]
LoadReranker {
source: text_encoder::NewFromFileError,
},
#[snafu(display("Failed to rerank documents."))]
Rerank { source: text_encoder::RerankError },
}
@ -32,34 +26,21 @@ pub async fn ask(
query: &str,
db: &Postgres,
tokenizer: &Tokenizer,
text_encoder: &TextEncoder,
reranker_path: &PathBuf,
embedder: &TextEncoder,
reranker: &TextEncoder,
chunk_size: usize,
limit: usize,
) -> Result<(), AskError> {
) -> Result<Vec<DocumentMatch>, AskError> {
let encodings = tokenizer.encode(query, chunk_size).context(EncodeSnafu)?;
let embeddings = text_encoder
.embed(encodings[0].clone())
.context(EmbedSnafu)?;
let embeddings = embedder.embed(encodings[0].clone()).context(EmbedSnafu)?;
let documents = db
.query(embeddings, (limit * 10) as i32)
.await
.context(QuerySnafu)?;
let reranker = TextEncoder::from_file(reranker_path).context(LoadRerankerSnafu)?;
let reranked_docs = reranker
.rerank(query, documents, tokenizer, limit)
.context(RerankSnafu)?;
for (i, doc) in reranked_docs.iter().enumerate() {
println!(
"{}. Book ID: {}, Score: {:.3}",
i + 1,
doc.book_id,
doc.similarity
);
println!(" {}\n", doc.text_chunk);
}
Ok(())
Ok(reranked_docs)
}

View file

@ -18,6 +18,7 @@ pub struct DocumentMatch {
}
/// PostgreSQL database connection pool and operations.
#[derive(Debug, Clone)]
pub struct Postgres {
/// Connection pool for database operations.
pool: Pool<sqlx::Postgres>,

View file

@ -28,6 +28,7 @@ type Model = SimplePlan<
pub type Embeddings = (Vec<f32>, Option<usize>, String);
/// ONNX-based text embedder for generating vector representations.
#[derive(Debug, Clone)]
pub struct TextEncoder {
/// Compiled ONNX model for inference.
model: Model,

View file

@ -28,6 +28,7 @@ pub enum EncodeError {
}
/// Wrapper around a tokenizer that handles text chunking.
#[derive(Debug, Clone)]
pub struct Tokenizer {
/// The underlying tokenizer implementation.
inner: tokenizers::Tokenizer,