diff --git a/mempalace/backends/__init__.py b/mempalace/backends/__init__.py index cb5f14d..ab22ec6 100644 --- a/mempalace/backends/__init__.py +++ b/mempalace/backends/__init__.py @@ -1,6 +1,64 @@ -"""Storage backend implementations for MemPalace.""" +"""Storage backend implementations for MemPalace (RFC 001). -from .base import BaseCollection +Public surface: + +* :class:`BaseCollection` — per-collection read/write contract. +* :class:`BaseBackend` — per-palace factory contract. +* :class:`PalaceRef` — value object identifying a palace for a backend. +* :class:`QueryResult` / :class:`GetResult` — typed read returns. +* Error classes: :class:`PalaceNotFoundError`, :class:`BackendClosedError`, + :class:`UnsupportedFilterError`, :class:`DimensionMismatchError`, + :class:`EmbedderIdentityMismatchError`. +* Registry: :func:`get_backend`, :func:`register`, :func:`available_backends`, + :func:`resolve_backend_for_palace`. +* In-tree Chroma default: :class:`ChromaBackend`, :class:`ChromaCollection`. 
+""" + +from .base import ( + BackendClosedError, + BackendError, + BaseBackend, + BaseCollection, + DimensionMismatchError, + EmbedderIdentityMismatchError, + GetResult, + HealthStatus, + PalaceNotFoundError, + PalaceRef, + QueryResult, + UnsupportedFilterError, +) from .chroma import ChromaBackend, ChromaCollection +from .registry import ( + available_backends, + get_backend, + get_backend_class, + register, + reset_backends, + resolve_backend_for_palace, + unregister, +) -__all__ = ["BaseCollection", "ChromaBackend", "ChromaCollection"] +__all__ = [ + "BackendClosedError", + "BackendError", + "BaseBackend", + "BaseCollection", + "ChromaBackend", + "ChromaCollection", + "DimensionMismatchError", + "EmbedderIdentityMismatchError", + "GetResult", + "HealthStatus", + "PalaceNotFoundError", + "PalaceRef", + "QueryResult", + "UnsupportedFilterError", + "available_backends", + "get_backend", + "get_backend_class", + "register", + "reset_backends", + "resolve_backend_for_palace", + "unregister", +] diff --git a/mempalace/backends/base.py b/mempalace/backends/base.py index 877da53..819d326 100644 --- a/mempalace/backends/base.py +++ b/mempalace/backends/base.py @@ -1,49 +1,354 @@ -"""Abstract collection interface for MemPalace storage backends.""" +"""Storage backend contract for MemPalace (RFC 001). + +This module defines the surface every storage backend must implement: + +* ``BaseCollection`` — the per-collection read/write interface, kwargs-only. +* ``BaseBackend`` — the per-palace factory, addressed by ``PalaceRef``. +* ``QueryResult`` / ``GetResult`` — typed result dataclasses that replace the + Chroma dict shape as the canonical return type. +* Error classes + ``HealthStatus`` — uniform across backends. + +This is the v1 cleanup from RFC 001 §10: full typed results, ``PalaceRef``, +registry-ready ABC. Embedder injection, maintenance hooks, and the full +conformance suite land in follow-up PRs. 
+""" from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from dataclasses import dataclass +from typing import ClassVar, Optional + + +# --------------------------------------------------------------------------- +# Errors +# --------------------------------------------------------------------------- + + +class BackendError(Exception): + """Base class for every storage-backend error raised by core.""" + + +class PalaceNotFoundError(BackendError, FileNotFoundError): + """Raised when ``get_collection(create=False)`` is called on a missing palace. + + Subclass of ``FileNotFoundError`` so legacy callers that catch the latter + (pre-#413 seam) keep working unchanged. + """ + + +class BackendClosedError(BackendError): + """Raised when a backend method is called after ``close()``.""" + + +class UnsupportedFilterError(BackendError): + """Raised when a where-clause uses an operator the backend does not implement. + + Silent dropping of unknown operators is forbidden by spec (RFC 001 §1.4). + """ + + +class DimensionMismatchError(BackendError): + """Raised when the embedding dimension on write does not match the collection.""" + + +class EmbedderIdentityMismatchError(BackendError): + """Raised when the stored embedder model name differs from the current one.""" + + +# --------------------------------------------------------------------------- +# Value objects +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class PalaceRef: + """A handle to a palace, consumed by backends. + + ``id`` is always present and is the key backends use to cache handles. + ``local_path`` is populated for filesystem-rooted palaces. + ``namespace`` is used by server-mode backends for tenant / prefix routing. 
+ """ + + id: str + local_path: Optional[str] = None + namespace: Optional[str] = None + + +@dataclass(frozen=True) +class HealthStatus: + ok: bool + detail: str = "" + + @classmethod + def healthy(cls, detail: str = "") -> "HealthStatus": + return cls(ok=True, detail=detail) + + @classmethod + def unhealthy(cls, detail: str) -> "HealthStatus": + return cls(ok=False, detail=detail) + + +_TYPED_RESULT_FIELDS = ("ids", "documents", "metadatas", "distances", "embeddings") + + +class _DictCompatMixin: + """Transitional dict-protocol access for typed results. + + RFC 001 §1.3 spec is attribute access (``result.ids``). The ``result["ids"]`` + and ``result.get("ids")`` forms are retained as a migration shim for callers + that predate the typed interface and are scheduled for removal in a follow- + up cleanup. New code MUST use attribute access. + """ + + def __getitem__(self, key: str): + if key in _TYPED_RESULT_FIELDS: + return getattr(self, key) + raise KeyError(key) + + def get(self, key: str, default=None): + if key in _TYPED_RESULT_FIELDS: + val = getattr(self, key, default) + return default if val is None else val + return default + + def __contains__(self, key: object) -> bool: + return key in _TYPED_RESULT_FIELDS and getattr(self, key, None) is not None + + +@dataclass(frozen=True) +class QueryResult(_DictCompatMixin): + """Typed return from ``BaseCollection.query``. + + Outer list dimension = number of query vectors / texts. + Inner list dimension = hits per query (may be zero). + + Fields not in ``include=`` at the call site are populated with empty lists + of the correct outer shape (never ``None``), except ``embeddings`` which + is ``None`` when not requested. 
+ """ + + ids: list[list[str]] + documents: list[list[str]] + metadatas: list[list[dict]] + distances: list[list[float]] + embeddings: Optional[list[list[list[float]]]] = None + + @classmethod + def empty(cls, num_queries: int = 1) -> "QueryResult": + """Construct an all-empty result preserving outer dimension.""" + return cls( + ids=[[] for _ in range(num_queries)], + documents=[[] for _ in range(num_queries)], + metadatas=[[] for _ in range(num_queries)], + distances=[[] for _ in range(num_queries)], + embeddings=None, + ) + + +@dataclass(frozen=True) +class GetResult(_DictCompatMixin): + """Typed return from ``BaseCollection.get``.""" + + ids: list[str] + documents: list[str] + metadatas: list[dict] + embeddings: Optional[list[list[float]]] = None + + @classmethod + def empty(cls) -> "GetResult": + return cls(ids=[], documents=[], metadatas=[], embeddings=None) + + +# --------------------------------------------------------------------------- +# Collection contract +# --------------------------------------------------------------------------- class BaseCollection(ABC): - """Smallest collection contract the rest of MemPalace relies on.""" + """Per-collection read/write surface every backend must implement.""" @abstractmethod def add( self, *, - documents: List[str], - ids: List[str], - metadatas: Optional[List[Dict[str, Any]]] = None, - ) -> None: - raise NotImplementedError + documents: list[str], + ids: list[str], + metadatas: Optional[list[dict]] = None, + embeddings: Optional[list[list[float]]] = None, + ) -> None: ... @abstractmethod def upsert( self, *, - documents: List[str], - ids: List[str], - metadatas: Optional[List[Dict[str, Any]]] = None, + documents: list[str], + ids: list[str], + metadatas: Optional[list[dict]] = None, + embeddings: Optional[list[list[float]]] = None, + ) -> None: ... 
+ + @abstractmethod + def query( + self, + *, + query_texts: Optional[list[str]] = None, + query_embeddings: Optional[list[list[float]]] = None, + n_results: int = 10, + where: Optional[dict] = None, + where_document: Optional[dict] = None, + include: Optional[list[str]] = None, + ) -> QueryResult: ... + + @abstractmethod + def get( + self, + *, + ids: Optional[list[str]] = None, + where: Optional[dict] = None, + where_document: Optional[dict] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + include: Optional[list[str]] = None, + ) -> GetResult: ... + + @abstractmethod + def delete( + self, + *, + ids: Optional[list[str]] = None, + where: Optional[dict] = None, + ) -> None: ... + + @abstractmethod + def count(self) -> int: ... + + # ------------------------------------------------------------------ + # Optional methods with ABC defaults (spec §1.2) + # ------------------------------------------------------------------ + + def estimated_count(self) -> int: + return self.count() + + def close(self) -> None: + return None + + def health(self) -> HealthStatus: + return HealthStatus.healthy() + + def update( + self, + *, + ids: list[str], + documents: Optional[list[str]] = None, + metadatas: Optional[list[dict]] = None, + embeddings: Optional[list[list[float]]] = None, ) -> None: - raise NotImplementedError + """Default non-atomic update: get + merge + upsert. + + Backends advertising ``supports_update`` MUST override with an atomic + single-round-trip implementation. 
+ """ + if documents is None and metadatas is None and embeddings is None: + raise ValueError("update requires at least one of documents, metadatas, embeddings") + + existing = self.get(ids=ids, include=["documents", "metadatas"]) + by_id = { + rid: (existing.documents[i], existing.metadatas[i]) + for i, rid in enumerate(existing.ids) + } + merged_docs: list[str] = [] + merged_metas: list[dict] = [] + for i, rid in enumerate(ids): + prev_doc, prev_meta = by_id.get(rid, ("", {})) + merged_docs.append(documents[i] if documents is not None else prev_doc) + new_meta = dict(prev_meta or {}) + if metadatas is not None: + new_meta.update(metadatas[i] or {}) + merged_metas.append(new_meta) + self.upsert( + documents=merged_docs, + ids=list(ids), + metadatas=merged_metas, + embeddings=embeddings, + ) + + +# --------------------------------------------------------------------------- +# Backend contract +# --------------------------------------------------------------------------- + + +class BaseBackend(ABC): + """Long-lived factory serving many palaces (RFC 001 §2). + + Instances are lightweight on construction — no I/O, no network. All + connection work is deferred to ``get_collection``. Instances are thread- + safe for concurrent ``get_collection`` calls across different palaces. + """ + + name: ClassVar[str] + spec_version: ClassVar[str] = "1.0" + capabilities: ClassVar[frozenset[str]] = frozenset() @abstractmethod - def update(self, **kwargs: Any) -> None: - """Update existing records. Must raise if any ID is missing.""" - raise NotImplementedError + def get_collection( + self, + *, + palace: PalaceRef, + collection_name: str, + create: bool = False, + options: Optional[dict] = None, + ) -> BaseCollection: ... - @abstractmethod - def query(self, **kwargs: Any) -> Dict[str, Any]: - raise NotImplementedError + def close_palace(self, palace: PalaceRef) -> None: + """Evict cached handles for a single palace. 
Default: no-op.""" + return None - @abstractmethod - def get(self, **kwargs: Any) -> Dict[str, Any]: - raise NotImplementedError + def close(self) -> None: + """Shut down the entire backend. Default: no-op.""" + return None - @abstractmethod - def delete(self, **kwargs: Any) -> None: - raise NotImplementedError + def health(self, palace: Optional[PalaceRef] = None) -> HealthStatus: + return HealthStatus.healthy() - @abstractmethod - def count(self) -> int: - raise NotImplementedError + # Optional detection hint used by selection priority (RFC 001 §3.3 (4)): + @classmethod + def detect(cls, path: str) -> bool: # pragma: no cover - default hook + return False + + +# --------------------------------------------------------------------------- +# Adapter utilities +# --------------------------------------------------------------------------- + + +# Keys the Chroma ``include=`` parameter accepts. +_VALID_INCLUDE_KEYS = frozenset({"documents", "metadatas", "distances", "embeddings"}) + + +@dataclass +class _IncludeSpec: + """Resolve an ``include=`` parameter with spec-mandated defaults.""" + + documents: bool = True + metadatas: bool = True + distances: bool = True # only meaningful for query + embeddings: bool = False + + @classmethod + def resolve( + cls, include: Optional[list[str]], *, default_distances: bool = True + ) -> "_IncludeSpec": + if include is None: + return cls( + documents=True, + metadatas=True, + distances=default_distances, + embeddings=False, + ) + keys = {k for k in include if k in _VALID_INCLUDE_KEYS} + return cls( + documents="documents" in keys, + metadatas="metadatas" in keys, + distances="distances" in keys, + embeddings="embeddings" in keys, + ) diff --git a/mempalace/backends/chroma.py b/mempalace/backends/chroma.py index 1a13675..f12a88b 100644 --- a/mempalace/backends/chroma.py +++ b/mempalace/backends/chroma.py @@ -1,17 +1,54 @@ -"""ChromaDB-backed MemPalace collection adapter.""" +"""ChromaDB-backed MemPalace storage backend (RFC 001 
reference implementation).""" import logging import os import sqlite3 +from typing import Any, Optional import chromadb -from .base import BaseCollection +from .base import ( + BaseBackend, + BaseCollection, + GetResult, + HealthStatus, + PalaceNotFoundError, + PalaceRef, + QueryResult, + UnsupportedFilterError, + _IncludeSpec, +) logger = logging.getLogger(__name__) -def _fix_blob_seq_ids(palace_path: str): +_REQUIRED_OPERATORS = frozenset({"$eq", "$ne", "$in", "$nin", "$and", "$or", "$contains"}) +_OPTIONAL_OPERATORS = frozenset({"$gt", "$gte", "$lt", "$lte"}) +_SUPPORTED_OPERATORS = _REQUIRED_OPERATORS | _OPTIONAL_OPERATORS + + +def _validate_where(where: Optional[dict]) -> None: + """Scan a where-clause for unknown operators and raise ``UnsupportedFilterError``. + + Spec (RFC 001 §1.4): silent dropping of unknown operators is forbidden. + """ + if not where: + return + stack = [where] + while stack: + node = stack.pop() + if not isinstance(node, dict): + continue + for k, v in node.items(): + if k.startswith("$") and k not in _SUPPORTED_OPERATORS: + raise UnsupportedFilterError(f"operator {k!r} not supported by chroma backend") + if isinstance(v, dict): + stack.append(v) + elif isinstance(v, list): + stack.extend(x for x in v if isinstance(x, dict)) + + +def _fix_blob_seq_ids(palace_path: str) -> None: """Fix ChromaDB 0.6.x -> 1.5.x migration bug: BLOB seq_ids -> INTEGER. ChromaDB 0.6.x stored seq_id as big-endian 8-byte BLOBs. 
ChromaDB 1.5.x @@ -43,62 +80,293 @@ def _fix_blob_seq_ids(palace_path: str): logger.exception("Could not fix BLOB seq_ids in %s", db_path) +# --------------------------------------------------------------------------- +# Collection adapter +# --------------------------------------------------------------------------- + + +def _as_list(v: Any) -> list: + """Coerce possibly-None scalar-or-list into a list (defensive for chroma nulls).""" + if v is None: + return [] + if isinstance(v, list): + return v + return [v] + + class ChromaCollection(BaseCollection): - """Thin adapter over a ChromaDB collection.""" + """Thin adapter translating ChromaDB dict returns into typed results.""" def __init__(self, collection): self._collection = collection - def add(self, *, documents, ids, metadatas=None): - self._collection.add(documents=documents, ids=ids, metadatas=metadatas) + # ------------------------------------------------------------------ + # Writes + # ------------------------------------------------------------------ - def upsert(self, *, documents, ids, metadatas=None): - self._collection.upsert(documents=documents, ids=ids, metadatas=metadatas) + def add(self, *, documents, ids, metadatas=None, embeddings=None): + kwargs: dict[str, Any] = {"documents": documents, "ids": ids} + if metadatas is not None: + kwargs["metadatas"] = metadatas + if embeddings is not None: + kwargs["embeddings"] = embeddings + self._collection.add(**kwargs) - def update(self, **kwargs): + def upsert(self, *, documents, ids, metadatas=None, embeddings=None): + kwargs: dict[str, Any] = {"documents": documents, "ids": ids} + if metadatas is not None: + kwargs["metadatas"] = metadatas + if embeddings is not None: + kwargs["embeddings"] = embeddings + self._collection.upsert(**kwargs) + + def update( + self, + *, + ids, + documents=None, + metadatas=None, + embeddings=None, + ): + if documents is None and metadatas is None and embeddings is None: + raise ValueError("update requires at least one of 
documents, metadatas, embeddings") + kwargs: dict[str, Any] = {"ids": ids} + if documents is not None: + kwargs["documents"] = documents + if metadatas is not None: + kwargs["metadatas"] = metadatas + if embeddings is not None: + kwargs["embeddings"] = embeddings self._collection.update(**kwargs) - def query(self, **kwargs): - return self._collection.query(**kwargs) + # ------------------------------------------------------------------ + # Reads + # ------------------------------------------------------------------ - def get(self, **kwargs): - return self._collection.get(**kwargs) + def query( + self, + *, + query_texts=None, + query_embeddings=None, + n_results=10, + where=None, + where_document=None, + include=None, + ) -> QueryResult: + _validate_where(where) + _validate_where(where_document) - def delete(self, **kwargs): + spec = _IncludeSpec.resolve(include, default_distances=True) + chroma_include: list[str] = [] + if spec.documents: + chroma_include.append("documents") + if spec.metadatas: + chroma_include.append("metadatas") + if spec.distances: + chroma_include.append("distances") + if spec.embeddings: + chroma_include.append("embeddings") + + kwargs: dict[str, Any] = { + "n_results": n_results, + "include": chroma_include, + } + if query_texts is not None: + kwargs["query_texts"] = query_texts + if query_embeddings is not None: + kwargs["query_embeddings"] = query_embeddings + if where is not None: + kwargs["where"] = where + if where_document is not None: + kwargs["where_document"] = where_document + + raw = self._collection.query(**kwargs) + + num_queries = ( + len(query_texts) + if query_texts is not None + else (len(query_embeddings) if query_embeddings is not None else 1) + ) + + ids = raw.get("ids") or [] + if not ids: + return QueryResult.empty(num_queries=num_queries) + + documents = raw.get("documents") or [[] for _ in ids] + metadatas = raw.get("metadatas") or [[] for _ in ids] + distances = raw.get("distances") or [[] for _ in ids] + 
embeddings_raw = raw.get("embeddings") if spec.embeddings else None + + def _none_list_to_empty(outer): + return [(inner or []) for inner in outer] + + return QueryResult( + ids=_none_list_to_empty(ids), + documents=_none_list_to_empty(documents), + metadatas=_none_list_to_empty(metadatas), + distances=_none_list_to_empty(distances), + embeddings=( + [list(inner) for inner in embeddings_raw] + if spec.embeddings and embeddings_raw is not None + else None + ), + ) + + def get( + self, + *, + ids=None, + where=None, + where_document=None, + limit=None, + offset=None, + include=None, + ) -> GetResult: + _validate_where(where) + _validate_where(where_document) + + spec = _IncludeSpec.resolve(include, default_distances=False) + chroma_include: list[str] = [] + if spec.documents: + chroma_include.append("documents") + if spec.metadatas: + chroma_include.append("metadatas") + if spec.embeddings: + chroma_include.append("embeddings") + + kwargs: dict[str, Any] = {"include": chroma_include} + if ids is not None: + kwargs["ids"] = ids + if where is not None: + kwargs["where"] = where + if where_document is not None: + kwargs["where_document"] = where_document + if limit is not None: + kwargs["limit"] = limit + if offset is not None: + kwargs["offset"] = offset + + raw = self._collection.get(**kwargs) + out_ids = list(raw.get("ids") or []) + out_docs = list(raw.get("documents") or []) if spec.documents else [] + out_metas = list(raw.get("metadatas") or []) if spec.metadatas else [] + out_embeds = raw.get("embeddings") if spec.embeddings else None + + # Pad doc/meta lists to match ids so downstream zipping is safe. 
+ if spec.documents and len(out_docs) < len(out_ids): + out_docs = out_docs + [""] * (len(out_ids) - len(out_docs)) + if spec.metadatas and len(out_metas) < len(out_ids): + out_metas = out_metas + [{}] * (len(out_ids) - len(out_metas)) + + return GetResult( + ids=out_ids, + documents=out_docs, + metadatas=out_metas, + embeddings=[list(v) for v in out_embeds] if out_embeds is not None else None, + ) + + def delete(self, *, ids=None, where=None): + _validate_where(where) + kwargs: dict[str, Any] = {} + if ids is not None: + kwargs["ids"] = ids + if where is not None: + kwargs["where"] = where self._collection.delete(**kwargs) def count(self): return self._collection.count() -class ChromaBackend: - """Factory for MemPalace's default ChromaDB backend.""" +# --------------------------------------------------------------------------- +# Backend +# --------------------------------------------------------------------------- + + +class ChromaBackend(BaseBackend): + """MemPalace's default ChromaDB backend. + + Maintains two caches: + + * ``self._clients`` — ``palace_path -> PersistentClient`` for callers + using the ``PalaceRef`` / :meth:`get_collection` path. + * An inode+mtime freshness check absorbed from ``mcp_server._get_client`` + (merged via #757) ensuring a palace rebuild on disk is detected on the + next :meth:`get_collection` call. + """ + + name = "chroma" + capabilities = frozenset( + { + "supports_embeddings_in", + "supports_embeddings_passthrough", + "supports_embeddings_out", + "supports_metadata_filters", + "supports_contains_fast", + "local_mode", + } + ) def __init__(self): - # Per-instance client cache: palace_path -> chromadb.PersistentClient - self._clients: dict = {} + # palace_path -> PersistentClient + self._clients: dict[str, Any] = {} + # palace_path -> (inode, mtime) of chroma.sqlite3 at cache time. 
+ self._freshness: dict[str, tuple[int, float]] = {} + self._closed = False # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ + @staticmethod + def _db_stat(palace_path: str) -> tuple[int, float]: + """Return ``(inode, mtime)`` of ``chroma.sqlite3`` or ``(0, 0.0)`` if absent.""" + db_path = os.path.join(palace_path, "chroma.sqlite3") + try: + st = os.stat(db_path) + return (st.st_ino, st.st_mtime) + except OSError: + return (0, 0.0) + def _client(self, palace_path: str): - """Return a cached PersistentClient for *palace_path*, creating one if needed.""" - if palace_path not in self._clients: + """Return a cached ``PersistentClient``, rebuilding on inode/mtime change. + + Handles the palace-rebuild case (repair/nuke/purge) by invalidating the + cache when ``chroma.sqlite3`` changes on disk. FAT/exFAT return inode 0, + so inode comparisons only fire when non-zero (matches #757 semantics). 
+ """ + if self._closed: + from .base import BackendClosedError # late import avoids cycles at module load + + raise BackendClosedError("ChromaBackend has been closed") + + cached = self._clients.get(palace_path) + cached_inode, cached_mtime = self._freshness.get(palace_path, (0, 0.0)) + current_inode, current_mtime = self._db_stat(palace_path) + + inode_changed = current_inode != 0 and cached_inode != 0 and current_inode != cached_inode + mtime_changed = ( + current_mtime != 0.0 and cached_mtime != 0.0 and current_mtime > cached_mtime + ) + + if cached is None or inode_changed or mtime_changed: _fix_blob_seq_ids(palace_path) - self._clients[palace_path] = chromadb.PersistentClient(path=palace_path) - return self._clients[palace_path] + cached = chromadb.PersistentClient(path=palace_path) + self._clients[palace_path] = cached + self._freshness[palace_path] = (current_inode, current_mtime) + return cached # ------------------------------------------------------------------ - # Public static helpers (for callers that manage their own caching) + # Public static helpers (legacy; prefer :meth:`get_collection`) # ------------------------------------------------------------------ @staticmethod def make_client(palace_path: str): - """Create and return a fresh PersistentClient (fix BLOB seq_ids first). + """Create a fresh ``PersistentClient`` (fixes BLOB seq_ids first). - Intended for long-lived callers (e.g. mcp_server) that keep their own - inode/mtime-based client cache. + Deprecated-ish: exposed for legacy long-lived callers that manage their + own client cache. New code should obtain a collection through + :meth:`get_collection` which manages caching internally. 
""" _fix_blob_seq_ids(palace_path) return chromadb.PersistentClient(path=palace_path) @@ -109,12 +377,31 @@ class ChromaBackend: return chromadb.__version__ # ------------------------------------------------------------------ - # Collection lifecycle + # BaseBackend surface # ------------------------------------------------------------------ - def get_collection(self, palace_path: str, collection_name: str, create: bool = False): + def get_collection( + self, + *args, + **kwargs, + ) -> ChromaCollection: + """Obtain a collection for a palace. + + Supports two calling conventions during the RFC 001 transition: + + * New (preferred): ``get_collection(palace=PalaceRef, collection_name=..., + create=False, options=None)``. + * Legacy: ``get_collection(palace_path, collection_name, create=False)`` + — still used by callers not yet migrated. + """ + palace_ref, collection_name, create, options = _normalize_get_collection_args(args, kwargs) + + palace_path = palace_ref.local_path + if palace_path is None: + raise PalaceNotFoundError("ChromaBackend requires PalaceRef.local_path") + if not create and not os.path.isdir(palace_path): - raise FileNotFoundError(palace_path) + raise PalaceNotFoundError(palace_path) if create: os.makedirs(palace_path, exist_ok=True) @@ -124,29 +411,113 @@ class ChromaBackend: pass client = self._client(palace_path) + hnsw_space = "cosine" + if options and isinstance(options, dict): + hnsw_space = options.get("hnsw_space", hnsw_space) + if create: collection = client.get_or_create_collection( - collection_name, metadata={"hnsw:space": "cosine"} + collection_name, metadata={"hnsw:space": hnsw_space} ) else: collection = client.get_collection(collection_name) return ChromaCollection(collection) - def get_or_create_collection( - self, palace_path: str, collection_name: str - ) -> "ChromaCollection": - """Shorthand for get_collection(..., create=True).""" + def close_palace(self, palace) -> None: + """Drop cached handles for ``palace``. 
Accepts ``PalaceRef`` or legacy path str.""" + path = palace.local_path if isinstance(palace, PalaceRef) else palace + if path is None: + return + self._clients.pop(path, None) + self._freshness.pop(path, None) + + def close(self) -> None: + self._clients.clear() + self._freshness.clear() + self._closed = True + + def health(self, palace: Optional[PalaceRef] = None) -> HealthStatus: + if self._closed: + return HealthStatus.unhealthy("backend closed") + return HealthStatus.healthy() + + @classmethod + def detect(cls, path: str) -> bool: + return os.path.isfile(os.path.join(path, "chroma.sqlite3")) + + # ------------------------------------------------------------------ + # Legacy (pre-RFC 001) surface — retained while callers migrate. + # ------------------------------------------------------------------ + + def get_or_create_collection(self, palace_path: str, collection_name: str) -> ChromaCollection: + """Legacy shim for ``get_collection(..., create=True)`` by path string.""" return self.get_collection(palace_path, collection_name, create=True) def delete_collection(self, palace_path: str, collection_name: str) -> None: - """Delete *collection_name* from the palace at *palace_path*.""" + """Delete ``collection_name`` from the palace at ``palace_path``.""" self._client(palace_path).delete_collection(collection_name) def create_collection( self, palace_path: str, collection_name: str, hnsw_space: str = "cosine" - ) -> "ChromaCollection": - """Create (not get-or-create) *collection_name* with cosine HNSW space.""" + ) -> ChromaCollection: + """Create (not get-or-create) ``collection_name`` with the given HNSW space.""" collection = self._client(palace_path).create_collection( collection_name, metadata={"hnsw:space": hnsw_space} ) return ChromaCollection(collection) + + +def _normalize_get_collection_args(args, kwargs): + """Unify legacy positional ``(palace_path, collection_name, create)`` calls + with the new kwargs-only ``(palace=PalaceRef, collection_name=..., 
create=...)``. + + Returns ``(PalaceRef, collection_name, create, options)``. + """ + # New-style: palace= kwarg with a PalaceRef (spec path). + if "palace" in kwargs: + palace_ref = kwargs.pop("palace") + if not isinstance(palace_ref, PalaceRef): + raise TypeError("palace= must be a PalaceRef instance") + collection_name = kwargs.pop("collection_name") + create = kwargs.pop("create", False) + options = kwargs.pop("options", None) + if kwargs: + raise TypeError(f"unexpected kwargs: {sorted(kwargs)}") + if args: + raise TypeError("positional args not allowed with palace= kwarg") + return palace_ref, collection_name, create, options + + # Legacy: first positional is a path string. + if args: + palace_path = args[0] + rest = list(args[1:]) + collection_name = kwargs.pop("collection_name", None) or (rest.pop(0) if rest else None) + if collection_name is None: + raise TypeError("collection_name is required") + create = kwargs.pop("create", False) + if rest: + create = rest.pop(0) + if kwargs: + raise TypeError(f"unexpected kwargs: {sorted(kwargs)}") + return ( + PalaceRef(id=palace_path, local_path=palace_path), + collection_name, + bool(create), + None, + ) + + # Legacy kwargs-only (palace_path=..., collection_name=..., create=...) + if "palace_path" in kwargs: + palace_path = kwargs.pop("palace_path") + collection_name = kwargs.pop("collection_name") + create = kwargs.pop("create", False) + if kwargs: + raise TypeError(f"unexpected kwargs: {sorted(kwargs)}") + return ( + PalaceRef(id=palace_path, local_path=palace_path), + collection_name, + bool(create), + None, + ) + + raise TypeError("get_collection requires palace= or a positional palace_path") diff --git a/mempalace/backends/registry.py b/mempalace/backends/registry.py new file mode 100644 index 0000000..7551bd3 --- /dev/null +++ b/mempalace/backends/registry.py @@ -0,0 +1,189 @@ +"""Backend registry + entry-point discovery (RFC 001 §3). 
def _discover_entry_points() -> None:
    """Load entry-point-declared backends and merge them into the registry.

    The merge happens at most once per process (guarded by ``_discovered``).
    Entry points whose name is already explicitly registered are skipped, so
    explicit registration wins. Entry points that fail to load, or that do
    not resolve to a ``BaseBackend`` subclass, are logged and ignored.

    Entry points are imported *without* holding ``_lock``. ``_lock`` is a
    plain non-reentrant ``Lock`` and ``register()`` acquires it, so a plugin
    module that calls ``register()`` at import time would deadlock if
    ``ep.load()`` ran under the lock. Only the final merge is done under it.
    """
    global _discovered
    if _discovered:
        return

    try:
        eps = metadata.entry_points()
        # Py ≥ 3.10 returns an EntryPoints object; older versions returned a dict.
        group = (
            eps.select(group=_ENTRY_POINT_GROUP)
            if hasattr(eps, "select")
            else eps.get(_ENTRY_POINT_GROUP, [])
        )
    except Exception:
        logger.exception("entry-point discovery for %s failed", _ENTRY_POINT_GROUP)
        group = []

    # Phase 1 (no lock held): import third-party code and validate it.
    loaded = []
    for ep in group:
        if ep.name in _explicit:
            continue  # explicit registration wins
        try:
            cls = ep.load()
        except Exception:
            logger.exception("failed to load backend entry point %r", ep.name)
            continue
        if not isinstance(cls, type) or not issubclass(cls, BaseBackend):
            logger.warning(
                "entry point %r did not resolve to a BaseBackend subclass (got %r)",
                ep.name,
                cls,
            )
            continue
        loaded.append((ep.name, cls))

    # Phase 2 (lock held): publish results exactly once.
    with _lock:
        if _discovered:
            return
        for name, cls in loaded:
            # Re-check under the lock: a register() call may have landed
            # while the entry points were being imported.
            if name in _explicit:
                continue
            _registry.setdefault(name, cls)
        _discovered = True
