"""ChromaDB-based learned pattern memory for family document routing.
Remembers document → family member assignments after user calibration.
Overrides LLM classification on future documents.
Adapts costco_route/item_memory.py:
- Items → document patterns
- Zones → member IDs
- Removed Instacart seed data (not applicable)
Sovereign: Zero imports from costco_route.
"""
import logging

import chromadb

from icarus.core.config.staging import CHROMA_DB_PATH
from icarus.core.embeddings import get_embedding
# Name of the ChromaDB collection that holds the learned assignments;
# passed to get_or_create_collection() and reported by stats().
CHROMA_COLLECTION = "family_assignments"
def _get_collection():
    """Open the persistent ChromaDB store and return the assignments collection.

    A ``PersistentClient`` rooted at ``CHROMA_DB_PATH`` is constructed on each
    call. The collection named ``CHROMA_COLLECTION`` is fetched, or created if
    it does not exist yet, with ``hnsw:space`` set to ``cosine``.
    """
    storage_path = str(CHROMA_DB_PATH)
    collection_options = {"hnsw:space": "cosine"}
    store = chromadb.PersistentClient(path=storage_path)
    return store.get_or_create_collection(
        name=CHROMA_COLLECTION,
        metadata=collection_options,
    )
def learn_assignment(document_hash: str, member_id: str, pattern: str, notes: str = ""):
    """Store (or overwrite) the family member a document pattern belongs to.

    The entry is upserted under ``document_hash``, so learning the same hash
    again replaces the earlier record instead of adding a duplicate.

    Args:
        document_hash: Unique hash of the document content; used as the
            record ID (for dedup).
        member_id: Family member ID (e.g. sully, harper, aundrea, matt).
        pattern: Extracted pattern that triggered the assignment. It is both
            embedded and stored as the record's document text.
        notes: Optional free-form notes (e.g. "Sully's field trip form").
    """
    record_metadata = {
        "member_id": member_id,
        "notes": notes,
        "source": "user_calibration",
    }
    collection = _get_collection()
    vector = get_embedding(pattern)
    collection.upsert(
        ids=[document_hash],
        embeddings=[vector],
        documents=[pattern],
        metadatas=[record_metadata],
    )
def lookup_assignment(pattern: str, threshold: float = 0.85) -> dict | None:
    """Return the closest learned assignment for a document pattern, if any.

    The pattern is embedded and compared against the stored patterns. Each
    candidate's ChromaDB cosine distance (0 = identical, 2 = opposite) is
    rescaled to a 0-1 similarity with ``1 - distance / 2``; ``threshold`` is
    applied to that rescaled value.

    Args:
        pattern: Document pattern to look up (e.g., "field trip permission").
        threshold: Minimum rescaled similarity (0-1) a candidate must reach.

    Returns:
        Dict with ``pattern``, ``member_id``, ``notes`` and ``similarity``
        (rounded to 3 decimals) for the most similar candidate at or above
        the threshold, or ``None`` when the collection is empty or nothing
        qualifies.
    """
    col = _get_collection()
    stored = col.count()
    if stored == 0:
        return None

    emb = get_embedding(pattern)
    results = col.query(
        query_embeddings=[emb],
        # Never ask for more neighbours than exist. NOTE(review): depending on
        # the ChromaDB version an oversized n_results warns or errors - confirm
        # against the pinned release.
        n_results=min(3, stored),
        include=["documents", "metadatas", "distances"],
    )
    if not results["ids"] or not results["ids"][0]:
        return None

    best_match = None
    # Track the unrounded similarity of the current best. Comparing against
    # the rounded value kept in the result dict would let a slightly worse
    # candidate win (e.g. 0.8992 > round(0.8994, 3) == 0.899).
    best_similarity = None
    candidates = zip(
        results["documents"][0],
        results["metadatas"][0],
        results["distances"][0],
    )
    for document, metadata, distance in candidates:
        similarity = 1.0 - (distance / 2.0)
        if similarity >= threshold and (
            best_similarity is None or similarity > best_similarity
        ):
            best_similarity = similarity
            best_match = {
                "pattern": document,
                "member_id": metadata["member_id"],
                "notes": metadata.get("notes", ""),
                "similarity": round(similarity, 3),
            }
    return best_match
def lookup_assignments(patterns: list[str], threshold: float = 0.85) -> dict[str, dict]:
    """Look up several patterns, one ``lookup_assignment`` call per pattern.

    Args:
        patterns: Document patterns to look up.
        threshold: Similarity threshold forwarded to ``lookup_assignment``.

    Returns:
        Dict of pattern → assignment containing only the patterns that
        produced a match.
    """
    return {
        candidate: found
        for candidate in patterns
        if (found := lookup_assignment(candidate, threshold))
    }
def lookup_assignment_exact(document_hash: str) -> dict | None:
    """Fetch the assignment stored under an exact document hash.

    Use for verifying existing calibrations. The lookup is best-effort: a
    failure while reading or unpacking the record is logged and reported as
    "not found" rather than raised.

    Args:
        document_hash: Record ID the assignment was learned under.

    Returns:
        Dict with ``pattern``, ``member_id`` and ``notes``, or ``None`` when
        no record exists for the hash or the lookup failed.
    """
    col = _get_collection()
    try:
        result = col.get(ids=[document_hash], include=["documents", "metadatas"])
        if result["ids"]:
            metadata = result["metadatas"][0]
            return {
                "pattern": result["documents"][0],
                "member_id": metadata["member_id"],
                "notes": metadata.get("notes", ""),
            }
    except Exception:
        # Keep the never-raise contract, but leave a trace so storage or
        # schema problems are not silently mistaken for "no calibration".
        logging.getLogger(__name__).warning(
            "Exact assignment lookup failed for %r", document_hash, exc_info=True
        )
    return None
def stats() -> dict:
    """Return stats about the learned assignment database.

    Returns:
        Dict with ``total_assignments`` (collection size),
        ``user_calibrations`` (records whose ``source`` metadata is
        ``"user_calibration"``; 0 if that count could not be obtained),
        ``collection`` (collection name) and ``path`` (storage path).
    """
    col = _get_collection()
    total = col.count()

    # Count by source. This part is best-effort: a failing metadata query
    # must not prevent the remaining stats from being reported.
    user_count = 0
    try:
        user_results = col.get(where={"source": "user_calibration"}, include=[])
        user_count = len(user_results["ids"])
    except Exception:
        logging.getLogger(__name__).warning(
            "Could not count user calibrations in %r",
            CHROMA_COLLECTION,
            exc_info=True,
        )

    return {
        "total_assignments": total,
        "user_calibrations": user_count,
        "collection": CHROMA_COLLECTION,
        "path": str(CHROMA_DB_PATH),
    }
def purge_old_assignments(months: int = 6) -> dict:
    """Placeholder for removing assignments older than N months.

    Family patterns may change over time (e.g., school changes), so old
    assignments may become stale. Nothing is deleted yet: ``months`` is
    currently unused, and a real implementation could add timestamp metadata
    to filter on.

    Args:
        months: Intended age cutoff in months (not applied yet).

    Returns:
        Dict with ``purged`` (always 0 for now) and ``remaining`` (the
        current total reported by ``stats()``).
    """
    remaining = stats()["total_assignments"]
    return {"purged": 0, "remaining": remaining}