"""Tests for Brain Query Interface."""
from datetime import datetime, timedelta

import pytest

from brain.query import (
    QueryPattern,
    detect_query_pattern,
    calculate_temporal_score,
    calculate_confidence
)
from brain.embeddings import chunk_document
class TestPatternDetection:
    """Tests for query pattern detection."""

    def test_detect_logistical_pattern(self):
        """Scheduling-style questions should be classified as LOGISTICAL."""
        logistical_queries = (
            "Is Friday a half-day?",
            "When is spring break?",
            "What time is dismissal today?",
            "Is practice cancelled this week?",
        )
        for q in logistical_queries:
            assert detect_query_pattern(q) == QueryPattern.LOGISTICAL, \
                f"Expected LOGISTICAL for: {q}"

    def test_detect_entity_pattern(self):
        """Fact-lookup questions should be classified as ENTITY."""
        entity_queries = (
            "What size is our HVAC filter?",
            "How much is the dance tuition?",
        )
        for q in entity_queries:
            assert detect_query_pattern(q) == QueryPattern.ENTITY, \
                f"Expected ENTITY for: {q}"
        # "What time" is logistical (about current schedule)
        assert detect_query_pattern("What time is soccer practice?") == QueryPattern.LOGISTICAL

    def test_detect_historical_pattern(self):
        """Past-tense questions should be classified as HISTORICAL."""
        historical_queries = (
            "What was the name of the roofer we used?",
            "When did we replace the water heater?",
            "What did we pay for the roof last year?",
        )
        for q in historical_queries:
            assert detect_query_pattern(q) == QueryPattern.HISTORICAL, \
                f"Expected HISTORICAL for: {q}"
        # Entity-like patterns in historical queries may be classified as entity
        assert detect_query_pattern("Who was Harper's teacher in 2023?") in (
            QueryPattern.HISTORICAL,
            QueryPattern.ENTITY,
        )

    def test_default_to_entity(self):
        """Ambiguous or empty queries should fall back to ENTITY."""
        for q in ("hello", "what", "test", ""):  # last entry is the empty string
            assert detect_query_pattern(q) == QueryPattern.ENTITY, \
                f"Expected ENTITY default for: {q}"
class TestTemporalScoring:
    """Test temporal scoring calculations.

    Fix: the original used ``__import__('datetime').timedelta`` in three
    places instead of a normal import; ``timedelta`` is now imported at the
    top of the file alongside ``datetime``.
    """

    def test_logistical_exponential_decay(self):
        """Logistical queries should use exponential decay."""
        now = datetime.now()
        score = calculate_temporal_score(now, QueryPattern.LOGISTICAL, "email")
        assert score > 0.8  # Recent email should have high score

        # 7 days old (one half-life): ~0.5 decay * 0.9 email weight
        old_date = now - timedelta(days=7)
        score_old = calculate_temporal_score(old_date, QueryPattern.LOGISTICAL, "email")
        assert abs(score_old - 0.45) < 0.1

    def test_logistical_recency_cutoff(self):
        """Logistical queries should penalize old documents."""
        very_old = datetime.now() - timedelta(days=40)
        score = calculate_temporal_score(very_old, QueryPattern.LOGISTICAL, "newsletter")
        # Should be penalized by 0.1 for being >30 days old
        assert score < 0.1

    def test_entity_linear_decay(self):
        """Entity queries should use linear decay."""
        score = calculate_temporal_score(datetime.now(), QueryPattern.ENTITY, "invoice")
        assert score > 0.9  # Recent invoice

        # 180 days old (2x half-life): score bottoms out at the cutoff
        old_date = datetime.now() - timedelta(days=180)
        score_old = calculate_temporal_score(old_date, QueryPattern.ENTITY, "invoice")
        assert score_old == 0.0

    def test_historical_no_decay(self):
        """Historical queries should not decay."""
        score = calculate_temporal_score(datetime(2020, 1, 1), QueryPattern.HISTORICAL, "invoice")
        # Full score based on source weight only (invoice weight for historical)
        assert score == 1.0

    def test_source_weights(self):
        """Different source types should have different weights."""
        doc_date = datetime.now()
        # Logistical: calendar_event = 1.0, email = 0.9
        cal_score = calculate_temporal_score(doc_date, QueryPattern.LOGISTICAL, "calendar_event")
        email_score = calculate_temporal_score(doc_date, QueryPattern.LOGISTICAL, "email")
        assert cal_score > email_score
class TestConfidenceScoring:
    """Tests for confidence scoring."""

    def test_logistical_low_confidence(self):
        """A low-scoring logistical answer should carry a verification warning."""
        confidence, message = calculate_confidence(0.2, QueryPattern.LOGISTICAL)
        assert confidence == 0.2
        assert "verify" in message.lower()

    def test_logistical_high_confidence(self):
        """A high-scoring logistical answer should be confident."""
        confidence, message = calculate_confidence(0.8, QueryPattern.LOGISTICAL)
        assert confidence == 0.9
        assert "recent" in message.lower()

    def test_entity_high_confidence(self):
        """Entity queries should be moderately confident."""
        confidence, _message = calculate_confidence(0.8, QueryPattern.ENTITY)
        assert confidence == 0.85

    def test_historical_always_confident(self):
        """Historical queries keep decent confidence even at low scores."""
        confidence, _message = calculate_confidence(0.1, QueryPattern.HISTORICAL)
        assert confidence == 0.8
class TestDocumentChunking:
    """Tests for document chunking."""

    def test_chunk_empty_text(self):
        """Empty input yields an empty chunk list."""
        assert chunk_document("") == []

    def test_chunk_short_text(self):
        """Text shorter than chunk_size comes back as one unmodified chunk."""
        text = "Short text"
        chunks = chunk_document(text, chunk_size=100)
        assert len(chunks) == 1
        assert chunks[0] == text

    def test_chunk_with_overlap(self):
        """Consecutive chunks must share words when overlap is requested."""
        text = " ".join(f"word{n}" for n in range(20))
        chunks = chunk_document(text, chunk_size=10, overlap=2)
        assert len(chunks) > 1
        # Each chunk after the first shares at least one word with its predecessor
        for i in range(1, len(chunks)):
            shared = set(chunks[i - 1].split()) & set(chunks[i].split())
            assert len(shared) > 0, f"No overlap between chunk {i-1} and {i}"
class TestIntegration:
    """Integration tests requiring ChromaDB."""

    @pytest.mark.skip(reason="Requires running ChromaDB and Ollama")
    def test_end_to_end_query(self):
        """Exercise the full query pipeline: embed, search, inspect results."""
        from brain.embeddings import get_embedding
        from brain.store import search_documents

        question = "What size is the HVAC filter?"
        vector = get_embedding(question)
        matches = search_documents(vector, QueryPattern.ENTITY, n_results=3)
        assert len(matches) > 0
        assert "filter" in matches[0]["text"].lower()
if __name__ == "__main__":
    # BUG FIX: the original read `if name == "main":` and `pytest.main([file, ...])`,
    # which raises NameError on `name`/`file`. Use the module dunders so the
    # file can be run directly with `python <this file>`.
    pytest.main([__file__, "-v"])