"""Shadow Validator — Validation metrics export and analysis. Provides daily exports of shadow_extractions for manual review, calculates precision/recall metrics, and generates validation reports. """ import json import csv import logging from pathlib import Path from datetime import datetime, timezone, timedelta from typing import Dict, Any, List, Optional from .shadow_database import ShadowDatabase class ShadowValidator: """Export and validate shadow mode extractions. Generates daily reports of extractions for Matt to review, calculates metrics once ground truth is provided. """ def __init__(self, db: ShadowDatabase, export_dir: Optional[Path] = None): self.db = db self.export_dir = Path(export_dir) if export_dir else Path.home() / ".icarus" / "shadow" / "exports" self.export_dir.mkdir(parents=True, exist_ok=True) # ----------------------------------------------------------------------- # Daily Export # ----------------------------------------------------------------------- def export_daily(self, date: Optional[str] = None) -> Path: """Export all messages and extractions for a given date. Returns path to exported file. """ if date is None: date = datetime.now(timezone.utc).strftime("%Y-%m-%d") # Get data for the date rows = self.db.export_for_review(date, date) if not rows: logging.info("[ShadowValidator] No data to export for %s", date) return None # Export to CSV for easy review export_path = self.export_dir / f"shadow_export_{date}.csv" with open(export_path, 'w', newline='', encoding='utf-8') as f: if rows: fieldnames = list(rows[0].keys()) writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() writer.writerows(rows) # Also export JSON for programmatic access json_path = self.export_dir / f"shadow_export_{date}.json" with open(json_path, 'w', encoding='utf-8') as f: json.dump(rows, f, indent=2, default=str) # Generate summary stats stats = self._generate_summary(rows, date) stats_path = self.export_dir / f"shadow_stats_{date}.json" with open(stats_path, 'w', encoding='utf-8') as f: json.dump(stats, f, indent=2) logging.info("[ShadowValidator] Exported %d rows to %s", len(rows), export_path) return export_path def _generate_summary(self, rows: List[Dict[str, Any]], date: str) -> Dict[str, Any]: """Generate summary statistics from exported rows.""" total_messages = len(rows) tripwire_fired = sum(1 for r in rows if r.get("tripwire_fired")) extractions = sum(1 for r in rows if r.get("extraction_type")) # Count by extraction type type_counts = {} for row in rows: ext_type = row.get("extraction_type") if ext_type: type_counts[ext_type] = type_counts.get(ext_type, 0) + 1 # Average confidence confidences = [r.get("extraction_confidence", 0) for r in rows if r.get("extraction_confidence")] avg_confidence = sum(confidences) / len(confidences) if confidences else 0 return { "date": date, "generated_at": datetime.now(timezone.utc).isoformat(), "total_messages": total_messages, "tripwire_fired": tripwire_fired, "extractions": extractions, "extraction_types": type_counts, "average_confidence": round(avg_confidence, 3), "tripwire_rate": round(tripwire_fired / total_messages, 3) if total_messages else 0, } # ----------------------------------------------------------------------- # Metrics Calculation (after manual validation) # ----------------------------------------------------------------------- def calculate_metrics(self, date: Optional[str] = None) -> Dict[str, Any]: """Calculate precision/recall metrics for validated extractions. 
    def calculate_metrics(self, date: Optional[str] = None) -> Dict[str, Any]:
        """Calculate precision/recall metrics for validated extractions.

        Requires ground truth to be set via update_validation().
        """
        if date is None:
            date = datetime.now(timezone.utc).strftime("%Y-%m-%d")

        # Get validated extractions for the date
        conn = self.db._get_connection()
        try:
            rows = [
                dict(row)  # convert sqlite rows to plain dicts so .get() works below
                for row in conn.execute(
                    """
                    SELECT * FROM shadow_extractions
                    WHERE DATE(created_at) = ?
                      AND status IN ('validated', 'rejected')
                    """,
                    (date,),
                ).fetchall()
            ]
        finally:
            conn.close()

        if not rows:
            return {
                "date": date,
                "status": "no_validated_data",
                "message": "No validated extractions for this date",
            }

        # Calculate metrics
        true_positives = 0
        false_positives = 0
        false_negatives = 0
        entity_correct = 0
        entity_incorrect = 0
        entity_missed = 0

        for row in rows:
            # Ground truth comparison; only rows with a ground-truth type contribute
            predicted_type = row["extraction_type"]
            actual_type = row.get("ground_truth_type")
            if actual_type:
                if predicted_type == actual_type and row["status"] == "validated":
                    true_positives += 1
                elif row["status"] == "rejected":
                    false_positives += 1
                elif predicted_type != actual_type:
                    # Wrong type: counts against both precision and recall
                    false_positives += 1
                    false_negatives += 1

            # Entity accuracy (extracted people vs. ground-truth people)
            extracted_who = json.loads(row["extracted_who"]) if row.get("extracted_who") else []
            actual_who = json.loads(row["ground_truth_who"]) if row.get("ground_truth_who") else []
            if actual_who:
                for person in extracted_who:
                    if person in actual_who:
                        entity_correct += 1
                    else:
                        entity_incorrect += 1
                for person in actual_who:
                    if person not in extracted_who:
                        entity_missed += 1

        # Calculate precision/recall
        tp_fp = true_positives + false_positives
        tp_fn = true_positives + false_negatives
        precision = true_positives / tp_fp if tp_fp else 0
        recall = true_positives / tp_fn if tp_fn else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) else 0
        entity_total = entity_correct + entity_incorrect
        entity_accuracy = entity_correct / entity_total if entity_total else 0

        metrics = {
            "date": date,
            "calculated_at": datetime.now(timezone.utc).isoformat(),
            "sample_size": len(rows),
            "tripwire": {
                "true_positives": true_positives,
                "false_positives": false_positives,
                "false_negatives": false_negatives,
                "precision": round(precision, 3),
                "recall": round(recall, 3),
                "f1_score": round(f1, 3),
            },
            "entity_recognition": {
                "correct": entity_correct,
                "incorrect": entity_incorrect,
                "missed": entity_missed,
                "accuracy": round(entity_accuracy, 3),
            },
        }

        return metrics

    # -----------------------------------------------------------------------
    # Review Queue
    # -----------------------------------------------------------------------

    def get_review_queue(self, limit: int = 50) -> List[Dict[str, Any]]:
        """Get extractions awaiting manual validation."""
        return self.db.get_unvalidated_extractions(limit)

    def export_review_queue(self, limit: int = 50) -> Optional[Path]:
        """Export the review queue to CSV.

        Returns the path to the exported file, or None if the queue is empty.
        """
        queue = self.get_review_queue(limit)
        if not queue:
            logging.info("[ShadowValidator] No items in review queue")
            return None

        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        export_path = self.export_dir / f"review_queue_{timestamp}.csv"
        with open(export_path, 'w', newline='', encoding='utf-8') as f:
            fieldnames = list(queue[0].keys())
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(queue)

        return export_path
    # -----------------------------------------------------------------------
    # Report Generation
    # -----------------------------------------------------------------------

    def generate_weekly_report(self, days: int = 7) -> Dict[str, Any]:
        """Generate a summary report covering the last `days` days."""
        end_date = datetime.now(timezone.utc)
        start_date = end_date - timedelta(days=days)

        conn = self.db._get_connection()
        try:
            # Get daily stats for the reporting window
            rows = conn.execute(
                """
                SELECT * FROM shadow_metrics
                WHERE date >= ? AND date <= ?
                ORDER BY date
                """,
                (start_date.strftime("%Y-%m-%d"), end_date.strftime("%Y-%m-%d")),
            ).fetchall()
            daily_stats = [dict(row) for row in rows]

            # Aggregate across days (treat missing/NULL counts as zero)
            total_messages = sum(r.get("total_messages") or 0 for r in daily_stats)
            total_tripwire = sum(r.get("tripwire_fired_count") or 0 for r in daily_stats)
            total_extractions = sum(r.get("extractions_created") or 0 for r in daily_stats)
        finally:
            conn.close()

        report = {
            "period": {
                "start": start_date.strftime("%Y-%m-%d"),
                "end": end_date.strftime("%Y-%m-%d"),
                "days": days,
            },
            "summary": {
                "total_messages": total_messages,
                "tripwire_fired": total_tripwire,
                "extractions": total_extractions,
                "tripwire_rate": round(total_tripwire / total_messages, 3) if total_messages else 0,
            },
            "daily": daily_stats,
            "generated_at": datetime.now(timezone.utc).isoformat(),
        }

        return report
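
# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not a prescribed workflow). It assumes
# ShadowDatabase() can be constructed with no arguments and that ground truth
# is recorded through a db.update_validation() call; that signature is an
# assumption here and may differ in shadow_database.py. Because of the
# relative import above, run this as a module (python -m ...), not as a script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    db = ShadowDatabase()  # assumed default constructor
    validator = ShadowValidator(db)

    # 1. Export today's messages/extractions for manual review.
    csv_path = validator.export_daily()
    print(f"Daily export: {csv_path}")

    # 2. Export anything still awaiting validation.
    queue_path = validator.export_review_queue(limit=50)
    print(f"Review queue: {queue_path}")

    # 3. After reviewing, record ground truth on the database side, e.g.:
    # db.update_validation(extraction_id, status="validated",
    #                      ground_truth_type="task",
    #                      ground_truth_who=["Matt"])  # hypothetical signature

    # 4. With ground truth in place, compute precision/recall and the rollup.
    print(json.dumps(validator.calculate_metrics(), indent=2))
    print(json.dumps(validator.generate_weekly_report(), indent=2))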