
"""ChromaDB vector store for document search."""

import logging
import os
from datetime import datetime
from typing import List, Dict, Optional

import chromadb
from chromadb.config import Settings

from config import IS_STAGING
from brain.query import QueryPattern, calculate_temporal_score, calculate_confidence, format_answer

logger = logging.getLogger(__name__)

# ChromaDB setup - new API (ChromaDB 1.x)

PERSIST_DIR = "./data/chroma_db" if IS_STAGING else "./data/chroma_db_prod"
os.makedirs(PERSIST_DIR, exist_ok=True)

# Initialize client at module level

try:
    chroma_client = chromadb.PersistentClient(path=PERSIST_DIR)
    # Use cosine distance so the similarity conversion in search_documents()
    # holds; ChromaDB collections default to L2 distance otherwise.
    collection = chroma_client.get_or_create_collection(
        "family_documents",
        metadata={"hnsw:space": "cosine"},
    )
    logger.info(f"ChromaDB initialized with collection 'family_documents' at {PERSIST_DIR}")
except Exception as e:
    logger.error(f"Failed to initialize ChromaDB: {e}")
    chroma_client = None
    collection = None

def get_collection():
    """Get the ChromaDB collection, initializing it if needed."""
    global collection
    if collection is None:
        if chroma_client is None:
            raise RuntimeError("ChromaDB client not initialized")
        try:
            collection = chroma_client.get_or_create_collection(
                "family_documents",
                metadata={"hnsw:space": "cosine"},
            )
        except Exception as e:
            logger.error(f"Failed to get ChromaDB collection: {e}")
            raise RuntimeError("ChromaDB not available") from e
    return collection

def store_document(doc_id: str, chunks: List[Dict]):
"""Store document chunks in ChromaDB.

Args:
    doc_id: Document identifier
    chunks: List of chunk dictionaries with embeddings
"""
coll = get_collection()

if not chunks:
    logger.warning(f"No chunks to store for document {doc_id}")
    return

try:
    coll.add(
        documents=[c["text"] for c in chunks],
        embeddings=[c["embedding"] for c in chunks],
        metadatas=[c["metadata"] for c in chunks],
        ids=[c["id"] for c in chunks]
    )
    logger.info(f"Stored {len(chunks)} chunks for document {doc_id}")
except Exception as e:
    logger.error(f"Failed to store document {doc_id}: {e}")
    raise
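
# Example usage (a sketch; the chunk shape below is an assumption based on the
# keys store_document() reads, and the embedding length depends on your model):
#
#     store_document("doc-123", [
#         {
#             "id": "doc-123-chunk-0",
#             "text": "The policy renews every January.",
#             "embedding": [0.0] * 384,
#             "metadata": {
#                 "doc_id": "doc-123",          # required by delete_document()
#                 "source_date": "2025-01-15",  # read by search_documents()
#                 "doc_type": "static_pdf",
#             },
#         },
#     ])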

def delete_document(doc_id: str):
"""Delete all chunks for a document.

Args:
    doc_id: Document identifier
"""
coll = get_collection()

try:
    # Find all chunks for this document
    results = coll.get(
        where={"doc_id": doc_id}
    )

    if results and results["ids"]:
        coll.delete(ids=results["ids"])
        logger.info(f"Deleted {len(results['ids'])} chunks for document {doc_id}")
except Exception as e:
    logger.error(f"Failed to delete document {doc_id}: {e}")

def search_documents(
    query_embedding: List[float],
    pattern: QueryPattern,
    n_results: int = 10,
    now: Optional[datetime] = None,
) -> List[Dict]:
    """Search documents with temporal re-ranking.

    Args:
        query_embedding: Query embedding vector
        pattern: Query pattern used for temporal weighting
        n_results: Number of results to return
        now: Optional reference date (defaults to the current time)

    Returns:
        List of scored results sorted by combined score, best first
    """
    coll = get_collection()

    if now is None:
        now = datetime.now()

    try:
        # Initial semantic search: over-fetch so re-ranking has candidates
        results = coll.query(
            query_embeddings=[query_embedding],
            n_results=n_results * 2,
        )

        if not results["ids"] or not results["ids"][0]:
            logger.info("No results found in semantic search")
            return []

        # Re-rank by temporal score
        scored_results = []

        for i in range(len(results["ids"][0])):
            doc_id = results["ids"][0][i]
            metadata = results["metadatas"][0][i] if results["metadatas"] else {}
            distance = results["distances"][0][i] if results["distances"] else 1.0
            text = results["documents"][0][i] if results["documents"] else ""

            # Parse source date, falling back to a fixed default
            source_date_str = metadata.get("source_date", "2024-01-01")
            try:
                doc_date = datetime.fromisoformat(source_date_str)
            except (ValueError, TypeError):
                doc_date = datetime(2024, 1, 1)

            doc_type = metadata.get("doc_type", "static_pdf")

            # Calculate temporal score
            temporal_score = calculate_temporal_score(doc_date, pattern, doc_type, now)

            # Convert distance to similarity. The collection is configured for
            # cosine distance, where 0 = identical and 2 = opposite, so map it
            # to a [0, 1] similarity where 1 = identical.
            semantic_score = 1.0 - (distance / 2.0)
            semantic_score = max(0.0, min(1.0, semantic_score))

            # Combined score: 70% semantic, 30% temporal
            combined = (semantic_score * 0.7) + (temporal_score * 0.3)

            scored_results.append({
                "id": doc_id,
                "text": text,
                "metadata": metadata,
                "semantic_score": semantic_score,
                "temporal_score": temporal_score,
                "combined_score": combined,
            })

        # Sort by combined score, descending
        scored_results.sort(key=lambda x: x["combined_score"], reverse=True)

        return scored_results[:n_results]

    except Exception as e:
        logger.error(f"Search failed: {e}")
        return []
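
# Example usage (a sketch; QueryPattern members come from brain.query and the
# name RECENT below is an assumption, as is the embed() helper):
#
#     query_vec = embed("when does the insurance policy renew?")
#     for hit in search_documents(query_vec, QueryPattern.RECENT, n_results=5):
#         print(f"{hit['combined_score']:.2f}  {hit['text'][:60]}")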

def get_document_count() -> int:
    """Get the total number of chunks in the collection."""
    coll = get_collection()
    try:
        return coll.count()
    except Exception as e:
        logger.error(f"Failed to get document count: {e}")
        return 0

def get_stats() -> Dict:
    """Get collection statistics."""
    try:
        coll = get_collection()
        count = coll.count()
        return {
            "total_chunks": count,
            "collection_name": "family_documents",
            "persist_directory": PERSIST_DIR,
        }
    except Exception as e:
        logger.error(f"Failed to get stats: {e}")
        return {
            "error": str(e),
            "total_chunks": 0,
            "collection_name": "family_documents",
            "persist_directory": PERSIST_DIR,
        }