# llm_client.py

"""LLM client for Costco Route Optimizer.

Calls local Ollama endpoints for item extraction and zone classification.
"""

import json
import requests

from costco_route.config import LLM_URL, EMBED_URL, LLM_MODEL, EMBED_MODEL

def _call_llm(prompt: str, timeout: int = 60) -> str:
    """Send *prompt* as a single user chat message and return the reply text.

    The request is non-streaming, uses a low temperature (0.1), and caps
    generation at 1024 tokens via ``num_predict``.

    Args:
        prompt: Full prompt text, sent as one ``user`` message.
        timeout: Seconds passed to ``requests.post`` as its timeout.

    Returns:
        The ``message.content`` field of the JSON response, with surrounding
        whitespace stripped.

    Raises:
        requests.RequestException: On connection errors, timeouts, or a
            non-2xx status (via ``raise_for_status``).
        KeyError: If the response JSON lacks ``message`` or ``content``.
    """
    chat_message = {"role": "user", "content": prompt}
    generation_options = {"temperature": 0.1, "num_predict": 1024}
    response = requests.post(
        LLM_URL,
        json={
            "model": LLM_MODEL,
            "messages": [chat_message],
            "stream": False,
            "options": generation_options,
        },
        timeout=timeout,
    )
    response.raise_for_status()
    reply = response.json()["message"]
    return reply["content"].strip()

def _get_embedding(text: str, timeout: int = 30) -> list[float]:
    """Request an embedding vector for *text* from the embedding endpoint.

    Args:
        text: Text to embed, sent as the ``prompt`` field.
        timeout: Seconds passed to ``requests.post`` as its timeout.

    Returns:
        The ``embedding`` field of the JSON response.

    Raises:
        requests.RequestException: On connection errors, timeouts, or a
            non-2xx status (via ``raise_for_status``).
        KeyError: If the response JSON has no ``embedding`` field.
    """
    response = requests.post(
        EMBED_URL,
        json={"model": EMBED_MODEL, "prompt": text},
        timeout=timeout,
    )
    response.raise_for_status()
    body = response.json()
    return body["embedding"]

def extract_items(raw_input: str) -> list[str]:
    """Parse a stream-of-consciousness grocery list into individual items.

    The LLM is asked to normalize item names and separate compound phrases,
    replying with a JSON array (optionally wrapped in a Markdown code fence).
    If the LLM call fails, the reply is not valid JSON, or the JSON is not an
    array, the input is instead split on commas and newlines.

    Args:
        raw_input: Free-form grocery list text.

    Returns:
        Non-empty, whitespace-stripped item strings. Input containing nothing
        but commas, newlines, and whitespace returns ``[]`` without calling
        the LLM.
    """
    import re

    # Fallback split: commas and newlines only (not spaces, which would break
    # multi-word items such as "paper towels").
    fallback = [part.strip() for part in re.split(r"[,\n]+", raw_input or "")]
    fallback = [part for part in fallback if part]
    if not fallback:
        # Nothing to extract; skip the LLM round-trip entirely.
        return []

    from costco_route.config import EXTRACT_PROMPT

    # Formatted outside the try so a broken prompt template fails loudly
    # instead of silently degrading to the fallback.
    prompt = EXTRACT_PROMPT.format(input=raw_input)
    try:
        raw = _call_llm(prompt)
        # Strip markdown code fences if present
        raw = raw.strip().removeprefix("```json").removeprefix("```").removesuffix("```").strip()
        parsed = json.loads(raw)
    except (ValueError, KeyError, TypeError, AttributeError, requests.RequestException):
        # ValueError covers json.JSONDecodeError; TypeError/AttributeError
        # cover a response whose "message"/"content" is not a dict/str.
        return fallback

    if not isinstance(parsed, list):
        return fallback

    # Drop JSON nulls (which would otherwise become the item "None") and blanks.
    stripped = (str(entry).strip() for entry in parsed if entry is not None)
    return [text for text in stripped if text]

def classify_items(items: list[str]) -> dict[str, list[str]]:
    """Classify grocery items into Costco warehouse zones.

    The LLM is asked for a JSON object mapping zone IDs to item lists
    (optionally wrapped in a Markdown code fence). The reply is normalized and
    then passed through ``_apply_sanity_checks`` to correct known
    misclassifications.

    Args:
        items: Item names to classify.

    Returns:
        A dict mapping two-digit zone IDs ("01" .. "10") to lists of item
        names. Items the LLM files under an unrecognized zone ID go to zone
        "04" (Pantry — the catch-all center). If the LLM call fails, or its
        reply is not a JSON object containing any items, every input item
        goes to "04" unchanged. An empty *items* list returns ``{}`` without
        calling the LLM.
    """
    if not items:
        return {}

    from costco_route.config import ZONE_CLASSIFY_PROMPT

    valid_zones = ("01", "02", "03", "04", "05", "06", "07", "08", "09", "10")

    items_str = "\n".join(f"- {item}" for item in items)
    prompt = ZONE_CLASSIFY_PROMPT.format(items=items_str)

    # Safest fallback: everything in pantry. Copy so the caller never gets
    # back an alias of its own list.
    fallback = {"04": list(items)}

    try:
        raw = _call_llm(prompt)
        # Strip markdown code fences if present
        raw = raw.strip().removeprefix("```json").removeprefix("```").removesuffix("```").strip()
        classified = json.loads(raw)
    except (ValueError, KeyError, TypeError, AttributeError, requests.RequestException):
        # ValueError covers json.JSONDecodeError; TypeError/AttributeError
        # cover a response whose "message"/"content" is not a dict/str.
        return fallback

    if not isinstance(classified, dict):
        # e.g. the model replied with a bare list; .items() would crash.
        return fallback

    result: dict[str, list[str]] = {}
    for zone_key, zone_items in classified.items():
        zone_id = str(zone_key).strip()
        if zone_id.isdecimal():
            # Tolerate dropped leading zeros ("4" -> "04").
            zone_id = f"{int(zone_id):02d}"
        if zone_id not in valid_zones:
            # Unclassifiable per the contract above -> catch-all pantry.
            zone_id = "04"
        if isinstance(zone_items, str):
            zone_items = [zone_items]
        if not isinstance(zone_items, list):
            continue
        stripped = (str(i).strip() for i in zone_items if i is not None)
        cleaned = [name for name in stripped if name]
        if cleaned:
            # setdefault/extend merges keys that normalize to the same zone.
            result.setdefault(zone_id, []).extend(cleaned)

    if not result:
        # The LLM returned no usable assignments; don't lose the items.
        return fallback

    # Hard sanity checks: override obvious LLM misclassifications
    return _apply_sanity_checks(result)

