diff --git a/mempalace/mcp_server.py b/mempalace/mcp_server.py
index 0c96fd1..035965d 100644
--- a/mempalace/mcp_server.py
+++ b/mempalace/mcp_server.py
@@ -46,6 +46,8 @@ import argparse  # noqa: E402 (deferred until after stdio protection above)
 import json  # noqa: E402
 import logging  # noqa: E402
 import hashlib  # noqa: E402
+import sqlite3  # noqa: E402
+import threading  # noqa: E402
 import time  # noqa: E402
 from datetime import date, datetime  # noqa: E402
 from pathlib import Path  # noqa: E402
@@ -79,7 +81,7 @@ from .palace_graph import (  # noqa: E402
     follow_tunnels,
 )
 
-from .knowledge_graph import KnowledgeGraph  # noqa: E402
+from .knowledge_graph import KnowledgeGraph, DEFAULT_KG_PATH  # noqa: E402
 
 logging.basicConfig(level=logging.INFO, format="%(message)s", stream=sys.stderr)
 logger = logging.getLogger("mempalace_mcp")
@@ -104,12 +106,71 @@ if _args.palace:
     os.environ["MEMPALACE_PALACE_PATH"] = os.path.abspath(_args.palace)
 
 _config = MempalaceConfig()
-# Only override KG path when --palace is explicitly provided; otherwise use
-# KnowledgeGraph's default (~/.mempalace/knowledge_graph.sqlite3).
-if _args.palace:
-    _kg = KnowledgeGraph(db_path=os.path.join(_config.palace_path, "knowledge_graph.sqlite3"))
-else:
-    _kg = KnowledgeGraph()
+
+_kg_by_path: dict[str, KnowledgeGraph] = {}
+_kg_cache_lock = threading.Lock()
+_palace_flag_given: bool = bool(_args.palace)
+
+
+def _resolve_kg_path() -> str:
+    """Return the KG sqlite path: palace-local under ``--palace``, else the default."""
+    if _palace_flag_given:
+        return os.path.join(_config.palace_path, "knowledge_graph.sqlite3")
+    return DEFAULT_KG_PATH
+
+
+def _get_kg() -> KnowledgeGraph:
+    """Return the cached KnowledgeGraph for the current path, creating it on first use."""
+    path = os.path.abspath(_resolve_kg_path())
+    # Unlocked lookup first; the entry is re-checked under the lock before creating.
+    kg = _kg_by_path.get(path)
+    if kg is not None:
+        return kg
+    with _kg_cache_lock:
+        kg = _kg_by_path.get(path)
+        if kg is None:
+            kg = KnowledgeGraph(db_path=path)
+            _kg_by_path[path] = kg
+        return kg
+
+
+def _call_kg(op):
+    """Run ``op(kg)`` against the cached KG with one-shot retry on close.
+
+    Race we're guarding against: a handler grabs ``kg = _get_kg()`` and is
+    about to call ``kg.add_triple(...)`` when ``tool_reconnect`` fires on
+    another thread, drains ``_kg_by_path``, and closes the underlying
+    sqlite3.Connection. The handler's call then raises
+    ``sqlite3.ProgrammingError: Cannot operate on a closed database`` and
+    bubbles up as a -32000 to the MCP client even though the user just
+    asked for a reconnect.
+
+    Catch that single class of error, evict the stale entry from the
+    cache (only if it still points at the closed instance — another
+    thread may have already replaced it), and try once more with a fresh
+    KG. Beyond one retry give up: a second close means we're losing a
+    sustained race we won't win in this loop, and a hung loop is worse
+    than a clear failure surface.
+
+    Any other ``sqlite3.ProgrammingError`` (bad bindings, wrong-thread
+    use, ...) is a genuine bug: it propagates on the first attempt and
+    the cached KG is left in place.
+    """
+    for attempt in range(2):
+        kg = _get_kg()
+        try:
+            return op(kg)
+        except sqlite3.ProgrammingError as exc:
+            # Retry only the closed-connection race, and only once.
+            if attempt or "closed" not in str(exc).lower():
+                raise
+            with _kg_cache_lock:
+                # Evict by identity rather than by re-resolving the path:
+                # MEMPALACE_PALACE_PATH may have rotated since _get_kg()
+                # ran, and a fresh instance cached by another thread must
+                # not be dropped.
+                for key in [k for k, v in _kg_by_path.items() if v is kg]:
+                    del _kg_by_path[key]
 
 
 _client_cache = None
