From 267a644f4f34e64ea4161fc24b4dcdfdb3652bd5 Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Tue, 14 Apr 2026 00:31:16 -0300 Subject: [PATCH] refactor: route all chromadb access through ChromaBackend Prerequisite for RFC 001 (plugin spec, #743). Removes every direct `import chromadb` outside the ChromaDB backend itself so the core modules depend only on the backend abstraction layer. Extends ChromaBackend with make_client, get_or_create_collection, delete_collection, create_collection, and backend_version. Adds update() to the BaseCollection contract. Non-backend callers (mcp_server, dedup, repair, migrate, cli) now go through the abstraction; tests patch ChromaBackend instead of chromadb. With this landed, the RFC 001 spec can be enforced and PalaceStore (#643) can ship as a plugin without touching core modules. --- mempalace/backends/base.py | 5 ++ mempalace/backends/chroma.py | 63 ++++++++++++++++++- mempalace/cli.py | 21 +++---- mempalace/dedup.py | 8 +-- mempalace/mcp_server.py | 12 ++-- mempalace/migrate.py | 18 +++--- mempalace/repair.py | 16 +++-- tests/test_cli.py | 106 +++++++++++++------------------- tests/test_dedup.py | 39 ++++++------ tests/test_repair.py | 114 ++++++++++++++++------------------- uv.lock | 2 +- 11 files changed, 215 insertions(+), 189 deletions(-) diff --git a/mempalace/backends/base.py b/mempalace/backends/base.py index 4685f51..877da53 100644 --- a/mempalace/backends/base.py +++ b/mempalace/backends/base.py @@ -27,6 +27,11 @@ class BaseCollection(ABC): ) -> None: raise NotImplementedError + @abstractmethod + def update(self, **kwargs: Any) -> None: + """Update existing records. Must raise if any ID is missing.""" + raise NotImplementedError + @abstractmethod def query(self, **kwargs: Any) -> Dict[str, Any]: raise NotImplementedError diff --git a/mempalace/backends/chroma.py b/mempalace/backends/chroma.py index 28fe55f..1a13675 100644 --- a/mempalace/backends/chroma.py +++ b/mempalace/backends/chroma.py @@ -55,6 +55,9 @@ class ChromaCollection(BaseCollection): def upsert(self, *, documents, ids, metadatas=None): self._collection.upsert(documents=documents, ids=ids, metadatas=metadatas) + def update(self, **kwargs): + self._collection.update(**kwargs) + def query(self, **kwargs): return self._collection.query(**kwargs) @@ -71,6 +74,44 @@ class ChromaCollection(BaseCollection): class ChromaBackend: """Factory for MemPalace's default ChromaDB backend.""" + def __init__(self): + # Per-instance client cache: palace_path -> chromadb.PersistentClient + self._clients: dict = {} + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _client(self, palace_path: str): + """Return a cached PersistentClient for *palace_path*, creating one if needed.""" + if palace_path not in self._clients: + _fix_blob_seq_ids(palace_path) + self._clients[palace_path] = chromadb.PersistentClient(path=palace_path) + return self._clients[palace_path] + + # ------------------------------------------------------------------ + # Public static helpers (for callers that manage their own caching) + # ------------------------------------------------------------------ + + @staticmethod + def make_client(palace_path: str): + """Create and return a fresh PersistentClient (fix BLOB seq_ids first). + + Intended for long-lived callers (e.g. mcp_server) that keep their own + inode/mtime-based client cache. + """ + _fix_blob_seq_ids(palace_path) + return chromadb.PersistentClient(path=palace_path) + + @staticmethod + def backend_version() -> str: + """Return the installed chromadb package version string.""" + return chromadb.__version__ + + # ------------------------------------------------------------------ + # Collection lifecycle + # ------------------------------------------------------------------ + def get_collection(self, palace_path: str, collection_name: str, create: bool = False): if not create and not os.path.isdir(palace_path): raise FileNotFoundError(palace_path) @@ -82,8 +123,7 @@ class ChromaBackend: except (OSError, NotImplementedError): pass - _fix_blob_seq_ids(palace_path) - client = chromadb.PersistentClient(path=palace_path) + client = self._client(palace_path) if create: collection = client.get_or_create_collection( collection_name, metadata={"hnsw:space": "cosine"} @@ -91,3 +131,22 @@ class ChromaBackend: else: collection = client.get_collection(collection_name) return ChromaCollection(collection) + + def get_or_create_collection( + self, palace_path: str, collection_name: str + ) -> "ChromaCollection": + """Shorthand for get_collection(..., create=True).""" + return self.get_collection(palace_path, collection_name, create=True) + + def delete_collection(self, palace_path: str, collection_name: str) -> None: + """Delete *collection_name* from the palace at *palace_path*.""" + self._client(palace_path).delete_collection(collection_name) + + def create_collection( + self, palace_path: str, collection_name: str, hnsw_space: str = "cosine" + ) -> "ChromaCollection": + """Create (not get-or-create) *collection_name* with cosine HNSW space.""" + collection = self._client(palace_path).create_collection( + collection_name, metadata={"hnsw:space": hnsw_space} + ) + return ChromaCollection(collection) diff --git a/mempalace/cli.py b/mempalace/cli.py index fa92ed6..f7f68d7 100644 --- a/mempalace/cli.py +++ b/mempalace/cli.py @@ -172,8 +172,8 @@ def cmd_status(args): def cmd_repair(args): """Rebuild palace vector index from SQLite metadata.""" - import chromadb import shutil + from .backends.chroma import ChromaBackend from .migrate import confirm_destructive_action, contains_palace_database palace_path = os.path.abspath( @@ -193,10 +193,11 @@ def cmd_repair(args): print(f"{'=' * 55}\n") print(f" Palace: {palace_path}") + backend = ChromaBackend() + # Try to read existing drawers try: - client = chromadb.PersistentClient(path=palace_path) - col = client.get_collection("mempalace_drawers") + col = backend.get_collection(palace_path, "mempalace_drawers") total = col.count() print(f" Drawers found: {total}") except Exception as e: @@ -243,8 +244,8 @@ def cmd_repair(args): shutil.copytree(palace_path, backup_path) print(" Rebuilding collection...") - client.delete_collection("mempalace_drawers") - new_col = client.create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"}) + backend.delete_collection(palace_path, "mempalace_drawers") + new_col = backend.create_collection(palace_path, "mempalace_drawers") filed = 0 for i in range(0, len(all_ids), batch_size): @@ -297,7 +298,7 @@ def cmd_mcp(args): def cmd_compress(args): """Compress drawers in a wing using AAAK Dialect.""" - import chromadb + from .backends.chroma import ChromaBackend from .dialect import Dialect palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path @@ -317,9 +318,9 @@ def cmd_compress(args): dialect = Dialect() # Connect to palace + backend = ChromaBackend() try: - client = chromadb.PersistentClient(path=palace_path) - col = client.get_collection("mempalace_drawers") + col = backend.get_collection(palace_path, "mempalace_drawers") except Exception: print(f"\n No palace found at {palace_path}") print(" Run: mempalace init then mempalace mine ") @@ -394,9 +395,7 @@ def cmd_compress(args): # Store compressed versions (unless dry-run) if not args.dry_run: try: - comp_col = client.get_or_create_collection( - "mempalace_compressed", metadata={"hnsw:space": "cosine"} - ) + comp_col = backend.get_or_create_collection(palace_path, "mempalace_compressed") for doc_id, compressed, meta, stats in compressed_entries: comp_meta = dict(meta) comp_meta["compression_ratio"] = round(stats["size_ratio"], 1) diff --git a/mempalace/dedup.py b/mempalace/dedup.py index c2f9f6b..6b1bac1 100644 --- a/mempalace/dedup.py +++ b/mempalace/dedup.py @@ -27,7 +27,7 @@ import os import time from collections import defaultdict -import chromadb +from .backends.chroma import ChromaBackend COLLECTION_NAME = "mempalace_drawers" @@ -130,8 +130,7 @@ def dedup_source_group(col, drawer_ids, threshold=DEFAULT_THRESHOLD, dry_run=Tru def show_stats(palace_path=None): """Show duplication statistics without making changes.""" palace_path = palace_path or _get_palace_path() - client = chromadb.PersistentClient(path=palace_path) - col = client.get_collection(COLLECTION_NAME) + col = ChromaBackend().get_collection(palace_path, COLLECTION_NAME) groups = get_source_groups(col) @@ -163,8 +162,7 @@ def dedup_palace( print(" MemPalace Deduplicator") print(f"{'=' * 55}") - client = chromadb.PersistentClient(path=palace_path) - col = client.get_collection(COLLECTION_NAME) + col = ChromaBackend().get_collection(palace_path, COLLECTION_NAME) print(f" Palace: {palace_path}") print(f" Drawers: {col.count():,}") diff --git a/mempalace/mcp_server.py b/mempalace/mcp_server.py index 31be8a4..4653f5f 100644 --- a/mempalace/mcp_server.py +++ b/mempalace/mcp_server.py @@ -32,7 +32,7 @@ from pathlib import Path from .config import MempalaceConfig, sanitize_name, sanitize_content from .version import __version__ -import chromadb +from .backends.chroma import ChromaBackend, ChromaCollection from .query_sanitizer import sanitize_query from .searcher import search_memories from .palace_graph import ( @@ -177,7 +177,7 @@ def _get_client(): mtime_changed = current_mtime != 0.0 and abs(current_mtime - _palace_db_mtime) > 0.01 if _client_cache is None or inode_changed or mtime_changed: - _client_cache = chromadb.PersistentClient(path=_config.palace_path) + _client_cache = ChromaBackend.make_client(_config.palace_path) _collection_cache = None _metadata_cache = None _metadata_cache_time = 0 @@ -192,13 +192,15 @@ def _get_collection(create=False): try: client = _get_client() if create: - _collection_cache = client.get_or_create_collection( - _config.collection_name, metadata={"hnsw:space": "cosine"} + _collection_cache = ChromaCollection( + client.get_or_create_collection( + _config.collection_name, metadata={"hnsw:space": "cosine"} + ) ) _metadata_cache = None _metadata_cache_time = 0 elif _collection_cache is None: - _collection_cache = client.get_collection(_config.collection_name) + _collection_cache = ChromaCollection(client.get_collection(_config.collection_name)) _metadata_cache = None _metadata_cache_time = 0 return _collection_cache diff --git a/mempalace/migrate.py b/mempalace/migrate.py index 319c670..2eebb61 100644 --- a/mempalace/migrate.py +++ b/mempalace/migrate.py @@ -134,7 +134,7 @@ def confirm_destructive_action( def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False): """Migrate a palace to the currently installed ChromaDB version.""" - import chromadb + from .backends.chroma import ChromaBackend palace_path = os.path.abspath(os.path.expanduser(palace_path)) db_path = os.path.join(palace_path, "chroma.sqlite3") @@ -152,19 +152,19 @@ def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False): # Detect version source_version = detect_chromadb_version(db_path) + target_version = ChromaBackend.backend_version() print(f" Source: ChromaDB {source_version}") - print(f" Target: ChromaDB {chromadb.__version__}") + print(f" Target: ChromaDB {target_version}") # Try reading with current chromadb first try: - client = chromadb.PersistentClient(path=palace_path) - col = client.get_collection("mempalace_drawers") + col = ChromaBackend().get_collection(palace_path, "mempalace_drawers") count = col.count() - print(f"\n Palace is already readable by chromadb {chromadb.__version__}.") + print(f"\n Palace is already readable by chromadb {target_version}.") print(f" {count} drawers found. No migration needed.") return True except Exception: - print(f"\n Palace is NOT readable by chromadb {chromadb.__version__}.") + print(f"\n Palace is NOT readable by chromadb {target_version}.") print(" Extracting from SQLite directly...") # Extract all drawers via raw SQL @@ -208,8 +208,8 @@ def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False): temp_palace = tempfile.mkdtemp(prefix="mempalace_migrate_") print(f" Creating fresh palace in {temp_palace}...") - client = chromadb.PersistentClient(path=temp_palace) - col = client.get_or_create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"}) + fresh_backend = ChromaBackend() + col = fresh_backend.get_or_create_collection(temp_palace, "mempalace_drawers") # Re-import in batches batch_size = 500 @@ -227,7 +227,7 @@ def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False): # Verify before swapping final_count = col.count() del col - del client + del fresh_backend # Swap: remove old palace, move new one into place print(" Swapping old palace for migrated version...") diff --git a/mempalace/repair.py b/mempalace/repair.py index d51be60..9a9aa88 100644 --- a/mempalace/repair.py +++ b/mempalace/repair.py @@ -32,7 +32,7 @@ import os import shutil import time -import chromadb +from .backends.chroma import ChromaBackend COLLECTION_NAME = "mempalace_drawers" @@ -90,8 +90,7 @@ def scan_palace(palace_path=None, only_wing=None): print(f"\n Palace: {palace_path}") print(" Loading...") - client = chromadb.PersistentClient(path=palace_path) - col = client.get_collection(COLLECTION_NAME) + col = ChromaBackend().get_collection(palace_path, COLLECTION_NAME) where = {"wing": only_wing} if only_wing else None total = col.count() @@ -174,8 +173,7 @@ def prune_corrupt(palace_path=None, confirm=False): print(" Re-run with --confirm to actually delete.") return - client = chromadb.PersistentClient(path=palace_path) - col = client.get_collection(COLLECTION_NAME) + col = ChromaBackend().get_collection(palace_path, COLLECTION_NAME) before = col.count() print(f" Collection size before: {before:,}") @@ -222,9 +220,9 @@ def rebuild_index(palace_path=None): print(f"{'=' * 55}\n") print(f" Palace: {palace_path}") - client = chromadb.PersistentClient(path=palace_path) + backend = ChromaBackend() try: - col = client.get_collection(COLLECTION_NAME) + col = backend.get_collection(palace_path, COLLECTION_NAME) total = col.count() except Exception as e: print(f" Error reading palace: {e}") @@ -264,8 +262,8 @@ def rebuild_index(palace_path=None): # Rebuild with correct HNSW settings print(" Rebuilding collection with hnsw:space=cosine...") - client.delete_collection(COLLECTION_NAME) - new_col = client.create_collection(COLLECTION_NAME, metadata={"hnsw:space": "cosine"}) + backend.delete_collection(palace_path, COLLECTION_NAME) + new_col = backend.create_collection(palace_path, COLLECTION_NAME) filed = 0 for i in range(0, len(all_ids), batch_size): diff --git a/tests/test_cli.py b/tests/test_cli.py index 0e95a8c..c4b4203 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -412,12 +412,21 @@ def test_main_compress_dispatches(): # ── cmd_repair ───────────────────────────────────────────────────────── +def _mock_backend_for(col=None, new_col=None): + """Build a mock ChromaBackend whose get_collection/create_collection return *col* / *new_col*.""" + mock_backend = MagicMock() + if col is not None: + mock_backend.get_collection.return_value = col + if new_col is not None: + mock_backend.create_collection.return_value = new_col + return mock_backend + + @patch("mempalace.cli.MempalaceConfig") def test_cmd_repair_no_palace(mock_config_cls, tmp_path, capsys): mock_config_cls.return_value.palace_path = str(tmp_path / "nonexistent") args = argparse.Namespace(palace=None) - mock_chromadb = MagicMock() - with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + with patch("mempalace.backends.chroma.ChromaBackend"): cmd_repair(args) out = capsys.readouterr().out assert "No palace found" in out @@ -429,8 +438,7 @@ def test_cmd_repair_requires_palace_database(mock_config_cls, tmp_path, capsys): palace_dir.mkdir() mock_config_cls.return_value.palace_path = str(palace_dir) args = argparse.Namespace(palace=None) - mock_chromadb = MagicMock() - with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + with patch("mempalace.backends.chroma.ChromaBackend"): cmd_repair(args) out = capsys.readouterr().out assert "No palace database found" in out @@ -443,11 +451,9 @@ def test_cmd_repair_error_reading(mock_config_cls, tmp_path, capsys): (palace_dir / "chroma.sqlite3").write_text("db") mock_config_cls.return_value.palace_path = str(palace_dir) args = argparse.Namespace(palace=None) - mock_chromadb = MagicMock() - mock_client = MagicMock() - mock_client.get_collection.side_effect = Exception("corrupt db") - mock_chromadb.PersistentClient.return_value = mock_client - with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + mock_backend = MagicMock() + mock_backend.get_collection.side_effect = Exception("corrupt db") + with patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend): cmd_repair(args) out = capsys.readouterr().out assert "Error reading palace" in out @@ -460,13 +466,10 @@ def test_cmd_repair_zero_drawers(mock_config_cls, tmp_path, capsys): (palace_dir / "chroma.sqlite3").write_text("db") mock_config_cls.return_value.palace_path = str(palace_dir) args = argparse.Namespace(palace=None) - mock_chromadb = MagicMock() mock_col = MagicMock() mock_col.count.return_value = 0 - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client - with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + mock_backend = _mock_backend_for(col=mock_col) + with patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend): cmd_repair(args) out = capsys.readouterr().out assert "Nothing to repair" in out @@ -479,7 +482,6 @@ def test_cmd_repair_success(mock_config_cls, tmp_path, capsys): (palace_dir / "chroma.sqlite3").write_text("db") mock_config_cls.return_value.palace_path = str(palace_dir) args = argparse.Namespace(palace=None, yes=True) - mock_chromadb = MagicMock() mock_col = MagicMock() mock_col.count.return_value = 2 mock_col.get.return_value = { @@ -487,12 +489,9 @@ def test_cmd_repair_success(mock_config_cls, tmp_path, capsys): "documents": ["doc1", "doc2"], "metadatas": [{"wing": "a"}, {"wing": "b"}], } - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col mock_new_col = MagicMock() - mock_client.create_collection.return_value = mock_new_col - mock_chromadb.PersistentClient.return_value = mock_client - with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + mock_backend = _mock_backend_for(col=mock_col, new_col=mock_new_col) + with patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend): cmd_repair(args) out = capsys.readouterr().out assert "Repair complete" in out @@ -506,20 +505,17 @@ def test_cmd_repair_aborts_without_confirmation(mock_config_cls, tmp_path, capsy (palace_dir / "chroma.sqlite3").write_text("db") mock_config_cls.return_value.palace_path = str(palace_dir) args = argparse.Namespace(palace=None) - mock_chromadb = MagicMock() mock_col = MagicMock() mock_col.count.return_value = 1 - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client + mock_backend = _mock_backend_for(col=mock_col) with ( - patch.dict("sys.modules", {"chromadb": mock_chromadb}), + patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend), patch("builtins.input", return_value="n"), ): cmd_repair(args) out = capsys.readouterr().out assert "Aborted." in out - mock_client.create_collection.assert_not_called() + mock_backend.create_collection.assert_not_called() # ── cmd_compress ─────────────────────────────────────────────────────── @@ -529,10 +525,10 @@ def test_cmd_repair_aborts_without_confirmation(mock_config_cls, tmp_path, capsy def test_cmd_compress_no_palace(mock_config_cls, capsys): mock_config_cls.return_value.palace_path = "/fake/palace" args = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None) - mock_chromadb = MagicMock() - mock_chromadb.PersistentClient.side_effect = Exception("no palace") + mock_backend = MagicMock() + mock_backend.get_collection.side_effect = Exception("no palace") with ( - patch.dict("sys.modules", {"chromadb": mock_chromadb}), + patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend), pytest.raises(SystemExit), ): cmd_compress(args) @@ -542,13 +538,10 @@ def test_cmd_compress_no_palace(mock_config_cls, capsys): def test_cmd_compress_no_drawers(mock_config_cls, capsys): mock_config_cls.return_value.palace_path = "/fake/palace" args = argparse.Namespace(palace=None, wing="mywing", dry_run=False, config=None) - mock_chromadb = MagicMock() mock_col = MagicMock() mock_col.get.return_value = {"documents": [], "metadatas": [], "ids": []} - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client - with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + mock_backend = _mock_backend_for(col=mock_col) + with patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend): cmd_compress(args) out = capsys.readouterr().out assert "No drawers found" in out @@ -567,7 +560,6 @@ def _make_mock_dialect_module(dialect_instance): def test_cmd_compress_dry_run(mock_config_cls, capsys): mock_config_cls.return_value.palace_path = "/fake/palace" args = argparse.Namespace(palace=None, wing=None, dry_run=True, config=None) - mock_chromadb = MagicMock() mock_col = MagicMock() mock_col.get.side_effect = [ { @@ -577,9 +569,7 @@ def test_cmd_compress_dry_run(mock_config_cls, capsys): }, {"documents": [], "metadatas": [], "ids": []}, ] - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client + mock_backend = _mock_backend_for(col=mock_col) mock_dialect = MagicMock() mock_dialect.compress.return_value = "compressed" @@ -593,12 +583,9 @@ def test_cmd_compress_dry_run(mock_config_cls, capsys): } mock_dialect_mod = _make_mock_dialect_module(mock_dialect) - with patch.dict( - "sys.modules", - { - "chromadb": mock_chromadb, - "mempalace.dialect": mock_dialect_mod, - }, + with ( + patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend), + patch.dict("sys.modules", {"mempalace.dialect": mock_dialect_mod}), ): cmd_compress(args) out = capsys.readouterr().out @@ -613,22 +600,16 @@ def test_cmd_compress_with_config(mock_config_cls, tmp_path, capsys): config_file = tmp_path / "entities.json" config_file.write_text('{"people": [], "projects": []}') args = argparse.Namespace(palace=None, wing=None, dry_run=True, config=str(config_file)) - mock_chromadb = MagicMock() mock_col = MagicMock() mock_col.get.return_value = {"documents": [], "metadatas": [], "ids": []} - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client + mock_backend = _mock_backend_for(col=mock_col) mock_dialect = MagicMock() mock_dialect_mod = _make_mock_dialect_module(mock_dialect) - with patch.dict( - "sys.modules", - { - "chromadb": mock_chromadb, - "mempalace.dialect": mock_dialect_mod, - }, + with ( + patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend), + patch.dict("sys.modules", {"mempalace.dialect": mock_dialect_mod}), ): cmd_compress(args) out = capsys.readouterr().out @@ -640,7 +621,6 @@ def test_cmd_compress_stores_results(mock_config_cls, capsys): """Non-dry-run compress stores to mempalace_compressed collection.""" mock_config_cls.return_value.palace_path = "/fake/palace" args = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None) - mock_chromadb = MagicMock() mock_col = MagicMock() mock_col.get.side_effect = [ { @@ -650,11 +630,10 @@ def test_cmd_compress_stores_results(mock_config_cls, capsys): }, {"documents": [], "metadatas": [], "ids": []}, ] - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col mock_comp_col = MagicMock() - mock_client.get_or_create_collection.return_value = mock_comp_col - mock_chromadb.PersistentClient.return_value = mock_client + mock_backend = MagicMock() + mock_backend.get_collection.return_value = mock_col + mock_backend.get_or_create_collection.return_value = mock_comp_col mock_dialect = MagicMock() mock_dialect.compress.return_value = "compressed" @@ -668,12 +647,9 @@ def test_cmd_compress_stores_results(mock_config_cls, capsys): } mock_dialect_mod = _make_mock_dialect_module(mock_dialect) - with patch.dict( - "sys.modules", - { - "chromadb": mock_chromadb, - "mempalace.dialect": mock_dialect_mod, - }, + with ( + patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend), + patch.dict("sys.modules", {"mempalace.dialect": mock_dialect_mod}), ): cmd_compress(args) out = capsys.readouterr().out diff --git a/tests/test_dedup.py b/tests/test_dedup.py index 2ddffb3..dfdd3de 100644 --- a/tests/test_dedup.py +++ b/tests/test_dedup.py @@ -198,8 +198,15 @@ def test_dedup_source_group_query_failure_keeps(): # ── show_stats ──────────────────────────────────────────────────────── -@patch("mempalace.dedup.chromadb") -def test_show_stats(mock_chromadb, tmp_path): +def _install_mock_backend(mock_backend_cls, collection): + mock_backend = MagicMock() + mock_backend.get_collection.return_value = collection + mock_backend_cls.return_value = mock_backend + return mock_backend + + +@patch("mempalace.dedup.ChromaBackend") +def test_show_stats(mock_backend_cls, tmp_path): mock_col = MagicMock() mock_col.count.return_value = 5 mock_col.get.side_effect = [ @@ -215,9 +222,7 @@ def test_show_stats(mock_chromadb, tmp_path): }, {"ids": []}, ] - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client + _install_mock_backend(mock_backend_cls, mock_col) dedup.show_stats(palace_path=str(tmp_path)) # should not raise @@ -227,13 +232,11 @@ def test_show_stats(mock_chromadb, tmp_path): @patch("mempalace.dedup.dedup_source_group") @patch("mempalace.dedup.get_source_groups") -@patch("mempalace.dedup.chromadb") -def test_dedup_palace_dry_run(mock_chromadb, mock_groups, mock_dedup_group, tmp_path): +@patch("mempalace.dedup.ChromaBackend") +def test_dedup_palace_dry_run(mock_backend_cls, mock_groups, mock_dedup_group, tmp_path): mock_col = MagicMock() mock_col.count.return_value = 10 - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client + _install_mock_backend(mock_backend_cls, mock_col) mock_groups.return_value = {"a.txt": ["d1", "d2", "d3", "d4", "d5"]} mock_dedup_group.return_value = (["d1", "d2", "d3"], ["d4", "d5"]) @@ -244,13 +247,11 @@ def test_dedup_palace_dry_run(mock_chromadb, mock_groups, mock_dedup_group, tmp_ @patch("mempalace.dedup.dedup_source_group") @patch("mempalace.dedup.get_source_groups") -@patch("mempalace.dedup.chromadb") -def test_dedup_palace_with_wing(mock_chromadb, mock_groups, mock_dedup_group, tmp_path): +@patch("mempalace.dedup.ChromaBackend") +def test_dedup_palace_with_wing(mock_backend_cls, mock_groups, mock_dedup_group, tmp_path): mock_col = MagicMock() mock_col.count.return_value = 10 - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client + _install_mock_backend(mock_backend_cls, mock_col) mock_groups.return_value = {} dedup.dedup_palace(palace_path=str(tmp_path), wing="test_wing", dry_run=True) @@ -259,13 +260,11 @@ def test_dedup_palace_with_wing(mock_chromadb, mock_groups, mock_dedup_group, tm @patch("mempalace.dedup.dedup_source_group") @patch("mempalace.dedup.get_source_groups") -@patch("mempalace.dedup.chromadb") -def test_dedup_palace_no_groups(mock_chromadb, mock_groups, mock_dedup_group, tmp_path): +@patch("mempalace.dedup.ChromaBackend") +def test_dedup_palace_no_groups(mock_backend_cls, mock_groups, mock_dedup_group, tmp_path): mock_col = MagicMock() mock_col.count.return_value = 3 - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client + _install_mock_backend(mock_backend_cls, mock_col) mock_groups.return_value = {} dedup.dedup_palace(palace_path=str(tmp_path), dry_run=True) diff --git a/tests/test_repair.py b/tests/test_repair.py index 604b0fb..9ae1812 100644 --- a/tests/test_repair.py +++ b/tests/test_repair.py @@ -66,22 +66,28 @@ def test_paginate_ids_offset_exception_fallback(): # ── scan_palace ─────────────────────────────────────────────────────── -@patch("mempalace.repair.chromadb") -def test_scan_palace_no_ids(mock_chromadb, tmp_path): +def _install_mock_backend(mock_backend_cls, collection): + """Wire mock_backend_cls so ChromaBackend().get_collection(...) returns *collection*.""" + mock_backend = MagicMock() + mock_backend.get_collection.return_value = collection + mock_backend_cls.return_value = mock_backend + return mock_backend + + +@patch("mempalace.repair.ChromaBackend") +def test_scan_palace_no_ids(mock_backend_cls, tmp_path): mock_col = MagicMock() mock_col.count.return_value = 0 mock_col.get.return_value = {"ids": []} - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client + _install_mock_backend(mock_backend_cls, mock_col) good, bad = repair.scan_palace(palace_path=str(tmp_path)) assert good == set() assert bad == set() -@patch("mempalace.repair.chromadb") -def test_scan_palace_all_good(mock_chromadb, tmp_path): +@patch("mempalace.repair.ChromaBackend") +def test_scan_palace_all_good(mock_backend_cls, tmp_path): mock_col = MagicMock() mock_col.count.return_value = 2 # _paginate_ids call @@ -89,9 +95,7 @@ def test_scan_palace_all_good(mock_chromadb, tmp_path): {"ids": ["id1", "id2"]}, # paginate {"ids": ["id1", "id2"]}, # probe batch — both returned ] - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client + _install_mock_backend(mock_backend_cls, mock_col) good, bad = repair.scan_palace(palace_path=str(tmp_path)) assert "id1" in good @@ -99,8 +103,8 @@ def test_scan_palace_all_good(mock_chromadb, tmp_path): assert len(bad) == 0 -@patch("mempalace.repair.chromadb") -def test_scan_palace_with_bad_ids(mock_chromadb, tmp_path): +@patch("mempalace.repair.ChromaBackend") +def test_scan_palace_with_bad_ids(mock_backend_cls, tmp_path): mock_col = MagicMock() mock_col.count.return_value = 2 @@ -117,26 +121,22 @@ def test_scan_palace_with_bad_ids(mock_chromadb, tmp_path): raise Exception("batch fail") mock_col.get.side_effect = get_side_effect - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client + _install_mock_backend(mock_backend_cls, mock_col) good, bad = repair.scan_palace(palace_path=str(tmp_path)) assert "good1" in good assert "bad1" in bad -@patch("mempalace.repair.chromadb") -def test_scan_palace_with_wing_filter(mock_chromadb, tmp_path): +@patch("mempalace.repair.ChromaBackend") +def test_scan_palace_with_wing_filter(mock_backend_cls, tmp_path): mock_col = MagicMock() mock_col.count.return_value = 1 mock_col.get.side_effect = [ {"ids": ["id1"]}, # paginate {"ids": ["id1"]}, # probe ] - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client + _install_mock_backend(mock_backend_cls, mock_col) repair.scan_palace(palace_path=str(tmp_path), only_wing="test_wing") # Verify where filter was passed @@ -147,38 +147,36 @@ def test_scan_palace_with_wing_filter(mock_chromadb, tmp_path): # ── prune_corrupt ───────────────────────────────────────────────────── -@patch("mempalace.repair.chromadb") -def test_prune_corrupt_no_file(mock_chromadb, tmp_path): +@patch("mempalace.repair.ChromaBackend") +def test_prune_corrupt_no_file(mock_backend_cls, tmp_path): # Should print message and return without error repair.prune_corrupt(palace_path=str(tmp_path)) -@patch("mempalace.repair.chromadb") -def test_prune_corrupt_dry_run(mock_chromadb, tmp_path): +@patch("mempalace.repair.ChromaBackend") +def test_prune_corrupt_dry_run(mock_backend_cls, tmp_path): bad_file = tmp_path / "corrupt_ids.txt" bad_file.write_text("bad1\nbad2\n") repair.prune_corrupt(palace_path=str(tmp_path), confirm=False) - # No chromadb calls in dry run - mock_chromadb.PersistentClient.assert_not_called() + # No backend calls in dry run + mock_backend_cls.assert_not_called() -@patch("mempalace.repair.chromadb") -def test_prune_corrupt_confirmed(mock_chromadb, tmp_path): +@patch("mempalace.repair.ChromaBackend") +def test_prune_corrupt_confirmed(mock_backend_cls, tmp_path): bad_file = tmp_path / "corrupt_ids.txt" bad_file.write_text("bad1\nbad2\n") mock_col = MagicMock() mock_col.count.side_effect = [10, 8] - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client + _install_mock_backend(mock_backend_cls, mock_col) repair.prune_corrupt(palace_path=str(tmp_path), confirm=True) mock_col.delete.assert_called_once() -@patch("mempalace.repair.chromadb") -def test_prune_corrupt_delete_failure_fallback(mock_chromadb, tmp_path): +@patch("mempalace.repair.ChromaBackend") +def test_prune_corrupt_delete_failure_fallback(mock_backend_cls, tmp_path): bad_file = tmp_path / "corrupt_ids.txt" bad_file.write_text("bad1\nbad2\n") @@ -186,9 +184,7 @@ def test_prune_corrupt_delete_failure_fallback(mock_chromadb, tmp_path): mock_col.count.side_effect = [10, 8] # Batch delete fails, per-id succeeds mock_col.delete.side_effect = [Exception("batch fail"), None, None] - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client + _install_mock_backend(mock_backend_cls, mock_col) repair.prune_corrupt(palace_path=str(tmp_path), confirm=True) assert mock_col.delete.call_count == 3 # 1 batch + 2 individual @@ -197,29 +193,27 @@ def test_prune_corrupt_delete_failure_fallback(mock_chromadb, tmp_path): # ── rebuild_index ───────────────────────────────────────────────────── -@patch("mempalace.repair.chromadb") -def test_rebuild_index_no_palace(mock_chromadb, tmp_path): +@patch("mempalace.repair.ChromaBackend") +def test_rebuild_index_no_palace(mock_backend_cls, tmp_path): nonexistent = str(tmp_path / "nope") repair.rebuild_index(palace_path=nonexistent) - mock_chromadb.PersistentClient.assert_not_called() + mock_backend_cls.assert_not_called() @patch("mempalace.repair.shutil") -@patch("mempalace.repair.chromadb") -def test_rebuild_index_empty_palace(mock_chromadb, mock_shutil, tmp_path): +@patch("mempalace.repair.ChromaBackend") +def test_rebuild_index_empty_palace(mock_backend_cls, mock_shutil, tmp_path): mock_col = MagicMock() mock_col.count.return_value = 0 - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_chromadb.PersistentClient.return_value = mock_client + mock_backend = _install_mock_backend(mock_backend_cls, mock_col) repair.rebuild_index(palace_path=str(tmp_path)) - mock_client.delete_collection.assert_not_called() + mock_backend.delete_collection.assert_not_called() @patch("mempalace.repair.shutil") -@patch("mempalace.repair.chromadb") -def test_rebuild_index_success(mock_chromadb, mock_shutil, tmp_path): +@patch("mempalace.repair.ChromaBackend") +def test_rebuild_index_success(mock_backend_cls, mock_shutil, tmp_path): # Create a fake sqlite file sqlite_path = tmp_path / "chroma.sqlite3" sqlite_path.write_text("fake") @@ -233,10 +227,8 @@ def test_rebuild_index_success(mock_chromadb, mock_shutil, tmp_path): } mock_new_col = MagicMock() - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - mock_client.create_collection.return_value = mock_new_col - mock_chromadb.PersistentClient.return_value = mock_client + mock_backend = _install_mock_backend(mock_backend_cls, mock_col) + mock_backend.create_collection.return_value = mock_new_col repair.rebuild_index(palace_path=str(tmp_path)) @@ -244,11 +236,9 @@ def test_rebuild_index_success(mock_chromadb, mock_shutil, tmp_path): mock_shutil.copy2.assert_called_once() assert "chroma.sqlite3" in str(mock_shutil.copy2.call_args) - # Verify: deleted and recreated with cosine - mock_client.delete_collection.assert_called_once_with("mempalace_drawers") - mock_client.create_collection.assert_called_once_with( - "mempalace_drawers", metadata={"hnsw:space": "cosine"} - ) + # Verify: deleted and recreated (cosine is the backend default) + mock_backend.delete_collection.assert_called_once_with(str(tmp_path), "mempalace_drawers") + mock_backend.create_collection.assert_called_once_with(str(tmp_path), "mempalace_drawers") # Verify: used upsert not add mock_new_col.upsert.assert_called_once() @@ -256,11 +246,11 @@ def test_rebuild_index_success(mock_chromadb, mock_shutil, tmp_path): @patch("mempalace.repair.shutil") -@patch("mempalace.repair.chromadb") -def test_rebuild_index_error_reading(mock_chromadb, mock_shutil, tmp_path): - mock_client = MagicMock() - mock_client.get_collection.side_effect = Exception("corrupt") - mock_chromadb.PersistentClient.return_value = mock_client +@patch("mempalace.repair.ChromaBackend") +def test_rebuild_index_error_reading(mock_backend_cls, mock_shutil, tmp_path): + mock_backend = MagicMock() + mock_backend.get_collection.side_effect = Exception("corrupt") + mock_backend_cls.return_value = mock_backend repair.rebuild_index(palace_path=str(tmp_path)) - mock_client.delete_collection.assert_not_called() + mock_backend.delete_collection.assert_not_called() diff --git a/uv.lock b/uv.lock index 413f104..f9b6dca 100644 --- a/uv.lock +++ b/uv.lock @@ -1239,7 +1239,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "autocorrect", marker = "extra == 'spellcheck'", specifier = ">=2.0" }, - { name = "chromadb", specifier = ">=0.5.0,<0.7" }, + { name = "chromadb", specifier = ">=0.5.0" }, { name = "psutil", marker = "extra == 'dev'", specifier = ">=5.9" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" },