"""ChromaDB-based learned pattern memory for family document routing. Remembers document → family member assignments after user calibration. Overrides LLM classification on future documents. Adapts costco_route/item_memory.py: - Items → document patterns - Zones → member IDs - Removed Instacart seed data (not applicable) Sovereign: Zero imports from costco_route. """ import chromadb from icarus.core.config.staging import CHROMA_DB_PATH from icarus.core.embeddings import get_embedding CHROMA_COLLECTION = "family_assignments" def _get_collection(): """Get or create the ChromaDB collection.""" client = chromadb.PersistentClient(path=str(CHROMA_DB_PATH)) return client.get_or_create_collection( name=CHROMA_COLLECTION, metadata={"hnsw:space": "cosine"}, ) def learn_assignment(document_hash: str, member_id: str, pattern: str, notes: str = ""): """Save a learned family assignment. Args: document_hash: Unique hash of document content (for dedup) member_id: Family member ID (sully, harper, aundrea, matt) pattern: Extracted pattern that triggered assignment (for embedding) notes: Optional notes (e.g., "Sully's field trip form") """ col = _get_collection() emb = get_embedding(pattern) col.upsert( ids=[document_hash], embeddings=[emb], documents=[pattern], metadatas=[{ "member_id": member_id, "notes": notes, "source": "user_calibration" }], ) def lookup_assignment(pattern: str, threshold: float = 0.85) -> dict | None: """Check if we have a learned assignment for a document pattern. Args: pattern: Document pattern to look up (e.g., "field trip permission") threshold: Cosine similarity threshold (0-1) Returns: Dict with member_id, notes, pattern, similarity — or None """ col = _get_collection() if col.count() == 0: return None emb = get_embedding(pattern) results = col.query( query_embeddings=[emb], n_results=3, include=["documents", "metadatas", "distances"], ) if not results["ids"] or not results["ids"][0]: return None # ChromaDB cosine distance: 0 = identical, 2 = opposite # Convert to similarity: 1 - distance/2 best_match = None for i, doc_id in enumerate(results["ids"][0]): distance = results["distances"][0][i] similarity = 1.0 - (distance / 2.0) if similarity >= threshold: match = { "pattern": results["documents"][0][i], "member_id": results["metadatas"][0][i]["member_id"], "notes": results["metadatas"][0][i].get("notes", ""), "similarity": round(similarity, 3), } if best_match is None or similarity > best_match["similarity"]: best_match = match return best_match def lookup_assignments(patterns: list[str], threshold: float = 0.85) -> dict[str, dict]: """Batch lookup for multiple patterns. Returns: Dict of pattern → assignment (only patterns with matches). """ results = {} for pattern in patterns: match = lookup_assignment(pattern, threshold) if match: results[pattern] = match return results def lookup_assignment_exact(document_hash: str) -> dict | None: """Exact match lookup by document hash. Use for verifying existing calibrations. """ col = _get_collection() try: result = col.get(ids=[document_hash], include=["documents", "metadatas"]) if result["ids"]: return { "pattern": result["documents"][0], "member_id": result["metadatas"][0]["member_id"], "notes": result["metadatas"][0].get("notes", ""), } except Exception: pass return None def stats() -> dict: """Return stats about the learned assignment database.""" col = _get_collection() total = col.count() # Count by source user_count = 0 try: user_results = col.get(where={"source": "user_calibration"}, include=[]) user_count = len(user_results["ids"]) except Exception: pass return { "total_assignments": total, "user_calibrations": user_count, "collection": CHROMA_COLLECTION, "path": str(CHROMA_DB_PATH), } def purge_old_assignments(months: int = 6) -> dict: """Remove assignments older than N months. Family patterns may change over time (e.g., school changes), so old assignments may become stale. """ # ChromaDB doesn't have native date filtering on metadata easily, # so for now this is a placeholder. Could add timestamp metadata. return {"purged": 0, "remaining": stats()["total_assignments"]}