"""ChromaDB vector store for document search."""
import logging
import os
from datetime import datetime
from typing import List, Dict, Optional
import chromadb
from config import IS_STAGING
from brain.query import QueryPattern, calculate_temporal_score, calculate_confidence, format_answer
logger = logging.getLogger(__name__)
# ChromaDB setup - new API (ChromaDB 1.x)
PERSIST_DIR = "./data/chroma_db" if IS_STAGING else "./data/chroma_db_prod"
os.makedirs(PERSIST_DIR, exist_ok=True)
# Initialize client at module level
try:
    chroma_client = chromadb.PersistentClient(path=PERSIST_DIR)
    # Cosine space keeps query distances in [0, 2]; search_documents relies on this
    collection = chroma_client.get_or_create_collection(
        "family_documents",
        metadata={"hnsw:space": "cosine"}
    )
    logger.info(f"ChromaDB initialized with collection 'family_documents' at {PERSIST_DIR}")
except Exception as e:
    logger.error(f"Failed to initialize ChromaDB: {e}")
    chroma_client = None
    collection = None

def get_collection():
    """Get the ChromaDB collection, initializing it if needed."""
    global collection
    if collection is None:
        if chroma_client is None:
            raise RuntimeError("ChromaDB client not initialized")
        try:
            collection = chroma_client.get_or_create_collection(
                "family_documents",
                metadata={"hnsw:space": "cosine"}
            )
        except Exception as e:
            logger.error(f"Failed to get ChromaDB collection: {e}")
            raise RuntimeError("ChromaDB not available") from e
    return collection

def store_document(doc_id: str, chunks: List[Dict]):
    """Store document chunks in ChromaDB.

    Args:
        doc_id: Document identifier
        chunks: List of chunk dicts with "id", "text", "embedding",
            and "metadata" keys
    """
    coll = get_collection()
    if not chunks:
        logger.warning(f"No chunks to store for document {doc_id}")
        return
    try:
        coll.add(
            documents=[c["text"] for c in chunks],
            embeddings=[c["embedding"] for c in chunks],
            metadatas=[c["metadata"] for c in chunks],
            ids=[c["id"] for c in chunks]
        )
        logger.info(f"Stored {len(chunks)} chunks for document {doc_id}")
    except Exception as e:
        logger.error(f"Failed to store document {doc_id}: {e}")
        raise
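
# Usage sketch (illustrative values): each chunk dict is assumed to carry the
# "id", "text", "embedding", and "metadata" keys that coll.add() reads above;
# embed() stands in for whatever helper produces vectors of the collection's
# dimensionality.
#
#     store_document("doc-42", [{
#         "id": "doc-42-chunk-0",
#         "text": "The policy renews each January.",
#         "embedding": embed("The policy renews each January."),
#         "metadata": {"doc_id": "doc-42", "source_date": "2025-01-15",
#                      "doc_type": "static_pdf"},
#     }])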

def delete_document(doc_id: str):
    """Delete all chunks for a document.

    Args:
        doc_id: Document identifier
    """
    coll = get_collection()
    try:
        # Find all chunks for this document, then delete them by id
        results = coll.get(where={"doc_id": doc_id})
        if results and results["ids"]:
            coll.delete(ids=results["ids"])
            logger.info(f"Deleted {len(results['ids'])} chunks for document {doc_id}")
    except Exception as e:
        logger.error(f"Failed to delete document {doc_id}: {e}")

def search_documents(
    query_embedding: List[float],
    pattern: QueryPattern,
    n_results: int = 10,
    now: Optional[datetime] = None
) -> List[Dict]:
    """Search documents with temporal re-ranking.

    Args:
        query_embedding: Query embedding vector
        pattern: Query pattern for temporal weighting
        n_results: Number of results to return
        now: Optional reference date (defaults to the current time)

    Returns:
        List of scored results sorted by combined score
    """
    coll = get_collection()
    if now is None:
        now = datetime.now()
    try:
        # Initial semantic search: fetch extra candidates for re-ranking
        results = coll.query(
            query_embeddings=[query_embedding],
            n_results=n_results * 2
        )
        if not results["ids"] or not results["ids"][0]:
            logger.info("No results found in semantic search")
            return []
        # Re-rank by temporal score
        scored_results = []
        for i in range(len(results["ids"][0])):
            doc_id = results["ids"][0][i]
            metadata = results["metadatas"][0][i] if results["metadatas"] else {}
            distance = results["distances"][0][i] if results["distances"] else 1.0
            text = results["documents"][0][i] if results["documents"] else ""
            # Parse source date
            source_date_str = metadata.get("source_date", "2024-01-01")
            try:
                doc_date = datetime.fromisoformat(source_date_str)
            except (ValueError, TypeError):
                doc_date = datetime(2024, 1, 1)
            doc_type = metadata.get("doc_type", "static_pdf")
            # Calculate temporal score
            temporal_score = calculate_temporal_score(doc_date, pattern, doc_type, now)
            # Convert distance to similarity. The collection uses cosine
            # distance (configured at creation above), where smaller = more
            # similar: distance 0 = identical, distance 2 = opposite.
            # Map to a [0, 1] similarity where 1 = identical.
            semantic_score = 1.0 - (distance / 2.0)
            semantic_score = max(0.0, min(1.0, semantic_score))
            # Combined score: 70% semantic, 30% temporal
            combined = (semantic_score * 0.7) + (temporal_score * 0.3)
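            # Worked example with illustrative numbers: distance 0.5 gives
            # semantic 0.75; with temporal 0.4, combined = 0.75 * 0.7
            # + 0.4 * 0.3 = 0.645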
            scored_results.append({
                "id": doc_id,
                "text": text,
                "metadata": metadata,
                "semantic_score": semantic_score,
                "temporal_score": temporal_score,
                "combined_score": combined
            })
        # Sort by combined score, best first
        scored_results.sort(key=lambda x: x["combined_score"], reverse=True)
        return scored_results[:n_results]
    except Exception as e:
        logger.error(f"Search failed: {e}")
        return []
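
# Usage sketch (assumptions flagged): QueryPattern's members live in
# brain.query and are not shown here, so some_pattern is a placeholder,
# and embed() is a hypothetical helper returning a query vector of the
# collection's dimensionality.
#
#     hits = search_documents(embed("when does the policy renew?"),
#                             pattern=some_pattern, n_results=5)
#     for hit in hits:
#         print(f"{hit['combined_score']:.3f}  {hit['text'][:60]}")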

def get_document_count() -> int:
    """Get total number of chunks in the collection."""
    coll = get_collection()
    try:
        return coll.count()
    except Exception as e:
        logger.error(f"Failed to get document count: {e}")
        return 0

def get_stats() -> Dict:
    """Get collection statistics."""
    try:
        coll = get_collection()
        count = coll.count()
        return {
            "total_chunks": count,
            "collection_name": "family_documents",
            "persist_directory": PERSIST_DIR
        }
    except Exception as e:
        logger.error(f"Failed to get stats: {e}")
        return {
            "error": str(e),
            "total_chunks": 0,
            "collection_name": "family_documents",
            "persist_directory": PERSIST_DIR
        }