Merge pull request #807 from sha2fiddy/fix/218-cosine-distance-metadata
Fix: set cosine distance metadata on all collection creation sites
This commit is contained in:
+1
-1
@@ -101,7 +101,7 @@ def config(tmp_dir, palace_path):
|
||||
def collection(palace_path):
|
||||
"""A ChromaDB collection pre-seeded in the temp palace."""
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
col = client.get_or_create_collection("mempalace_drawers")
|
||||
col = client.get_or_create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"})
|
||||
yield col
|
||||
client.delete_collection("mempalace_drawers")
|
||||
del client
|
||||
|
||||
@@ -82,6 +82,20 @@ def test_chroma_backend_create_true_creates_directory_and_collection(tmp_path):
|
||||
client.get_collection("mempalace_drawers")
|
||||
|
||||
|
||||
def test_chroma_backend_creates_collection_with_cosine_distance(tmp_path):
|
||||
palace_path = tmp_path / "palace"
|
||||
|
||||
ChromaBackend().get_collection(
|
||||
str(palace_path),
|
||||
collection_name="mempalace_drawers",
|
||||
create=True,
|
||||
)
|
||||
|
||||
client = chromadb.PersistentClient(path=str(palace_path))
|
||||
col = client.get_collection("mempalace_drawers")
|
||||
assert col.metadata.get("hnsw:space") == "cosine"
|
||||
|
||||
|
||||
def test_fix_blob_seq_ids_converts_blobs_to_integers(tmp_path):
|
||||
"""Simulate a ChromaDB 0.6.x database with BLOB seq_ids and verify repair."""
|
||||
db_path = tmp_path / "chroma.sqlite3"
|
||||
|
||||
@@ -31,7 +31,10 @@ def _get_collection(palace_path, create=False):
|
||||
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
if create:
|
||||
return client, client.get_or_create_collection("mempalace_drawers")
|
||||
return (
|
||||
client,
|
||||
client.get_or_create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"}),
|
||||
)
|
||||
return client, client.get_collection("mempalace_drawers")
|
||||
|
||||
|
||||
@@ -319,7 +322,7 @@ class TestSearchTool:
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace import mcp_server
|
||||
|
||||
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
|
||||
monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail())
|
||||
|
||||
result = mcp_server.tool_list_rooms(wing="../etc/passwd")
|
||||
assert "error" in result
|
||||
@@ -328,7 +331,7 @@ class TestSearchTool:
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace import mcp_server
|
||||
|
||||
monkeypatch.setattr(mcp_server, "search_memories", lambda *args, **kwargs: pytest.fail())
|
||||
monkeypatch.setattr(mcp_server, "search_memories", lambda: pytest.fail())
|
||||
|
||||
result = mcp_server.tool_search(query="JWT", room="../backend")
|
||||
assert "error" in result
|
||||
@@ -337,7 +340,7 @@ class TestSearchTool:
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace import mcp_server
|
||||
|
||||
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
|
||||
monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail())
|
||||
|
||||
result = mcp_server.tool_list_drawers(wing="../notes")
|
||||
assert "error" in result
|
||||
@@ -346,7 +349,7 @@ class TestSearchTool:
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace import mcp_server
|
||||
|
||||
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
|
||||
monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail())
|
||||
|
||||
result = mcp_server.tool_find_tunnels(wing_a="../project")
|
||||
assert "error" in result
|
||||
|
||||
+5
-2
@@ -27,7 +27,8 @@ def test_project_mining():
|
||||
os.makedirs(project_root / "backend")
|
||||
|
||||
write_file(
|
||||
project_root / "backend" / "app.py", "def main():\n print('hello world')\n" * 20
|
||||
project_root / "backend" / "app.py",
|
||||
"def main():\n print('hello world')\n" * 20,
|
||||
)
|
||||
with open(project_root / "mempalace.yaml", "w") as f:
|
||||
yaml.dump(
|
||||
@@ -215,7 +216,9 @@ def test_file_already_mined_check_mtime():
|
||||
palace_path = os.path.join(tmpdir, "palace")
|
||||
os.makedirs(palace_path)
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
col = client.get_or_create_collection("mempalace_drawers")
|
||||
col = client.get_or_create_collection(
|
||||
"mempalace_drawers", metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
test_file = os.path.join(tmpdir, "test.txt")
|
||||
with open(test_file, "w") as f:
|
||||
|
||||
Reference in New Issue
Block a user