"""Tests for explicit tunnel helpers in mempalace.palace_graph.""" from unittest.mock import MagicMock, patch import pytest with patch.dict("sys.modules", {"chromadb": MagicMock()}): import mempalace.palace_graph as palace_graph def _use_tmp_tunnel_file(monkeypatch, tmp_path): tunnel_file = tmp_path / "tunnels.json" monkeypatch.setattr(palace_graph, "_TUNNEL_FILE", str(tunnel_file)) return tunnel_file class TestTunnelStorage: def test_load_tunnels_missing_file_returns_empty_list(self, tmp_path, monkeypatch): _use_tmp_tunnel_file(monkeypatch, tmp_path) assert palace_graph._load_tunnels() == [] def test_load_tunnels_corrupt_file_returns_empty_list(self, tmp_path, monkeypatch): tunnel_file = _use_tmp_tunnel_file(monkeypatch, tmp_path) tunnel_file.write_text("{not valid json", encoding="utf-8") assert palace_graph._load_tunnels() == [] def test_save_and_load_round_trip(self, tmp_path, monkeypatch): _use_tmp_tunnel_file(monkeypatch, tmp_path) tunnels = [ { "id": "abc123", "source": {"wing": "wing_code", "room": "auth"}, "target": {"wing": "wing_people", "room": "users"}, "label": "same concept", } ] palace_graph._save_tunnels(tunnels) assert palace_graph._load_tunnels() == tunnels class TestExplicitTunnels: def test_create_tunnel_deduplicates_reverse_order_and_updates_label( self, tmp_path, monkeypatch ): _use_tmp_tunnel_file(monkeypatch, tmp_path) first = palace_graph.create_tunnel( "wing_code", "auth", "wing_people", "users", label="same concept" ) second = palace_graph.create_tunnel( "wing_people", "users", "wing_code", "auth", label="updated label" ) assert first["id"] == second["id"] assert len(palace_graph.list_tunnels()) == 1 assert second["label"] == "updated label" assert second["created_at"] == first["created_at"] assert "updated_at" in second def test_create_tunnel_rejects_empty_names(self, tmp_path, monkeypatch): _use_tmp_tunnel_file(monkeypatch, tmp_path) with pytest.raises(ValueError): palace_graph.create_tunnel("", "auth", "wing_people", "users") def test_list_tunnels_filters_by_either_side(self, tmp_path, monkeypatch): _use_tmp_tunnel_file(monkeypatch, tmp_path) palace_graph.create_tunnel("wing_code", "auth", "wing_people", "users", label="A") palace_graph.create_tunnel("wing_ops", "deploy", "wing_people", "users", label="B") assert len(palace_graph.list_tunnels()) == 2 assert len(palace_graph.list_tunnels("wing_people")) == 2 assert len(palace_graph.list_tunnels("wing_code")) == 1 def test_delete_tunnel_removes_saved_tunnel(self, tmp_path, monkeypatch): _use_tmp_tunnel_file(monkeypatch, tmp_path) tunnel = palace_graph.create_tunnel( "wing_code", "auth", "wing_people", "users", label="same concept" ) assert palace_graph.delete_tunnel(tunnel["id"]) == {"deleted": tunnel["id"]} assert palace_graph.list_tunnels() == [] def test_follow_tunnels_returns_direction_and_preview(self, tmp_path, monkeypatch): _use_tmp_tunnel_file(monkeypatch, tmp_path) palace_graph.create_tunnel( "wing_code", "auth", "wing_people", "users", label="same concept", target_drawer_id="drawer_users_1", ) col = MagicMock() col.get.return_value = { "ids": ["drawer_users_1"], "documents": ["A" * 400], "metadatas": [{}], } outgoing = palace_graph.follow_tunnels("wing_code", "auth", col=col) assert len(outgoing) == 1 assert outgoing[0]["direction"] == "outgoing" assert outgoing[0]["connected_wing"] == "wing_people" assert outgoing[0]["connected_room"] == "users" assert outgoing[0]["drawer_id"] == "drawer_users_1" assert len(outgoing[0]["drawer_preview"]) == 300 incoming = palace_graph.follow_tunnels("wing_people", "users", col=col) assert len(incoming) == 1 assert incoming[0]["direction"] == "incoming" assert incoming[0]["connected_wing"] == "wing_code" def test_follow_tunnels_returns_connections_even_if_collection_lookup_fails( self, tmp_path, monkeypatch ): _use_tmp_tunnel_file(monkeypatch, tmp_path) palace_graph.create_tunnel( "wing_code", "auth", "wing_people", "users", label="same concept", target_drawer_id="drawer_users_1", ) col = MagicMock() col.get.side_effect = RuntimeError("boom") connections = palace_graph.follow_tunnels("wing_code", "auth", col=col) assert len(connections) == 1 assert "drawer_preview" not in connections[0]