115 lines
2.7 KiB
Rust
115 lines
2.7 KiB
Rust
|
use std::sync::Arc;
|
||
|
|
||
|
use axum::{
|
||
|
Json,
|
||
|
extract::{Query, State},
|
||
|
http::StatusCode,
|
||
|
};
|
||
|
use serde::{Deserialize, Serialize};
|
||
|
use snafu::{ResultExt, Snafu, ensure};
|
||
|
use utoipa::ToSchema;
|
||
|
|
||
|
use super::{TAG, error::HttpStatus, state::AppState};
|
||
|
use crate::{http_error, query, storage::DocumentMatch};
|
||
|
|
||
|
/// Largest value accepted for the `limit` query parameter of the query route.
const MAX_LIMIT: usize = 10;
|
||
|
|
||
|
#[derive(Debug, Snafu)]
|
||
|
pub enum QueryError {
|
||
|
#[snafu(display("'limit' query parameter must be a positive integer <= {MAX_LIMIT}."))]
|
||
|
Limit,
|
||
|
#[snafu(display("Failed to run query."))]
|
||
|
Query { source: query::AskError },
|
||
|
}
|
||
|
|
||
|
impl HttpStatus for QueryError {
|
||
|
fn status_code(&self) -> StatusCode {
|
||
|
match self {
|
||
|
QueryError::Limit => StatusCode::BAD_REQUEST,
|
||
|
QueryError::Query { source: _ } => StatusCode::INTERNAL_SERVER_ERROR,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// NOTE(review): `http_error!` is defined elsewhere in the crate; presumably it uses the
// `HttpStatus` impl to turn `QueryError` into an HTTP response — confirm against the macro.
http_error!(QueryError);
|
||
|
|
||
|
#[derive(Deserialize)]
|
||
|
pub struct QueryParams {
|
||
|
pub limit: Option<usize>,
|
||
|
}
|
||
|
|
||
|
#[derive(Debug, Serialize, ToSchema)]
|
||
|
pub struct QueryResponse {
|
||
|
pub results: Vec<DocumentResult>,
|
||
|
pub count: usize,
|
||
|
pub query: String,
|
||
|
}
|
||
|
|
||
|
impl From<(Vec<DocumentMatch>, String)> for QueryResponse {
|
||
|
fn from((documents, query): (Vec<DocumentMatch>, String)) -> Self {
|
||
|
let results: Vec<DocumentResult> = documents.into_iter().map(Into::into).collect();
|
||
|
let count = results.len();
|
||
|
|
||
|
Self {
|
||
|
results,
|
||
|
count,
|
||
|
query,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(Debug, Serialize, ToSchema)]
|
||
|
pub struct DocumentResult {
|
||
|
pub book_id: i64,
|
||
|
pub text_chunk: String,
|
||
|
pub similarity: f64,
|
||
|
}
|
||
|
|
||
|
impl From<DocumentMatch> for DocumentResult {
|
||
|
fn from(doc: DocumentMatch) -> Self {
|
||
|
Self {
|
||
|
book_id: doc.book_id,
|
||
|
text_chunk: doc.text_chunk,
|
||
|
similarity: doc.similarity,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[utoipa::path(
|
||
|
post,
|
||
|
path = "/query",
|
||
|
tag = TAG,
|
||
|
params(
|
||
|
("limit" = Option<usize>, Query, description = "Maximum number of results")
|
||
|
),
|
||
|
request_body = String,
|
||
|
responses(
|
||
|
(status = OK, body = QueryResponse),
|
||
|
(status = 400, description = "Wrong parameter."),
|
||
|
(status = 500, description = "Failure to query database.")
|
||
|
)
|
||
|
)]
|
||
|
pub async fn route(
|
||
|
Query(params): Query<QueryParams>,
|
||
|
State(state): State<Arc<AppState>>,
|
||
|
body: String,
|
||
|
) -> Result<Json<QueryResponse>, QueryError> {
|
||
|
let limit = params.limit.unwrap_or(5);
|
||
|
ensure!(limit <= MAX_LIMIT, LimitSnafu);
|
||
|
|
||
|
let results = query::ask(
|
||
|
&body,
|
||
|
&state.db,
|
||
|
&state.tokenizer,
|
||
|
&state.embedder,
|
||
|
&state.reranker,
|
||
|
state.chunk_size,
|
||
|
limit,
|
||
|
)
|
||
|
.await
|
||
|
.context(QuerySnafu)?;
|
||
|
let response = QueryResponse::from((results, body));
|
||
|
|
||
|
Ok(Json(response))
|
||
|
}
|