"""Shadow Validator — Validation metrics export and analysis.
Provides daily exports of shadow_extractions for manual review,
calculates precision/recall metrics, and generates validation reports.
"""
import json
import csv
import logging
from pathlib import Path
from datetime import datetime, timezone, timedelta
from typing import Dict, Any, List, Optional
from .shadow_database import ShadowDatabase
class ShadowValidator:
"""Export and validate shadow mode extractions.
Generates daily reports of extractions for Matt to review,
calculates metrics once ground truth is provided.
"""
def __init__(self, db: ShadowDatabase, export_dir: Optional[Path] = None):
self.db = db
self.export_dir = Path(export_dir) if export_dir else Path.home() / ".icarus" / "shadow" / "exports"
self.export_dir.mkdir(parents=True, exist_ok=True)
# -----------------------------------------------------------------------
# Daily Export
# -----------------------------------------------------------------------
    def export_daily(self, date: Optional[str] = None) -> Optional[Path]:
        """Export all messages and extractions for a given date.

        Defaults to today (UTC). Returns the path to the CSV export,
        or None if there is no data for the date.
        """
        if date is None:
            date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        # Get data for the date
        rows = self.db.export_for_review(date, date)
        if not rows:
            logging.info("[ShadowValidator] No data to export for %s", date)
            return None
        # Export to CSV for easy review
        export_path = self.export_dir / f"shadow_export_{date}.csv"
        with open(export_path, 'w', newline='', encoding='utf-8') as f:
            # rows is guaranteed non-empty here (early return above)
            fieldnames = list(rows[0].keys())
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)
# Also export JSON for programmatic access
json_path = self.export_dir / f"shadow_export_{date}.json"
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(rows, f, indent=2, default=str)
# Generate summary stats
stats = self._generate_summary(rows, date)
stats_path = self.export_dir / f"shadow_stats_{date}.json"
with open(stats_path, 'w', encoding='utf-8') as f:
json.dump(stats, f, indent=2)
logging.info("[ShadowValidator] Exported %d rows to %s", len(rows), export_path)
return export_path
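    # Example (a sketch; assumes shadow mode has populated the database and that
    # ShadowDatabase() can be constructed with its defaults):
    #
    #   validator = ShadowValidator(ShadowDatabase())
    #   csv_path = validator.export_daily("2025-01-15")
    #   # Writes, under ~/.icarus/shadow/exports/:
    #   #   shadow_export_2025-01-15.csv   (for manual review)
    #   #   shadow_export_2025-01-15.json  (programmatic access)
    #   #   shadow_stats_2025-01-15.json   (summary stats)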
def _generate_summary(self, rows: List[Dict[str, Any]], date: str) -> Dict[str, Any]:
"""Generate summary statistics from exported rows."""
total_messages = len(rows)
tripwire_fired = sum(1 for r in rows if r.get("tripwire_fired"))
extractions = sum(1 for r in rows if r.get("extraction_type"))
# Count by extraction type
type_counts = {}
for row in rows:
ext_type = row.get("extraction_type")
if ext_type:
type_counts[ext_type] = type_counts.get(ext_type, 0) + 1
# Average confidence
confidences = [r.get("extraction_confidence", 0) for r in rows if r.get("extraction_confidence")]
avg_confidence = sum(confidences) / len(confidences) if confidences else 0
return {
"date": date,
"generated_at": datetime.now(timezone.utc).isoformat(),
"total_messages": total_messages,
"tripwire_fired": tripwire_fired,
"extractions": extractions,
"extraction_types": type_counts,
"average_confidence": round(avg_confidence, 3),
"tripwire_rate": round(tripwire_fired / total_messages, 3) if total_messages else 0,
}
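    # The daily stats file then looks roughly like this (values and extraction
    # type names are illustrative, not from real data):
    #   {
    #     "date": "2025-01-15",
    #     "total_messages": 120,
    #     "tripwire_fired": 18,
    #     "extractions": 14,
    #     "extraction_types": {"commitment": 9, "preference": 5},
    #     "average_confidence": 0.81,
    #     "tripwire_rate": 0.15    # 18 / 120
    #   }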
# -----------------------------------------------------------------------
# Metrics Calculation (after manual validation)
# -----------------------------------------------------------------------
def calculate_metrics(self, date: Optional[str] = None) -> Dict[str, Any]:
"""Calculate precision/recall metrics for validated extractions.
Requires ground truth to be set via update_validation().
"""
if date is None:
date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        # Get validated extractions for the date, converted to plain dicts so the
        # .get() lookups below work (sqlite3.Row has no .get method)
        conn = self.db._get_connection()
        try:
            rows = [
                dict(row)
                for row in conn.execute(
                    """
                    SELECT * FROM shadow_extractions
                    WHERE DATE(created_at) = ? AND status IN ('validated', 'rejected')
                    """,
                    (date,),
                ).fetchall()
            ]
        finally:
            conn.close()
if not rows:
return {
"date": date,
"status": "no_validated_data",
"message": "No validated extractions for this date",
}
# Calculate metrics
true_positives = 0
false_positives = 0
false_negatives = 0
entity_correct = 0
entity_incorrect = 0
entity_missed = 0
        for row in rows:
            # Ground truth comparison: a validated extraction with a matching type is a
            # true positive; a rejected extraction is a false positive; a wrong-type
            # prediction counts as both a false positive (for the predicted type) and
            # a false negative (for the actual type that was missed).
            predicted_type = row["extraction_type"]
            actual_type = row.get("ground_truth_type")
            if actual_type:
                if predicted_type == actual_type and row["status"] == "validated":
                    true_positives += 1
                elif row["status"] == "rejected":
                    false_positives += 1
                elif predicted_type != actual_type:
                    false_positives += 1
                    false_negatives += 1
# Entity accuracy
extracted_who = json.loads(row["extracted_who"]) if row.get("extracted_who") else []
actual_who = json.loads(row["ground_truth_who"]) if row.get("ground_truth_who") else []
if actual_who:
for person in extracted_who:
if person in actual_who:
entity_correct += 1
else:
entity_incorrect += 1
for person in actual_who:
if person not in extracted_who:
entity_missed += 1
# Calculate precision/recall
precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
entity_accuracy = entity_correct / (entity_correct + entity_incorrect) if (entity_correct + entity_incorrect) > 0 else 0
metrics = {
"date": date,
"calculated_at": datetime.now(timezone.utc).isoformat(),
"sample_size": len(rows),
"tripwire": {
"true_positives": true_positives,
"false_positives": false_positives,
"false_negatives": false_negatives,
"precision": round(precision, 3),
"recall": round(recall, 3),
"f1_score": round(f1, 3),
},
"entity_recognition": {
"correct": entity_correct,
"incorrect": entity_incorrect,
"missed": entity_missed,
"accuracy": round(entity_accuracy, 3),
},
}
return metrics
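    # Worked example (illustrative numbers only): with 8 true positives,
    # 2 false positives, and 1 false negative,
    #   precision = 8 / (8 + 2)                      = 0.8
    #   recall    = 8 / (8 + 1)                      ≈ 0.889
    #   f1_score  = 2 * (0.8 * 0.889) / (0.8 + 0.889) ≈ 0.842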
# -----------------------------------------------------------------------
# Review Queue
# -----------------------------------------------------------------------
def get_review_queue(self, limit: int = 50) -> List[Dict[str, Any]]:
"""Get extractions awaiting manual validation."""
return self.db.get_unvalidated_extractions(limit)
    def export_review_queue(self, limit: int = 50) -> Optional[Path]:
        """Export the review queue to CSV. Returns None if the queue is empty."""
        queue = self.get_review_queue(limit)
        if not queue:
            logging.info("[ShadowValidator] No items in review queue")
            return None
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
export_path = self.export_dir / f"review_queue_{timestamp}.csv"
with open(export_path, 'w', newline='', encoding='utf-8') as f:
fieldnames = list(queue[0].keys())
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(queue)
return export_path
# -----------------------------------------------------------------------
# Report Generation
# -----------------------------------------------------------------------
    def generate_weekly_report(self, days: int = 7) -> Dict[str, Any]:
        """Generate a summary report covering the last `days` days (default: weekly)."""
        end_date = datetime.now(timezone.utc)
        # Inclusive date range below, so subtract (days - 1) to cover exactly `days` dates
        start_date = end_date - timedelta(days=days - 1)
conn = self.db._get_connection()
try:
# Get daily stats
rows = conn.execute(
"""
SELECT * FROM shadow_metrics
WHERE date >= ? AND date <= ?
ORDER BY date
""",
(start_date.strftime("%Y-%m-%d"), end_date.strftime("%Y-%m-%d")),
).fetchall()
daily_stats = [dict(row) for row in rows]
# Aggregate
total_messages = sum(r.get("total_messages", 0) for r in daily_stats)
total_tripwire = sum(r.get("tripwire_fired_count", 0) for r in daily_stats)
total_extractions = sum(r.get("extractions_created", 0) for r in daily_stats)
finally:
conn.close()
report = {
"period": {
"start": start_date.strftime("%Y-%m-%d"),
"end": end_date.strftime("%Y-%m-%d"),
"days": days,
},
"summary": {
"total_messages": total_messages,
"tripwire_fired": total_tripwire,
"extractions": total_extractions,
"tripwire_rate": round(total_tripwire / total_messages, 3) if total_messages else 0,
},
"daily": daily_stats,
"generated_at": datetime.now(timezone.utc).isoformat(),
}
return report
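

if __name__ == "__main__":
    # Minimal CLI sketch, not part of the validated API. Assumes ShadowDatabase()
    # can be constructed with its default settings and that shadow data exists.
    logging.basicConfig(level=logging.INFO)
    validator = ShadowValidator(ShadowDatabase())
    exported = validator.export_daily()
    print(f"Daily export: {exported}")
    print(json.dumps(validator.calculate_metrics(), indent=2))
    print(json.dumps(validator.generate_weekly_report(), indent=2, default=str))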