@@ -1065,7 +1126,7 @@ def tool_kg_query(entity: str, as_of: str = None, direction: str = "both"):
         return {"error": str(e)}
     if direction not in ("outgoing", "incoming", "both"):
         return {"error": "direction must be 'outgoing', 'incoming', or 'both'"}
-    results = _kg.query_entity(entity, as_of=as_of, direction=direction)
+    results = _call_kg(lambda kg: kg.query_entity(entity, as_of=as_of, direction=direction))
     return {"entity": entity, "as_of": as_of, "facts": results, "count": len(results)}
 
 
@@ -1111,15 +1172,17 @@ def tool_kg_add(
             "source_drawer_id": source_drawer_id,
         },
     )
-    triple_id = _kg.add_triple(
-        subject,
-        predicate,
-        object,
-        valid_from=valid_from,
-        valid_to=valid_to,
-        source_closet=source_closet,
-        source_file=source_file,
-        source_drawer_id=source_drawer_id,
+    triple_id = _call_kg(
+        lambda kg: kg.add_triple(
+            subject,
+            predicate,
+            object,
+            valid_from=valid_from,
+            valid_to=valid_to,
+            source_closet=source_closet,
+            source_file=source_file,
+            source_drawer_id=source_drawer_id,
+        )
     )
     return {"success": True, "triple_id": triple_id, "fact": f"{subject} → {predicate} → {object}"}
 
@@ -1151,7 +1214,7 @@ def tool_kg_invalidate(subject: str, predicate: str, object: str, ended: str = N
             "ended": resolved_ended,
         },
     )