def _apply_sanity_checks(classified: dict[str, list[str]]) -> dict[str, list[str]]:
"""Override obviously wrong zone assignments.

The LLM sometimes hallucinates zones for common items.
These rules catch the most frequent mistakes.
"""
# Items that must NOT be in Zone 10 (checkout/food court)
NEVER_ZONE_10 = [
    "toilet paper", "paper towels", "detergent", "trash bags",
    "laundry", "cleaning", "dish soap",
]
# Items that belong in Zone 06 (dairy/cold) not Zone 04 (pantry)
DAIRY_KEYWORDS = [
    "cheese", "cheddar", "mozzarella", "yogurt", "butter", "sour cream",
    "cream cheese", "half and half", "heavy cream",
]
# Items that belong in Zone 07 (fresh/produce) not Zone 02 (seasonal)
PRODUCE_KEYWORDS = [
    "fruit", "tomato", "lettuce", "romaine", "onion", "potato",
    "avocado", "banana", "apple", "berry", "berries", "vegetable",
    "pepper", "carrot", "broccoli", "spinach", "salad", "corn",
    "celery", "cucumber", "mushroom", "zucchini", "squash", "kale",
]
# Items that belong in Zone 04 (pantry) not elsewhere
PANTRY_KEYWORDS = {
    "stock": "04", "broth": "04", "soup": "04", "canned": "04",
}
# Bacon → cold room (Zone 06), not meat counter (Zone 07)
BACON_KEYWORDS = ["bacon"]

# Fix Zone 10 misclassifications → move to Zone 08
if "10" in classified:
    misplaced = []
    remaining = []
    for item in classified["10"]:
        item_lower = item.lower()
        if any(kw in item_lower for kw in NEVER_ZONE_10):
            misplaced.append(item)
        else:
            remaining.append(item)
    if misplaced:
        classified.setdefault("08", []).extend(misplaced)
    if remaining:
        classified["10"] = remaining
    else:
        del classified["10"]

# Fix Zone 04 dairy items → move to Zone 06
if "04" in classified:
    misplaced = []
    remaining = []
    for item in classified["04"]:
        item_lower = item.lower()
        if any(kw in item_lower for kw in DAIRY_KEYWORDS):
            misplaced.append(item)
        else:
            remaining.append(item)
    if misplaced:
        classified.setdefault("06", []).extend(misplaced)
    if remaining:
        classified["04"] = remaining
    else:
        del classified["04"]

# Fix Zone 02 produce items → move to Zone 07
if "02" in classified:
    misplaced = []
    remaining = []
    for item in classified["02"]:
        item_lower = item.lower()
        if any(kw in item_lower for kw in PRODUCE_KEYWORDS):
            misplaced.append(item)
        else:
            remaining.append(item)
    if misplaced:
        classified.setdefault("07", []).extend(misplaced)
    if remaining:
        classified["02"] = remaining
    else:
        del classified["02"]

# Fix bacon → Zone 06 (cold room), not Zone 07 (meat counter)
for zone in list(classified.keys()):
    bacon_items = [item for item in classified[zone] if "bacon" in item.lower()]
    if bacon_items and zone != "06":
        classified[zone] = [item for item in classified[zone] if "bacon" not in item.lower()]
        classified.setdefault("06", []).extend(bacon_items)
        if not classified[zone]:
            del classified[zone]

# Fix stock/broth → Zone 04 (pantry)
for zone in list(classified.keys()):
    stock_items = [item for item in classified[zone] if any(kw in item.lower() for kw in PANTRY_KEYWORDS)]
    if stock_items and zone != "04":
        classified[zone] = [item for item in classified[zone] if not any(kw in item.lower() for kw in PANTRY_KEYWORDS)]
        classified.setdefault("04", []).extend(stock_items)
        if not classified[zone]:
            del classified[zone]

return classified

def get_embedding(text: str) -> list[float]:
    """Return the embedding vector for *text*.

    Public entry point; delegates to ``_get_embedding`` using its default
    timeout.
    """
    embedding = _get_embedding(text)
    return embedding