"""Tests for Brain Query Interface.""" import pytest from datetime import datetime from brain.query import ( QueryPattern, detect_query_pattern, calculate_temporal_score, calculate_confidence ) from brain.embeddings import chunk_document class TestPatternDetection: """Test query pattern detection.""" def test_detect_logistical_pattern(self): """Should detect logistical queries.""" queries = [ "Is Friday a half-day?", "When is spring break?", "What time is dismissal today?", "Is practice cancelled this week?" ] for q in queries: result = detect_query_pattern(q) assert result == QueryPattern.LOGISTICAL, f"Expected LOGISTICAL for: {q}" def test_detect_entity_pattern(self): """Should detect entity queries.""" queries = [ "What size is our HVAC filter?", "How much is the dance tuition?" ] for q in queries: result = detect_query_pattern(q) assert result == QueryPattern.ENTITY, f"Expected ENTITY for: {q}" # "What time" is logistical (about current schedule) result = detect_query_pattern("What time is soccer practice?") assert result == QueryPattern.LOGISTICAL def test_detect_historical_pattern(self): """Should detect historical queries.""" queries = [ "What was the name of the roofer we used?", "When did we replace the water heater?", "What did we pay for the roof last year?" ] for q in queries: result = detect_query_pattern(q) assert result == QueryPattern.HISTORICAL, f"Expected HISTORICAL for: {q}" # Entity-like patterns in historical queries may be classified as entity result = detect_query_pattern("Who was Harper's teacher in 2023?") assert result in [QueryPattern.HISTORICAL, QueryPattern.ENTITY] def test_default_to_entity(self): """Should default to ENTITY for ambiguous queries.""" queries = [ "hello", "what", "test", "" # empty ] for q in queries: result = detect_query_pattern(q) assert result == QueryPattern.ENTITY, f"Expected ENTITY default for: {q}" class TestTemporalScoring: """Test temporal scoring calculations.""" def test_logistical_exponential_decay(self): """Logistical queries should use exponential decay.""" doc_date = datetime.now() score = calculate_temporal_score(doc_date, QueryPattern.LOGISTICAL, "email") assert score > 0.8 # Recent email should have high score # 7 days old (one half-life) old_date = datetime.now() - __import__('datetime').timedelta(days=7) score_old = calculate_temporal_score(old_date, QueryPattern.LOGISTICAL, "email") assert abs(score_old - 0.45) < 0.1 # ~0.5 * 0.9 def test_logistical_recency_cutoff(self): """Logistical queries should penalize old documents.""" very_old = datetime.now() - __import__('datetime').timedelta(days=40) score = calculate_temporal_score(very_old, QueryPattern.LOGISTICAL, "newsletter") # Should be penalized by 0.1 for being >30 days old assert score < 0.1 def test_entity_linear_decay(self): """Entity queries should use linear decay.""" doc_date = datetime.now() score = calculate_temporal_score(doc_date, QueryPattern.ENTITY, "invoice") assert score > 0.9 # Recent invoice # 180 days old (2x half-life) old_date = datetime.now() - __import__('datetime').timedelta(days=180) score_old = calculate_temporal_score(old_date, QueryPattern.ENTITY, "invoice") assert score_old == 0.0 # At cutoff def test_historical_no_decay(self): """Historical queries should not decay.""" old_date = datetime(2020, 1, 1) score = calculate_temporal_score(old_date, QueryPattern.HISTORICAL, "invoice") # Should be full score based on source weight only assert score == 1.0 # Invoice weight for historical def test_source_weights(self): """Different source types should have different weights.""" doc_date = datetime.now() # Logistical: calendar_event = 1.0, email = 0.9 cal_score = calculate_temporal_score(doc_date, QueryPattern.LOGISTICAL, "calendar_event") email_score = calculate_temporal_score(doc_date, QueryPattern.LOGISTICAL, "email") assert cal_score > email_score class TestConfidenceScoring: """Test confidence scoring.""" def test_logistical_low_confidence(self): """Logistical queries with low scores should warn.""" conf, msg = calculate_confidence(0.2, QueryPattern.LOGISTICAL) assert conf == 0.2 assert "verify" in msg.lower() def test_logistical_high_confidence(self): """Logistical queries with high scores should be confident.""" conf, msg = calculate_confidence(0.8, QueryPattern.LOGISTICAL) assert conf == 0.9 assert "recent" in msg.lower() def test_entity_high_confidence(self): """Entity queries should be moderately confident.""" conf, msg = calculate_confidence(0.8, QueryPattern.ENTITY) assert conf == 0.85 def test_historical_always_confident(self): """Historical queries should always have decent confidence.""" conf, msg = calculate_confidence(0.1, QueryPattern.HISTORICAL) assert conf == 0.8 class TestDocumentChunking: """Test document chunking.""" def test_chunk_empty_text(self): """Should handle empty text.""" result = chunk_document("") assert result == [] def test_chunk_short_text(self): """Should return single chunk for short text.""" text = "Short text" result = chunk_document(text, chunk_size=100) assert len(result) == 1 assert result[0] == text def test_chunk_with_overlap(self): """Should create overlapping chunks.""" words = ["word" + str(i) for i in range(20)] text = " ".join(words) result = chunk_document(text, chunk_size=10, overlap=2) assert len(result) > 1 # Each chunk after first should overlap with previous for i in range(1, len(result)): prev_words = set(result[i-1].split()) curr_words = set(result[i].split()) overlap = len(prev_words & curr_words) assert overlap > 0, f"No overlap between chunk {i-1} and {i}" class TestIntegration: """Integration tests requiring ChromaDB.""" @pytest.mark.skip(reason="Requires running ChromaDB and Ollama") def test_end_to_end_query(self): """Test full query pipeline.""" from brain.embeddings import get_embedding from brain.store import search_documents query = "What size is the HVAC filter?" pattern = QueryPattern.ENTITY embedding = get_embedding(query) results = search_documents(embedding, pattern, n_results=3) assert len(results) > 0 assert "filter" in results[0]["text"].lower() if __name__ == "__main__": pytest.main([__file__, "-v"])