224 lines
6.2 KiB
Rust
224 lines
6.2 KiB
Rust
//! Database storage and retrieval for document embeddings.
|
|
|
|
use futures::future::try_join_all;
|
|
use snafu::{ResultExt, Snafu};
|
|
use sqlx::{FromRow, Pool, postgres::PgPoolOptions};
|
|
|
|
use crate::{hash::SHA256_LENGTH, text_encoder::Embeddings};
|
|
|
|
/// A document chunk with similarity score from vector search.
|
|
#[derive(Debug, FromRow)]
|
|
pub struct DocumentMatch {
|
|
/// Calibre book ID.
|
|
pub book_id: i64,
|
|
/// Text content of the chunk.
|
|
pub text_chunk: String,
|
|
/// Cosine similarity score (0.0 to 1.0).
|
|
pub similarity: f64,
|
|
}
|
|
|
|
/// PostgreSQL database connection pool and operations.
|
|
pub struct Postgres {
|
|
/// Connection pool for database operations.
|
|
pool: Pool<sqlx::Postgres>,
|
|
}
|
|
|
|
/// Error when connecting to database.
|
|
#[derive(Debug, Snafu)]
|
|
#[snafu(display("Failed to connect to database."))]
|
|
pub struct ConnectionError {
|
|
source: sqlx::Error,
|
|
}
|
|
|
|
/// Error when retrieving document version.
|
|
#[derive(Debug, Snafu)]
|
|
#[snafu(display("Failed to get highest document version."))]
|
|
pub struct VersionError {
|
|
source: sqlx::Error,
|
|
}
|
|
|
|
/// Error when storing embeddings.
|
|
#[derive(Debug, Snafu)]
|
|
#[snafu(display("Failed to store embeddings."))]
|
|
pub struct EmbeddingError {
|
|
source: sqlx::Error,
|
|
}
|
|
|
|
/// Error when querying for similar documents.
|
|
#[derive(Debug, Snafu)]
|
|
#[snafu(display("Failed to query similarity data."))]
|
|
pub struct QueryError {
|
|
source: sqlx::Error,
|
|
}
|
|
|
|
/// Error when updating document hash.
|
|
#[derive(Debug, Snafu)]
|
|
#[snafu(display("Failed to update document hash."))]
|
|
pub struct UpdateHashError {
|
|
source: sqlx::Error,
|
|
}
|
|
|
|
/// Error when checking if document needs update.
|
|
#[derive(Debug, Snafu)]
|
|
#[snafu(display("Failed to check if document needs update."))]
|
|
pub struct CheckUpdateError {
|
|
source: sqlx::Error,
|
|
}
|
|
|
|
/// Error when removing old document versions.
|
|
#[derive(Debug, Snafu)]
|
|
#[snafu(display("Failed to delete old document versions."))]
|
|
pub struct RemoveOldVersionsError {
|
|
source: sqlx::Error,
|
|
}
|
|
|
|
impl Postgres {
|
|
/// Create a new database connection pool.
|
|
pub async fn connect(db_url: &str, max_connections: u32) -> Result<Self, ConnectionError> {
|
|
let pool = PgPoolOptions::new()
|
|
.max_connections(max_connections)
|
|
.connect(db_url)
|
|
.await
|
|
.context(ConnectionSnafu)?;
|
|
|
|
Ok(Self { pool })
|
|
}
|
|
|
|
/// Get the highest version number for a book's embeddings.
|
|
pub async fn embedding_version(&self, book_id: i64) -> Result<i32, VersionError> {
|
|
let version = sqlx::query_scalar(
|
|
"SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1",
|
|
)
|
|
.bind(book_id)
|
|
.fetch_one(&self.pool)
|
|
.await
|
|
.context(VersionSnafu)?;
|
|
|
|
Ok(version)
|
|
}
|
|
|
|
/// Store multiple embeddings for a book in parallel.
|
|
pub async fn add_multiple_embeddings(
|
|
&self,
|
|
book_id: i64,
|
|
version: i32,
|
|
embeddings: Vec<Embeddings>,
|
|
) -> Result<(), EmbeddingError> {
|
|
let inserts: Vec<_> = embeddings
|
|
.into_iter()
|
|
.map(|e| self.add_embeddings(book_id, version, e))
|
|
.collect();
|
|
|
|
try_join_all(inserts).await?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Store a single embedding for a book.
|
|
pub async fn add_embeddings(
|
|
&self,
|
|
book_id: i64,
|
|
version: i32,
|
|
embeddings: Embeddings,
|
|
) -> Result<(), EmbeddingError> {
|
|
let vector_string = embeddings_to_vec_string(embeddings.0);
|
|
sqlx::query(
|
|
"INSERT INTO documents (book_id, version, embedding, text_chunk) VALUES ($1, $2, $3::vector, $4)",
|
|
)
|
|
.bind(book_id)
|
|
.bind(version)
|
|
.bind(vector_string)
|
|
.bind(embeddings.2)
|
|
.execute(&self.pool)
|
|
.await
|
|
.context(EmbeddingSnafu)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Find documents similar to the query embedding using cosine similarity.
|
|
pub async fn query(
|
|
&self,
|
|
query: Embeddings,
|
|
limit: i32,
|
|
) -> Result<Vec<DocumentMatch>, QueryError> {
|
|
let vector_string = embeddings_to_vec_string(query.0);
|
|
let books = sqlx::query_as::<_, DocumentMatch>(
|
|
"SELECT book_id, text_chunk, 1 - (embedding <=> $1::vector) as similarity
|
|
FROM documents
|
|
ORDER BY embedding <=> $1::vector
|
|
LIMIT $2",
|
|
)
|
|
.bind(vector_string)
|
|
.bind(limit)
|
|
.fetch_all(&self.pool)
|
|
.await
|
|
.context(QuerySnafu)?;
|
|
|
|
Ok(books)
|
|
}
|
|
|
|
/// Update or insert the hash for a book.
|
|
pub async fn update_hash(
|
|
&self,
|
|
book_id: i64,
|
|
hash: [u8; SHA256_LENGTH],
|
|
) -> Result<(), UpdateHashError> {
|
|
sqlx::query(
|
|
"INSERT INTO hashes (book_id, hash) VALUES ($1, $2)
|
|
ON CONFLICT (book_id) DO UPDATE SET hash = $2",
|
|
)
|
|
.bind(book_id)
|
|
.bind(&hash[..])
|
|
.execute(&self.pool)
|
|
.await
|
|
.context(UpdateHashSnafu)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Check if a book needs to be reprocessed based on its hash.
|
|
pub async fn book_needs_update(
|
|
&self,
|
|
book_id: i64,
|
|
hash: [u8; SHA256_LENGTH],
|
|
) -> Result<bool, CheckUpdateError> {
|
|
let exists = sqlx::query_scalar::<_, bool>(
|
|
"SELECT NOT EXISTS (
|
|
SELECT 1 FROM hashes
|
|
WHERE book_id = $1 AND hash = $2
|
|
)",
|
|
)
|
|
.bind(book_id)
|
|
.bind(&hash[..])
|
|
.fetch_one(&self.pool)
|
|
.await
|
|
.context(CheckUpdateSnafu)?;
|
|
|
|
Ok(exists)
|
|
}
|
|
|
|
/// Remove all but the latest version of embeddings for a book.
|
|
pub async fn remove_old_versions(&self, book_id: i64) -> Result<(), RemoveOldVersionsError> {
|
|
sqlx::query(
|
|
"DELETE FROM documents WHERE version < (SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1)",
|
|
)
|
|
.bind(book_id)
|
|
.execute(&self.pool)
|
|
.await
|
|
.context(RemoveOldVersionsSnafu)?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// Render a list of floats as the bracketed, comma-separated text form
/// (`[v1,v2,...]`) that callers bind as a parameter and cast with `::vector`.
///
/// An empty input produces `[]`. Each value is formatted with `f32`'s
/// `Display` implementation and no whitespace is inserted.
fn embeddings_to_vec_string(vector: Vec<f32>) -> String {
    let mut rendered = String::from("[");
    for (position, value) in vector.iter().enumerate() {
        // Separator goes before every element except the first.
        if position > 0 {
            rendered.push(',');
        }
        rendered.push_str(&value.to_string());
    }
    rendered.push(']');
    rendered
}
|