diff --git a/mempalace/backends/chroma.py b/mempalace/backends/chroma.py index 6cec804..e73ed41 100644 --- a/mempalace/backends/chroma.py +++ b/mempalace/backends/chroma.py @@ -4,7 +4,9 @@ import contextlib import datetime as _dt import logging import os +import pickle import sqlite3 +from numbers import Integral from pathlib import Path from typing import Any, Optional @@ -490,22 +492,17 @@ def hnsw_capacity_status(palace_path: str, collection_name: str = "mempalace_dra divergence_floor = max(_HNSW_DIVERGENCE_FALLBACK_FLOOR, 2 * sync_threshold) if hnsw_count is None: - # No pickle yet — segment hasn't persisted metadata. Could be - # fresh-but-unflushed (normal) or interrupted-mid-flush (bad). - # We can't distinguish without the pickle, so only flag - # divergence when sqlite holds clearly more than two flush - # windows worth — same threshold as the with-pickle path. - if sqlite_count > divergence_floor: - out["status"] = "diverged" - out["diverged"] = True - out["divergence"] = sqlite_count - out["message"] = ( - f"sqlite holds {sqlite_count:,} embeddings but the HNSW segment " - "has never flushed metadata — vector search will return nothing " - "until the segment is rebuilt. Run `mempalace repair`." - ) - else: - out["message"] = "HNSW segment metadata not yet flushed; skipping" + # No pickle yet, so this probe cannot measure HNSW capacity. + # Chroma 1.5.x can have binary HNSW files without a flushed + # metadata pickle; absence of the pickle alone is not proof that + # vector search is unusable or dangerous. Keep the status unknown + # so MCP does not globally disable vectors on an inconclusive + # signal. Corrupt/invalid metadata, when present, is handled by + # quarantine_invalid_hnsw_metadata before Chroma opens. + out["message"] = ( + "HNSW capacity unavailable: metadata has not been flushed; " + "leaving vector search enabled" + ) return out divergence = sqlite_count - hnsw_count @@ -592,6 +589,97 @@ def _pin_hnsw_threads(collection) -> None: _BLOB_FIX_MARKER = ".blob_seq_ids_migrated" +def _valid_dimensionality(value: object) -> bool: + return isinstance(value, Integral) and not isinstance(value, bool) and int(value) > 0 + + +def _persisted_metadata_fields(obj: object) -> tuple[object, object]: + if isinstance(obj, dict): + return obj.get("dimensionality"), obj.get("id_to_label") + return getattr(obj, "dimensionality", None), getattr(obj, "id_to_label", None) + + +def quarantine_invalid_hnsw_metadata(palace_path: str) -> list[str]: + """Quarantine segment dirs whose ``index_metadata.pickle`` is unreadable or invalid. + + Chroma's persisted HNSW metadata is untrusted disk state. If a segment has + labels but no valid positive dimensionality, current Chroma versions can + accept the pickle and crash later in the Rust loader. We rename the entire + segment out of the way before ``PersistentClient`` opens so Chroma can + rebuild cleanly instead of touching known-bad metadata. + """ + try: + entries = os.listdir(palace_path) + except OSError: + return [] + + moved: list[str] = [] + for name in entries: + if "-" not in name or name.startswith(".") or ".drift-" in name or ".corrupt-" in name: + continue + seg_dir = os.path.join(palace_path, name) + if not os.path.isdir(seg_dir): + continue + + meta_path = os.path.join(seg_dir, "index_metadata.pickle") + if not os.path.isfile(meta_path): + continue + + reason = None + try: + persisted = _SafePersistentDataUnpickler.load(meta_path) + except (EOFError, OSError): + logger.debug( + "Skipping invalid-HNSW quarantine for transient metadata read in %s", + meta_path, + exc_info=True, + ) + continue + except pickle.UnpicklingError as exc: + if "truncated" in str(exc).lower() or "ran out of input" in str(exc).lower(): + logger.debug( + "Skipping invalid-HNSW quarantine for transient metadata read in %s", + meta_path, + exc_info=True, + ) + continue + reason = f"invalid index_metadata.pickle: {exc}" + except Exception as exc: + reason = f"invalid index_metadata.pickle: {exc}" + else: + if not isinstance(persisted, dict) and not ( + hasattr(persisted, "dimensionality") or hasattr(persisted, "id_to_label") + ): + reason = f"unrecognized index_metadata.pickle payload: {type(persisted).__name__}" + else: + dimensionality, id_to_label = _persisted_metadata_fields(persisted) + if id_to_label is not None and not isinstance(id_to_label, dict): + reason = f"invalid id_to_label type {type(id_to_label).__name__}" + else: + has_labels = bool(id_to_label) + if has_labels and not _valid_dimensionality(dimensionality): + reason = ( + "labels present but dimensionality is missing or invalid " + f"({dimensionality!r})" + ) + elif dimensionality is not None and not _valid_dimensionality(dimensionality): + reason = f"invalid dimensionality {dimensionality!r}" + + if reason is None: + continue + + stamp = _dt.datetime.now().strftime("%Y%m%d-%H%M%S") + target = f"{seg_dir}.corrupt-{stamp}" + try: + os.rename(seg_dir, target) + moved.append(target) + logger.warning("Quarantined invalid HNSW metadata in %s: %s", seg_dir, reason) + except OSError: + logger.exception("Failed to quarantine invalid HNSW metadata in %s", seg_dir) + + return moved + + def _fix_blob_seq_ids(palace_path: str) -> None: """Fix ChromaDB 0.6.x -> 1.5.x migration bug: BLOB seq_ids -> INTEGER. @@ -1045,6 +1133,13 @@ class ChromaBackend(BaseBackend): ) if cached is None or inode_changed or mtime_changed or mtime_appeared: + # An inode swap means we are reopening a different physical DB + # (post-restore, fresh palace at the same path, etc.); drop the + # per-process gate so the quarantine pre-checks run again + # against the new disk state instead of trusting cached "we + # already cleaned this path" credit from the prior inode. + if inode_changed: + ChromaBackend._quarantined_paths.discard(palace_path) ChromaBackend._prepare_palace_for_open(palace_path) cached = chromadb.PersistentClient(path=palace_path) self._clients[palace_path] = cached @@ -1058,26 +1153,27 @@ class ChromaBackend(BaseBackend): # Public static helpers (legacy; prefer :meth:`get_collection`) # ------------------------------------------------------------------ - # Per-process record of palaces that have already had quarantine_stale_hnsw - # invoked at least once. The proactive drift check is a *cold-start* - # protection — it catches HNSW segments that arrived stale relative to - # ``chroma.sqlite3`` (e.g. cross-machine replication, partial restore, - # crashed-mid-write). Once a long-running process has opened the palace - # cleanly, re-firing on every reconnect is a *runtime thrash*: the - # daemon's own writes bump sqlite mtime but HNSW flushes batch on - # chromadb's internal cadence, so the mtime gap naturally exceeds the - # threshold under steady write load even though nothing is corrupt. + # Per-process record of palaces that have already had the cold-start + # quarantine invoked at least once. The proactive HNSW checks are a + # *cold-start* protection — they catch segments that arrive stale relative + # to ``chroma.sqlite3`` or invalid on disk (e.g. cross-machine replication, + # partial restore, crashed-mid-write). Once a long-running process has + # opened the palace cleanly, re-firing the stale check on every reconnect + # is a *runtime thrash*: the daemon's own writes bump sqlite mtime but HNSW + # flushes batch on chromadb's internal cadence, so the mtime gap naturally + # exceeds the threshold under steady write load even though nothing is + # corrupt. # Real runtime drift is still handled — palace-daemon's ``_auto_repair`` # calls :func:`quarantine_stale_hnsw` directly on observed HNSW errors, # which bypasses this gate. # # Thread-safety: this set is mutated without a lock. Two concurrent # ``make_client()`` calls for the same palace can both pass the - # membership check and both invoke ``quarantine_stale_hnsw``. That's - # safe because the function is idempotent (mtime check + timestamped - # rename of distinct directories), so the worst-case race produces - # one redundant rename attempt that no-ops. Idempotency is the - # safety property; locking would add cost without correctness gain. + # membership check and both invoke the cold-start quarantine. That's + # safe because the functions are idempotent (mtime checks + timestamped + # rename of distinct directories), so the worst-case race produces one + # redundant rename attempt that no-ops. Idempotency is the safety + # property; locking would add cost without correctness gain. _quarantined_paths: set[str] = set() @staticmethod @@ -1085,12 +1181,16 @@ class ChromaBackend(BaseBackend): """Run the pre-open safety pass shared by :meth:`make_client` and :meth:`_client`. - Two steps, both required before constructing a ``PersistentClient``: + Three steps, all required before constructing a ``PersistentClient``: 1. ``_fix_blob_seq_ids`` — repairs the BLOB seq_id quirk that bites certain chromadb migrations. - 2. ``quarantine_stale_hnsw`` — gated by :attr:`_quarantined_paths` so - it fires once per palace per process. This is the SIGSEGV + 2. ``quarantine_invalid_hnsw_metadata`` — renames aside any HNSW + ``index_metadata.pickle`` that fails to load, so chromadb opens + against an empty index instead of crashing on the unloadable + pickle (#1266 / PR #1285). + 3. ``quarantine_stale_hnsw`` — also gated by :attr:`_quarantined_paths` + so it fires once per palace per process. This is the SIGSEGV prevention path for stale HNSW segments (see #1121, #1132, #1263); wiring it through this helper means CLI mining, search, repair, and status all benefit, not just the legacy ``make_client`` @@ -1102,6 +1202,7 @@ class ChromaBackend(BaseBackend): """ _fix_blob_seq_ids(palace_path) if palace_path not in ChromaBackend._quarantined_paths: + quarantine_invalid_hnsw_metadata(palace_path) quarantine_stale_hnsw(palace_path) ChromaBackend._quarantined_paths.add(palace_path) @@ -1113,7 +1214,7 @@ class ChromaBackend(BaseBackend): own client cache. New code should obtain a collection through :meth:`get_collection` which manages caching internally. - Quarantines stale HNSW segments **once per palace per process**. See + Quarantines HNSW segments **once per palace per process**. See :attr:`_quarantined_paths` for the rationale (cold-start protection vs. runtime thrash on steady-write daemons). """ diff --git a/mempalace/cli.py b/mempalace/cli.py index 95915eb..740de96 100644 --- a/mempalace/cli.py +++ b/mempalace/cli.py @@ -654,7 +654,14 @@ def cmd_repair(args): import shutil from .backends.chroma import ChromaBackend from .migrate import confirm_destructive_action, contains_palace_database - from .repair import TruncationDetected, check_extraction_safety + from .repair import ( + RebuildCollectionError, + TruncationDetected, + _close_chroma_handles, + _extract_drawers, + _rebuild_collection_via_temp, + check_extraction_safety, + ) palace_path = os.path.abspath( os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path @@ -762,18 +769,7 @@ def cmd_repair(args): # Extract all drawers in batches print("\n Extracting drawers...") batch_size = 5000 - all_ids = [] - all_docs = [] - all_metas = [] - offset = 0 - while offset < total: - batch = col.get(limit=batch_size, offset=offset, include=["documents", "metadatas"]) - if not batch["ids"]: - break - all_ids.extend(batch["ids"]) - all_docs.extend(batch["documents"]) - all_metas.extend(batch["metadatas"]) - offset += len(batch["ids"]) + all_ids, all_docs, all_metas = _extract_drawers(col, total, batch_size) print(f" Extracted {len(all_ids)} drawers") # ── #1208 guard ────────────────────────────────────────────────── @@ -793,7 +789,6 @@ def cmd_repair(args): print(e.message) return - # Backup and rebuild palace_path = os.path.normpath(palace_path) backup_path = palace_path + ".backup" if os.path.exists(backup_path): @@ -807,18 +802,33 @@ def cmd_repair(args): print(f" Backing up to {backup_path}...") shutil.copytree(palace_path, backup_path) - print(" Rebuilding collection...") - backend.delete_collection(palace_path, "mempalace_drawers") - new_col = backend.create_collection(palace_path, "mempalace_drawers") - - filed = 0 - for i in range(0, len(all_ids), batch_size): - batch_ids = all_ids[i : i + batch_size] - batch_docs = all_docs[i : i + batch_size] - batch_metas = all_metas[i : i + batch_size] - new_col.add(documents=batch_docs, ids=batch_ids, metadatas=batch_metas) - filed += len(batch_ids) - print(f" Re-filed {filed}/{len(all_ids)} drawers...") + try: + filed = _rebuild_collection_via_temp( + backend, + palace_path, + all_ids, + all_docs, + all_metas, + batch_size, + progress=print, + ) + except RebuildCollectionError as e: + print(f" Repair failed: {e}") + if getattr(e, "live_replaced", False): + print(" Live collection was already replaced; restoring from backup...") + try: + _close_chroma_handles(palace_path, backend=backend) + if os.path.exists(palace_path): + shutil.rmtree(palace_path) + shutil.copytree(backup_path, palace_path) + print(f" Restore complete from backup: {backup_path}") + except Exception as restore_error: + print(f" Automatic restore failed: {restore_error}") + print(" Manual recovery required:") + print(f" 1. Remove or rename the broken directory: {palace_path}") + print(f" 2. Restore the backup directory to: {palace_path}") + print(f" Backup location: {backup_path}") + sys.exit(1) print(f"\n Repair complete. {filed} drawers rebuilt.") print(f" Backup saved at {backup_path}") diff --git a/mempalace/repair.py b/mempalace/repair.py index 90fb0cf..34d165c 100644 --- a/mempalace/repair.py +++ b/mempalace/repair.py @@ -38,10 +38,13 @@ from collections import defaultdict from datetime import datetime from typing import Iterator, Optional +from chromadb.errors import NotFoundError as ChromaNotFoundError + from .backends.chroma import ChromaBackend, hnsw_capacity_status COLLECTION_NAME = "mempalace_drawers" +REPAIR_TEMP_COLLECTION = f"{COLLECTION_NAME}__repair_tmp" # The closets collection (AAAK index layer) is intentionally fixed — # closets reference drawer IDs by string and live alongside drawers in the @@ -125,6 +128,108 @@ def _paginate_ids(col, where=None): return ids +def _extract_drawers(col, total: int, batch_size: int): + all_ids = [] + all_docs = [] + all_metas = [] + offset = 0 + while offset < total: + batch = col.get(limit=batch_size, offset=offset, include=["documents", "metadatas"]) + if not batch["ids"]: + break + all_ids.extend(batch["ids"]) + all_docs.extend(batch["documents"]) + all_metas.extend(batch["metadatas"]) + offset += len(batch["ids"]) + return all_ids, all_docs, all_metas + + +def _verify_collection_count(col, expected: int, label: str) -> None: + actual = col.count() + if actual != expected: + raise RuntimeError(f"{label} count mismatch: expected {expected}, got {actual}") + + +def _is_missing_collection_value_error(exc: ValueError) -> bool: + message = str(exc).lower() + return "does not exist" in message or "not found" in message + + +def _delete_collection_if_exists(backend, palace_path: str, collection_name: str) -> None: + try: + backend.delete_collection(palace_path, collection_name) + except ValueError as exc: + if _is_missing_collection_value_error(exc): + return + raise + except (FileNotFoundError, ChromaNotFoundError): + return + + +class RebuildCollectionError(RuntimeError): + """Raised when temp rebuild fails, carrying whether the live swap happened.""" + + def __init__(self, message: str, *, live_replaced: bool): + super().__init__(message) + self.live_replaced = live_replaced + + +def _rebuild_collection_via_temp( + backend, + palace_path: str, + all_ids, + all_docs, + all_metas, + batch_size: int, + progress=print, +) -> int: + expected = len(all_ids) + temp_name = REPAIR_TEMP_COLLECTION + live_replaced = False + + try: + _delete_collection_if_exists(backend, palace_path, temp_name) + + progress(f" Building temporary collection: {temp_name}") + temp_col = backend.create_collection(palace_path, temp_name) + staged = 0 + for i in range(0, expected, batch_size): + batch_ids = all_ids[i : i + batch_size] + batch_docs = all_docs[i : i + batch_size] + batch_metas = all_metas[i : i + batch_size] + temp_col.upsert(documents=batch_docs, ids=batch_ids, metadatas=batch_metas) + staged += len(batch_ids) + progress(f" Staged {staged}/{expected} drawers...") + _verify_collection_count(temp_col, expected, "temporary rebuild") + + progress(" Rebuilding live collection...") + backend.delete_collection(palace_path, COLLECTION_NAME) + live_replaced = True + new_col = backend.create_collection(palace_path, COLLECTION_NAME) + + rebuilt = 0 + for i in range(0, expected, batch_size): + batch_ids = all_ids[i : i + batch_size] + batch_docs = all_docs[i : i + batch_size] + batch_metas = all_metas[i : i + batch_size] + new_col.upsert(documents=batch_docs, ids=batch_ids, metadatas=batch_metas) + rebuilt += len(batch_ids) + progress(f" Re-filed {rebuilt}/{expected} drawers...") + _verify_collection_count(new_col, expected, "rebuilt live collection") + + try: + _delete_collection_if_exists(backend, palace_path, temp_name) + except Exception: + pass + return rebuilt + except Exception as exc: + try: + _delete_collection_if_exists(backend, palace_path, temp_name) + except Exception: + pass + raise RebuildCollectionError(str(exc), live_replaced=live_replaced) from exc + + def scan_palace(palace_path=None, only_wing=None): """Scan the palace for corrupt/unfetchable IDs. @@ -415,18 +520,7 @@ def rebuild_index(palace_path=None, confirm_truncation_ok: bool = False): # Extract all drawers in batches print("\n Extracting drawers...") batch_size = 5000 - all_ids = [] - all_docs = [] - all_metas = [] - offset = 0 - while offset < total: - batch = col.get(limit=batch_size, offset=offset, include=["documents", "metadatas"]) - if not batch["ids"]: - break - all_ids.extend(batch["ids"]) - all_docs.extend(batch["documents"]) - all_metas.extend(batch["metadatas"]) - offset += len(batch["ids"]) + all_ids, all_docs, all_metas = _extract_drawers(col, total, batch_size) print(f" Extracted {len(all_ids)} drawers") # ── #1208 guard ────────────────────────────────────────────────── @@ -449,28 +543,33 @@ def rebuild_index(palace_path=None, confirm_truncation_ok: bool = False): # Rebuild with correct HNSW settings print(" Rebuilding collection with hnsw:space=cosine...") - backend.delete_collection(palace_path, COLLECTION_NAME) - new_col = backend.create_collection(palace_path, COLLECTION_NAME) - - filed = 0 try: - for i in range(0, len(all_ids), batch_size): - batch_ids = all_ids[i : i + batch_size] - batch_docs = all_docs[i : i + batch_size] - batch_metas = all_metas[i : i + batch_size] - new_col.upsert(documents=batch_docs, ids=batch_ids, metadatas=batch_metas) - filed += len(batch_ids) - print(f" Re-filed {filed}/{len(all_ids)} drawers...") - except Exception as e: + filed = _rebuild_collection_via_temp( + backend, + palace_path, + all_ids, + all_docs, + all_metas, + batch_size, + progress=print, + ) + except RebuildCollectionError as e: print(f"\n ERROR during rebuild: {e}") - print(f" Only {filed}/{len(all_ids)} drawers were re-filed.") - if os.path.exists(backup_path): + print(" Rebuild aborted before completion.") + if e.live_replaced and os.path.exists(backup_path): print(f" Restoring from backup: {backup_path}") - backend.delete_collection(palace_path, COLLECTION_NAME) - shutil.copy2(backup_path, sqlite_path) - print(" Backup restored. Palace is back to pre-repair state.") - else: + try: + _close_chroma_handles(palace_path, backend=backend) + _delete_collection_if_exists(backend, palace_path, COLLECTION_NAME) + shutil.copy2(backup_path, sqlite_path) + print(" Backup restored. Palace is back to pre-repair state.") + except Exception as restore_error: + print(f" Backup restore failed: {restore_error}") + print(f" Manual restore required from: {backup_path}") + elif e.live_replaced: print(" No backup available. Re-mine from source files to recover.") + else: + print(" Live collection was not replaced; leaving the original palace untouched.") raise print(f"\n Repair complete. {filed} drawers rebuilt.") @@ -909,12 +1008,18 @@ def status(palace_path=None) -> dict: # --------------------------------------------------------------------------- -def _close_chroma_handles(palace_path: str) -> None: - """Drop ChromaBackend + chromadb singleton caches so OS mmap handles release.""" +def _close_chroma_handles(palace_path: str, backend: "ChromaBackend | None" = None) -> None: + """Drop ChromaBackend + chromadb singleton caches so OS mmap handles release. + + When ``backend`` is provided, close the live instance so rollback/restore + releases the handles it was already using. Otherwise fall back to a + transient backend instance for the max-seq-id repair path. + """ import gc try: - ChromaBackend().close_palace(palace_path) + closer = backend if backend is not None else ChromaBackend() + closer.close_palace(palace_path) except Exception: pass try: diff --git a/mempalace/searcher.py b/mempalace/searcher.py index 536610e..b318d99 100644 --- a/mempalace/searcher.py +++ b/mempalace/searcher.py @@ -406,6 +406,31 @@ def _bm25_only_via_sqlite( "hint": "Run: mempalace init && mempalace mine ", } + def _metadata_filter_sql(row_id_expr: str) -> tuple[str, list[str]]: + clauses = [] + params = [] + for key, value in (("wing", wing), ("room", room)): + if not value: + continue + clauses.append( + f""" + AND EXISTS ( + SELECT 1 + FROM embedding_metadata mf + WHERE mf.id = {row_id_expr} + AND mf.key = ? + AND COALESCE( + mf.string_value, + CAST(mf.int_value AS TEXT), + CAST(mf.float_value AS TEXT), + CAST(mf.bool_value AS TEXT) + ) = ? + ) + """ + ) + params.extend([key, value]) + return "".join(clauses), params + try: conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True) except sqlite3.Error as e: @@ -418,15 +443,17 @@ def _bm25_only_via_sqlite( candidate_ids: list[int] = [] if tokens: fts_query = " OR ".join(tokens) + filter_sql, filter_params = _metadata_filter_sql("embedding_fulltext_search.rowid") try: rows = conn.execute( - """ + f""" SELECT rowid FROM embedding_fulltext_search WHERE embedding_fulltext_search MATCH ? + {filter_sql} LIMIT ? """, - (fts_query, max_candidates), + (fts_query, *filter_params, max_candidates), ).fetchall() candidate_ids = [r[0] for r in rows] except sqlite3.Error: @@ -444,17 +471,19 @@ def _bm25_only_via_sqlite( # fall back to ordering by primary-key id and finally to an # empty result rather than letting search raise. try: + filter_sql, filter_params = _metadata_filter_sql("e.id") rows = conn.execute( - """ + f""" SELECT e.id FROM embeddings e JOIN segments s ON e.segment_id = s.id JOIN collections c ON s.collection = c.id WHERE c.name = 'mempalace_drawers' + {filter_sql} ORDER BY e.created_at DESC LIMIT ? """, - (max_candidates,), + (*filter_params, max_candidates), ).fetchall() candidate_ids = [r[0] for r in rows] except sqlite3.Error: @@ -463,17 +492,19 @@ def _bm25_only_via_sqlite( exc_info=True, ) try: + filter_sql, filter_params = _metadata_filter_sql("e.id") rows = conn.execute( - """ + f""" SELECT e.id FROM embeddings e JOIN segments s ON e.segment_id = s.id JOIN collections c ON s.collection = c.id WHERE c.name = 'mempalace_drawers' + {filter_sql} ORDER BY e.id DESC LIMIT ? """, - (max_candidates,), + (*filter_params, max_candidates), ).fetchall() candidate_ids = [r[0] for r in rows] except sqlite3.Error: diff --git a/tests/test_backends.py b/tests/test_backends.py index 06998c5..4cd9480 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -1,4 +1,5 @@ import os +import pickle import shutil import sqlite3 from pathlib import Path @@ -19,6 +20,7 @@ from mempalace.backends.chroma import ( ChromaCollection, _fix_blob_seq_ids, _pin_hnsw_threads, + quarantine_invalid_hnsw_metadata, quarantine_stale_hnsw, ) @@ -755,7 +757,10 @@ def test_make_client_quarantines_only_on_first_call_per_palace(tmp_path, monkeyp """Quarantine fires on first ``make_client()`` for a palace, then is skipped on subsequent calls — prevents runtime thrash where a daemon's own steady writes bump ``chroma.sqlite3`` faster than HNSW flushes, - making the mtime heuristic falsely trigger every reconnect.""" + making the mtime heuristic falsely trigger every reconnect. + + Invalid metadata quarantine shares the same cold-start gate here; the + more aggressive refresh path lives in ``_client()``.""" from mempalace.backends.chroma import ChromaBackend palace_path = str(tmp_path / "palace") @@ -782,6 +787,34 @@ def test_make_client_quarantines_only_on_first_call_per_palace(tmp_path, monkeyp ], "quarantine_stale_hnsw should fire once per palace per process, not on every reconnect" +def test_make_client_gates_invalid_metadata_on_first_call(tmp_path, monkeypatch): + """Invalid metadata quarantine is gated on the first make_client() call.""" + from mempalace.backends.chroma import ChromaBackend + + palace_path = str(tmp_path / "palace") + os.makedirs(palace_path, exist_ok=True) + (Path(palace_path) / "chroma.sqlite3").write_text("") + + monkeypatch.setattr(ChromaBackend, "_quarantined_paths", set()) + + calls: list[str] = [] + + def _invalid(path, *args, **kwargs): + calls.append(path) + return [] + + def _stale(path, stale_seconds=300.0): + return [] + + monkeypatch.setattr("mempalace.backends.chroma.quarantine_invalid_hnsw_metadata", _invalid) + monkeypatch.setattr("mempalace.backends.chroma.quarantine_stale_hnsw", _stale) + + ChromaBackend.make_client(palace_path) + ChromaBackend.make_client(palace_path) + + assert calls == [palace_path] + + def test_make_client_quarantines_each_palace_independently(tmp_path, monkeypatch): """Two distinct palaces each get one quarantine attempt — the gate is keyed by palace path, not global.""" @@ -919,3 +952,245 @@ def test_get_collection_applies_retrofit_on_existing_palace(tmp_path): ) assert wrapper._collection.configuration_json["hnsw"]["num_threads"] == 1 + + +def test_quarantine_invalid_hnsw_metadata_renames_missing_dimensionality(tmp_path): + palace = tmp_path / "palace" + palace.mkdir() + seg = palace / "abcd-1234-5678" + seg.mkdir() + with open(seg / "index_metadata.pickle", "wb") as f: + pickle.dump({"dimensionality": None, "id_to_label": {"a": 1}}, f) + + moved = quarantine_invalid_hnsw_metadata(str(palace)) + + assert len(moved) == 1 + assert ".corrupt-" in moved[0] + assert not seg.exists() + + +def test_quarantine_invalid_hnsw_metadata_allows_uninitialized_segment(tmp_path): + palace = tmp_path / "palace" + palace.mkdir() + seg = palace / "abcd-1234-5678" + seg.mkdir() + with open(seg / "index_metadata.pickle", "wb") as f: + pickle.dump({"dimensionality": None, "id_to_label": {}}, f) + + moved = quarantine_invalid_hnsw_metadata(str(palace)) + + assert moved == [] + assert seg.exists() + + +def test_quarantine_invalid_hnsw_metadata_rejects_non_dict_id_to_label(tmp_path): + palace = tmp_path / "palace" + palace.mkdir() + seg = palace / "abcd-1234-5678" + seg.mkdir() + with open(seg / "index_metadata.pickle", "wb") as f: + pickle.dump({"dimensionality": 8, "id_to_label": ["a", "b"]}, f) + + moved = quarantine_invalid_hnsw_metadata(str(palace)) + + assert len(moved) == 1 + assert ".corrupt-" in moved[0] + assert not seg.exists() + + +def test_quarantine_invalid_hnsw_metadata_rejects_non_schema_payload(tmp_path): + palace = tmp_path / "palace" + palace.mkdir() + seg = palace / "abcd-1234-5678" + seg.mkdir() + with open(seg / "index_metadata.pickle", "wb") as f: + pickle.dump(["not", "a", "metadata", "object"], f) + + moved = quarantine_invalid_hnsw_metadata(str(palace)) + + assert len(moved) == 1 + assert ".corrupt-" in moved[0] + assert not seg.exists() + + +def _dangerous_pickle_payload_executed(): + raise AssertionError("unsafe pickle payload executed") + + +class _DangerousPickle: + def __reduce__(self): + return (_dangerous_pickle_payload_executed, ()) + + +def test_quarantine_invalid_hnsw_metadata_rejects_unsafe_pickle(tmp_path): + palace = tmp_path / "palace" + palace.mkdir() + seg = palace / "abcd-1234-5678" + seg.mkdir() + with open(seg / "index_metadata.pickle", "wb") as f: + pickle.dump(_DangerousPickle(), f) + + moved = quarantine_invalid_hnsw_metadata(str(palace)) + + assert len(moved) == 1 + assert ".corrupt-" in moved[0] + assert not seg.exists() + + +def test_quarantine_invalid_hnsw_metadata_skips_transient_read_errors(tmp_path, monkeypatch): + palace = tmp_path / "palace" + palace.mkdir() + seg = palace / "abcd-1234-5678" + seg.mkdir() + meta = seg / "index_metadata.pickle" + meta.write_bytes(b"partial") + + monkeypatch.setattr( + "mempalace.backends.chroma._SafePersistentDataUnpickler.load", + lambda path: (_ for _ in ()).throw(EOFError("flush in progress")), + ) + + moved = quarantine_invalid_hnsw_metadata(str(palace)) + + assert moved == [] + assert seg.exists() + + +def test_quarantine_invalid_hnsw_metadata_skips_truncated_pickle(tmp_path, monkeypatch): + palace = tmp_path / "palace" + palace.mkdir() + seg = palace / "abcd-1234-5678" + seg.mkdir() + meta = seg / "index_metadata.pickle" + meta.write_bytes(b"partial") + + monkeypatch.setattr( + "mempalace.backends.chroma._SafePersistentDataUnpickler.load", + lambda path: (_ for _ in ()).throw(pickle.UnpicklingError("pickle data was truncated")), + ) + + moved = quarantine_invalid_hnsw_metadata(str(palace)) + + assert moved == [] + assert seg.exists() + + +def test_chroma_backend_preflights_metadata_before_persistent_client(tmp_path, monkeypatch): + palace = tmp_path / "palace" + palace.mkdir() + calls = [] + + def _record(name): + def inner(path, *args, **kwargs): + calls.append((name, path)) + return [] if name != "blob" else None + + return inner + + monkeypatch.setattr("mempalace.backends.chroma._fix_blob_seq_ids", _record("blob")) + monkeypatch.setattr( + "mempalace.backends.chroma.quarantine_invalid_hnsw_metadata", _record("invalid") + ) + monkeypatch.setattr("mempalace.backends.chroma.quarantine_stale_hnsw", _record("stale")) + + class DummyClient: + pass + + monkeypatch.setattr( + "mempalace.backends.chroma.chromadb.PersistentClient", lambda path: DummyClient() + ) + + backend = ChromaBackend() + backend._client(str(palace)) + + assert calls == [ + ("blob", str(palace)), + ("invalid", str(palace)), + ("stale", str(palace)), + ] + + +def test_chroma_backend_stale_quarantine_is_cold_start_only_on_refresh(tmp_path, monkeypatch): + palace = tmp_path / "palace" + palace.mkdir() + (palace / "chroma.sqlite3").write_text("") + calls = [] + + def _record(name): + def inner(path, *args, **kwargs): + calls.append((name, path)) + return [] if name != "blob" else None + + return inner + + monkeypatch.setattr(ChromaBackend, "_quarantined_paths", set()) + monkeypatch.setattr("mempalace.backends.chroma._fix_blob_seq_ids", _record("blob")) + monkeypatch.setattr( + "mempalace.backends.chroma.quarantine_invalid_hnsw_metadata", _record("invalid") + ) + monkeypatch.setattr("mempalace.backends.chroma.quarantine_stale_hnsw", _record("stale")) + + class DummyClient: + pass + + monkeypatch.setattr( + "mempalace.backends.chroma.chromadb.PersistentClient", lambda path: DummyClient() + ) + + backend = ChromaBackend() + stats = iter([(1, 1.0), (1, 1.0), (1, 2.0), (1, 2.0)]) + monkeypatch.setattr(backend, "_db_stat", lambda path: next(stats)) + + backend._client(str(palace)) + backend._client(str(palace)) + + assert calls == [ + ("blob", str(palace)), + ("invalid", str(palace)), + ("stale", str(palace)), + ("blob", str(palace)), + ] + + +def test_chroma_backend_requarantines_after_inode_replacement(tmp_path, monkeypatch): + palace = tmp_path / "palace" + palace.mkdir() + (palace / "chroma.sqlite3").write_text("") + calls = [] + + def _record(name): + def inner(path, *args, **kwargs): + calls.append((name, path)) + return [] if name != "blob" else None + + return inner + + monkeypatch.setattr(ChromaBackend, "_quarantined_paths", set()) + monkeypatch.setattr("mempalace.backends.chroma._fix_blob_seq_ids", _record("blob")) + monkeypatch.setattr( + "mempalace.backends.chroma.quarantine_invalid_hnsw_metadata", _record("invalid") + ) + monkeypatch.setattr("mempalace.backends.chroma.quarantine_stale_hnsw", _record("stale")) + + class DummyClient: + pass + + monkeypatch.setattr( + "mempalace.backends.chroma.chromadb.PersistentClient", lambda path: DummyClient() + ) + + backend = ChromaBackend() + stats = iter([(1, 1.0), (1, 1.0), (2, 2.0), (2, 2.0)]) + monkeypatch.setattr(backend, "_db_stat", lambda path: next(stats)) + + backend._client(str(palace)) + backend._client(str(palace)) + + assert calls == [ + ("blob", str(palace)), + ("invalid", str(palace)), + ("stale", str(palace)), + ("blob", str(palace)), + ("invalid", str(palace)), + ("stale", str(palace)), + ] diff --git a/tests/test_cli.py b/tests/test_cli.py index 71ca63d..de00664 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -4,7 +4,7 @@ import argparse import shlex import sys from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch import pytest @@ -815,13 +815,58 @@ def test_cmd_repair_success(mock_config_cls, tmp_path, capsys): "documents": ["doc1", "doc2"], "metadatas": [{"wing": "a"}, {"wing": "b"}], } + mock_temp_col = MagicMock() + mock_temp_col.count.return_value = 2 mock_new_col = MagicMock() + mock_new_col.count.return_value = 2 mock_backend = _mock_backend_for(col=mock_col, new_col=mock_new_col) + mock_backend.create_collection.side_effect = [mock_temp_col, mock_new_col] with patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend): cmd_repair(args) out = capsys.readouterr().out assert "Repair complete" in out assert "2 drawers rebuilt" in out + assert mock_backend.delete_collection.call_args_list == [ + call(str(palace_dir), "mempalace_drawers__repair_tmp"), + call(str(palace_dir), "mempalace_drawers"), + call(str(palace_dir), "mempalace_drawers__repair_tmp"), + ] + mock_temp_col.upsert.assert_called_once() + mock_new_col.upsert.assert_called_once() + mock_new_col.add.assert_not_called() + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_repair_restores_backup_on_live_rebuild_failure(mock_config_cls, tmp_path, capsys): + palace_dir = tmp_path / "palace" + palace_dir.mkdir() + (palace_dir / "chroma.sqlite3").write_text("db") + mock_config_cls.return_value.palace_path = str(palace_dir) + args = argparse.Namespace(palace=None, yes=True) + mock_col = MagicMock() + mock_col.count.return_value = 2 + mock_col.get.return_value = { + "ids": ["id1", "id2"], + "documents": ["doc1", "doc2"], + "metadatas": [{"wing": "a"}, {"wing": "b"}], + } + mock_temp_col = MagicMock() + mock_temp_col.count.return_value = 2 + mock_backend = _mock_backend_for(col=mock_col) + mock_backend.create_collection.side_effect = [mock_temp_col, RuntimeError("live build failed")] + with patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend): + with pytest.raises(SystemExit) as excinfo: + cmd_repair(args) + out = capsys.readouterr().out + assert excinfo.value.code == 1 + assert "Repair failed" in out + assert "restoring from backup" in out + mock_backend.close_palace.assert_called_once_with(str(palace_dir)) + assert mock_backend.delete_collection.call_args_list == [ + call(str(palace_dir), "mempalace_drawers__repair_tmp"), + call(str(palace_dir), "mempalace_drawers"), + call(str(palace_dir), "mempalace_drawers__repair_tmp"), + ] @patch("mempalace.cli.MempalaceConfig") diff --git a/tests/test_hnsw_capacity.py b/tests/test_hnsw_capacity.py index 512fc9c..912def8 100644 --- a/tests/test_hnsw_capacity.py +++ b/tests/test_hnsw_capacity.py @@ -238,14 +238,39 @@ def test_capacity_status_tolerates_flush_lag(tmp_path): assert info["status"] == "ok" -def test_capacity_status_flags_unflushed_with_large_sqlite(tmp_path): - """No pickle + many sqlite rows is its own divergence signal.""" +def test_capacity_status_does_not_flag_unflushed_with_large_sqlite(tmp_path): + """No pickle + many sqlite rows is inconclusive, not divergence.""" seg = "seg-noflush" _seed_chroma_db(str(tmp_path), sqlite_count=10_000, segment_id=seg) info = hnsw_capacity_status(str(tmp_path), COLLECTION) - assert info["diverged"] is True + assert info["diverged"] is False + assert info["status"] == "unknown" + assert info["divergence"] is None assert info["hnsw_count"] is None - assert "never flushed" in info["message"] + assert "capacity unavailable" in info["message"] + assert "leaving vector search enabled" in info["message"] + + +def test_mcp_probe_does_not_disable_vectors_for_unflushed_metadata(tmp_path, monkeypatch): + """The MCP preflight must not route all searches to BM25 on this signal.""" + from mempalace import mcp_server + + seg = "seg-mcp-noflush" + _seed_chroma_db(str(tmp_path), sqlite_count=10_000, segment_id=seg) + + class _Cfg: + palace_path = str(tmp_path) + + monkeypatch.setattr(mcp_server, "_config", _Cfg()) + monkeypatch.setattr(mcp_server, "_vector_disabled", True) + monkeypatch.setattr(mcp_server, "_vector_disabled_reason", "old divergence") + + mcp_server._refresh_vector_disabled_flag() + + assert mcp_server._vector_disabled is False + assert mcp_server._vector_disabled_reason == "" + assert mcp_server._vector_capacity_status["status"] == "unknown" + assert "leaving vector search enabled" in mcp_server._vector_capacity_status["message"] def test_capacity_status_quiet_for_empty_palace(tmp_path): @@ -372,6 +397,17 @@ def _seed_drawers(palace: str, segment_id: str, drawers: list[tuple[str, dict, s conn.close() +def _set_drawer_created_at(palace: str, timestamps: dict[int, str]) -> None: + db_path = os.path.join(palace, "chroma.sqlite3") + conn = sqlite3.connect(db_path) + try: + for emb_id, created_at in timestamps.items(): + conn.execute("UPDATE embeddings SET created_at = ? WHERE id = ?", (created_at, emb_id)) + conn.commit() + finally: + conn.close() + + @pytest.fixture def palace_with_drawers(tmp_path): seg = "seg-bm25" @@ -417,6 +453,122 @@ def test_bm25_fallback_filters_by_wing(palace_with_drawers): assert all(r["wing"] == "design" for r in out["results"]) +def test_bm25_fallback_applies_wing_before_fts_candidate_limit(tmp_path): + seg = "seg-bm25-fts-limit" + _seed_chroma_db(str(tmp_path), sqlite_count=0, segment_id=seg) + _seed_drawers( + str(tmp_path), + seg, + [ + ( + "shared token outside target wing", + {"wing": "ops", "room": "incidents", "source_file": "/x/ops.md"}, + "d-1", + ), + ( + "shared token inside target wing", + {"wing": "project", "room": "diary", "source_file": "/x/project.md"}, + "d-2", + ), + ], + ) + + out = _bm25_only_via_sqlite("shared token", str(tmp_path), wing="project", max_candidates=1) + + assert out["total_before_filter"] == 1 + assert len(out["results"]) == 1 + assert out["results"][0]["wing"] == "project" + + +def test_bm25_fallback_applies_room_before_fts_candidate_limit(tmp_path): + seg = "seg-bm25-room-limit" + _seed_chroma_db(str(tmp_path), sqlite_count=0, segment_id=seg) + _seed_drawers( + str(tmp_path), + seg, + [ + ( + "shared token wrong room", + {"wing": "project", "room": "scratch", "source_file": "/x/scratch.md"}, + "d-1", + ), + ( + "shared token right room", + {"wing": "project", "room": "diary", "source_file": "/x/diary.md"}, + "d-2", + ), + ], + ) + + out = _bm25_only_via_sqlite( + "shared token", + str(tmp_path), + wing="project", + room="diary", + max_candidates=1, + ) + + assert out["total_before_filter"] == 1 + assert len(out["results"]) == 1 + assert out["results"][0]["wing"] == "project" + assert out["results"][0]["room"] == "diary" + + +def test_bm25_fallback_applies_wing_before_recency_candidate_limit(tmp_path): + seg = "seg-bm25-recency-limit" + _seed_chroma_db(str(tmp_path), sqlite_count=0, segment_id=seg) + _seed_drawers( + str(tmp_path), + seg, + [ + ( + "target drawer for short query", + {"wing": "project", "room": "diary", "source_file": "/x/project.md"}, + "d-1", + ), + ( + "newer drawer outside target wing", + {"wing": "ops", "room": "incidents", "source_file": "/x/ops.md"}, + "d-2", + ), + ], + ) + _set_drawer_created_at( + str(tmp_path), + { + 1: "2026-01-01 00:00:00", + 2: "2026-02-01 00:00:00", + }, + ) + + out = _bm25_only_via_sqlite("a", str(tmp_path), wing="project", max_candidates=1) + + assert out["total_before_filter"] == 1 + assert len(out["results"]) == 1 + assert out["results"][0]["wing"] == "project" + + +def test_bm25_fallback_returns_empty_when_filtered_wing_has_no_candidates(tmp_path): + seg = "seg-bm25-empty-filter" + _seed_chroma_db(str(tmp_path), sqlite_count=0, segment_id=seg) + _seed_drawers( + str(tmp_path), + seg, + [ + ( + "shared token outside target wing", + {"wing": "ops", "room": "incidents", "source_file": "/x/ops.md"}, + "d-1", + ), + ], + ) + + out = _bm25_only_via_sqlite("shared token", str(tmp_path), wing="project", max_candidates=1) + + assert out["total_before_filter"] == 0 + assert out["results"] == [] + + def test_bm25_fallback_no_palace(tmp_path): out = _bm25_only_via_sqlite("anything", str(tmp_path)) assert "error" in out diff --git a/tests/test_repair.py b/tests/test_repair.py index 0277d9d..a60836a 100644 --- a/tests/test_repair.py +++ b/tests/test_repair.py @@ -2,7 +2,7 @@ import os import sqlite3 -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch import pytest @@ -229,8 +229,11 @@ def test_rebuild_index_success(mock_backend_cls, mock_shutil, tmp_path): } mock_new_col = MagicMock() + mock_new_col.count.return_value = 2 + mock_temp_col = MagicMock() + mock_temp_col.count.return_value = 2 mock_backend = _install_mock_backend(mock_backend_cls, mock_col) - mock_backend.create_collection.return_value = mock_new_col + mock_backend.create_collection.side_effect = [mock_temp_col, mock_new_col] repair.rebuild_index(palace_path=str(tmp_path)) @@ -239,14 +242,74 @@ def test_rebuild_index_success(mock_backend_cls, mock_shutil, tmp_path): assert "chroma.sqlite3" in str(mock_shutil.copy2.call_args) # Verify: deleted and recreated (cosine is the backend default) - mock_backend.delete_collection.assert_called_once_with(str(tmp_path), "mempalace_drawers") - mock_backend.create_collection.assert_called_once_with(str(tmp_path), "mempalace_drawers") + assert mock_backend.create_collection.call_args_list == [ + call(str(tmp_path), "mempalace_drawers__repair_tmp"), + call(str(tmp_path), "mempalace_drawers"), + ] + assert mock_backend.delete_collection.call_args_list == [ + call(str(tmp_path), "mempalace_drawers__repair_tmp"), + call(str(tmp_path), "mempalace_drawers"), + call(str(tmp_path), "mempalace_drawers__repair_tmp"), + ] # Verify: used upsert not add + mock_temp_col.upsert.assert_called_once() mock_new_col.upsert.assert_called_once() mock_new_col.add.assert_not_called() +@patch("mempalace.repair.shutil") +@patch("mempalace.repair.ChromaBackend") +def test_rebuild_index_ignores_missing_temp_collection_at_start( + mock_backend_cls, mock_shutil, tmp_path +): + sqlite_path = tmp_path / "chroma.sqlite3" + sqlite_path.write_text("fake") + + def _fake_copy2(src, dst): + with open(dst, "w") as handle: + handle.write("backup") + + mock_shutil.copy2.side_effect = _fake_copy2 + + mock_col = MagicMock() + mock_col.count.return_value = 2 + mock_col.get.return_value = { + "ids": ["id1", "id2"], + "documents": ["doc1", "doc2"], + "metadatas": [{"wing": "a"}, {"wing": "b"}], + } + + mock_new_col = MagicMock() + mock_new_col.count.return_value = 2 + mock_temp_col = MagicMock() + mock_temp_col.count.return_value = 2 + mock_backend = _install_mock_backend(mock_backend_cls, mock_col) + mock_backend.create_collection.side_effect = [mock_temp_col, mock_new_col] + mock_backend.delete_collection.side_effect = [ + ValueError("Collection [mempalace_drawers__repair_tmp] does not exist"), + None, + None, + ] + + repair.rebuild_index(palace_path=str(tmp_path)) + + assert mock_shutil.copy2.call_count == 1 + assert mock_backend.delete_collection.call_args_list == [ + call(str(tmp_path), "mempalace_drawers__repair_tmp"), + call(str(tmp_path), "mempalace_drawers"), + call(str(tmp_path), "mempalace_drawers__repair_tmp"), + ] + + +def test_delete_collection_if_exists_reraises_unexpected_value_error(): + mock_backend = MagicMock() + mock_backend.delete_collection.side_effect = ValueError("invalid collection name") + + with pytest.raises(ValueError, match="invalid collection name"): + repair._delete_collection_if_exists(mock_backend, "/palace", "bad/name") + + @patch("mempalace.repair.shutil") @patch("mempalace.repair.ChromaBackend") def test_rebuild_index_error_reading(mock_backend_cls, mock_shutil, tmp_path): @@ -365,19 +428,261 @@ def test_rebuild_index_proceeds_with_override(mock_backend_cls, mock_shutil, tmp }, {"ids": [], "documents": [], "metadatas": []}, ] + mock_temp_col = MagicMock() + mock_temp_col.count.return_value = 10_000 mock_new_col = MagicMock() + mock_new_col.count.return_value = 10_000 mock_backend.get_collection.return_value = mock_col - mock_backend.create_collection.return_value = mock_new_col + mock_backend.create_collection.side_effect = [mock_temp_col, mock_new_col] mock_backend_cls.return_value = mock_backend with patch("mempalace.repair.sqlite_drawer_count", return_value=67_580): repair.rebuild_index(palace_path=str(tmp_path), confirm_truncation_ok=True) - mock_backend.delete_collection.assert_called_once() - mock_backend.create_collection.assert_called_once() + assert mock_backend.delete_collection.call_count == 3 + assert mock_backend.create_collection.call_count == 2 + mock_temp_col.upsert.assert_called() mock_new_col.upsert.assert_called() +@patch("mempalace.repair.shutil") +@patch("mempalace.repair.ChromaBackend") +def test_rebuild_index_stage_failure_leaves_live_collection_untouched( + mock_backend_cls, mock_shutil, tmp_path +): + sqlite_path = tmp_path / "chroma.sqlite3" + sqlite_path.write_text("fake") + + mock_col = MagicMock() + mock_col.count.return_value = 2 + mock_col.get.return_value = { + "ids": ["id1", "id2"], + "documents": ["doc1", "doc2"], + "metadatas": [{"wing": "a"}, {"wing": "b"}], + } + mock_temp_col = MagicMock() + mock_temp_col.count.return_value = 1 + mock_backend = _install_mock_backend(mock_backend_cls, mock_col) + mock_backend.create_collection.return_value = mock_temp_col + + with pytest.raises(repair.RebuildCollectionError) as excinfo: + repair.rebuild_index(palace_path=str(tmp_path)) + + assert excinfo.value.live_replaced is False + assert mock_shutil.copy2.call_count == 1 + assert mock_backend.delete_collection.call_args_list == [ + call(str(tmp_path), "mempalace_drawers__repair_tmp"), + call(str(tmp_path), "mempalace_drawers__repair_tmp"), + ] + + +@patch("mempalace.repair.shutil") +@patch("mempalace.repair.ChromaBackend") +def test_rebuild_index_live_failure_restores_backup(mock_backend_cls, mock_shutil, tmp_path): + sqlite_path = tmp_path / "chroma.sqlite3" + sqlite_path.write_text("fake") + + def _fake_copy2(src, dst): + with open(dst, "w") as handle: + handle.write("backup") + + mock_shutil.copy2.side_effect = _fake_copy2 + + mock_col = MagicMock() + mock_col.count.return_value = 2 + mock_col.get.return_value = { + "ids": ["id1", "id2"], + "documents": ["doc1", "doc2"], + "metadatas": [{"wing": "a"}, {"wing": "b"}], + } + mock_temp_col = MagicMock() + mock_temp_col.count.return_value = 2 + mock_new_col = MagicMock() + mock_new_col.upsert.side_effect = RuntimeError("live upsert failed") + active_backend = MagicMock() + active_backend.get_collection.return_value = mock_col + active_backend.create_collection.side_effect = [mock_temp_col, mock_new_col] + helper_backend = MagicMock() + mock_backend_cls.side_effect = [active_backend, helper_backend] + + with pytest.raises(repair.RebuildCollectionError) as excinfo: + repair.rebuild_index(palace_path=str(tmp_path)) + + assert excinfo.value.live_replaced is True + assert mock_shutil.copy2.call_count == 2 + assert active_backend.delete_collection.call_args_list == [ + call(str(tmp_path), "mempalace_drawers__repair_tmp"), + call(str(tmp_path), "mempalace_drawers"), + call(str(tmp_path), "mempalace_drawers__repair_tmp"), + call(str(tmp_path), "mempalace_drawers"), + ] + active_backend.close_palace.assert_called_once_with(str(tmp_path)) + helper_backend.close_palace.assert_not_called() + + +@patch("mempalace.repair.shutil") +@patch("mempalace.repair.ChromaBackend") +def test_rebuild_index_live_delete_missing_still_restores_backup( + mock_backend_cls, mock_shutil, tmp_path +): + sqlite_path = tmp_path / "chroma.sqlite3" + sqlite_path.write_text("fake") + + def _fake_copy2(src, dst): + with open(dst, "w") as handle: + handle.write("backup") + + mock_shutil.copy2.side_effect = _fake_copy2 + + mock_col = MagicMock() + mock_col.count.return_value = 2 + mock_col.get.return_value = { + "ids": ["id1", "id2"], + "documents": ["doc1", "doc2"], + "metadatas": [{"wing": "a"}, {"wing": "b"}], + } + mock_temp_col = MagicMock() + mock_temp_col.count.return_value = 2 + mock_backend = _install_mock_backend(mock_backend_cls, mock_col) + mock_backend.create_collection.side_effect = [mock_temp_col, RuntimeError("create failed")] + mock_backend.delete_collection.side_effect = [ + None, + None, + None, + repair.ChromaNotFoundError("missing"), + ] + + with pytest.raises(repair.RebuildCollectionError) as excinfo: + repair.rebuild_index(palace_path=str(tmp_path)) + + assert excinfo.value.live_replaced is True + assert mock_shutil.copy2.call_count == 2 + assert mock_backend.delete_collection.call_args_list == [ + call(str(tmp_path), "mempalace_drawers__repair_tmp"), + call(str(tmp_path), "mempalace_drawers"), + call(str(tmp_path), "mempalace_drawers__repair_tmp"), + call(str(tmp_path), "mempalace_drawers"), + ] + + +@patch("mempalace.repair.shutil") +@patch("mempalace.repair.ChromaBackend") +def test_rebuild_index_restore_failure_preserves_original_error( + mock_backend_cls, mock_shutil, tmp_path, capsys +): + sqlite_path = tmp_path / "chroma.sqlite3" + sqlite_path.write_text("fake") + + def _copy2_side_effect(src, dst): + if str(src).endswith(".backup"): + raise PermissionError("locked sqlite") + with open(dst, "w") as handle: + handle.write("backup") + + mock_shutil.copy2.side_effect = _copy2_side_effect + + mock_col = MagicMock() + mock_col.count.return_value = 2 + mock_col.get.return_value = { + "ids": ["id1", "id2"], + "documents": ["doc1", "doc2"], + "metadatas": [{"wing": "a"}, {"wing": "b"}], + } + mock_temp_col = MagicMock() + mock_temp_col.count.return_value = 2 + mock_new_col = MagicMock() + mock_new_col.upsert.side_effect = RuntimeError("live upsert failed") + mock_backend = _install_mock_backend(mock_backend_cls, mock_col) + mock_backend.create_collection.side_effect = [mock_temp_col, mock_new_col] + + with pytest.raises(repair.RebuildCollectionError) as excinfo: + repair.rebuild_index(palace_path=str(tmp_path)) + + out = capsys.readouterr().out + assert "locked sqlite" in out + assert "Manual restore required" in out + assert "live upsert failed" in str(excinfo.value) + + +@patch("mempalace.repair.ChromaBackend") +def test_rebuild_collection_via_temp_keeps_original_error_when_cleanup_fails( + mock_backend_cls, +): + mock_col = MagicMock() + mock_col.count.return_value = 2 + mock_temp_col = MagicMock() + mock_temp_col.count.return_value = 2 + mock_backend = _install_mock_backend(mock_backend_cls, mock_col) + mock_backend.create_collection.side_effect = [mock_temp_col, RuntimeError("live build failed")] + mock_backend.delete_collection.side_effect = [ + None, + None, + RuntimeError("cleanup failed"), + ] + + with pytest.raises(repair.RebuildCollectionError) as excinfo: + repair._rebuild_collection_via_temp( + mock_backend, + "/palace", + ["id1", "id2"], + ["doc1", "doc2"], + [{"wing": "a"}, {"wing": "b"}], + batch_size=5000, + progress=lambda *args, **kwargs: None, + ) + + assert "live build failed" in str(excinfo.value) + assert excinfo.value.live_replaced is True + assert mock_backend.delete_collection.call_args_list == [ + call("/palace", "mempalace_drawers__repair_tmp"), + call("/palace", "mempalace_drawers"), + call("/palace", "mempalace_drawers__repair_tmp"), + ] + + +@patch("mempalace.repair.shutil") +@patch("mempalace.repair.ChromaBackend") +def test_rebuild_index_ignores_temp_cleanup_failure_after_success( + mock_backend_cls, mock_shutil, tmp_path +): + sqlite_path = tmp_path / "chroma.sqlite3" + sqlite_path.write_text("fake") + + def _fake_copy2(src, dst): + with open(dst, "w") as handle: + handle.write("backup") + + mock_shutil.copy2.side_effect = _fake_copy2 + + mock_col = MagicMock() + mock_col.count.return_value = 2 + mock_col.get.return_value = { + "ids": ["id1", "id2"], + "documents": ["doc1", "doc2"], + "metadatas": [{"wing": "a"}, {"wing": "b"}], + } + mock_temp_col = MagicMock() + mock_temp_col.count.return_value = 2 + mock_new_col = MagicMock() + mock_new_col.count.return_value = 2 + mock_backend = _install_mock_backend(mock_backend_cls, mock_col) + mock_backend.create_collection.side_effect = [mock_temp_col, mock_new_col] + mock_backend.delete_collection.side_effect = [ + None, + None, + RuntimeError("cleanup failed"), + ] + + repair.rebuild_index(palace_path=str(tmp_path)) + + assert mock_shutil.copy2.call_count == 1 + assert mock_backend.delete_collection.call_args_list == [ + call(str(tmp_path), "mempalace_drawers__repair_tmp"), + call(str(tmp_path), "mempalace_drawers"), + call(str(tmp_path), "mempalace_drawers__repair_tmp"), + ] + + # ── repair_max_seq_id ─────────────────────────────────────────────────