Fix: set cosine distance metadata on all collection creation sites

ChromaDB's HNSW index defaults to L2 (Euclidean) distance, but
MemPalace scoring computes similarity as `1 - distance`, which is only
meaningful for cosine distance (range 0-2).
Add metadata={"hnsw:space": "cosine"} to the 4 production and 3 test
collection-creation call sites that were missing it.

Closes #218
This commit is contained in:
eblander
2026-04-13 11:00:52 -04:00
parent 6614b9b4e7
commit 1e86892e62
7 changed files with 224 additions and 66 deletions
+6 -2
View File
@@ -101,7 +101,9 @@ def config(tmp_dir, palace_path):
def collection(palace_path):
"""A ChromaDB collection pre-seeded in the temp palace."""
client = chromadb.PersistentClient(path=palace_path)
col = client.get_or_create_collection("mempalace_drawers")
col = client.get_or_create_collection(
"mempalace_drawers", metadata={"hnsw:space": "cosine"}
)
yield col
client.delete_collection("mempalace_drawers")
del client
@@ -185,7 +187,9 @@ def seeded_kg(kg):
kg.add_triple("Alice", "parent_of", "Max", valid_from="2015-04-01")
kg.add_triple("Max", "does", "swimming", valid_from="2025-01-01")
kg.add_triple("Max", "does", "chess", valid_from="2024-06-01")
kg.add_triple("Alice", "works_at", "Acme Corp", valid_from="2020-01-01", valid_to="2024-12-31")
kg.add_triple(
"Alice", "works_at", "Acme Corp", valid_from="2020-01-01", valid_to="2024-12-31"
)
kg.add_triple("Alice", "works_at", "NewCo", valid_from="2025-01-01")
return kg
+14
View File
@@ -82,6 +82,20 @@ def test_chroma_backend_create_true_creates_directory_and_collection(tmp_path):
client.get_collection("mempalace_drawers")
def test_chroma_backend_creates_collection_with_cosine_distance(tmp_path):
    """A collection created through ChromaBackend reports cosine as its HNSW space."""
    palace_dir = str(tmp_path / "palace")

    # Create the collection via the backend under test.
    ChromaBackend().get_collection(
        palace_dir,
        collection_name="mempalace_drawers",
        create=True,
    )

    # Open the same path with a separate client and read the collection metadata.
    reopened = chromadb.PersistentClient(path=palace_dir)
    drawers = reopened.get_collection("mempalace_drawers")
    assert drawers.metadata.get("hnsw:space") == "cosine"
