"""ChromaDB vector store for document search.""" import logging import os from datetime import datetime from typing import List, Dict, Optional import chromadb from chromadb.config import Settings from config import IS_STAGING from brain.query import QueryPattern, calculate_temporal_score, calculate_confidence, format_answer logger = logging.getLogger(__name__) # ChromaDB setup - New API (ChromaDB 1.x) PERSIST_DIR = "./data/chroma_db" if IS_STAGING else "./data/chroma_db_prod" os.makedirs(PERSIST_DIR, exist_ok=True) # Initialize client at module level try: chroma_client = chromadb.PersistentClient( path=PERSIST_DIR ) collection = chroma_client.get_or_create_collection("family_documents") logger.info(f"ChromaDB initialized with collection 'family_documents' at {PERSIST_DIR}") except Exception as e: logger.error(f"Failed to initialize ChromaDB: {e}") chroma_client = None collection = None def get_collection(): """Get the ChromaDB collection, initializing if needed.""" global collection if collection is None: if chroma_client is None: raise RuntimeError("ChromaDB client not initialized") try: collection = chroma_client.get_or_create_collection("family_documents") except Exception as e: logger.error(f"Failed to get ChromaDB collection: {e}") raise RuntimeError("ChromaDB not available") return collection def store_document(doc_id: str, chunks: List[Dict]): """Store document chunks in ChromaDB. Args: doc_id: Document identifier chunks: List of chunk dictionaries with embeddings """ coll = get_collection() if not chunks: logger.warning(f"No chunks to store for document {doc_id}") return try: coll.add( documents=[c["text"] for c in chunks], embeddings=[c["embedding"] for c in chunks], metadatas=[c["metadata"] for c in chunks], ids=[c["id"] for c in chunks] ) logger.info(f"Stored {len(chunks)} chunks for document {doc_id}") except Exception as e: logger.error(f"Failed to store document {doc_id}: {e}") raise def delete_document(doc_id: str): """Delete all chunks for a document. Args: doc_id: Document identifier """ coll = get_collection() try: # Find all chunks for this document results = coll.get( where={"doc_id": doc_id} ) if results and results["ids"]: coll.delete(ids=results["ids"]) logger.info(f"Deleted {len(results['ids'])} chunks for document {doc_id}") except Exception as e: logger.error(f"Failed to delete document {doc_id}: {e}") def search_documents( query_embedding: List[float], pattern: QueryPattern, n_results: int = 10, now: Optional[datetime] = None ) -> List[Dict]: """Search documents with temporal re-ranking. 


def search_documents(
    query_embedding: List[float],
    pattern: QueryPattern,
    n_results: int = 10,
    now: Optional[datetime] = None,
) -> List[Dict]:
    """Search documents with temporal re-ranking.

    Args:
        query_embedding: Query embedding vector
        pattern: Query pattern for temporal weighting
        n_results: Number of results to return
        now: Optional reference date (defaults to the current time)

    Returns:
        List of scored results sorted by combined score, best first
    """
    coll = get_collection()
    if now is None:
        now = datetime.now()

    try:
        # Initial semantic search: fetch extra candidates for re-ranking
        results = coll.query(
            query_embeddings=[query_embedding],
            n_results=n_results * 2,
        )

        if not results["ids"] or not results["ids"][0]:
            logger.info("No results found in semantic search")
            return []

        # Re-rank by temporal score
        scored_results = []
        for i in range(len(results["ids"][0])):
            doc_id = results["ids"][0][i]
            metadata = (results["metadatas"][0][i] or {}) if results["metadatas"] else {}
            distance = results["distances"][0][i] if results["distances"] else 1.0
            text = results["documents"][0][i] if results["documents"] else ""

            # Parse source date, falling back to a fixed default
            source_date_str = metadata.get("source_date", "2024-01-01")
            try:
                doc_date = datetime.fromisoformat(source_date_str)
            except (ValueError, TypeError):
                doc_date = datetime(2024, 1, 1)

            doc_type = metadata.get("doc_type", "static_pdf")

            # Calculate temporal score
            temporal_score = calculate_temporal_score(doc_date, pattern, doc_type, now)

            # Convert distance to similarity. The collection is created with
            # cosine distance (see module setup), where 0 = identical and
            # 2 = opposite, so map it onto a [0, 1] similarity with
            # 1 = identical.
            semantic_score = 1.0 - (distance / 2.0)
            semantic_score = max(0.0, min(1.0, semantic_score))

            # Combined score: 70% semantic, 30% temporal
            combined = (semantic_score * 0.7) + (temporal_score * 0.3)

            scored_results.append({
                "id": doc_id,
                "text": text,
                "metadata": metadata,
                "semantic_score": semantic_score,
                "temporal_score": temporal_score,
                "combined_score": combined,
            })

        # Sort by combined score, best first
        scored_results.sort(key=lambda x: x["combined_score"], reverse=True)
        return scored_results[:n_results]

    except Exception as e:
        logger.error(f"Search failed: {e}")
        return []


def get_document_count() -> int:
    """Get the total number of chunks in the collection."""
    coll = get_collection()
    try:
        return coll.count()
    except Exception as e:
        logger.error(f"Failed to get document count: {e}")
        return 0


def get_stats() -> Dict:
    """Get collection statistics."""
    try:
        coll = get_collection()
        return {
            "total_chunks": coll.count(),
            "collection_name": COLLECTION_NAME,
            "persist_directory": PERSIST_DIR,
        }
    except Exception as e:
        logger.error(f"Failed to get stats: {e}")
        return {
            "error": str(e),
            "total_chunks": 0,
            "collection_name": COLLECTION_NAME,
            "persist_directory": PERSIST_DIR,
        }
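

# Usage sketch (illustrative, with hypothetical helpers): searching with
# temporal re-ranking. `embed_text` and `classify_query` do not exist in
# this module; the caller is assumed to produce the query embedding and
# the QueryPattern by whatever means it has available.
#
#     query = "When does the kids' soccer season start?"
#     query_embedding = embed_text(query)    # hypothetical embedder
#     pattern = classify_query(query)        # hypothetical -> QueryPattern
#
#     for r in search_documents(query_embedding, pattern, n_results=5):
#         print(f"{r['combined_score']:.2f}  {r['text'][:60]}")

if __name__ == "__main__":
    # Minimal smoke test: read-only, prints collection stats.
    print(get_stats())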