Merge pull request #807 from sha2fiddy/fix/218-cosine-distance-metadata

Fix: set cosine distance metadata on all collection creation sites
This commit is contained in:
Igor Lins e Silva
2026-04-13 21:18:40 -03:00
committed by GitHub
7 changed files with 53 additions and 17 deletions
+3 -1
View File
@@ -85,7 +85,9 @@ class ChromaBackend:
_fix_blob_seq_ids(palace_path)
client = chromadb.PersistentClient(path=palace_path)
if create:
collection = client.get_or_create_collection(collection_name)
collection = client.get_or_create_collection(
collection_name, metadata={"hnsw:space": "cosine"}
)
else:
collection = client.get_collection(collection_name)
return ChromaCollection(collection)
+17 -5
View File
@@ -156,7 +156,11 @@ def cmd_migrate(args):
from .migrate import migrate
palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path
migrate(palace_path=palace_path, dry_run=args.dry_run, confirm=getattr(args, "yes", False))
migrate(
palace_path=palace_path,
dry_run=args.dry_run,
confirm=getattr(args, "yes", False),
)
def cmd_status(args):
@@ -240,7 +244,7 @@ def cmd_repair(args):
print(" Rebuilding collection...")
client.delete_collection("mempalace_drawers")
new_col = client.create_collection("mempalace_drawers")
new_col = client.create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"})
filed = 0
for i in range(0, len(all_ids), batch_size):
@@ -328,7 +332,11 @@ def cmd_compress(args):
offset = 0
while True:
try:
kwargs = {"include": ["documents", "metadatas"], "limit": _BATCH, "offset": offset}
kwargs = {
"include": ["documents", "metadatas"],
"limit": _BATCH,
"offset": offset,
}
if where:
kwargs["where"] = where
batch = col.get(**kwargs)
@@ -386,7 +394,9 @@ def cmd_compress(args):
# Store compressed versions (unless dry-run)
if not args.dry_run:
try:
comp_col = client.get_or_create_collection("mempalace_compressed")
comp_col = client.get_or_create_collection(
"mempalace_compressed", metadata={"hnsw:space": "cosine"}
)
for doc_id, compressed, meta, stats in compressed_entries:
comp_meta = dict(meta)
comp_meta["compression_ratio"] = round(stats["size_ratio"], 1)
@@ -431,7 +441,9 @@ def main():
p_init = sub.add_parser("init", help="Detect rooms from your folder structure")
p_init.add_argument("dir", help="Project directory to set up")
p_init.add_argument(
"--yes", action="store_true", help="Auto-accept all detected entities (non-interactive)"
"--yes",
action="store_true",
help="Auto-accept all detected entities (non-interactive)",
)
# mine
+5 -3
View File
@@ -33,13 +33,15 @@ def extract_drawers_from_sqlite(db_path: str) -> list:
conn.row_factory = sqlite3.Row
# Get all embedding IDs and their documents
rows = conn.execute("""
rows = conn.execute(
"""
SELECT e.embedding_id,
MAX(CASE WHEN em.key = 'chroma:document' THEN em.string_value END) as document
FROM embeddings e
JOIN embedding_metadata em ON em.id = e.id
GROUP BY e.embedding_id
""").fetchall()
"""
).fetchall()
drawers = []
for row in rows:
@@ -207,7 +209,7 @@ def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False):
temp_palace = tempfile.mkdtemp(prefix="mempalace_migrate_")
print(f" Creating fresh palace in {temp_palace}...")
client = chromadb.PersistentClient(path=temp_palace)
col = client.get_or_create_collection("mempalace_drawers")
col = client.get_or_create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"})
# Re-import in batches
batch_size = 500
+1 -1
View File
@@ -101,7 +101,7 @@ def config(tmp_dir, palace_path):
def collection(palace_path):
"""A ChromaDB collection pre-seeded in the temp palace."""
client = chromadb.PersistentClient(path=palace_path)
col = client.get_or_create_collection("mempalace_drawers")
col = client.get_or_create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"})
yield col
client.delete_collection("mempalace_drawers")
del client
+14
View File
@@ -82,6 +82,20 @@ def test_chroma_backend_create_true_creates_directory_and_collection(tmp_path):
client.get_collection("mempalace_drawers")
def test_chroma_backend_creates_collection_with_cosine_distance(tmp_path):
palace_path = tmp_path / "palace"
ChromaBackend().get_collection(
str(palace_path),
collection_name="mempalace_drawers",
create=True,
)
client = chromadb.PersistentClient(path=str(palace_path))
col = client.get_collection("mempalace_drawers")
assert col.metadata.get("hnsw:space") == "cosine"
def test_fix_blob_seq_ids_converts_blobs_to_integers(tmp_path):
"""Simulate a ChromaDB 0.6.x database with BLOB seq_ids and verify repair."""
db_path = tmp_path / "chroma.sqlite3"
+8 -5
View File
@@ -31,7 +31,10 @@ def _get_collection(palace_path, create=False):
client = chromadb.PersistentClient(path=palace_path)
if create:
return client, client.get_or_create_collection("mempalace_drawers")
return (
client,
client.get_or_create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"}),
)
return client, client.get_collection("mempalace_drawers")
@@ -319,7 +322,7 @@ class TestSearchTool:
_patch_mcp_server(monkeypatch, config, kg)
from mempalace import mcp_server
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail())
result = mcp_server.tool_list_rooms(wing="../etc/passwd")
assert "error" in result
@@ -328,7 +331,7 @@ class TestSearchTool:
_patch_mcp_server(monkeypatch, config, kg)
from mempalace import mcp_server
monkeypatch.setattr(mcp_server, "search_memories", lambda *args, **kwargs: pytest.fail())
monkeypatch.setattr(mcp_server, "search_memories", lambda: pytest.fail())
result = mcp_server.tool_search(query="JWT", room="../backend")
assert "error" in result
@@ -337,7 +340,7 @@ class TestSearchTool:
_patch_mcp_server(monkeypatch, config, kg)
from mempalace import mcp_server
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail())
result = mcp_server.tool_list_drawers(wing="../notes")
assert "error" in result
@@ -346,7 +349,7 @@ class TestSearchTool:
_patch_mcp_server(monkeypatch, config, kg)
from mempalace import mcp_server
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail())
result = mcp_server.tool_find_tunnels(wing_a="../project")
assert "error" in result
+5 -2
View File
@@ -27,7 +27,8 @@ def test_project_mining():
os.makedirs(project_root / "backend")
write_file(
project_root / "backend" / "app.py", "def main():\n print('hello world')\n" * 20
project_root / "backend" / "app.py",
"def main():\n print('hello world')\n" * 20,
)
with open(project_root / "mempalace.yaml", "w") as f:
yaml.dump(
@@ -215,7 +216,9 @@ def test_file_already_mined_check_mtime():
palace_path = os.path.join(tmpdir, "palace")
os.makedirs(palace_path)
client = chromadb.PersistentClient(path=palace_path)
col = client.get_or_create_collection("mempalace_drawers")
col = client.get_or_create_collection(
"mempalace_drawers", metadata={"hnsw:space": "cosine"}
)
test_file = os.path.join(tmpdir, "test.txt")
with open(test_file, "w") as f: