initial commit
This commit is contained in:
commit
0bd97d0ed3
25 changed files with 6581 additions and 0 deletions
224
src/storage.rs
Normal file
224
src/storage.rs
Normal file
|
@ -0,0 +1,224 @@
|
|||
//! Database storage and retrieval for document embeddings.
|
||||
|
||||
use futures::future::try_join_all;
|
||||
use snafu::{ResultExt, Snafu};
|
||||
use sqlx::{FromRow, Pool, postgres::PgPoolOptions};
|
||||
|
||||
use crate::{hash::SHA256_LENGTH, text_encoder::Embeddings};
|
||||
|
||||
/// A document chunk with similarity score from vector search.
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct DocumentMatch {
|
||||
/// Calibre book ID.
|
||||
pub book_id: i64,
|
||||
/// Text content of the chunk.
|
||||
pub text_chunk: String,
|
||||
/// Cosine similarity score (0.0 to 1.0).
|
||||
pub similarity: f64,
|
||||
}
|
||||
|
||||
/// PostgreSQL database connection pool and operations.
|
||||
pub struct Postgres {
|
||||
/// Connection pool for database operations.
|
||||
pool: Pool<sqlx::Postgres>,
|
||||
}
|
||||
|
||||
/// Error when connecting to database.
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(display("Failed to connect to database."))]
|
||||
pub struct ConnectionError {
|
||||
source: sqlx::Error,
|
||||
}
|
||||
|
||||
/// Error when retrieving document version.
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(display("Failed to get highest document version."))]
|
||||
pub struct VersionError {
|
||||
source: sqlx::Error,
|
||||
}
|
||||
|
||||
/// Error when storing embeddings.
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(display("Failed to store embeddings."))]
|
||||
pub struct EmbeddingError {
|
||||
source: sqlx::Error,
|
||||
}
|
||||
|
||||
/// Error when querying for similar documents.
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(display("Failed to query similarity data."))]
|
||||
pub struct QueryError {
|
||||
source: sqlx::Error,
|
||||
}
|
||||
|
||||
/// Error when updating document hash.
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(display("Failed to update document hash."))]
|
||||
pub struct UpdateHashError {
|
||||
source: sqlx::Error,
|
||||
}
|
||||
|
||||
/// Error when checking if document needs update.
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(display("Failed to check if document needs update."))]
|
||||
pub struct CheckUpdateError {
|
||||
source: sqlx::Error,
|
||||
}
|
||||
|
||||
/// Error when removing old document versions.
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(display("Failed to delete old document versions."))]
|
||||
pub struct RemoveOldVersionsError {
|
||||
source: sqlx::Error,
|
||||
}
|
||||
|
||||
impl Postgres {
|
||||
/// Create a new database connection pool.
|
||||
pub async fn connect(db_url: &str, max_connections: u32) -> Result<Self, ConnectionError> {
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(max_connections)
|
||||
.connect(db_url)
|
||||
.await
|
||||
.context(ConnectionSnafu)?;
|
||||
|
||||
Ok(Self { pool })
|
||||
}
|
||||
|
||||
/// Get the highest version number for a book's embeddings.
|
||||
pub async fn embedding_version(&self, book_id: i64) -> Result<i32, VersionError> {
|
||||
let version = sqlx::query_scalar(
|
||||
"SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1",
|
||||
)
|
||||
.bind(book_id)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.context(VersionSnafu)?;
|
||||
|
||||
Ok(version)
|
||||
}
|
||||
|
||||
/// Store multiple embeddings for a book in parallel.
|
||||
pub async fn add_multiple_embeddings(
|
||||
&self,
|
||||
book_id: i64,
|
||||
version: i32,
|
||||
embeddings: Vec<Embeddings>,
|
||||
) -> Result<(), EmbeddingError> {
|
||||
let inserts: Vec<_> = embeddings
|
||||
.into_iter()
|
||||
.map(|e| self.add_embeddings(book_id, version, e))
|
||||
.collect();
|
||||
|
||||
try_join_all(inserts).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Store a single embedding for a book.
|
||||
pub async fn add_embeddings(
|
||||
&self,
|
||||
book_id: i64,
|
||||
version: i32,
|
||||
embeddings: Embeddings,
|
||||
) -> Result<(), EmbeddingError> {
|
||||
let vector_string = embeddings_to_vec_string(embeddings.0);
|
||||
sqlx::query(
|
||||
"INSERT INTO documents (book_id, version, embedding, text_chunk) VALUES ($1, $2, $3::vector, $4)",
|
||||
)
|
||||
.bind(book_id)
|
||||
.bind(version)
|
||||
.bind(vector_string)
|
||||
.bind(embeddings.2)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.context(EmbeddingSnafu)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Find documents similar to the query embedding using cosine similarity.
|
||||
pub async fn query(
|
||||
&self,
|
||||
query: Embeddings,
|
||||
limit: i32,
|
||||
) -> Result<Vec<DocumentMatch>, QueryError> {
|
||||
let vector_string = embeddings_to_vec_string(query.0);
|
||||
let books = sqlx::query_as::<_, DocumentMatch>(
|
||||
"SELECT book_id, text_chunk, 1 - (embedding <=> $1::vector) as similarity
|
||||
FROM documents
|
||||
ORDER BY embedding <=> $1::vector
|
||||
LIMIT $2",
|
||||
)
|
||||
.bind(vector_string)
|
||||
.bind(limit)
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.context(QuerySnafu)?;
|
||||
|
||||
Ok(books)
|
||||
}
|
||||
|
||||
/// Update or insert the hash for a book.
|
||||
pub async fn update_hash(
|
||||
&self,
|
||||
book_id: i64,
|
||||
hash: [u8; SHA256_LENGTH],
|
||||
) -> Result<(), UpdateHashError> {
|
||||
sqlx::query(
|
||||
"INSERT INTO hashes (book_id, hash) VALUES ($1, $2)
|
||||
ON CONFLICT (book_id) DO UPDATE SET hash = $2",
|
||||
)
|
||||
.bind(book_id)
|
||||
.bind(&hash[..])
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.context(UpdateHashSnafu)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a book needs to be reprocessed based on its hash.
|
||||
pub async fn book_needs_update(
|
||||
&self,
|
||||
book_id: i64,
|
||||
hash: [u8; SHA256_LENGTH],
|
||||
) -> Result<bool, CheckUpdateError> {
|
||||
let exists = sqlx::query_scalar::<_, bool>(
|
||||
"SELECT NOT EXISTS (
|
||||
SELECT 1 FROM hashes
|
||||
WHERE book_id = $1 AND hash = $2
|
||||
)",
|
||||
)
|
||||
.bind(book_id)
|
||||
.bind(&hash[..])
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.context(CheckUpdateSnafu)?;
|
||||
|
||||
Ok(exists)
|
||||
}
|
||||
|
||||
/// Remove all but the latest version of embeddings for a book.
|
||||
pub async fn remove_old_versions(&self, book_id: i64) -> Result<(), RemoveOldVersionsError> {
|
||||
sqlx::query(
|
||||
"DELETE FROM documents WHERE version < (SELECT COALESCE(MAX(version), 0) FROM documents WHERE book_id = $1)",
|
||||
)
|
||||
.bind(book_id)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.context(RemoveOldVersionsSnafu)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Render a list of floats as a bracketed, comma-separated vector literal.
///
/// Components are formatted with `f32`'s `Display` and joined without spaces,
/// e.g. `[1,2.5,-3]`; an empty input yields `[]`.
fn embeddings_to_vec_string(vector: Vec<f32>) -> String {
    let mut literal = String::from("[");
    for (index, component) in vector.iter().enumerate() {
        // Separator goes before every component except the first.
        if index > 0 {
            literal.push(',');
        }
        literal.push_str(&component.to_string());
    }
    literal.push(']');
    literal
}
|
Loading…
Add table
Add a link
Reference in a new issue