+ """ + _discover_entry_points() + with _lock: + inst = _instances.get(name) + if inst is not None: + return inst + cls = _registry.get(name) + if cls is None: + raise KeyError(f"unknown backend {name!r}; available: {sorted(_registry.keys())}") + inst = cls() + _instances[name] = inst + return inst + + +def reset_backends() -> None: + """Close and drop all cached backend instances (primarily for tests).""" + with _lock: + for inst in _instances.values(): + try: + inst.close() + except Exception: + logger.exception("error closing backend during reset") + _instances.clear() + + +def resolve_backend_for_palace( + *, + explicit: Optional[str] = None, + config_value: Optional[str] = None, + env_value: Optional[str] = None, + palace_path: Optional[str] = None, + default: str = "chroma", +) -> str: + """Resolve the backend name for a palace per RFC 001 §3.3 priority order. + + 1. Explicit kwarg / CLI flag + 2. Per-palace config value + 3. ``MEMPALACE_BACKEND`` env var + 4. Auto-detect from on-disk artifacts (migration/upgrade path only) + 5. Default (``chroma``) + + Auto-detection is strictly a migration aid: it fires only when a local path + is presented, no earlier rule has chosen a backend, AND the path already + contains backend-identifiable artifacts. For new palaces, (5) wins. 
+ """ + for candidate in (explicit, config_value, env_value): + if candidate: + return candidate + + _discover_entry_points() + if palace_path: + for name, cls in _registry.items(): + try: + if cls.detect(palace_path): + return name + except Exception: + logger.exception("detect() raised on backend %r", name) + continue + return default + + +# --------------------------------------------------------------------------- +# Built-in registration +# --------------------------------------------------------------------------- + + +def _register_builtins() -> None: + """Register chroma as the in-tree default.""" + from .chroma import ChromaBackend + + # Use setdefault semantics so a caller that pre-registered for tests wins. + if "chroma" not in _registry: + _registry["chroma"] = ChromaBackend + + +_register_builtins() diff --git a/mempalace/searcher.py b/mempalace/searcher.py index db809d9..081d3a7 100644 --- a/mempalace/searcher.py +++ b/mempalace/searcher.py @@ -30,15 +30,16 @@ class SearchError(Exception): _TOKEN_RE = re.compile(r"\w{2,}", re.UNICODE) -def _first_or_empty(results: dict, key: str) -> list: - """Return the first inner list of a ChromaDB query result, or []. +def _first_or_empty(results, key: str) -> list: + """Return the first inner list of a query result field, or []. - ChromaDB returns shapes like ``{"documents": [["a", "b"]], ...}`` for a - successful query, but ``{"documents": [], ...}`` (empty outer list) when - the collection is empty or the filter excludes everything. Indexing - ``[0]`` blindly raises IndexError in that case (issue #195). + Accepts both the typed :class:`QueryResult` (attribute access) and the + pre-typed chroma dict shape; this polymorphism is retained so test mocks + still work and callers mid-migration do not crash. Preserves the empty- + collection semantics from issue #195: when no queries returned hits, the + outer list may be empty and indexing ``[0]`` would raise. 
""" - outer = results.get(key) + outer = getattr(results, key, None) if not isinstance(results, dict) else results.get(key) if not outer: return [] return outer[0] or [] @@ -209,7 +210,7 @@ def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, ra return {"text": matched_doc, "drawer_index": chunk_idx, "total_drawers": None} indexed_docs = [] - for doc, meta in zip(neighbors.get("documents") or [], neighbors.get("metadatas") or []): + for doc, meta in zip(neighbors.documents, neighbors.metadatas): ci = meta.get("chunk_index") if isinstance(ci, int): indexed_docs.append((ci, doc)) @@ -224,8 +225,7 @@ def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, ra total_drawers = None try: all_meta = drawers_col.get(where={"source_file": src}, include=["metadatas"]) - ids = all_meta.get("ids") or [] - total_drawers = len(ids) if ids else None + total_drawers = len(all_meta.ids) if all_meta.ids else None except Exception: pass @@ -451,8 +451,8 @@ def search_memories( ) except Exception: continue - docs = source_drawers.get("documents") or [] - metas_ = source_drawers.get("metadatas") or [] + docs = source_drawers.documents + metas_ = source_drawers.metadatas if len(docs) <= 1: continue diff --git a/pyproject.toml b/pyproject.toml index f3067f3..e03dbe3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,9 @@ Repository = "https://github.com/MemPalace/mempalace" [project.scripts] mempalace = "mempalace.cli:main" +[project.entry-points."mempalace.backends"] +chroma = "mempalace.backends.chroma:ChromaBackend" + [project.optional-dependencies] dev = ["pytest>=7.0", "pytest-cov>=4.0", "ruff>=0.4.0", "psutil>=5.9"] spellcheck = ["autocorrect>=2.0"] diff --git a/tests/test_backends.py b/tests/test_backends.py index a620bf9..6535691 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -3,12 +3,34 @@ import sqlite3 import chromadb import pytest +from mempalace.backends import ( + GetResult, + PalaceRef, + 
QueryResult, + UnsupportedFilterError, + available_backends, + get_backend, +) from mempalace.backends.chroma import ChromaBackend, ChromaCollection, _fix_blob_seq_ids class _FakeCollection: - def __init__(self): + """Stand-in for a chromadb.Collection returning raw chroma-shaped dicts.""" + + def __init__(self, query_response=None, get_response=None, count_value=7): self.calls = [] + self._query_response = query_response or { + "ids": [["a", "b"]], + "documents": [["da", "db"]], + "metadatas": [[{"wing": "w1"}, {"wing": "w2"}]], + "distances": [[0.1, 0.2]], + } + self._get_response = get_response or { + "ids": ["a"], + "documents": ["da"], + "metadatas": [{"wing": "w1"}], + } + self._count_value = count_value def add(self, **kwargs): self.calls.append(("add", kwargs)) @@ -16,41 +38,142 @@ class _FakeCollection: def upsert(self, **kwargs): self.calls.append(("upsert", kwargs)) + def update(self, **kwargs): + self.calls.append(("update", kwargs)) + def query(self, **kwargs): self.calls.append(("query", kwargs)) - return {"kind": "query"} + return self._query_response def get(self, **kwargs): self.calls.append(("get", kwargs)) - return {"kind": "get"} + return self._get_response def delete(self, **kwargs): self.calls.append(("delete", kwargs)) def count(self): self.calls.append(("count", {})) - return 7 + return self._count_value -def test_chroma_collection_delegates_methods(): +def test_chroma_collection_returns_typed_query_result(): + fake = _FakeCollection() + collection = ChromaCollection(fake) + + result = collection.query(query_texts=["q"]) + + assert isinstance(result, QueryResult) + assert result.ids == [["a", "b"]] + assert result.documents == [["da", "db"]] + assert result.metadatas == [[{"wing": "w1"}, {"wing": "w2"}]] + assert result.distances == [[0.1, 0.2]] + assert result.embeddings is None + + +def test_chroma_collection_returns_typed_get_result(): + fake = _FakeCollection() + collection = ChromaCollection(fake) + + result = 
def test_query_result_empty_preserves_outer_dimension():
    """``QueryResult.empty`` yields one empty inner list per requested query."""
    blank = QueryResult.empty(num_queries=2)
    one_empty_list_per_query = [[], []]
    assert blank.ids == one_empty_list_per_query
    assert blank.documents == one_empty_list_per_query
    assert blank.distances == one_empty_list_per_query
    assert blank.embeddings is None


