"""ChromaDB-based learned item location store.

Remembers where items actually are after user calibration and overrides the
LLM classification on future trips.

Two-tier matching:
- User calibrations (source=user_calibration): standard threshold (0.85)
- Instacart seed data (source=instacart): higher threshold (0.93) to prevent
  false matches from the 50K product database

Each tier is queried separately, so a user calibration can never be crowded
out of the candidate list by near-duplicate seed products.

Similarity scores are ``1 - distance / 2`` where ``distance`` is ChromaDB's
cosine distance (0 = identical, 2 = opposite). The score therefore lies in
[0, 1] and equals ``(1 + cos) / 2`` -- it is NOT the raw cosine similarity.
A threshold of 0.85 corresponds to a raw cosine of 0.70, and 0.93 to 0.86.
"""

import logging

import chromadb

from costco_route.config import CHROMA_PATH, CHROMA_COLLECTION, EMBED_MODEL
from costco_route.llm_client import get_embedding

logger = logging.getLogger(__name__)

# Metadata "source" values identifying the two matching tiers.
SOURCE_USER = "user_calibration"
SOURCE_SEED = "instacart"

# Default minimum similarity per tier (scale described in the module docstring).
USER_THRESHOLD = 0.85
SEED_THRESHOLD = 0.93

# Number of nearest neighbours fetched from the index per tier.
_N_RESULTS = 5

# Stored document IDs are truncated to this many characters.
_MAX_ID_LEN = 64


def _get_collection():
    """Open (or create) the persistent ChromaDB collection using cosine distance."""
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    return client.get_or_create_collection(
        name=CHROMA_COLLECTION,
        metadata={"hnsw:space": "cosine"},
    )


def _doc_id(item: str) -> str:
    """Return the stable document ID used to store ``item``.

    Lower-cases, trims surrounding whitespace, replaces each space with "_"
    and truncates to 64 characters. Existing data was written with exactly
    this normalization, so it must not change. Distinct names sharing the same
    64-character normalized prefix map to the same ID.
    """
    return item.lower().strip().replace(" ", "_")[:_MAX_ID_LEN]


def _similarity(distance: float) -> float:
    """Map a ChromaDB cosine distance (0..2) onto a 0..1 similarity score."""
    return 1.0 - (distance / 2.0)


def _best_match(
    col, emb, source: str, min_similarity: float, n_results: int
) -> dict | None:
    """Return the closest entry of one source tier scoring >= ``min_similarity``.

    Args:
        col: ChromaDB collection to query.
        emb: Query embedding for the item being looked up.
        source: Metadata ``source`` value to restrict the search to.
        min_similarity: Minimum similarity (see module docstring for scale).
        n_results: Number of nearest neighbours to fetch for this tier.

    Returns:
        Match dict (item, zone, notes, similarity, source) or None.
    """
    # Filtering by source inside the query keeps the tiers independent: the
    # large seed set cannot occupy all candidate slots ahead of a calibration.
    results = col.query(
        query_embeddings=[emb],
        n_results=n_results,
        where={"source": source},
        include=["documents", "metadatas", "distances"],
    )
    if not results["ids"] or not results["ids"][0]:
        return None

    best = None
    best_score = -1.0
    for doc, meta, distance in zip(
        results["documents"][0],
        results["metadatas"][0],
        results["distances"][0],
    ):
        score = _similarity(distance)
        meta = meta or {}
        # An entry without a zone cannot answer "where is it", so skip it.
        if score < min_similarity or score <= best_score or "zone" not in meta:
            continue
        best_score = score
        best = {
            "item": doc,
            "zone": meta["zone"],
            "notes": meta.get("notes", ""),
            "similarity": round(score, 3),
            "source": source,
        }
    return best


def learn_item(item: str, zone_id: str, notes: str = "") -> None:
    """Save a learned item location, overwriting any previous entry for the item.

    Args:
        item: Item name (e.g., "Kirkland Organic Milk")
        zone_id: Zone ID (01-10)
        notes: Optional notes (e.g., "back corner by bakery")

    Raises:
        ValueError: If ``item`` is empty or whitespace-only, which would map
            to an empty document ID.
    """
    doc_id = _doc_id(item)
    if not doc_id:
        raise ValueError("item name must not be empty")
    col = _get_collection()
    emb = get_embedding(item)
    col.upsert(
        ids=[doc_id],
        embeddings=[emb],
        documents=[item],
        metadatas=[{"zone": zone_id, "notes": notes, "source": SOURCE_USER}],
    )


