
"""
Knowledge Graph — SQLite + FAISS entity store for structured memory.

Architecture:
- SQLite: persistent entity/relation store with controlled vocabulary
- FAISS: vector index for semantic search over entity descriptions
- Ollama (remote): text embeddings via Qwen3-Embedding-0.6B or nomic-embed-text
- All write paths validate, normalize, and deduplicate on insert
"""

import json
import os
import uuid
import re
import time
import logging
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Optional

import numpy as np

logger = logging.getLogger("knowledge_graph")

# ──────────────────────────────────────────────
# Paths
# ──────────────────────────────────────────────

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DB_PATH = os.path.join(BASE_DIR, "entities.db")
VECTORS_DIR = os.path.join(BASE_DIR, "vectors")
os.makedirs(VECTORS_DIR, exist_ok=True)

INDEX_PATH = os.path.join(VECTORS_DIR, "entity_index.faiss")
ID_MAP_PATH = os.path.join(VECTORS_DIR, "id_map.json")

# ──────────────────────────────────────────────
# Controlled relation vocabulary
# ──────────────────────────────────────────────

VALID_RELATION_TYPES = frozenset({
    # family
    "father_of", "mother_of", "parent_of", "child_of",
    "sibling_of", "spouse_of",
    # social
    "friend_of", "colleague_of", "manages", "reports_to",
    # spatial
    "lives_in", "works_at", "located_in",
    # temporal
    "happens_on", "precedes", "follows",
    # project
    "owns", "builds_on", "depends_on", "replaces",
    # knowledge
    "extracted_from", "contradicts", "extends", "cites",
})

VAGUE_RELATION_TYPES = frozenset({
    "related_to", "associated_with", "connected_to", "linked_to",
    "has_relation", "involves", "correlates_with",
})

VALID_ENTITY_TYPES = frozenset({"person", "place", "project", "event", "fact"})
VALID_SOURCES = frozenset({"live", "extraction", "dream"})

# ──────────────────────────────────────────────
# SQLite schema & helpers
# ──────────────────────────────────────────────

SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS entities (
    id TEXT PRIMARY KEY,
    type TEXT NOT NULL CHECK(type IN ('person','place','project','event','fact')),
    subject TEXT NOT NULL,
    description TEXT NOT NULL DEFAULT '',
    aliases TEXT NOT NULL DEFAULT '[]',
    tags TEXT NOT NULL DEFAULT '[]',
    confidence REAL NOT NULL DEFAULT 1.0 CHECK(confidence >= 0 AND confidence <= 1),
    source TEXT NOT NULL DEFAULT 'live' CHECK(source IN ('live','extraction','dream')),
    created_at TEXT NOT NULL,
    updated_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS relations (
    id TEXT PRIMARY KEY,
    source_entity_id TEXT NOT NULL REFERENCES entities(id) ON DELETE CASCADE,
    target_entity_id TEXT NOT NULL REFERENCES entities(id) ON DELETE CASCADE,
    relation_type TEXT NOT NULL,
    confidence REAL NOT NULL DEFAULT 1.0 CHECK(confidence >= 0 AND confidence <= 1),
    source TEXT NOT NULL DEFAULT 'live' CHECK(source IN ('live','extraction','dream')),
    created_at TEXT NOT NULL,
    updated_at TEXT NOT NULL,
    UNIQUE(source_entity_id, target_entity_id, relation_type)
);

CREATE INDEX IF NOT EXISTS idx_relations_source ON relations(source_entity_id);
CREATE INDEX IF NOT EXISTS idx_relations_target ON relations(target_entity_id);
CREATE INDEX IF NOT EXISTS idx_relations_type ON relations(relation_type);
CREATE INDEX IF NOT EXISTS idx_entities_subject ON entities(subject);
CREATE INDEX IF NOT EXISTS idx_entities_type ON entities(type);
"""

def _get_db():
    """Return a thread-safe connection. Use context manager for writes."""
    import sqlite3
    conn = sqlite3.connect(DB_PATH, check_same_thread=False)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA journal_mode=WAL")
    conn.execute("PRAGMA foreign_keys=ON")
    return conn

_db_conn = None

def get_connection():
    global _db_conn
    if _db_conn is None:
        _db_conn = _get_db()
        _db_conn.executescript(SCHEMA_SQL)
        _db_conn.commit()
    return _db_conn

@contextmanager
def transaction():
    conn = get_connection()
    try:
        yield conn
        conn.commit()
    except Exception:
        conn.rollback()
        raise

# ──────────────────────────────────────────────
# Embedding helpers (Ollama remote on Gaming PC)
# ──────────────────────────────────────────────

EMBEDDING_MODEL = os.environ.get("KNOWLEDGE_EMBEDDING_MODEL", "nomic-embed-text")
OLLAMA_HOST = os.environ.get("KNOWLEDGE_OLLAMA_HOST", "http://matt-pc:11434")
_embedding_cache = {}

def _get_embedding_dim() -> int:
    """Detect embedding dimension by making a test call."""
    try:
        test_vec = _call_ollama_embed("test dimension")
        return len(test_vec)
    except Exception:
        return 768  # fallback to nomic-embed-text default

def _call_ollama_embed(text: str) -> list[float]:
    """Call Ollama embedding API on the remote Gaming PC."""
    import httpx
    try:
        resp = httpx.post(
            f"{OLLAMA_HOST}/api/embeddings",
            json={"model": EMBEDDING_MODEL, "prompt": text},
            timeout=30,
        )
        resp.raise_for_status()
        return resp.json()["embedding"]
    except Exception as e:
        logger.warning(f"Ollama embedding failed, using zero vector: {e}")
        # Return a zero vector of reasonable dimension as fallback
        return [0.0] * 768

def embed_text(text: str) -> np.ndarray:
    """Get embedding vector for text, with caching."""
    if not text or not text.strip():
        dim = _get_embedding_dim()
        return np.zeros(dim, dtype=np.float32)

    cache_key = text.strip().lower()
    if cache_key in _embedding_cache:
        return _embedding_cache[cache_key]

    vec = np.array(_call_ollama_embed(text), dtype=np.float32)
    # Normalize to unit length for cosine similarity
    norm = np.linalg.norm(vec)
    if norm > 0:
        vec = vec / norm
    _embedding_cache[cache_key] = vec
    return vec
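# Unit-normalizing here means the inner product used by IndexFlatIP below is exactly
# cosine similarity. Tiny illustration (made-up 2-d vectors, not real embeddings):
#   a = np.array([1.0, 0.0]); b = np.array([0.6, 0.8])  # both unit length
#   float(a @ b)  # == 0.6, the cosine of the angle between them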

# ──────────────────────────────────────────────
# Subject normalization for dedup
# ──────────────────────────────────────────────

def normalize_subject(subject: str) -> str:
    """Normalize a subject string for deterministic comparison."""
    s = subject.strip().lower()
    # Collapse whitespace
    s = re.sub(r'\s+', ' ', s)
    # Remove punctuation (keep letters, digits, spaces, hyphens, apostrophes)
    s = re.sub(r'[^\w\s\'-]', '', s)
    # Collapse again after removal
    s = re.sub(r'\s+', ' ', s).strip()
    return s
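# Illustrative behaviour (hypothetical inputs):
#   normalize_subject("  Dr.  Smith's   PC! ")  -> "dr smith's pc"
#   normalize_subject("KNOWLEDGE-GRAPH  v2")    -> "knowledge-graph v2"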

# ──────────────────────────────────────────────
# Validation helpers
# ──────────────────────────────────────────────

def _now() -> str:
    return datetime.now(timezone.utc).isoformat()

def _uuid() -> str:
    return str(uuid.uuid4())

class ValidationError(ValueError):
    pass

def validate_entity_type(type_: str):
    if type_ not in VALID_ENTITY_TYPES:
        raise ValidationError(
            f"Invalid entity type '{type_}'. Must be one of: {', '.join(sorted(VALID_ENTITY_TYPES))}"
        )

def validate_source(source: str):
    if source not in VALID_SOURCES:
        raise ValidationError(
            f"Invalid source '{source}'. Must be one of: {', '.join(sorted(VALID_SOURCES))}"
        )

def validate_relation_type(relation_type: str):
    if relation_type in VAGUE_RELATION_TYPES:
        raise ValidationError(
            f"Vague relation type '{relation_type}' is rejected. "
            f"Use a specific type from: {', '.join(sorted(VALID_RELATION_TYPES))}"
        )
    if relation_type not in VALID_RELATION_TYPES:
        raise ValidationError(
            f"Invalid relation type '{relation_type}'. "
            f"Must be one of: {', '.join(sorted(VALID_RELATION_TYPES))}"
        )

def validate_confidence(confidence: float):
    if not (0 <= confidence <= 1):
        raise ValidationError(f"Confidence must be between 0 and 1, got {confidence}")

# ──────────────────────────────────────────────
# FAISS index management
# ──────────────────────────────────────────────

_index = None
_id_map: dict[str, int] = {} # entity_id -> faiss index position
_rev_id_map: dict[int, str] = {} # faiss index position -> entity_id

def _get_faiss_index():
    """Lazy-load or create the FAISS index."""
    global _index, _id_map, _rev_id_map
    import faiss

    if _index is not None:
        return _index

    if os.path.exists(INDEX_PATH) and os.path.exists(ID_MAP_PATH):
        try:
            _index = faiss.read_index(INDEX_PATH)
            with open(ID_MAP_PATH, "r") as f:
                _id_map = json.load(f)
            _rev_id_map = {v: k for k, v in _id_map.items()}
            logger.info(f"Loaded FAISS index with {_index.ntotal} vectors")
            return _index
        except Exception as e:
            logger.warning(f"Failed to load FAISS index, creating new: {e}")

    # Create new index — use inner product (cosine with normalized vectors)
    dim = _get_embedding_dim()
    _index = faiss.IndexFlatIP(dim)
    _id_map = {}
    _rev_id_map = {}
    return _index

def _rebuild_index():
    """Rebuild the entire FAISS index from scratch from the DB."""
    global _index, _id_map, _rev_id_map
    import faiss

    conn = get_connection()
    rows = conn.execute(
        "SELECT id, subject, description FROM entities"
    ).fetchall()

    dim = _get_embedding_dim()
    _index = faiss.IndexFlatIP(dim)
    _id_map = {}
    _rev_id_map = {}

    embeddings = []
    ids = []
    for row in rows:
        text = f"{row['subject']}: {row['description']}"
        vec = embed_text(text)
        embeddings.append(vec)
        ids.append(row["id"])

    if embeddings:
        mat = np.array(embeddings, dtype=np.float32)
        _index.add(mat)
        _id_map = {eid: i for i, eid in enumerate(ids)}
        _rev_id_map = {i: eid for i, eid in enumerate(ids)}

    _save_index()
    logger.info(f"Rebuilt FAISS index with {len(ids)} vectors")

def _add_to_index(entity_id: str, text: str):
    """Add a single vector to the FAISS index."""
    global _id_map, _rev_id_map
    index = _get_faiss_index()
    vec = embed_text(text).reshape(1, -1)
    idx = index.ntotal
    index.add(vec)
    _id_map[entity_id] = idx
    _rev_id_map[idx] = entity_id
    _save_index()

def _remove_from_index(entity_id: str):
    """Rebuild index when entities are removed/merged."""
    # FAISS doesn't support removal easily; rebuild
    _rebuild_index()

def _update_index(entity_id: str, text: str):
    """Update an entity's vector in the index."""
    # FAISS doesn't support in-place update; rebuild
    _rebuild_index()

def _save_index():
    import faiss
    faiss.write_index(_index, INDEX_PATH)
    with open(ID_MAP_PATH, "w") as f:
        json.dump(_id_map, f)
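# The two maps mirror FAISS's positional ids: a vector appended at position
# index.ntotal is later resolved via _rev_id_map[position] -> entity_id.
# Illustrative state after two inserts (hypothetical ids):
#   _id_map     == {"ent-a": 0, "ent-b": 1}
#   _rev_id_map == {0: "ent-a", 1: "ent-b"}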

# ──────────────────────────────────────────────
# Core API functions
# ──────────────────────────────────────────────

def add_entity(
    type_: str,
    subject: str,
    description: str = "",
    aliases: Optional[list] = None,
    tags: Optional[list] = None,
    confidence: float = 1.0,
    source: str = "live",
) -> str:
    """Add an entity. Returns the entity ID.

    Validates type and source, normalizes the subject, and merges into an
    existing entity of the same type when their normalized subjects match,
    instead of inserting a duplicate.
    """
    validate_entity_type(type_)
    validate_source(source)
    validate_confidence(confidence)

    subject = subject.strip()
    if not subject:
        raise ValidationError("Subject must not be empty")

    aliases = aliases or []
    tags = tags or []
    norm = normalize_subject(subject)

    now = _now()
    entity_id = _uuid()

    with transaction() as conn:
        # Check for existing entity with same normalized subject + type
        existing = conn.execute(
            "SELECT id, subject, description, aliases, tags FROM entities WHERE type = ?",
            (type_,),
        ).fetchall()

        for ex in existing:
            if normalize_subject(ex["subject"]) == norm:
                # Same normalized subject + same type — treat as a duplicate:
                # merge aliases/tags into the existing row and keep the longer
                # description rather than inserting a second entity.
                logger.info(
                    f"Duplicate detected (same subject + type): "
                    f"'{ex['subject']}' == '{subject}', merging aliases/tags"
                )
                existing_aliases = set(json.loads(ex["aliases"]))
                existing_tags = set(json.loads(ex["tags"]))
                merged_aliases = existing_aliases | {subject} | set(aliases)
                merged_tags = existing_tags | set(tags)

                # Prefer longer description
                merged_desc = ex["description"]
                if len(description) > len(merged_desc):
                    merged_desc = description

                conn.execute(
                    """UPDATE entities
                       SET aliases = ?, tags = ?, description = ?,
                           updated_at = ?, confidence = MAX(confidence, ?)
                       WHERE id = ?""",
                    (
                        json.dumps(sorted(merged_aliases)),
                        json.dumps(sorted(merged_tags)),
                        merged_desc,
                        now,
                        confidence,
                        ex["id"],
                    ),
                )
                _update_index(ex["id"], f"{ex['subject']}: {merged_desc}")
                return ex["id"]

        conn.execute(
            """INSERT INTO entities (id, type, subject, description, aliases, tags,
                                     confidence, source, created_at, updated_at)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            (
                entity_id, type_, subject, description,
                json.dumps(aliases), json.dumps(tags),
                confidence, source, now, now,
            ),
        )

    # Add to vector index
    text = f"{subject}: {description}"
    _add_to_index(entity_id, text)

    return entity_id
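# Illustrative dedup behaviour (hypothetical data):
#   id1 = add_entity("person", "Ada Lovelace", description="mathematician")
#   id2 = add_entity("person", "ada  lovelace!", tags=["history"])
#   assert id1 == id2  # the second call merges into the first entity instead of inserting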

def add_relation(
    source_id: str,
    target_id: str,
    relation_type: str,
    confidence: float = 1.0,
    source: str = "live",
) -> str:
    """Add a relation between two entities. Returns relation ID.

    Validates relation type (rejects vague types), checks both entities exist.
    """
    validate_relation_type(relation_type)
    validate_source(source)
    validate_confidence(confidence)

    now = _now()
    relation_id = _uuid()

    with transaction() as conn:
        # Verify both entities exist
        src = conn.execute("SELECT id FROM entities WHERE id = ?", (source_id,)).fetchone()
        if not src:
            raise ValidationError(f"Source entity {source_id} not found")
        tgt = conn.execute("SELECT id FROM entities WHERE id = ?", (target_id,)).fetchone()
        if not tgt:
            raise ValidationError(f"Target entity {target_id} not found")

        # Check for existing relation (unique constraint handles this)
        existing = conn.execute(
            """SELECT id FROM relations
               WHERE source_entity_id = ? AND target_entity_id = ? AND relation_type = ?""",
            (source_id, target_id, relation_type),
        ).fetchone()

        if existing:
            # Update confidence
            conn.execute(
                "UPDATE relations SET confidence = MAX(confidence, ?), updated_at = ? WHERE id = ?",
                (confidence, now, existing["id"]),
            )
            return existing["id"]

        conn.execute(
            """INSERT INTO relations (id, source_entity_id, target_entity_id,
                                      relation_type, confidence, source, created_at, updated_at)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
            (relation_id, source_id, target_id, relation_type,
             confidence, source, now, now),
        )

    return relation_id
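# Illustrative usage (hypothetical entity ids):
#   add_relation(person_id, company_id, "works_at", confidence=0.9)
#   add_relation(person_id, company_id, "related_to")  # raises ValidationError (vague type)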

def get_entity(entity_id: str) -> Optional[dict]:
    """Get full entity with all its relations."""
    conn = get_connection()
    row = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id,)).fetchone()
    if not row:
        return None

    entity = dict(row)
    entity["aliases"] = json.loads(entity["aliases"])
    entity["tags"] = json.loads(entity["tags"])

    # Get all relations where this entity participates
    out_relations = conn.execute(
        """SELECT r.*, e.subject AS target_subject, e.type AS target_type
           FROM relations r JOIN entities e ON r.target_entity_id = e.id
           WHERE r.source_entity_id = ?""",
        (entity_id,),
    ).fetchall()

    in_relations = conn.execute(
        """SELECT r.*, e.subject AS source_subject, e.type AS source_type
           FROM relations r JOIN entities e ON r.source_entity_id = e.id
           WHERE r.target_entity_id = ?""",
        (entity_id,),
    ).fetchall()

    entity["relations_out"] = [dict(r) for r in out_relations]
    entity["relations_in"] = [dict(r) for r in in_relations]

    return entity

def search_entities(query: str, top_k: int = 5) -> list[dict]:
    """Semantic search over entity descriptions using FAISS."""
    if top_k < 1:
        top_k = 1

    index = _get_faiss_index()
    if index.ntotal == 0:
        return []

    query_vec = embed_text(query).reshape(1, -1)
    k = min(top_k, index.ntotal)
    scores, indices = index.search(query_vec, k)

    conn = get_connection()
    results = []
    for score, idx in zip(scores[0], indices[0]):
        eid = _rev_id_map.get(int(idx))
        if not eid:
            continue
        row = conn.execute(
            "SELECT id, type, subject, description, confidence, source FROM entities WHERE id = ?",
            (eid,),
        ).fetchone()
        if row:
            result = dict(row)
            result["similarity"] = float(score)
            results.append(result)

    return results
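# Illustrative result shape (hypothetical data):
#   search_entities("who maintains the memory system?", top_k=3)
#   # -> [{"id": "...", "type": "person", "subject": "...", "description": "...",
#   #      "confidence": 1.0, "source": "live", "similarity": 0.83}, ...]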

def explore(entity_id: str, depth: int = 1) -> Optional[dict]:
    """Graph traversal: returns entity + neighbors up to depth hops.

    Returns dict with root, nodes (neighbors by ID), and edges.
    """
    if depth < 1:
        depth = 1

    entity = get_entity(entity_id)
    if not entity:
        return None

    visited = {entity_id}
    frontier = [entity_id]
    all_nodes = {entity_id: entity}
    all_edges = []

    for _ in range(depth):
        next_frontier = []
        for fid in frontier:
            node = get_entity(fid)
            if not node:
                continue
            for rel in node.get("relations_out", []):
                tid = rel["target_entity_id"]
                all_edges.append(rel)
                if tid not in visited:
                    visited.add(tid)
                    next_frontier.append(tid)
                    neighbor = get_entity(tid)
                    if neighbor:
                        all_nodes[tid] = neighbor
            for rel in node.get("relations_in", []):
                sid = rel["source_entity_id"]
                all_edges.append(rel)
                if sid not in visited:
                    visited.add(sid)
                    next_frontier.append(sid)
                    neighbor = get_entity(sid)
                    if neighbor:
                        all_nodes[sid] = neighbor
        frontier = next_frontier

    return {
        "root": entity,
        "nodes": {k: v for k, v in all_nodes.items() if k != entity_id},
        "edges": all_edges,
    }
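# Illustrative traversal (hypothetical graph): with A -works_at-> B and B -located_in-> C,
#   explore(a_id, depth=2)
#   # -> {"root": <A>, "nodes": {b_id: <B>, c_id: <C>}, "edges": [A->B, B->C, ...]}
# Edges are collected from both directions, so the same edge can appear more than once.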

def merge_entities(entity_id_1: str, entity_id_2: str) -> dict:
    """Merge two entities. Entity 1 absorbs entity 2.
    Re-points all relations from entity 2 to entity 1, then deletes entity 2.
    """
    with transaction() as conn:
        e1 = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id_1,)).fetchone()
        e2 = conn.execute("SELECT * FROM entities WHERE id = ?", (entity_id_2,)).fetchone()

        if not e1:
            raise ValidationError(f"Entity {entity_id_1} not found")
        if not e2:
            raise ValidationError(f"Entity {entity_id_2} not found")

        # Merge metadata
        e1_aliases = set(json.loads(e1["aliases"]))
        e2_aliases = set(json.loads(e2["aliases"]))
        merged_aliases = e1_aliases | e2_aliases | {e1["subject"], e2["subject"]}

        e1_tags = set(json.loads(e1["tags"]))
        e2_tags = set(json.loads(e2["tags"]))
        merged_tags = e1_tags | e2_tags

        # Prefer the longer description
        merged_desc = e1["description"] if len(e1["description"]) >= len(e2["description"]) else e2["description"]
        merged_confidence = max(e1["confidence"], e2["confidence"])
        now = _now()

        conn.execute(
            """UPDATE entities
               SET aliases = ?, tags = ?, description = ?,
                   confidence = ?, updated_at = ?
               WHERE id = ?""",
            (
                json.dumps(sorted(merged_aliases)),
                json.dumps(sorted(merged_tags)),
                merged_desc, merged_confidence, now, entity_id_1,
            ),
        )

        # Re-point relations: entity_2 as source → entity_1
        conn.execute(
            """UPDATE OR IGNORE relations
               SET source_entity_id = ?, updated_at = ?
               WHERE source_entity_id = ?""",
            (entity_id_1, now, entity_id_2),
        )

        # Re-point relations: entity_2 as target → entity_1
        conn.execute(
            """UPDATE OR IGNORE relations
               SET target_entity_id = ?, updated_at = ?
               WHERE target_entity_id = ?""",
            (entity_id_1, now, entity_id_2),
        )

        # Delete relations that are now self-loops (entity → itself)
        conn.execute(
            "DELETE FROM relations WHERE source_entity_id = ? AND target_entity_id = ?",
            (entity_id_1, entity_id_1),
        )

        # Delete entity 2
        conn.execute("DELETE FROM entities WHERE id = ?", (entity_id_2,))

    # Rebuild index after merge
    _rebuild_index()

    return {"merged_into": entity_id_1, "deleted": entity_id_2}

def decay_stale(days: int = 90, decay_rate: float = 0.1):
    """Decay confidence for old relations with source='dream' or 'extraction'."""
    from datetime import timedelta
    cutoff = (datetime.now(timezone.utc) - timedelta(days=days)).isoformat()

    with transaction() as conn:
        updated = conn.execute(
            """UPDATE relations
               SET confidence = MAX(0.0, confidence - ?),
                   updated_at = ?
               WHERE source IN ('dream', 'extraction')
               AND created_at < ?
               AND confidence > 0""",
            (decay_rate, _now(), cutoff),
        ).rowcount

    return {"decayed_count": updated}

def infer_relations():
    """Placeholder: find entity pairs that share tags but have no edge."""
    conn = get_connection()
    # Simple heuristic: entities with overlapping tags but no relation in either direction
    rows = conn.execute(
        """SELECT e1.id AS e1_id, e2.id AS e2_id,
                  e1.tags AS e1_tags, e2.tags AS e2_tags
           FROM entities e1
           JOIN entities e2 ON e1.id < e2.id
           WHERE e1.tags != '[]' AND e2.tags != '[]'
           AND NOT EXISTS (
               SELECT 1 FROM relations r
               WHERE (r.source_entity_id = e1.id AND r.target_entity_id = e2.id)
                  OR (r.source_entity_id = e2.id AND r.target_entity_id = e1.id)
           )"""
    ).fetchall()

    candidates = []
    for row in rows:
        t1 = set(json.loads(row["e1_tags"]))
        t2 = set(json.loads(row["e2_tags"]))
        common = t1 & t2
        if len(common) >= 2:
            candidates.append({
                "entity_id_1": row["e1_id"],
                "entity_id_2": row["e2_id"],
                "shared_tags": sorted(common),
            })

    return {"candidates": candidates}

# ──────────────────────────────────────────────
# Database initialization
# ──────────────────────────────────────────────

def init_db():
    """Initialize the database schema. Safe to call multiple times."""
    conn = get_connection()
    conn.executescript(SCHEMA_SQL)
    conn.commit()
    logger.info("Database schema initialized")

def health() -> dict:
    """Return health status of all subsystems."""
    conn = get_connection()
    entity_count = conn.execute("SELECT COUNT(*) FROM entities").fetchone()[0]
    relation_count = conn.execute("SELECT COUNT(*) FROM relations").fetchone()[0]
    index = _get_faiss_index()

    # Check Ollama connectivity
    ollama_ok = False
    try:
        import httpx
        resp = httpx.get(f"{OLLAMA_HOST}/api/tags", timeout=5)
        ollama_ok = resp.status_code == 200
    except Exception:
        pass

    return {
        "status": "ok",
        "database": {
            "path": DB_PATH,
            "entity_count": entity_count,
            "relation_count": relation_count,
        },
        "vector_index": {
            "size": index.ntotal if index else 0,
            "dimension": _get_embedding_dim(),
        },
        "embedding": {
            "model": EMBEDDING_MODEL,
            "host": OLLAMA_HOST,
            "reachable": ollama_ok,
        },
        "controlled_vocabulary": {
            "valid_relation_types": sorted(VALID_RELATION_TYPES),
            "entity_types": sorted(VALID_ENTITY_TYPES),
        },
    }

# Run schema init on import

init_db()
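
# Minimal smoke-test sketch. Assumes the Ollama host configured above is reachable;
# the entities below are placeholders and WILL be written to the live entities.db.
if __name__ == "__main__":
    print(json.dumps(health(), indent=2))
    person = add_entity("person", "Example Person", description="placeholder entity for a smoke test")
    place = add_entity("place", "Example City", description="placeholder location")
    add_relation(person, place, "lives_in", confidence=0.8)
    print(search_entities("where does the example person live?", top_k=3))
    print(explore(person, depth=1))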