//! Query processing and document retrieval.
|
||
|
|
||
|
use std::path::PathBuf;
|
||
|
|
||
|
use snafu::{ResultExt, Snafu};
|
||
|
|
||
|
use crate::{
|
||
|
storage::{self, Postgres},
|
||
|
text_encoder::{self, TextEncoder},
|
||
|
tokenize::{self, Tokenizer},
|
||
|
};
|
||
|
|
||
|
/// Errors that occur during query processing.
|
||
|
#[derive(Debug, Snafu)]
|
||
|
pub enum AskError {
|
||
|
#[snafu(display("Failed to encode query."))]
|
||
|
Encode { source: tokenize::EncodeError },
|
||
|
#[snafu(display("Failed to embed query."))]
|
||
|
Embed { source: text_encoder::EmbedError },
|
||
|
#[snafu(display("Failed to retrieve similar documents."))]
|
||
|
Query { source: storage::QueryError },
|
||
|
#[snafu(display("Failed to load reranker model."))]
|
||
|
LoadReranker {
|
||
|
source: text_encoder::NewFromFileError,
|
||
|
},
|
||
|
#[snafu(display("Failed to rerank documents."))]
|
||
|
Rerank { source: text_encoder::RerankError },
|
||
|
}
|
||
|
|
||
|
/// Process a user query and return ranked document matches.
|
||
|
pub async fn ask(
|
||
|
query: &str,
|
||
|
db: &Postgres,
|
||
|
tokenizer: &Tokenizer,
|
||
|
text_encoder: &TextEncoder,
|
||
|
reranker_path: &PathBuf,
|
||
|
chunk_size: usize,
|
||
|
limit: usize,
|
||
|
) -> Result<(), AskError> {
|
||
|
let encodings = tokenizer.encode(query, chunk_size).context(EncodeSnafu)?;
|
||
|
let embeddings = text_encoder
|
||
|
.embed(encodings[0].clone())
|
||
|
.context(EmbedSnafu)?;
|
||
|
let documents = db
|
||
|
.query(embeddings, (limit * 10) as i32)
|
||
|
.await
|
||
|
.context(QuerySnafu)?;
|
||
|
|
||
|
let reranker = TextEncoder::from_file(reranker_path).context(LoadRerankerSnafu)?;
|
||
|
let reranked_docs = reranker
|
||
|
.rerank(query, documents, tokenizer, limit)
|
||
|
.context(RerankSnafu)?;
|
||
|
|
||
|
for (i, doc) in reranked_docs.iter().enumerate() {
|
||
|
println!(
|
||
|
"{}. Book ID: {}, Score: {:.3}",
|
||
|
i + 1,
|
||
|
doc.book_id,
|
||
|
doc.similarity
|
||
|
);
|
||
|
println!(" {}\n", doc.text_chunk);
|
||
|
}
|
||
|
|
||
|
Ok(())
|
||
|
}
|