# test_brain.py

"""Tests for Brain Query Interface."""

from datetime import datetime, timedelta

import pytest

from brain.query import (
    QueryPattern,
    detect_query_pattern,
    calculate_temporal_score,
    calculate_confidence,
)
from brain.embeddings import chunk_document

class TestPatternDetection:
    """Test query pattern detection.

    Verifies that detect_query_pattern classifies free-text queries into
    LOGISTICAL (current schedule), ENTITY (facts/attributes), or
    HISTORICAL (past events) patterns.
    """

    def test_detect_logistical_pattern(self):
        """Should detect logistical queries."""
        queries = [
            "Is Friday a half-day?",
            "When is spring break?",
            "What time is dismissal today?",
            "Is practice cancelled this week?"
        ]
        for q in queries:
            result = detect_query_pattern(q)
            assert result == QueryPattern.LOGISTICAL, f"Expected LOGISTICAL for: {q}"

    def test_detect_entity_pattern(self):
        """Should detect entity queries."""
        queries = [
            "What size is our HVAC filter?",
            "How much is the dance tuition?"
        ]
        for q in queries:
            result = detect_query_pattern(q)
            assert result == QueryPattern.ENTITY, f"Expected ENTITY for: {q}"

        # "What time" is logistical (about current schedule)
        result = detect_query_pattern("What time is soccer practice?")
        assert result == QueryPattern.LOGISTICAL

    def test_detect_historical_pattern(self):
        """Should detect historical queries."""
        queries = [
            "What was the name of the roofer we used?",
            "When did we replace the water heater?",
            "What did we pay for the roof last year?"
        ]
        for q in queries:
            result = detect_query_pattern(q)
            assert result == QueryPattern.HISTORICAL, f"Expected HISTORICAL for: {q}"

        # Entity-like patterns in historical queries may be classified as entity
        result = detect_query_pattern("Who was Harper's teacher in 2023?")
        assert result in [QueryPattern.HISTORICAL, QueryPattern.ENTITY]

    def test_default_to_entity(self):
        """Should default to ENTITY for ambiguous queries."""
        queries = [
            "hello",
            "what",
            "test",
            ""  # empty
        ]
        for q in queries:
            result = detect_query_pattern(q)
            assert result == QueryPattern.ENTITY, f"Expected ENTITY default for: {q}"

class TestTemporalScoring:
    """Test temporal scoring calculations.

    Covers the three decay profiles of calculate_temporal_score:
    exponential decay with recency cutoff (LOGISTICAL), linear decay
    (ENTITY), and no decay (HISTORICAL), plus per-source-type weights.
    """

    def test_logistical_exponential_decay(self):
        """Logistical queries should use exponential decay."""
        doc_date = datetime.now()
        score = calculate_temporal_score(doc_date, QueryPattern.LOGISTICAL, "email")
        assert score > 0.8  # Recent email should have high score

        # 7 days old (one half-life)
        old_date = datetime.now() - timedelta(days=7)
        score_old = calculate_temporal_score(old_date, QueryPattern.LOGISTICAL, "email")
        assert abs(score_old - 0.45) < 0.1  # ~0.5 * 0.9

    def test_logistical_recency_cutoff(self):
        """Logistical queries should penalize old documents."""
        very_old = datetime.now() - timedelta(days=40)
        score = calculate_temporal_score(very_old, QueryPattern.LOGISTICAL, "newsletter")
        # Should be penalized by 0.1 for being >30 days old
        assert score < 0.1

    def test_entity_linear_decay(self):
        """Entity queries should use linear decay."""
        doc_date = datetime.now()
        score = calculate_temporal_score(doc_date, QueryPattern.ENTITY, "invoice")
        assert score > 0.9  # Recent invoice

        # 180 days old (2x half-life)
        old_date = datetime.now() - timedelta(days=180)
        score_old = calculate_temporal_score(old_date, QueryPattern.ENTITY, "invoice")
        assert score_old == 0.0  # At cutoff

    def test_historical_no_decay(self):
        """Historical queries should not decay."""
        old_date = datetime(2020, 1, 1)
        score = calculate_temporal_score(old_date, QueryPattern.HISTORICAL, "invoice")
        # Should be full score based on source weight only
        assert score == 1.0  # Invoice weight for historical

    def test_source_weights(self):
        """Different source types should have different weights."""
        doc_date = datetime.now()

        # Logistical: calendar_event = 1.0, email = 0.9
        cal_score = calculate_temporal_score(doc_date, QueryPattern.LOGISTICAL, "calendar_event")
        email_score = calculate_temporal_score(doc_date, QueryPattern.LOGISTICAL, "email")
        assert cal_score > email_score

class TestConfidenceScoring:
    """Test confidence scoring.

    calculate_confidence returns a (confidence, message) pair whose
    values depend on the raw relevance score and the query pattern.
    """

    def test_logistical_low_confidence(self):
        """Logistical queries with low scores should warn."""
        conf, msg = calculate_confidence(0.2, QueryPattern.LOGISTICAL)
        assert conf == 0.2
        assert "verify" in msg.lower()

    def test_logistical_high_confidence(self):
        """Logistical queries with high scores should be confident."""
        conf, msg = calculate_confidence(0.8, QueryPattern.LOGISTICAL)
        assert conf == 0.9
        assert "recent" in msg.lower()

    def test_entity_high_confidence(self):
        """Entity queries should be moderately confident."""
        conf, msg = calculate_confidence(0.8, QueryPattern.ENTITY)
        assert conf == 0.85

    def test_historical_always_confident(self):
        """Historical queries should always have decent confidence."""
        conf, msg = calculate_confidence(0.1, QueryPattern.HISTORICAL)
        assert conf == 0.8

class TestDocumentChunking:
    """Test document chunking.

    chunk_document splits text into word-based chunks of at most
    chunk_size words, with `overlap` words shared between neighbors.
    """

    def test_chunk_empty_text(self):
        """Should handle empty text."""
        result = chunk_document("")
        assert result == []

    def test_chunk_short_text(self):
        """Should return single chunk for short text."""
        text = "Short text"
        result = chunk_document(text, chunk_size=100)
        assert len(result) == 1
        assert result[0] == text

    def test_chunk_with_overlap(self):
        """Should create overlapping chunks."""
        words = ["word" + str(i) for i in range(20)]
        text = " ".join(words)
        result = chunk_document(text, chunk_size=10, overlap=2)

        assert len(result) > 1
        # Each chunk after first should overlap with previous
        for i in range(1, len(result)):
            prev_words = set(result[i-1].split())
            curr_words = set(result[i].split())
            overlap = len(prev_words & curr_words)
            assert overlap > 0, f"No overlap between chunk {i-1} and {i}"

class TestIntegration:
    """Integration tests requiring ChromaDB."""

    @pytest.mark.skip(reason="Requires running ChromaDB and Ollama")
    def test_end_to_end_query(self):
        """Test full query pipeline: embed a query, search the store."""
        # Imported lazily so collecting this module doesn't require the
        # embedding/store backends to be installed.
        from brain.embeddings import get_embedding
        from brain.store import search_documents

        query = "What size is the HVAC filter?"
        pattern = QueryPattern.ENTITY

        embedding = get_embedding(query)
        results = search_documents(embedding, pattern, n_results=3)

        assert len(results) > 0
        assert "filter" in results[0]["text"].lower()

# Allow running this file directly (dunders were mangled to bare
# `name`/`main`/`file` in the original, a NameError at import time).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])