def test_typed_results_support_dict_compat_access():
    """Transitional compat shim per base.py — retained until callers migrate to attrs."""
    row = GetResult(ids=["a"], documents=["da"], metadatas=[{"w": 1}])
    # Subscript, .get() with and without a default, and membership all work.
    assert row["ids"] == ["a"]
    assert row.get("documents") == ["da"]
    assert row.get("missing", "default") == "default"
    assert "ids" in row
    assert "missing" not in row


def test_chroma_collection_query_empty_result_preserves_outer_shape():
    """An empty raw response is reported with one empty list per query text."""
    raw_empty = {"ids": [], "documents": [], "metadatas": [], "distances": []}
    wrapped = ChromaCollection(_FakeCollection(query_response=raw_empty))

    outcome = wrapped.query(query_texts=["q1", "q2"])

    assert outcome.ids == [[], []]
    assert outcome.documents == [[], []]
    assert outcome.distances == [[], []]


def test_chroma_collection_rejects_unknown_where_operator():
    """A ``$regex`` where-operator raises ``UnsupportedFilterError``."""
    wrapped = ChromaCollection(_FakeCollection())

    with pytest.raises(UnsupportedFilterError):
        wrapped.query(query_texts=["q"], where={"$regex": "foo"})


def test_chroma_collection_delegates_writes():
    """Write calls and ``count`` reach the underlying collection in order."""
    recorder = _FakeCollection()
    wrapped = ChromaCollection(recorder)

    wrapped.add(documents=["d"], ids=["1"], metadatas=[{"wing": "w"}])
    wrapped.upsert(documents=["u"], ids=["2"], metadatas=[{"room": "r"}])
    wrapped.delete(ids=["1"])
    assert wrapped.count() == 7

    # Only the method names are compared, not the forwarded kwargs.
    seen = [method for method, *_ in recorder.calls]
    assert seen == ["add", "upsert", "delete", "count"]


