267a644f4f
Prerequisite for RFC 001 (plugin spec, #743). Removes every direct `import chromadb` outside the ChromaDB backend itself so the core modules depend only on the backend abstraction layer. Extends ChromaBackend with make_client, get_or_create_collection, delete_collection, create_collection, and backend_version. Adds update() to the BaseCollection contract. Non-backend callers (mcp_server, dedup, repair, migrate, cli) now go through the abstraction; tests patch ChromaBackend instead of chromadb. With this landed, the RFC 001 spec can be enforced and PalaceStore (#643) can ship as a plugin without touching core modules.
153 lines
5.6 KiB
Python
153 lines
5.6 KiB
Python
"""ChromaDB-backed MemPalace collection adapter."""
|
|
|
|
import logging
|
|
import os
|
|
import sqlite3
|
|
|
|
import chromadb
|
|
|
|
from .base import BaseCollection
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _fix_blob_seq_ids(palace_path: str):
    """Fix ChromaDB 0.6.x -> 1.5.x migration bug: BLOB seq_ids -> INTEGER.

    ChromaDB 0.6.x stored seq_id as big-endian 8-byte BLOBs. ChromaDB 1.5.x
    expects INTEGER. The auto-migration doesn't convert existing rows, causing
    the Rust compactor to crash with "mismatched types; Rust type u64 (as SQL
    type INTEGER) is not compatible with SQL type BLOB".

    Must run BEFORE PersistentClient is created (the compactor fires on init).

    Args:
        palace_path: Directory that may contain ``chroma.sqlite3``. When the
            file is not there, the function returns without doing anything.

    Returns:
        None. Any failure is logged via ``logger.exception`` and swallowed, so
        this best-effort repair never raises to the caller.
    """
    db_path = os.path.join(palace_path, "chroma.sqlite3")
    if not os.path.isfile(db_path):
        return
    try:
        conn = sqlite3.connect(db_path)
        try:
            # ``with conn`` only scopes the transaction (commit on success,
            # rollback on error); it does NOT close the connection, hence the
            # explicit ``finally`` below.
            with conn:
                for table in ("embeddings", "max_seq_id"):
                    try:
                        rows = conn.execute(
                            f"SELECT rowid, seq_id FROM {table} WHERE typeof(seq_id) = 'blob'"
                        ).fetchall()
                    except sqlite3.OperationalError:
                        # The query could not run against this table (e.g. it
                        # is missing from this schema); move on to the next.
                        continue
                    if not rows:
                        continue
                    # Decode each big-endian BLOB into the integer it encodes.
                    updates = [
                        (int.from_bytes(blob, byteorder="big"), rowid) for rowid, blob in rows
                    ]
                    conn.executemany(f"UPDATE {table} SET seq_id = ? WHERE rowid = ?", updates)
                    logger.info("Fixed %d BLOB seq_ids in %s", len(updates), table)
                conn.commit()
        finally:
            # Release the SQLite handle deterministically instead of leaving
            # it to garbage collection; the database file is about to be
            # opened again by the ChromaDB client.
            conn.close()
    except Exception:
        logger.exception("Could not fix BLOB seq_ids in %s", db_path)
|
|
|
|
|
|
class ChromaCollection(BaseCollection):
    """BaseCollection subclass that delegates to a wrapped ChromaDB collection.

    Each method forwards its arguments unchanged to the method of the same
    name on the wrapped object; read operations hand back whatever the
    wrapped object returns.
    """

    def __init__(self, collection):
        # The underlying collection object that every method delegates to.
        self._collection = collection

    def add(self, *, documents, ids, metadatas=None):
        """Forward keyword-only *documents*, *ids*, *metadatas* to ``add``."""
        target = self._collection
        target.add(documents=documents, ids=ids, metadatas=metadatas)

    def upsert(self, *, documents, ids, metadatas=None):
        """Forward keyword-only *documents*, *ids*, *metadatas* to ``upsert``."""
        target = self._collection
        target.upsert(documents=documents, ids=ids, metadatas=metadatas)

    def update(self, **kwargs):
        """Pass all keyword arguments through to the wrapped ``update``."""
        target = self._collection
        target.update(**kwargs)

    def query(self, **kwargs):
        """Run the wrapped ``query`` with *kwargs* and return its result."""
        result = self._collection.query(**kwargs)
        return result

    def get(self, **kwargs):
        """Run the wrapped ``get`` with *kwargs* and return its result."""
        result = self._collection.get(**kwargs)
        return result

    def delete(self, **kwargs):
        """Pass all keyword arguments through to the wrapped ``delete``."""
        target = self._collection
        target.delete(**kwargs)

    def count(self):
        """Return whatever the wrapped collection's ``count`` returns."""
        total = self._collection.count()
        return total
|
|
|
|
|
|
class ChromaBackend:
    """Factory for MemPalace's default ChromaDB backend.

    Instances keep one ``chromadb.PersistentClient`` per palace path and hand
    out collections wrapped in ``ChromaCollection``.
    """

    def __init__(self):
        # Per-instance cache mapping palace_path -> chromadb.PersistentClient.
        self._clients: dict = {}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _client(self, palace_path: str):
        """Return the PersistentClient cached for *palace_path*.

        On a cache miss the BLOB seq_id repair runs first, then a new client
        is created, stored in the cache, and returned.
        """
        if palace_path in self._clients:
            return self._clients[palace_path]

        # The repair has to happen before the client is constructed.
        _fix_blob_seq_ids(palace_path)
        fresh = chromadb.PersistentClient(path=palace_path)
        self._clients[palace_path] = fresh
        return fresh

    # ------------------------------------------------------------------
    # Public static helpers (for callers that manage their own caching)
    # ------------------------------------------------------------------

    @staticmethod
    def make_client(palace_path: str):
        """Build a brand-new PersistentClient after fixing BLOB seq_ids.

        Nothing is cached here. Intended for long-lived callers (e.g.
        mcp_server) that keep their own inode/mtime-based client cache.
        """
        _fix_blob_seq_ids(palace_path)
        fresh = chromadb.PersistentClient(path=palace_path)
        return fresh

    @staticmethod
    def backend_version() -> str:
        """Return ``chromadb.__version__`` for the installed package."""
        version = chromadb.__version__
        return version

    # ------------------------------------------------------------------
    # Collection lifecycle
    # ------------------------------------------------------------------

    def get_collection(self, palace_path: str, collection_name: str, create: bool = False):
        """Open *collection_name* in the palace at *palace_path*.

        With ``create`` false, the palace directory must already exist and
        the collection is fetched with the client's ``get_collection``. With
        ``create`` true, the directory is created if needed, restricted to
        mode 0o700 on a best-effort basis, and the collection is obtained
        with ``get_or_create_collection`` using cosine ``hnsw:space``.

        Returns:
            A ``ChromaCollection`` wrapping the underlying collection.

        Raises:
            FileNotFoundError: ``create`` is false and *palace_path* is not
                an existing directory.
        """
        if create:
            os.makedirs(palace_path, exist_ok=True)
            try:
                os.chmod(palace_path, 0o700)
            except (OSError, NotImplementedError):
                # Tightening permissions is best-effort; carry on without it.
                pass
        elif not os.path.isdir(palace_path):
            raise FileNotFoundError(palace_path)

        client = self._client(palace_path)
        if not create:
            return ChromaCollection(client.get_collection(collection_name))

        cosine_metadata = {"hnsw:space": "cosine"}
        opened = client.get_or_create_collection(collection_name, metadata=cosine_metadata)
        return ChromaCollection(opened)

    def get_or_create_collection(
        self, palace_path: str, collection_name: str
    ) -> "ChromaCollection":
        """Convenience wrapper: ``get_collection`` with ``create=True``."""
        wrapped = self.get_collection(palace_path, collection_name, create=True)
        return wrapped

    def delete_collection(self, palace_path: str, collection_name: str) -> None:
        """Remove *collection_name* from the palace at *palace_path*."""
        client = self._client(palace_path)
        client.delete_collection(collection_name)

    def create_collection(
        self, palace_path: str, collection_name: str, hnsw_space: str = "cosine"
    ) -> "ChromaCollection":
        """Create (not get-or-create) *collection_name* in the palace.

        *hnsw_space* is stored as the collection's ``hnsw:space`` metadata
        and defaults to ``"cosine"``.
        """
        client = self._client(palace_path)
        space_metadata = {"hnsw:space": hnsw_space}
        created = client.create_collection(collection_name, metadata=space_metadata)
        return ChromaCollection(created)
|