-    _kg.invalidate(subject, predicate, object, ended=resolved_ended)
+    _call_kg(lambda kg: kg.invalidate(subject, predicate, object, ended=resolved_ended))
     return {
         "success": True,
         "fact": f"{subject} → {predicate} → {object}",
@@ -1166,13 +1229,13 @@ def tool_kg_timeline(entity: str = None):
             entity = sanitize_kg_value(entity, "entity")
         except ValueError as e:
             return {"error": str(e)}
-    results = _kg.timeline(entity)
+    results = _call_kg(lambda kg: kg.timeline(entity))
     return {"entity": entity or "all", "timeline": results, "count": len(results)}
 
 
 def tool_kg_stats():
     """Knowledge graph overview: entities, triples, relationship types."""
-    return _kg.stats()
+    return _call_kg(lambda kg: kg.stats())
 
 
 # ==================== AGENT DIARY ====================
@@ -1404,10 +1467,11 @@ def tool_memories_filed_away():
 
 
 def tool_reconnect():
-    """Force the MCP server to drop the cached ChromaDB collection and reconnect.
+    """Force the MCP server to drop cached ChromaDB + KnowledgeGraph state.
 
     Use after external scripts or CLI commands modify the palace database
-    directly, which can leave the in-memory HNSW index stale.
+    or replace ``knowledge_graph.sqlite3`` directly, which can leave the
+    in-memory HNSW index stale or pin a handle to the replaced SQLite file.
     """
     global \
         _client_cache, \
@@ -1425,6 +1489,17 @@ def tool_reconnect():
     # still applies after the reconnect.
     _vector_disabled = False
     _vector_disabled_reason = ""
+    # Drain the per-path KnowledgeGraph cache so a replaced sqlite file is
+    # reopened on the next tool call rather than served from a stale handle.
+    with _kg_cache_lock:
+        for kg in _kg_by_path.values():
+            try:
+                kg.close()
+            except Exception:
+                # Best effort: one bad handle must not block the drain,
+                # but leave a trace on stderr instead of failing silently.
+                logger.warning("KnowledgeGraph close() failed during reconnect", exc_info=True)
+        _kg_by_path.clear()
     try:
         col = _get_collection()
         if col is None:
diff --git a/tests/benchmarks/test_mcp_bench.py b/tests/benchmarks/test_mcp_bench.py
index 4e8330b..42e73ec 100644
--- a/tests/benchmarks/test_mcp_bench.py
+++ b/tests/benchmarks/test_mcp_bench.py
@@ -40,8 +40,9 @@ def _patch_mcp_config(monkeypatch, palace_path, tmp_path):
 
     import mempalace.mcp_server as mcp_mod
 
+    kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))
     monkeypatch.setattr(mcp_mod, "_config", cfg)
-    monkeypatch.setattr(mcp_mod, "_kg", KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3")))
+    monkeypatch.setattr(mcp_mod, "_get_kg", lambda: kg)
 
 
 def _get_rss_mb():
diff --git a/tests/benchmarks/test_memory_profile.py b/tests/benchmarks/test_memory_profile.py
index b299b2d..047bfaa 100644
--- a/tests/benchmarks/test_memory_profile.py
+++ b/tests/benchmarks/test_memory_profile.py
@@ -84,8 +84,9 @@ class TestToolStatusMemoryProfile:
         cfg = MempalaceConfig(config_dir=str(tmp_path / "cfg"))
         monkeypatch.setattr(cfg, "_file_config", {"palace_path": palace_path})
 
+        kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))
         monkeypatch.setattr(mcp_mod, "_config", cfg)
-        monkeypatch.setattr(mcp_mod, "_kg", KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3")))
+        monkeypatch.setattr(mcp_mod, "_get_kg", lambda: kg)
 
         from mempalace.mcp_server import tool_status
 
diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py
index e9caca1..5c92dcd 100644
--- a/tests/test_mcp_server.py
+++ b/tests/test_mcp_server.py
@@ -8,6 +8,7 @@ via monkeypatch to avoid touching real data.
 
 from datetime import datetime
 import json
+import os
 import sys
 
 import pytest
@@ -18,7 +19,7 @@ def _patch_mcp_server(monkeypatch, config, kg):
     from mempalace import mcp_server
 
     monkeypatch.setattr(mcp_server, "_config", config)
-    monkeypatch.setattr(mcp_server, "_kg", kg)
+    monkeypatch.setattr(mcp_server, "_get_kg", lambda: kg)
 
 
 def _get_collection(palace_path, create=False):
@@ -1196,3 +1197,257 @@ class TestCacheInvalidation:
         for kwargs in captured["get"]:
             assert "embedding_function" in kwargs
             assert kwargs["embedding_function"] is not None
+
+
+class TestKGLazyCache:
+    """Lazy per-path KnowledgeGraph cache (issue #1136)."""
+
+    def test_lazy_init_no_import_side_effect(self, tmp_path):
+        """Importing mcp_server must not create knowledge_graph.sqlite3.
+
+        Runs in a fresh subprocess with HOME pointed at tmp_path so the
+        assertion targets a clean filesystem, independent of conftest's
+        session-level HOME patch.
+        """
+        import subprocess
+        import sys
+
+        kg_file = tmp_path / ".mempalace" / "knowledge_graph.sqlite3"
+        env = {k: v for k, v in os.environ.items() if not k.startswith("MEMPAL")}
+        env["HOME"] = str(tmp_path)
+        env["USERPROFILE"] = str(tmp_path)
+        result = subprocess.run(
+            [sys.executable, "-c", "import mempalace.mcp_server"],
+            env=env,
+            capture_output=True,
+            text=True,
+            timeout=30,
+        )
+        assert result.returncode == 0, f"import failed: {result.stderr}"
+        assert not kg_file.exists(), f"import created sqlite file at {kg_file} as a side effect"
+
+    def test_get_kg_returns_same_instance(self, tmp_path, monkeypatch):
+        """Two calls with the same resolved path return the same KG."""
+        from mempalace import mcp_server
+
+        monkeypatch.setattr(mcp_server, "_kg_by_path", {})
+        monkeypatch.setattr(mcp_server, "_palace_flag_given", True)
+        monkeypatch.setenv("MEMPALACE_PALACE_PATH", str(tmp_path))
+
+        kg1 = mcp_server._get_kg()
+        kg2 = mcp_server._get_kg()
+        assert kg1 is kg2
+        assert len(mcp_server._kg_by_path) == 1
+
+    def test_get_kg_different_paths_different_instances(self, tmp_path, monkeypatch):
+        """Different palace paths map to different KG instances."""
+        from mempalace import mcp_server
+
+        tmp_a = tmp_path / "a"
+        tmp_b = tmp_path / "b"
+        tmp_a.mkdir()
+        tmp_b.mkdir()
+
+        monkeypatch.setattr(mcp_server, "_kg_by_path", {})
+        monkeypatch.setattr(mcp_server, "_palace_flag_given", True)
+
+        monkeypatch.setenv("MEMPALACE_PALACE_PATH", str(tmp_a))
+        kg_a = mcp_server._get_kg()
+        monkeypatch.setenv("MEMPALACE_PALACE_PATH", str(tmp_b))
+        kg_b = mcp_server._get_kg()
+
+        assert kg_a is not kg_b
+        assert len(mcp_server._kg_by_path) == 2
+
+    def test_multi_tenant_env_switch(self, tmp_path, monkeypatch):
+        """The issue #1136 acceptance scenario.
+
+        Rotating MEMPALACE_PALACE_PATH between MCP tool calls must route
+        each call to the correct tenant's KG sqlite file.
+        """
+        from mempalace import mcp_server
+
+        tmp_a = tmp_path / "tenant_a"
+        tmp_b = tmp_path / "tenant_b"
+        tmp_a.mkdir()
+        tmp_b.mkdir()
+
+        monkeypatch.setattr(mcp_server, "_kg_by_path", {})
+        monkeypatch.setattr(mcp_server, "_palace_flag_given", True)
+
+        monkeypatch.setenv("MEMPALACE_PALACE_PATH", str(tmp_a))
+        add_result = mcp_server.tool_kg_add(
+            subject="alice_secret",
+            predicate="owns",
+            object="repo_a",
+        )
+        assert add_result.get("success") is True, add_result
+
+        monkeypatch.setenv("MEMPALACE_PALACE_PATH", str(tmp_b))
+        query_b = mcp_server.tool_kg_query(entity="alice_secret")
+        assert query_b.get("count", 0) == 0, f"tenant B leaked tenant A's fact: {query_b}"
+
+        monkeypatch.setenv("MEMPALACE_PALACE_PATH", str(tmp_a))
+        query_a = mcp_server.tool_kg_query(entity="alice_secret")
+        assert query_a.get("count", 0) >= 1, f"tenant A lost its own fact: {query_a}"
+
+    def test_cache_thread_safe(self, tmp_path, monkeypatch):
+        """Concurrent _get_kg() for the same path yields one instance."""
+        import concurrent.futures
+        from mempalace import mcp_server
+
+        monkeypatch.setattr(mcp_server, "_kg_by_path", {})
+        monkeypatch.setattr(mcp_server, "_palace_flag_given", True)
+        monkeypatch.setenv("MEMPALACE_PALACE_PATH", str(tmp_path))
+
+        with concurrent.futures.ThreadPoolExecutor(max_workers=16) as pool:
+            results = list(pool.map(lambda _: mcp_server._get_kg(), range(16)))
+
+        ids = {id(kg) for kg in results}
+        assert len(ids) == 1, f"expected 1 unique instance, got {len(ids)}"
+        assert len(mcp_server._kg_by_path) == 1
+
+    def test_tool_reconnect_drains_kg_cache(self, monkeypatch):
+        """``tool_reconnect`` must close cached KG instances and clear the dict.
+
+        Without this, an external replacement of ``knowledge_graph.sqlite3``
+        leaves the server pinned to a stale ``sqlite3.Connection``.
+        """
+        from mempalace import mcp_server
+
+        class _FakeKG:
+            def __init__(self):
+                self.closed = False
+
+            def close(self):
+                self.closed = True
+
+        fake_a = _FakeKG()
+        fake_b = _FakeKG()
+        monkeypatch.setattr(mcp_server, "_kg_by_path", {"/a": fake_a, "/b": fake_b})
+        # Bypass real ChromaDB so the test isolates KG-cache behaviour.
+        monkeypatch.setattr(mcp_server, "_get_collection", lambda: None)
+
+        mcp_server.tool_reconnect()
+
+        assert fake_a.closed is True
+        assert fake_b.closed is True
+        assert mcp_server._kg_by_path == {}
+
+    def test_tool_reconnect_swallows_kg_close_errors(self, monkeypatch):
+        """A failing ``close()`` on one cached KG must not block cache clearing."""
+        from mempalace import mcp_server
+
+        class _BoomKG:
+            def close(self):
+                raise RuntimeError("boom")
+
+        monkeypatch.setattr(mcp_server, "_kg_by_path", {"/a": _BoomKG()})
+        monkeypatch.setattr(mcp_server, "_get_collection", lambda: None)
+
+        mcp_server.tool_reconnect()
+
+        assert mcp_server._kg_by_path == {}
+
+    def test_call_kg_retries_after_concurrent_close(self, monkeypatch):
+        """A KG closed mid-handler must trigger a one-shot retry with a fresh
+        instance — not surface a -32000 to the MCP client."""
+        import sqlite3 as _sqlite3
+
+        from mempalace import mcp_server
+
+        path = "/fake/palace/knowledge_graph.sqlite3"
+        monkeypatch.setattr(mcp_server, "_resolve_kg_path", lambda: path)
+
+        class _ClosedKG:
+            def query_entity(self, entity, **kwargs):
+                raise _sqlite3.ProgrammingError("Cannot operate on a closed database")
+
+        class _FreshKG:
+            def query_entity(self, entity, **kwargs):
+                return [{"entity": entity}]
+
+        cache = {os.path.abspath(path): _ClosedKG()}
+        monkeypatch.setattr(mcp_server, "_kg_by_path", cache)
+
+        # Second _get_kg() call (after the cache eviction) constructs a new
+        # KG. Patch the constructor so we don't open a real sqlite file.
+        monkeypatch.setattr(mcp_server, "KnowledgeGraph", lambda **_: _FreshKG())
+
+        result = mcp_server._call_kg(lambda kg: kg.query_entity("Alice"))
+        assert result == [{"entity": "Alice"}]
+        # The closed instance must be evicted; the fresh one must be cached.
+        assert isinstance(cache[os.path.abspath(path)], _FreshKG)
+
+    def test_call_kg_does_not_retry_on_other_errors(self, monkeypatch):
+        """Non-ProgrammingError exceptions must propagate without retry —
+        we don't want the retry guard masking real bugs."""
+        from mempalace import mcp_server
+
+        path = "/fake/palace/knowledge_graph.sqlite3"
+        monkeypatch.setattr(mcp_server, "_resolve_kg_path", lambda: path)
+
+        calls = {"count": 0}
+
+        class _FailingKG:
+            def query_entity(self, entity, **kwargs):
+                calls["count"] += 1
+                raise ValueError("bad input")
+
+        monkeypatch.setattr(mcp_server, "_kg_by_path", {os.path.abspath(path): _FailingKG()})
+        monkeypatch.setattr(mcp_server, "KnowledgeGraph", lambda **_: _FailingKG())
+
+        with pytest.raises(ValueError, match="bad input"):
+            mcp_server._call_kg(lambda kg: kg.query_entity("Alice"))
+        assert calls["count"] == 1, "non-ProgrammingError must not trigger retry"
+
+    def test_call_kg_does_not_retry_on_unrelated_programming_error(self, monkeypatch):
+        """A ProgrammingError that is not the closed-connection race must
+        propagate on the first attempt and leave the cached KG in place."""
+        import sqlite3 as _sqlite3
+
+        from mempalace import mcp_server
+
+        path = "/fake/palace/knowledge_graph.sqlite3"
+        monkeypatch.setattr(mcp_server, "_resolve_kg_path", lambda: path)
+
+        calls = {"count": 0}
+
+        class _BadBindingKG:
+            def query_entity(self, entity, **kwargs):
+                calls["count"] += 1
+                raise _sqlite3.ProgrammingError("Incorrect number of bindings supplied")
+
+        healthy = _BadBindingKG()
+        cache = {os.path.abspath(path): healthy}
+        monkeypatch.setattr(mcp_server, "_kg_by_path", cache)
+
+        with pytest.raises(_sqlite3.ProgrammingError, match="bindings"):
+            mcp_server._call_kg(lambda kg: kg.query_entity("Alice"))
+        assert calls["count"] == 1, "unrelated ProgrammingError must not trigger retry"
+        assert cache[os.path.abspath(path)] is healthy
+
+    def test_call_kg_gives_up_after_one_retry(self, monkeypatch):
+        """If the second attempt also hits a closed DB, give up rather than
+        loop forever — a sustained close-stream is a different bug."""
+        import sqlite3 as _sqlite3
+
+        from mempalace import mcp_server
+
+        path = "/fake/palace/knowledge_graph.sqlite3"
+        monkeypatch.setattr(mcp_server, "_resolve_kg_path", lambda: path)
+
+        calls = {"count": 0}
+
+        class _AlwaysClosedKG:
+            def query_entity(self, entity, **kwargs):
+                calls["count"] += 1
+                raise _sqlite3.ProgrammingError("closed again")
+
+        cache = {}
+        monkeypatch.setattr(mcp_server, "_kg_by_path", cache)
+        monkeypatch.setattr(mcp_server, "KnowledgeGraph", lambda **_: _AlwaysClosedKG())
+
+        with pytest.raises(_sqlite3.ProgrammingError):
+            mcp_server._call_kg(lambda kg: kg.query_entity("Alice"))
+        assert calls["count"] == 2, "expected exactly one retry beyond the initial attempt"