def test_fix_blob_seq_ids_converts_blobs_to_integers(tmp_path):
"""Simulate a ChromaDB 0.6.x database with BLOB seq_ids and verify repair."""
db_path = tmp_path / "chroma.sqlite3"
+77 -27
View File
@@ -30,7 +30,12 @@ def _get_collection(palace_path, create=False):
client = chromadb.PersistentClient(path=palace_path)
if create:
return client, client.get_or_create_collection("mempalace_drawers")
return (
client,
client.get_or_create_collection(
"mempalace_drawers", metadata={"hnsw:space": "cosine"}
),
)
return client, client.get_collection("mempalace_drawers")
@@ -92,7 +97,9 @@ class TestHandleRequest:
def test_notifications_initialized_returns_none(self):
from mempalace.mcp_server import handle_request
resp = handle_request({"method": "notifications/initialized", "id": None, "params": {}})
resp = handle_request(
{"method": "notifications/initialized", "id": None, "params": {}}
)
assert resp is None
def test_ping_returns_empty_result(self):
@@ -113,7 +120,9 @@ class TestHandleRequest:
assert "mempalace_add_drawer" in names
assert "mempalace_kg_add" in names
def test_null_arguments_does_not_hang(self, monkeypatch, config, palace_path, seeded_kg):
def test_null_arguments_does_not_hang(
self, monkeypatch, config, palace_path, seeded_kg
):
"""Sending arguments: null should return a result, not hang (#394)."""
_patch_mcp_server(monkeypatch, config, seeded_kg)
from mempalace.mcp_server import handle_request
@@ -218,7 +227,9 @@ class TestReadTools:
assert result["total_drawers"] == 0
assert result["wings"] == {}
def test_status_with_data(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_status_with_data(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_status
@@ -235,7 +246,9 @@ class TestReadTools:
assert result["wings"]["project"] == 3
assert result["wings"]["notes"] == 1
def test_list_rooms_all(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_list_rooms_all(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_list_rooms
@@ -244,7 +257,9 @@ class TestReadTools:
assert "frontend" in result["rooms"]
assert "planning" in result["rooms"]
def test_list_rooms_filtered(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_list_rooms_filtered(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_list_rooms
@@ -252,7 +267,9 @@ class TestReadTools:
assert "backend" in result["rooms"]
assert "planning" not in result["rooms"]
def test_get_taxonomy(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_get_taxonomy(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_get_taxonomy
@@ -273,7 +290,9 @@ class TestReadTools:
class TestSearchTool:
def test_search_basic(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_search_basic(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_search
@@ -284,14 +303,18 @@ class TestSearchTool:
top = result["results"][0]
assert "JWT" in top["text"] or "authentication" in top["text"].lower()
def test_search_with_wing_filter(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_search_with_wing_filter(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_search
result = tool_search(query="planning", wing="notes")
assert all(r["wing"] == "notes" for r in result["results"])
def test_search_with_room_filter(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_search_with_room_filter(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_search
@@ -310,7 +333,9 @@ class TestSearchTool:
assert "results" in result
# Old name takes precedence when both provided
result_strict = tool_search(query="JWT", max_distance=999.0, min_similarity=0.01)
result_strict = tool_search(
query="JWT", max_distance=999.0, min_similarity=0.01
)
result_loose = tool_search(query="JWT", max_distance=0.01, min_similarity=999.0)
assert len(result_strict["results"]) <= len(result_loose["results"])
@@ -318,7 +343,7 @@ class TestSearchTool:
_patch_mcp_server(monkeypatch, config, kg)
from mempalace import mcp_server
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail())
result = mcp_server.tool_list_rooms(wing="../etc/passwd")
assert "error" in result
@@ -327,7 +352,7 @@ class TestSearchTool:
_patch_mcp_server(monkeypatch, config, kg)
from mempalace import mcp_server
monkeypatch.setattr(mcp_server, "search_memories", lambda *args, **kwargs: pytest.fail())
monkeypatch.setattr(mcp_server, "search_memories", lambda: pytest.fail())
result = mcp_server.tool_search(query="JWT", room="../backend")
assert "error" in result
@@ -336,7 +361,7 @@ class TestSearchTool:
_patch_mcp_server(monkeypatch, config, kg)
from mempalace import mcp_server
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail())
result = mcp_server.tool_list_drawers(wing="../notes")
assert "error" in result
@@ -345,7 +370,7 @@ class TestSearchTool:
_patch_mcp_server(monkeypatch, config, kg)
from mempalace import mcp_server
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail())
result = mcp_server.tool_find_tunnels(wing_a="../project")
assert "error" in result
@@ -402,7 +427,9 @@ class TestWriteTools:
assert result2["success"] is True
assert result2["reason"] == "already_exists"
def test_add_drawer_shared_header_no_collision(self, monkeypatch, config, palace_path, kg):
def test_add_drawer_shared_header_no_collision(
self, monkeypatch, config, palace_path, kg
):
"""Documents sharing a >100-char header must get distinct IDs (full-content hash)."""
_patch_mcp_server(monkeypatch, config, kg)
_client, _col = _get_collection(palace_path, create=True)
@@ -414,7 +441,10 @@ class TestWriteTools:
header
+ "Decision: Use PostgreSQL for primary storage. Rationale: ACID compliance required."
)
doc2 = header + "Decision: Use Redis for session caching. Rationale: sub-ms latency needed."
doc2 = (
header
+ "Decision: Use Redis for session caching. Rationale: sub-ms latency needed."
)
result1 = tool_add_drawer(wing="work", room="decisions", content=doc1)
result2 = tool_add_drawer(wing="work", room="decisions", content=doc2)
@@ -425,7 +455,9 @@ class TestWriteTools:
result1["drawer_id"] != result2["drawer_id"]
), "Documents with shared header but different content must have distinct drawer IDs"
def test_delete_drawer(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_delete_drawer(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_delete_drawer
@@ -433,14 +465,18 @@ class TestWriteTools:
assert result["success"] is True
assert seeded_collection.count() == 3
def test_delete_drawer_not_found(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_delete_drawer_not_found(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_delete_drawer
result = tool_delete_drawer("nonexistent_drawer")
assert result["success"] is False
def test_check_duplicate(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_check_duplicate(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_check_duplicate
@@ -469,14 +505,18 @@ class TestWriteTools:
assert result["room"] == "backend"
assert "JWT tokens" in result["content"]
def test_get_drawer_not_found(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_get_drawer_not_found(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_get_drawer
result = tool_get_drawer("nonexistent_drawer")
assert "error" in result
def test_list_drawers(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_list_drawers(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_list_drawers
@@ -504,7 +544,9 @@ class TestWriteTools:
assert result["count"] == 2
assert all(d["room"] == "backend" for d in result["drawers"])
def test_list_drawers_pagination(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_list_drawers_pagination(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_list_drawers
@@ -522,7 +564,9 @@ class TestWriteTools:
result = tool_list_drawers(offset=-5)
assert result["offset"] == 0
def test_update_drawer_content(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_update_drawer_content(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_update_drawer, tool_get_drawer
@@ -540,19 +584,25 @@ class TestWriteTools:
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_update_drawer
result = tool_update_drawer("drawer_proj_backend_aaa", wing="new_wing", room="new_room")
result = tool_update_drawer(
"drawer_proj_backend_aaa", wing="new_wing", room="new_room"
)
assert result["success"] is True
assert result["wing"] == "new_wing"
assert result["room"] == "new_room"
def test_update_drawer_not_found(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_update_drawer_not_found(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_update_drawer
result = tool_update_drawer("nonexistent_drawer", content="hello")
assert result["success"] is False
def test_update_drawer_noop(self, monkeypatch, config, palace_path, seeded_collection, kg):
def test_update_drawer_noop(
self, monkeypatch, config, palace_path, seeded_collection, kg
):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_update_drawer
+20 -7
View File
@@ -27,7 +27,8 @@ def test_project_mining():
os.makedirs(project_root / "backend")
write_file(
project_root / "backend" / "app.py", "def main():\n print('hello world')\n" * 20
project_root / "backend" / "app.py",
"def main():\n print('hello world')\n" * 20,
)
with open(project_root / "mempalace.yaml", "w") as f:
yaml.dump(
@@ -59,7 +60,9 @@ def test_scan_project_respects_gitignore():
write_file(project_root / ".gitignore", "ignored.py\ngenerated/\n")
write_file(project_root / "src" / "app.py", "print('hello')\n" * 20)
write_file(project_root / "ignored.py", "print('ignore me')\n" * 20)
write_file(project_root / "generated" / "artifact.py", "print('artifact')\n" * 20)
write_file(
project_root / "generated" / "artifact.py", "print('artifact')\n" * 20
)
assert scanned_files(project_root) == ["src/app.py"]
finally:
@@ -74,7 +77,9 @@ def test_scan_project_respects_nested_gitignore():
write_file(project_root / ".gitignore", "*.log\n")
write_file(project_root / "subrepo" / ".gitignore", "tasks/\n")
write_file(project_root / "subrepo" / "src" / "main.py", "print('main')\n" * 20)
write_file(project_root / "subrepo" / "tasks" / "task.py", "print('task')\n" * 20)
write_file(
project_root / "subrepo" / "tasks" / "task.py", "print('task')\n" * 20
)
write_file(project_root / "subrepo" / "debug.log", "debug\n" * 20)
assert scanned_files(project_root) == ["subrepo/src/main.py"]
@@ -133,7 +138,9 @@ def test_scan_project_can_disable_gitignore():
write_file(project_root / ".gitignore", "data/\n")
write_file(project_root / "data" / "stuff.csv", "a,b,c\n" * 20)
assert scanned_files(project_root, respect_gitignore=False) == ["data/stuff.csv"]
assert scanned_files(project_root, respect_gitignore=False) == [
"data/stuff.csv"
]
finally:
shutil.rmtree(tmpdir)
@@ -146,7 +153,9 @@ def test_scan_project_can_include_ignored_directory():
write_file(project_root / ".gitignore", "docs/\n")
write_file(project_root / "docs" / "guide.md", "# Guide\n" * 20)
assert scanned_files(project_root, include_ignored=["docs"]) == ["docs/guide.md"]
assert scanned_files(project_root, include_ignored=["docs"]) == [
"docs/guide.md"
]
finally:
shutil.rmtree(tmpdir)
@@ -215,7 +224,9 @@ def test_file_already_mined_check_mtime():
palace_path = os.path.join(tmpdir, "palace")
os.makedirs(palace_path)
client = chromadb.PersistentClient(path=palace_path)
col = client.get_or_create_collection("mempalace_drawers")
col = client.get_or_create_collection(
"mempalace_drawers", metadata={"hnsw:space": "cosine"}
)
test_file = os.path.join(tmpdir, "test.txt")
with open(test_file, "w") as f:
@@ -269,7 +280,9 @@ def test_mine_dry_run_with_tiny_file_no_crash():
project_root = Path(tmpdir).resolve()
# One normal file and one that falls below MIN_CHUNK_SIZE
write_file(project_root / "good.py", "def main():\n print('hello world')\n" * 20)
write_file(
project_root / "good.py", "def main():\n print('hello world')\n" * 20
)
write_file(project_root / "tiny.txt", "x")
with open(project_root / "mempalace.yaml", "w") as f: