"""ChromaDB-based learned item location store.
Remembers where items actually are once the user has calibrated them, and
overrides the LLM classification on future trips.
Two-tier matching:
- User calibrations (source=user_calibration): standard threshold (0.85)
- Instacart seed data (source=instacart): higher threshold (0.93) to prevent
false matches from the 50K product database
"""
import chromadb
from costco_route.config import CHROMA_PATH, CHROMA_COLLECTION, EMBED_MODEL
from costco_route.llm_client import get_embedding
def _get_collection():
    """Open the persistent Chroma store and return the item-location collection.

    The collection is created on first use with cosine distance as its
    HNSW space; later calls reuse the existing collection.
    """
    store = chromadb.PersistentClient(path=CHROMA_PATH)
    collection_settings = {"hnsw:space": "cosine"}
    return store.get_or_create_collection(
        name=CHROMA_COLLECTION,
        metadata=collection_settings,
    )
def learn_item(item: str, zone_id: str, notes: str = ""):
    """Record (or replace) the user-calibrated zone for an item.

    The record ID is the item name lowercased, stripped, with spaces turned
    into underscores and truncated to 64 characters. Learning the same name
    again therefore upserts over the earlier entry.

    Args:
        item: Item name (e.g., "Kirkland Organic Milk").
        zone_id: Zone ID (e.g., "01" through "10").
        notes: Optional free-text hint (e.g., "back corner by bakery").
    """
    collection = _get_collection()
    vector = get_embedding(item)
    record_id = item.lower().strip().replace(" ", "_")[:64]
    record_meta = {"zone": zone_id, "notes": notes, "source": "user_calibration"}
    collection.upsert(
        ids=[record_id],
        embeddings=[vector],
        documents=[item],
        metadatas=[record_meta],
    )
def lookup_item(
    item: str,
    threshold: float = 0.85,
    *,
    seed_threshold: float = 0.93,
) -> dict | None:
    """Check if we have a learned location for an item.

    Two-tier matching:
    - User calibrations (source=user_calibration) must reach ``threshold``.
    - Instacart seed data (source=instacart) must reach ``seed_threshold``
      (0.93 by default). This stricter bar prevents false matches from the
      50K product database.

    Each tier is searched with its own metadata-filtered query. A qualifying
    user calibration is therefore found even when many seed products sit
    closer to the query embedding. A user calibration always wins over a
    seed match.

    Similarity is ChromaDB's cosine distance rescaled to [0, 1] as
    ``1 - distance / 2``. It is NOT raw cosine similarity: the default 0.85
    corresponds to a raw cosine similarity of 0.70, and 0.93 to 0.86.

    Args:
        item: Item to look up.
        threshold: Minimum rescaled similarity for user calibrations (0-1).
        seed_threshold: Minimum rescaled similarity for Instacart seed
            entries (0-1).

    Returns:
        Dict with item, zone, notes, similarity (rounded to 3 places) and
        source, or None if no entry in either tier is similar enough.
    """
    col = _get_collection()
    if col.count() == 0:
        return None
    emb = get_embedding(item)
    # Priority order: consult seed data only when no user calibration qualifies.
    for source, min_similarity in (
        ("user_calibration", threshold),
        ("instacart", seed_threshold),
    ):
        match = _best_source_match(col, emb, source, min_similarity)
        if match is not None:
            return match
    return None


def _best_source_match(col, emb, source: str, min_similarity: float) -> dict | None:
    """Return the nearest entry tagged ``source`` if it clears ``min_similarity``."""
    results = col.query(
        query_embeddings=[emb],
        n_results=1,  # filtered query: the nearest hit is the tier's best
        where={"source": source},
        include=["documents", "metadatas", "distances"],
    )
    if not results["ids"] or not results["ids"][0]:
        return None
    # ChromaDB cosine distance: 0 = identical, 2 = opposite -> rescale to [0, 1].
    similarity = 1.0 - (results["distances"][0][0] / 2.0)
    if similarity < min_similarity:
        return None
    meta = results["metadatas"][0][0]
    return {
        "item": results["documents"][0][0],
        "zone": meta["zone"],
        "notes": meta.get("notes", ""),
        "similarity": round(similarity, 3),
        "source": source,
    }
def lookup_items(items: list[str], threshold: float = 0.85) -> dict[str, dict]:
    """Look up several items in one call.

    Args:
        items: Item names; each is passed to ``lookup_item``.
        threshold: Forwarded to ``lookup_item`` as its user-calibration threshold.

    Returns:
        Mapping of item name to its lookup result. Items without a match
        are omitted.
    """
    return {
        name: hit
        for name in items
        if (hit := lookup_item(name, threshold))
    }
def lookup_item_exact(item: str) -> dict | None:
    """Exact-ID lookup that only returns user calibrations, never seed data.

    Use for the /learn command to verify existing calibrations. The record ID
    is derived from ``item`` the same way ``learn_item`` derives it:
    lowercased, stripped, spaces to underscores, truncated to 64 chars.

    Args:
        item: Item name as entered by the user.

    Returns:
        Dict with item, zone and notes for a stored user calibration. Returns
        None if there is no such calibration or the store could not be read.
    """
    col = _get_collection()
    doc_id = item.lower().strip().replace(" ", "_")[:64]
    try:
        result = col.get(ids=[doc_id], include=["documents", "metadatas"])
    except Exception:
        # Best-effort: a read failure is reported as "no calibration".
        return None
    if not result["ids"]:
        return None
    meta = result["metadatas"][0] or {}
    # Seed entries (or records lacking a zone) are not verified calibrations.
    if meta.get("source") != "user_calibration" or "zone" not in meta:
        return None
    return {
        "item": result["documents"][0],
        "zone": meta["zone"],
        "notes": meta.get("notes", ""),
    }
def stats() -> dict:
    """Summarize the learned-item collection.

    Returns:
        Dict with the total record count, per-source counts for user
        calibrations and Instacart seeds, and the collection name and path.
        Per-source counting is best-effort: if a count query fails, that
        count and any not yet taken stay 0.
    """
    col = _get_collection()
    total = col.count()
    per_source = {"user_calibration": 0, "instacart": 0}
    try:
        for tag in per_source:
            per_source[tag] = len(col.get(where={"source": tag}, include=[])["ids"])
    except Exception:
        pass
    summary = {"total_items": total}
    summary["user_calibrations"] = per_source["user_calibration"]
    summary["instacart_seeded"] = per_source["instacart"]
    summary["collection"] = CHROMA_COLLECTION
    summary["path"] = CHROMA_PATH
    return summary
def purge_old(months: int = 6) -> dict:
    """Placeholder for pruning stale entries older than ``months`` months.

    Costco moves products around, so old calibrations may go stale. Nothing
    is removed yet: calibrations are stored without a timestamp, so
    ``months`` is currently ignored. The result always reports zero purged,
    alongside the current total from ``stats()``.
    """
    # TODO: store a timestamp in metadata so this can filter by age.
    remaining = stats()["total_items"]
    return {"purged": 0, "remaining": remaining}