use std::sync::Arc; use axum::{ Json, extract::{Query, State}, http::StatusCode, }; use serde::{Deserialize, Serialize}; use snafu::{ResultExt, Snafu, ensure}; use utoipa::ToSchema; use super::{TAG, error::HttpStatus, state::AppState}; use crate::{http_error, query, storage::DocumentMatch}; const MAX_LIMIT: usize = 10; #[derive(Debug, Snafu)] pub enum QueryError { #[snafu(display("'limit' query parameter must be a positive integer <= {MAX_LIMIT}."))] Limit, #[snafu(display("Failed to run query."))] Query { source: query::AskError }, } impl HttpStatus for QueryError { fn status_code(&self) -> StatusCode { match self { QueryError::Limit => StatusCode::BAD_REQUEST, QueryError::Query { source: _ } => StatusCode::INTERNAL_SERVER_ERROR, } } } http_error!(QueryError); #[derive(Deserialize)] pub struct QueryParams { pub limit: Option, } #[derive(Debug, Serialize, ToSchema)] pub struct QueryResponse { pub results: Vec, pub count: usize, pub query: String, } impl From<(Vec, String)> for QueryResponse { fn from((documents, query): (Vec, String)) -> Self { let results: Vec = documents.into_iter().map(Into::into).collect(); let count = results.len(); Self { results, count, query, } } } #[derive(Debug, Serialize, ToSchema)] pub struct DocumentResult { pub book_id: i64, pub text_chunk: String, pub similarity: f64, } impl From for DocumentResult { fn from(doc: DocumentMatch) -> Self { Self { book_id: doc.book_id, text_chunk: doc.text_chunk, similarity: doc.similarity, } } } #[utoipa::path( post, path = "/query", tag = TAG, params( ("limit" = Option, Query, description = "Maximum number of results") ), request_body = String, responses( (status = OK, body = QueryResponse), (status = 400, description = "Wrong parameter."), (status = 500, description = "Failure to query database.") ) )] pub async fn route( Query(params): Query, State(state): State>, body: String, ) -> Result, QueryError> { let limit = params.limit.unwrap_or(5); ensure!(limit <= MAX_LIMIT, LimitSnafu); let results = query::ask( &body, &state.db, &state.tokenizer, &state.embedder, &state.reranker, state.chunk_size, limit, ) .await .context(QuerySnafu)?; let response = QueryResponse::from((results, body)); Ok(Json(response)) }