def test_registry_exposes_chroma_by_default():
    """``chroma`` is listed and resolves to a ``ChromaBackend`` instance."""
    assert "chroma" in available_backends()
    assert isinstance(get_backend("chroma"), ChromaBackend)


def test_registry_unknown_backend_raises():
    """Looking up an unregistered name raises ``KeyError``."""
    with pytest.raises(KeyError):
        get_backend("no-such-backend-exists")


def test_resolve_backend_priority_order(tmp_path):
    """Resolution order is explicit > config > env > default."""
    from mempalace.backends import resolve_backend_for_palace as resolve

    # explicit kwarg wins over everything
    assert resolve(explicit="pg", config_value="lance") == "pg"
    # config value wins over env / default
    assert resolve(config_value="lance", env_value="qdrant") == "lance"
    # env wins over default
    assert resolve(env_value="qdrant", default="chroma") == "qdrant"
    # falls back to default
    assert resolve() == "chroma"


def test_chroma_detect_matches_palace_with_chroma_sqlite(tmp_path):
    """``detect`` is True for a dir holding ``chroma.sqlite3``, False for its parent."""
    marker = tmp_path / "chroma.sqlite3"
    marker.write_bytes(b"")
    assert ChromaBackend.detect(str(tmp_path)) is True
    assert ChromaBackend.detect(str(tmp_path.parent)) is False


def test_chroma_backend_accepts_palace_ref_kwarg(tmp_path):
    """``get_collection(palace=PalaceRef(...), create=True)`` creates the palace dir."""
    palace_dir = tmp_path / "palace"
    ref = PalaceRef(id=str(palace_dir), local_path=str(palace_dir))

    drawers = ChromaBackend().get_collection(
        palace=ref,
        collection_name="mempalace_drawers",
        create=True,
    )

    assert palace_dir.is_dir()
    assert isinstance(drawers, ChromaCollection)