def lookup_item(
    item: str,
    threshold: float = USER_THRESHOLD,
    seed_threshold: float = SEED_THRESHOLD,
    n_results: int = _N_RESULTS,
) -> dict | None:
    """Check if we have a learned location for an item.

    Two-tier matching: user calibrations (source=user_calibration) are tried
    first against ``threshold``; only if none qualifies is Instacart seed data
    (source=instacart) tried against the stricter ``seed_threshold``. Entries
    with any other source are ignored.

    Args:
        item: Item to look up.
        threshold: Minimum similarity for user calibrations (0-1, scale
            ``(1 + cos) / 2`` -- see module docstring).
        seed_threshold: Minimum similarity for Instacart seed entries.
        n_results: Nearest neighbours fetched per tier.

    Returns:
        Dict with item, zone, notes, similarity and source -- or None if no
        entry meets its tier's threshold.
    """
    col = _get_collection()
    if col.count() == 0:
        return None

    emb = get_embedding(item)
    # User calibrations always win; the seed query is only needed as fallback.
    user_match = _best_match(col, emb, SOURCE_USER, threshold, n_results)
    if user_match is not None:
        return user_match
    return _best_match(col, emb, SOURCE_SEED, seed_threshold, n_results)


def lookup_items(
    items: list[str],
    threshold: float = USER_THRESHOLD,
    seed_threshold: float = SEED_THRESHOLD,
) -> dict[str, dict]:
    """Batch lookup for multiple items.

    Args:
        items: Item names to look up.
        threshold: Minimum similarity for user calibrations.
        seed_threshold: Minimum similarity for Instacart seed entries.

    Returns:
        Dict of item -> lookup result (only items with matches).
    """
    results = {}
    for item in items:
        match = lookup_item(item, threshold, seed_threshold)
        if match:
            results[item] = match
    return results


def lookup_item_exact(item: str) -> dict | None:
    """Exact match lookup -- only returns user calibrations, not seed data.

    Use for the /learn command to verify existing calibrations.

    Args:
        item: Item name; normalized to a document ID the same way as
            ``learn_item``.

    Returns:
        Dict with item, zone and notes, or None if there is no stored
        user calibration under that ID (or the lookup failed).
    """
    doc_id = _doc_id(item)
    if not doc_id:
        return None
    col = _get_collection()
    try:
        result = col.get(ids=[doc_id], include=["documents", "metadatas"])
    except Exception:
        # Best-effort lookup: record why it failed, but let the caller continue.
        logger.exception("Exact lookup failed for %r", item)
        return None

    if not result["ids"]:
        return None
    meta = result["metadatas"][0] or {}
    # Enforce the documented contract: seed data never counts as a calibration.
    if meta.get("source") != SOURCE_USER or "zone" not in meta:
        return None
    return {
        "item": result["documents"][0],
        "zone": meta["zone"],
        "notes": meta.get("notes", ""),
    }


def stats() -> dict:
    """Return stats about the learned item database.

    Per-source counts are best-effort: if counting fails, the error is logged
    and the affected counts stay 0.
    """
    col = _get_collection()
    total = col.count()

    user_count = 0
    seed_count = 0
    try:
        user_count = len(col.get(where={"source": SOURCE_USER}, include=[])["ids"])
        seed_count = len(col.get(where={"source": SOURCE_SEED}, include=[])["ids"])
    except Exception:
        logger.exception("Could not count learned items by source")

    return {
        "total_items": total,
        "user_calibrations": user_count,
        "instacart_seeded": seed_count,
        "collection": CHROMA_COLLECTION,
        "path": CHROMA_PATH,
    }


def purge_old(months: int = 6) -> dict:
    """Placeholder for removing entries older than ``months`` months.

    Costco moves products around, so old calibrations may go stale. Stored
    entries do not carry a timestamp yet, so nothing is removed: this always
    reports ``purged == 0`` along with the current total item count.
    """
    # TODO: store a timestamp in learn_item() metadata and filter on it here.
    return {"purged": 0, "remaining": stats()["total_items"]}