fix(chroma): harden HNSW startup preflight

This commit is contained in:
Mika Cohen
2026-04-30 09:31:32 -06:00
parent 5e540da06b
commit c3e1104e75
2 changed files with 395 additions and 19 deletions
+116 -15
View File
@@ -3,7 +3,9 @@
import datetime as _dt
import logging
import os
import pickle
import sqlite3
from numbers import Integral
from pathlib import Path
from typing import Any, Optional
@@ -591,6 +593,97 @@ def _pin_hnsw_threads(collection) -> None:
_BLOB_FIX_MARKER = ".blob_seq_ids_migrated"
def _valid_dimensionality(value: object) -> bool:
return isinstance(value, Integral) and not isinstance(value, bool) and int(value) > 0
def _persisted_metadata_fields(obj: object) -> tuple[object, object]:
if isinstance(obj, dict):
return obj.get("dimensionality"), obj.get("id_to_label")
return getattr(obj, "dimensionality", None), getattr(obj, "id_to_label", None)
def quarantine_invalid_hnsw_metadata(palace_path: str) -> list[str]:
"""Quarantine segment dirs whose ``index_metadata.pickle`` is unreadable or invalid.
Chroma's persisted HNSW metadata is untrusted disk state. If a segment has
labels but no valid positive dimensionality, current Chroma versions can
accept the pickle and crash later in the Rust loader. We rename the entire
segment out of the way before ``PersistentClient`` opens so Chroma can
rebuild cleanly instead of touching known-bad metadata.
"""
try:
entries = os.listdir(palace_path)
except OSError:
return []
moved: list[str] = []
for name in entries:
if "-" not in name or name.startswith(".") or ".drift-" in name or ".corrupt-" in name:
continue
seg_dir = os.path.join(palace_path, name)
if not os.path.isdir(seg_dir):
continue
meta_path = os.path.join(seg_dir, "index_metadata.pickle")
if not os.path.isfile(meta_path):
continue
reason = None
try:
persisted = _SafePersistentDataUnpickler.load(meta_path)
except (EOFError, OSError):
logger.debug(
"Skipping invalid-HNSW quarantine for transient metadata read in %s",
meta_path,
exc_info=True,
)
continue
except pickle.UnpicklingError as exc:
if "truncated" in str(exc).lower() or "ran out of input" in str(exc).lower():
logger.debug(
"Skipping invalid-HNSW quarantine for transient metadata read in %s",
meta_path,
exc_info=True,
)
continue
reason = f"invalid index_metadata.pickle: {exc}"
except Exception as exc:
reason = f"invalid index_metadata.pickle: {exc}"
else:
if not isinstance(persisted, dict) and not (
hasattr(persisted, "dimensionality") or hasattr(persisted, "id_to_label")
):
reason = f"unrecognized index_metadata.pickle payload: {type(persisted).__name__}"
else:
dimensionality, id_to_label = _persisted_metadata_fields(persisted)
if id_to_label is not None and not isinstance(id_to_label, dict):
reason = f"invalid id_to_label type {type(id_to_label).__name__}"
else:
has_labels = bool(id_to_label)
if has_labels and not _valid_dimensionality(dimensionality):
reason = (
"labels present but dimensionality is missing or invalid "
f"({dimensionality!r})"
)
elif dimensionality is not None and not _valid_dimensionality(dimensionality):
reason = f"invalid dimensionality {dimensionality!r}"
if reason is None:
continue
stamp = _dt.datetime.now().strftime("%Y%m%d-%H%M%S")
target = f"{seg_dir}.corrupt-{stamp}"
try:
os.rename(seg_dir, target)
moved.append(target)
logger.warning("Quarantined invalid HNSW metadata in %s: %s", seg_dir, reason)
except OSError:
logger.exception("Failed to quarantine invalid HNSW metadata in %s", seg_dir)
return moved
def _fix_blob_seq_ids(palace_path: str) -> None:
"""Fix ChromaDB 0.6.x -> 1.5.x migration bug: BLOB seq_ids -> INTEGER.
@@ -994,6 +1087,12 @@ class ChromaBackend(BaseBackend):
if cached is None or inode_changed or mtime_changed or mtime_appeared:
_fix_blob_seq_ids(palace_path)
if inode_changed:
ChromaBackend._quarantined_paths.discard(palace_path)
if palace_path not in ChromaBackend._quarantined_paths:
quarantine_invalid_hnsw_metadata(palace_path)
quarantine_stale_hnsw(palace_path)
ChromaBackend._quarantined_paths.add(palace_path)
cached = chromadb.PersistentClient(path=palace_path)
self._clients[palace_path] = cached
# Re-stat after the client constructor runs: chromadb creates
@@ -1006,26 +1105,27 @@ class ChromaBackend(BaseBackend):
# Public static helpers (legacy; prefer :meth:`get_collection`)
# ------------------------------------------------------------------
# Per-process record of palaces that have already had quarantine_stale_hnsw
# invoked at least once. The proactive drift check is a *cold-start*
# protection — it catches HNSW segments that arrived stale relative to
# ``chroma.sqlite3`` (e.g. cross-machine replication, partial restore,
# crashed-mid-write). Once a long-running process has opened the palace
# cleanly, re-firing on every reconnect is a *runtime thrash*: the
# daemon's own writes bump sqlite mtime but HNSW flushes batch on
# chromadb's internal cadence, so the mtime gap naturally exceeds the
# threshold under steady write load even though nothing is corrupt.
# Per-process record of palaces that have already had the cold-start
# quarantine invoked at least once. The proactive HNSW checks are a
# *cold-start* protection — they catch segments that arrive stale relative
# to ``chroma.sqlite3`` or invalid on disk (e.g. cross-machine replication,
# partial restore, crashed-mid-write). Once a long-running process has
# opened the palace cleanly, re-firing the stale check on every reconnect
# is a *runtime thrash*: the daemon's own writes bump sqlite mtime but HNSW
# flushes batch on chromadb's internal cadence, so the mtime gap naturally
# exceeds the threshold under steady write load even though nothing is
# corrupt.
# Real runtime drift is still handled — palace-daemon's ``_auto_repair``
# calls :func:`quarantine_stale_hnsw` directly on observed HNSW errors,
# which bypasses this gate.
#
# Thread-safety: this set is mutated without a lock. Two concurrent
# ``make_client()`` calls for the same palace can both pass the
# membership check and both invoke ``quarantine_stale_hnsw``. That's
# safe because the function is idempotent (mtime check + timestamped
# rename of distinct directories), so the worst-case race produces
# one redundant rename attempt that no-ops. Idempotency is the
# safety property; locking would add cost without correctness gain.
# membership check and both invoke the cold-start quarantine. That's
# safe because the functions are idempotent (mtime checks + timestamped
# rename of distinct directories), so the worst-case race produces one
# redundant rename attempt that no-ops. Idempotency is the safety
# property; locking would add cost without correctness gain.
_quarantined_paths: set[str] = set()
@staticmethod
@@ -1036,12 +1136,13 @@ class ChromaBackend(BaseBackend):
own client cache. New code should obtain a collection through
:meth:`get_collection` which manages caching internally.
Quarantines stale HNSW segments **once per palace per process**. See
Quarantines HNSW segments **once per palace per process**. See
:attr:`_quarantined_paths` for the rationale (cold-start protection
vs. runtime thrash on steady-write daemons).
"""
_fix_blob_seq_ids(palace_path)
if palace_path not in ChromaBackend._quarantined_paths:
quarantine_invalid_hnsw_metadata(palace_path)
quarantine_stale_hnsw(palace_path)
ChromaBackend._quarantined_paths.add(palace_path)
return chromadb.PersistentClient(path=palace_path)
+279 -4
View File
@@ -1,4 +1,5 @@
import os
import pickle
import sqlite3
from pathlib import Path
@@ -18,6 +19,7 @@ from mempalace.backends.chroma import (
ChromaCollection,
_fix_blob_seq_ids,
_pin_hnsw_threads,
quarantine_invalid_hnsw_metadata,
quarantine_stale_hnsw,
)
@@ -708,7 +710,10 @@ def test_make_client_quarantines_only_on_first_call_per_palace(tmp_path, monkeyp
"""Quarantine fires on first ``make_client()`` for a palace, then is
skipped on subsequent calls — prevents runtime thrash where a daemon's
own steady writes bump ``chroma.sqlite3`` faster than HNSW flushes,
making the mtime heuristic falsely trigger every reconnect."""
making the mtime heuristic falsely trigger every reconnect.
Invalid metadata quarantine shares the same cold-start gate here; the
more aggressive refresh path lives in ``_client()``."""
from mempalace.backends.chroma import ChromaBackend
palace_path = str(tmp_path / "palace")
@@ -730,9 +735,37 @@ def test_make_client_quarantines_only_on_first_call_per_palace(tmp_path, monkeyp
ChromaBackend.make_client(palace_path)
ChromaBackend.make_client(palace_path)
assert calls == [
palace_path
], "quarantine_stale_hnsw should fire once per palace per process, not on every reconnect"
assert calls == [palace_path], (
"quarantine_stale_hnsw should fire once per palace per process, not on every reconnect"
)
def test_make_client_gates_invalid_metadata_on_first_call(tmp_path, monkeypatch):
"""Invalid metadata quarantine is gated on the first make_client() call."""
from mempalace.backends.chroma import ChromaBackend
palace_path = str(tmp_path / "palace")
os.makedirs(palace_path, exist_ok=True)
(Path(palace_path) / "chroma.sqlite3").write_text("")
monkeypatch.setattr(ChromaBackend, "_quarantined_paths", set())
calls: list[str] = []
def _invalid(path, *args, **kwargs):
calls.append(path)
return []
def _stale(path, stale_seconds=300.0):
return []
monkeypatch.setattr("mempalace.backends.chroma.quarantine_invalid_hnsw_metadata", _invalid)
monkeypatch.setattr("mempalace.backends.chroma.quarantine_stale_hnsw", _stale)
ChromaBackend.make_client(palace_path)
ChromaBackend.make_client(palace_path)
assert calls == [palace_path]
def test_make_client_quarantines_each_palace_independently(tmp_path, monkeypatch):
@@ -811,3 +844,245 @@ def test_get_collection_applies_retrofit_on_existing_palace(tmp_path):
)
assert wrapper._collection.configuration_json["hnsw"]["num_threads"] == 1
def test_quarantine_invalid_hnsw_metadata_renames_missing_dimensionality(tmp_path):
palace = tmp_path / "palace"
palace.mkdir()
seg = palace / "abcd-1234-5678"
seg.mkdir()
with open(seg / "index_metadata.pickle", "wb") as f:
pickle.dump({"dimensionality": None, "id_to_label": {"a": 1}}, f)
moved = quarantine_invalid_hnsw_metadata(str(palace))
assert len(moved) == 1
assert ".corrupt-" in moved[0]
assert not seg.exists()
def test_quarantine_invalid_hnsw_metadata_allows_uninitialized_segment(tmp_path):
palace = tmp_path / "palace"
palace.mkdir()
seg = palace / "abcd-1234-5678"
seg.mkdir()
with open(seg / "index_metadata.pickle", "wb") as f:
pickle.dump({"dimensionality": None, "id_to_label": {}}, f)
moved = quarantine_invalid_hnsw_metadata(str(palace))
assert moved == []
assert seg.exists()
def test_quarantine_invalid_hnsw_metadata_rejects_non_dict_id_to_label(tmp_path):
palace = tmp_path / "palace"
palace.mkdir()
seg = palace / "abcd-1234-5678"
seg.mkdir()
with open(seg / "index_metadata.pickle", "wb") as f:
pickle.dump({"dimensionality": 8, "id_to_label": ["a", "b"]}, f)
moved = quarantine_invalid_hnsw_metadata(str(palace))
assert len(moved) == 1
assert ".corrupt-" in moved[0]
assert not seg.exists()
def test_quarantine_invalid_hnsw_metadata_rejects_non_schema_payload(tmp_path):
palace = tmp_path / "palace"
palace.mkdir()
seg = palace / "abcd-1234-5678"
seg.mkdir()
with open(seg / "index_metadata.pickle", "wb") as f:
pickle.dump(["not", "a", "metadata", "object"], f)
moved = quarantine_invalid_hnsw_metadata(str(palace))
assert len(moved) == 1
assert ".corrupt-" in moved[0]
assert not seg.exists()
def _dangerous_pickle_payload_executed():
raise AssertionError("unsafe pickle payload executed")
class _DangerousPickle:
def __reduce__(self):
return (_dangerous_pickle_payload_executed, ())
def test_quarantine_invalid_hnsw_metadata_rejects_unsafe_pickle(tmp_path):
palace = tmp_path / "palace"
palace.mkdir()
seg = palace / "abcd-1234-5678"
seg.mkdir()
with open(seg / "index_metadata.pickle", "wb") as f:
pickle.dump(_DangerousPickle(), f)
moved = quarantine_invalid_hnsw_metadata(str(palace))
assert len(moved) == 1
assert ".corrupt-" in moved[0]
assert not seg.exists()
def test_quarantine_invalid_hnsw_metadata_skips_transient_read_errors(tmp_path, monkeypatch):
palace = tmp_path / "palace"
palace.mkdir()
seg = palace / "abcd-1234-5678"
seg.mkdir()
meta = seg / "index_metadata.pickle"
meta.write_bytes(b"partial")
monkeypatch.setattr(
"mempalace.backends.chroma._SafePersistentDataUnpickler.load",
lambda path: (_ for _ in ()).throw(EOFError("flush in progress")),
)
moved = quarantine_invalid_hnsw_metadata(str(palace))
assert moved == []
assert seg.exists()
def test_quarantine_invalid_hnsw_metadata_skips_truncated_pickle(tmp_path, monkeypatch):
palace = tmp_path / "palace"
palace.mkdir()
seg = palace / "abcd-1234-5678"
seg.mkdir()
meta = seg / "index_metadata.pickle"
meta.write_bytes(b"partial")
monkeypatch.setattr(
"mempalace.backends.chroma._SafePersistentDataUnpickler.load",
lambda path: (_ for _ in ()).throw(pickle.UnpicklingError("pickle data was truncated")),
)
moved = quarantine_invalid_hnsw_metadata(str(palace))
assert moved == []
assert seg.exists()
def test_chroma_backend_preflights_metadata_before_persistent_client(tmp_path, monkeypatch):
palace = tmp_path / "palace"
palace.mkdir()
calls = []
def _record(name):
def inner(path, *args, **kwargs):
calls.append((name, path))
return [] if name != "blob" else None
return inner
monkeypatch.setattr("mempalace.backends.chroma._fix_blob_seq_ids", _record("blob"))
monkeypatch.setattr(
"mempalace.backends.chroma.quarantine_invalid_hnsw_metadata", _record("invalid")
)
monkeypatch.setattr("mempalace.backends.chroma.quarantine_stale_hnsw", _record("stale"))
class DummyClient:
pass
monkeypatch.setattr(
"mempalace.backends.chroma.chromadb.PersistentClient", lambda path: DummyClient()
)
backend = ChromaBackend()
backend._client(str(palace))
assert calls == [
("blob", str(palace)),
("invalid", str(palace)),
("stale", str(palace)),
]
def test_chroma_backend_stale_quarantine_is_cold_start_only_on_refresh(tmp_path, monkeypatch):
palace = tmp_path / "palace"
palace.mkdir()
(palace / "chroma.sqlite3").write_text("")
calls = []
def _record(name):
def inner(path, *args, **kwargs):
calls.append((name, path))
return [] if name != "blob" else None
return inner
monkeypatch.setattr(ChromaBackend, "_quarantined_paths", set())
monkeypatch.setattr("mempalace.backends.chroma._fix_blob_seq_ids", _record("blob"))
monkeypatch.setattr(
"mempalace.backends.chroma.quarantine_invalid_hnsw_metadata", _record("invalid")
)
monkeypatch.setattr("mempalace.backends.chroma.quarantine_stale_hnsw", _record("stale"))
class DummyClient:
pass
monkeypatch.setattr(
"mempalace.backends.chroma.chromadb.PersistentClient", lambda path: DummyClient()
)
backend = ChromaBackend()
stats = iter([(1, 1.0), (1, 1.0), (1, 2.0), (1, 2.0)])
monkeypatch.setattr(backend, "_db_stat", lambda path: next(stats))
backend._client(str(palace))
backend._client(str(palace))
assert calls == [
("blob", str(palace)),
("invalid", str(palace)),
("stale", str(palace)),
("blob", str(palace)),
]
def test_chroma_backend_requarantines_after_inode_replacement(tmp_path, monkeypatch):
palace = tmp_path / "palace"
palace.mkdir()
(palace / "chroma.sqlite3").write_text("")
calls = []
def _record(name):
def inner(path, *args, **kwargs):
calls.append((name, path))
return [] if name != "blob" else None
return inner
monkeypatch.setattr(ChromaBackend, "_quarantined_paths", set())
monkeypatch.setattr("mempalace.backends.chroma._fix_blob_seq_ids", _record("blob"))
monkeypatch.setattr(
"mempalace.backends.chroma.quarantine_invalid_hnsw_metadata", _record("invalid")
)
monkeypatch.setattr("mempalace.backends.chroma.quarantine_stale_hnsw", _record("stale"))
class DummyClient:
pass
monkeypatch.setattr(
"mempalace.backends.chroma.chromadb.PersistentClient", lambda path: DummyClient()
)
backend = ChromaBackend()
stats = iter([(1, 1.0), (1, 1.0), (2, 2.0), (2, 2.0)])
monkeypatch.setattr(backend, "_db_stat", lambda path: next(stats))
backend._client(str(palace))
backend._client(str(palace))
assert calls == [
("blob", str(palace)),
("invalid", str(palace)),
("stale", str(palace)),
("blob", str(palace)),
("invalid", str(palace)),
("stale", str(palace)),
]