diff --git a/mempalace/onboarding.py b/mempalace/onboarding.py index f578d91..70f7b54 100644 --- a/mempalace/onboarding.py +++ b/mempalace/onboarding.py @@ -312,7 +312,7 @@ def _generate_aaak_bootstrap( ] ) - (mempalace_dir / "aaak_entities.md").write_text("\n".join(registry_lines)) + (mempalace_dir / "aaak_entities.md").write_text("\n".join(registry_lines), encoding="utf-8") # Critical facts bootstrap (pre-palace — before any mining) facts_lines = [ @@ -359,7 +359,7 @@ def _generate_aaak_bootstrap( ] ) - (mempalace_dir / "critical_facts.md").write_text("\n".join(facts_lines)) + (mempalace_dir / "critical_facts.md").write_text("\n".join(facts_lines), encoding="utf-8") def run_onboarding( diff --git a/pyproject.toml b/pyproject.toml index 00e5c93..3a6a8b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ testpaths = ["tests"] source = ["mempalace"] [tool.coverage.report] -fail_under = 30 +fail_under = 50 show_missing = true exclude_lines = [ "if __name__", diff --git a/tests/test_entity_detector.py b/tests/test_entity_detector.py new file mode 100644 index 0000000..b3d378b --- /dev/null +++ b/tests/test_entity_detector.py @@ -0,0 +1,263 @@ +"""Tests for mempalace.entity_detector.""" + +import os + +from mempalace.entity_detector import ( + PROSE_EXTENSIONS, + STOPWORDS, + classify_entity, + detect_entities, + extract_candidates, + scan_for_detection, + score_entity, +) + + +# ── extract_candidates ────────────────────────────────────────────────── + + +def test_extract_candidates_finds_frequent_names(): + text = "Riley said hello. Riley laughed. Riley smiled. Riley waved." + result = extract_candidates(text) + assert "Riley" in result + assert result["Riley"] >= 3 + + +def test_extract_candidates_ignores_stopwords(): + # "The" appears many times but is a stopword + text = "The The The The The The" + result = extract_candidates(text) + assert "The" not in result + + +def test_extract_candidates_requires_min_frequency(): + text = "Riley said hi. Devon waved." + result = extract_candidates(text) + # Each name appears only once, below the threshold of 3 + assert "Riley" not in result + assert "Devon" not in result + + +def test_extract_candidates_finds_multi_word_names(): + # Multi-word names need 3+ occurrences and no stopwords + text = ( + "Claude Code is great. Claude Code rocks. " + "Claude Code works. Claude Code rules." + ) + result = extract_candidates(text) + assert "Claude Code" in result + + +def test_extract_candidates_empty_text(): + result = extract_candidates("") + assert result == {} + + +# ── score_entity ──────────────────────────────────────────────────────── + + +def test_score_entity_person_verbs(): + text = "Riley said hello. Riley asked why. Riley told me." + lines = text.splitlines() + result = score_entity("Riley", text, lines) + assert result["person_score"] > 0 + assert len(result["person_signals"]) > 0 + + +def test_score_entity_project_verbs(): + text = "We are building ChromaDB. We deployed ChromaDB. Install ChromaDB." + lines = text.splitlines() + result = score_entity("ChromaDB", text, lines) + assert result["project_score"] > 0 + assert len(result["project_signals"]) > 0 + + +def test_score_entity_dialogue_markers(): + text = "Riley: Hey, how are you?\nRiley: I'm fine." + lines = text.splitlines() + result = score_entity("Riley", text, lines) + assert result["person_score"] > 0 + + +def test_score_entity_code_ref(): + text = "Check out ChromaDB.py for details. Also ChromaDB.js is good." + lines = text.splitlines() + result = score_entity("ChromaDB", text, lines) + assert result["project_score"] > 0 + + +def test_score_entity_no_signals(): + text = "Nothing interesting here at all." + lines = text.splitlines() + result = score_entity("Riley", text, lines) + assert result["person_score"] == 0 + assert result["project_score"] == 0 + + +# ── classify_entity ───────────────────────────────────────────────────── + + +def test_classify_entity_no_signals_gives_uncertain(): + scores = { + "person_score": 0, + "project_score": 0, + "person_signals": [], + "project_signals": [], + } + result = classify_entity("Foo", 10, scores) + assert result["type"] == "uncertain" + assert result["name"] == "Foo" + + +def test_classify_entity_strong_project(): + scores = { + "person_score": 0, + "project_score": 10, + "person_signals": [], + "project_signals": ["project verb (5x)", "code file reference (2x)"], + } + result = classify_entity("ChromaDB", 5, scores) + assert result["type"] == "project" + + +def test_classify_entity_strong_person_needs_two_signal_types(): + scores = { + "person_score": 10, + "project_score": 0, + "person_signals": [ + "dialogue marker (3x)", + "'Riley ...' action (4x)", + ], + "project_signals": [], + } + result = classify_entity("Riley", 8, scores) + assert result["type"] == "person" + + +def test_classify_entity_pronoun_only_is_uncertain(): + scores = { + "person_score": 8, + "project_score": 0, + "person_signals": ["pronoun nearby (4x)"], + "project_signals": [], + } + result = classify_entity("Riley", 5, scores) + assert result["type"] == "uncertain" + + +def test_classify_entity_mixed_signals(): + scores = { + "person_score": 5, + "project_score": 5, + "person_signals": ["pronoun nearby (2x)"], + "project_signals": ["project verb (2x)"], + } + result = classify_entity("Lantern", 5, scores) + assert result["type"] == "uncertain" + assert "mixed signals" in result["signals"][-1] + + +# ── detect_entities (integration) ─────────────────────────────────────── + + +def test_detect_entities_with_person_file(tmp_path): + f = tmp_path / "notes.txt" + content = "\n".join( + [ + "Riley said hello today.", + "Riley asked about the project.", + "Riley told me she was happy.", + "Riley: I think we should go.", + "Hey Riley, thanks for the help.", + "Riley laughed and smiled.", + "Riley decided to join.", + "Riley pushed the change.", + ] + ) + f.write_text(content) + result = detect_entities([f]) + all_names = [e["name"] for cat in result.values() for e in cat] + assert "Riley" in all_names + + +def test_detect_entities_with_project_file(tmp_path): + f = tmp_path / "readme.txt" + # "ChromaDB" has uppercase+lowercase mix but extract_candidates looks + # for /[A-Z][a-z]{1,19}/ — so we need a name that matches that regex. + # Use "Lantern" which matches the capitalized-word pattern. + content = "\n".join( + [ + "The Lantern project is great.", + "Building Lantern was fun.", + "We deployed Lantern today.", + "Install Lantern with pip install Lantern.", + "Check Lantern.py for the source.", + "Lantern v2 is faster.", + ] + ) + f.write_text(content) + result = detect_entities([f]) + all_names = [e["name"] for cat in result.values() for e in cat] + assert "Lantern" in all_names + + +def test_detect_entities_empty_files(tmp_path): + f = tmp_path / "empty.txt" + f.write_text("") + result = detect_entities([f]) + assert result == {"people": [], "projects": [], "uncertain": []} + + +def test_detect_entities_handles_missing_file(tmp_path): + missing = tmp_path / "nonexistent.txt" + result = detect_entities([missing]) + assert result == {"people": [], "projects": [], "uncertain": []} + + +def test_detect_entities_respects_max_files(tmp_path): + files = [] + for i in range(5): + f = tmp_path / f"file{i}.txt" + f.write_text("Riley said hello. " * 10) + files.append(f) + # max_files=2 should only read 2 files + result = detect_entities(files, max_files=2) + # Should still work without error + assert isinstance(result, dict) + + +# ── scan_for_detection ────────────────────────────────────────────────── + + +def test_scan_for_detection_finds_prose(tmp_path): + (tmp_path / "notes.md").write_text("hello") + (tmp_path / "data.txt").write_text("world") + (tmp_path / "code.py").write_text("import os") + files = scan_for_detection(str(tmp_path)) + extensions = {os.path.splitext(str(f))[1] for f in files} + # Prose files should be found + assert ".md" in extensions or ".txt" in extensions + + +def test_scan_for_detection_skips_git_dir(tmp_path): + git_dir = tmp_path / ".git" + git_dir.mkdir() + (git_dir / "config.txt").write_text("git config") + (tmp_path / "readme.md").write_text("hello") + files = scan_for_detection(str(tmp_path)) + file_strs = [str(f) for f in files] + assert not any(".git" in f for f in file_strs) + + +# ── module-level constants ────────────────────────────────────────────── + + +def test_stopwords_contains_common_words(): + assert "the" in STOPWORDS + assert "import" in STOPWORDS + assert "class" in STOPWORDS + + +def test_prose_extensions(): + assert ".txt" in PROSE_EXTENSIONS + assert ".md" in PROSE_EXTENSIONS diff --git a/tests/test_entity_registry.py b/tests/test_entity_registry.py new file mode 100644 index 0000000..b92bf84 --- /dev/null +++ b/tests/test_entity_registry.py @@ -0,0 +1,313 @@ +"""Tests for mempalace.entity_registry.""" + +from unittest.mock import patch + +from mempalace.entity_registry import ( + COMMON_ENGLISH_WORDS, + PERSON_CONTEXT_PATTERNS, + EntityRegistry, +) + + +# ── COMMON_ENGLISH_WORDS ──────────────────────────────────────────────── + + +def test_common_english_words_has_expected_entries(): + assert "ever" in COMMON_ENGLISH_WORDS + assert "grace" in COMMON_ENGLISH_WORDS + assert "will" in COMMON_ENGLISH_WORDS + assert "may" in COMMON_ENGLISH_WORDS + assert "monday" in COMMON_ENGLISH_WORDS + + +def test_common_english_words_is_lowercase(): + for word in COMMON_ENGLISH_WORDS: + assert word == word.lower(), f"{word} should be lowercase" + + +# ── PERSON_CONTEXT_PATTERNS ───────────────────────────────────────────── + + +def test_person_context_patterns_is_nonempty(): + assert len(PERSON_CONTEXT_PATTERNS) > 0 + + +# ── EntityRegistry creation and empty state ───────────────────────────── + + +def test_load_from_nonexistent_dir(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + assert registry.people == {} + assert registry.projects == [] + assert registry.mode == "personal" + assert registry.ambiguous_flags == [] + + +def test_save_and_load_roundtrip(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="work", + people=[{"name": "Alice", "relationship": "colleague", "context": "work"}], + projects=["MemPalace"], + ) + # Load again from same dir + loaded = EntityRegistry.load(config_dir=tmp_path) + assert loaded.mode == "work" + assert "Alice" in loaded.people + assert "MemPalace" in loaded.projects + + +def test_save_creates_file(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.save() + assert (tmp_path / "entity_registry.json").exists() + + +# ── seed ──────────────────────────────────────────────────────────────── + + +def test_seed_registers_people(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[ + {"name": "Riley", "relationship": "daughter", "context": "personal"}, + {"name": "Devon", "relationship": "friend", "context": "personal"}, + ], + projects=["MemPalace"], + ) + assert "Riley" in registry.people + assert "Devon" in registry.people + assert registry.people["Riley"]["relationship"] == "daughter" + assert registry.people["Riley"]["source"] == "onboarding" + assert registry.people["Riley"]["confidence"] == 1.0 + + +def test_seed_registers_projects(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="work", people=[], projects=["Acme", "Widget"]) + assert registry.projects == ["Acme", "Widget"] + + +def test_seed_sets_mode(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="combo", people=[], projects=[]) + assert registry.mode == "combo" + + +def test_seed_flags_ambiguous_names(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[ + {"name": "Grace", "relationship": "friend", "context": "personal"}, + {"name": "Riley", "relationship": "daughter", "context": "personal"}, + ], + projects=[], + ) + assert "grace" in registry.ambiguous_flags + # Riley is not a common English word + assert "riley" not in registry.ambiguous_flags + + +def test_seed_with_aliases(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Maxwell", "relationship": "friend", "context": "personal"}], + projects=[], + aliases={"Max": "Maxwell"}, + ) + assert "Maxwell" in registry.people + assert "Max" in registry.people + assert registry.people["Max"].get("canonical") == "Maxwell" + + +def test_seed_skips_empty_names(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "", "relationship": "", "context": "personal"}], + projects=[], + ) + assert len(registry.people) == 0 + + +# ── lookup ────────────────────────────────────────────────────────────── + + +def test_lookup_known_person(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + projects=[], + ) + result = registry.lookup("Riley") + assert result["type"] == "person" + assert result["confidence"] == 1.0 + assert result["name"] == "Riley" + + +def test_lookup_known_project(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="work", people=[], projects=["MemPalace"]) + result = registry.lookup("MemPalace") + assert result["type"] == "project" + assert result["confidence"] == 1.0 + + +def test_lookup_unknown_word(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="personal", people=[], projects=[]) + result = registry.lookup("Xyzzy") + assert result["type"] == "unknown" + assert result["confidence"] == 0.0 + + +def test_lookup_case_insensitive(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + projects=[], + ) + result = registry.lookup("riley") + assert result["type"] == "person" + + +def test_lookup_alias(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Maxwell", "relationship": "friend", "context": "personal"}], + projects=[], + aliases={"Max": "Maxwell"}, + ) + result = registry.lookup("Max") + assert result["type"] == "person" + + +# ── disambiguation ────────────────────────────────────────────────────── + + +def test_lookup_ambiguous_word_as_person(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Grace", "relationship": "friend", "context": "personal"}], + projects=[], + ) + result = registry.lookup("Grace", context="I went with Grace today") + assert result["type"] == "person" + + +def test_lookup_ambiguous_word_as_concept(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Ever", "relationship": "friend", "context": "personal"}], + projects=[], + ) + result = registry.lookup("Ever", context="have you ever tried this") + assert result["type"] == "concept" + + +# ── research (Wikipedia) — mocked ────────────────────────────────────── + + +def test_research_caches_result(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="personal", people=[], projects=[]) + + mock_result = { + "inferred_type": "person", + "confidence": 0.80, + "wiki_summary": "Saoirse is an Irish given name.", + "wiki_title": "Saoirse", + } + + with patch("mempalace.entity_registry._wikipedia_lookup", return_value=mock_result): + result = registry.research("Saoirse", auto_confirm=True) + assert result["inferred_type"] == "person" + + # Second call should use cache, not call Wikipedia again + with patch( + "mempalace.entity_registry._wikipedia_lookup", + side_effect=AssertionError("should not be called"), + ): + cached = registry.research("Saoirse") + assert cached["inferred_type"] == "person" + + +def test_confirm_research_adds_to_people(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="personal", people=[], projects=[]) + + mock_result = { + "inferred_type": "person", + "confidence": 0.80, + "wiki_summary": "Saoirse is a name", + "wiki_title": "Saoirse", + } + with patch("mempalace.entity_registry._wikipedia_lookup", return_value=mock_result): + registry.research("Saoirse", auto_confirm=False) + + registry.confirm_research("Saoirse", entity_type="person", relationship="friend") + assert "Saoirse" in registry.people + assert registry.people["Saoirse"]["source"] == "wiki" + + +# ── extract_people_from_query ─────────────────────────────────────────── + + +def test_extract_people_from_query(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[ + {"name": "Riley", "relationship": "daughter", "context": "personal"}, + {"name": "Devon", "relationship": "friend", "context": "personal"}, + ], + projects=[], + ) + found = registry.extract_people_from_query("What did Riley say about the weather?") + assert "Riley" in found + assert "Devon" not in found + + +# ── extract_unknown_candidates ────────────────────────────────────────── + + +def test_extract_unknown_candidates(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="personal", people=[], projects=[]) + unknowns = registry.extract_unknown_candidates("Saoirse went to the store") + assert "Saoirse" in unknowns + + +def test_extract_unknown_candidates_skips_known(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + projects=[], + ) + unknowns = registry.extract_unknown_candidates("Riley went to the store") + assert "Riley" not in unknowns + + +# ── summary ───────────────────────────────────────────────────────────── + + +def test_summary(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + projects=["MemPalace"], + ) + s = registry.summary() + assert "personal" in s + assert "Riley" in s + assert "MemPalace" in s diff --git a/tests/test_general_extractor.py b/tests/test_general_extractor.py new file mode 100644 index 0000000..6a68fc3 --- /dev/null +++ b/tests/test_general_extractor.py @@ -0,0 +1,242 @@ +"""Tests for mempalace.general_extractor.""" + +from mempalace.general_extractor import ( + ALL_MARKERS, + NEGATIVE_WORDS, + POSITIVE_WORDS, + _extract_prose, + _get_sentiment, + _has_resolution, + _is_code_line, + _score_markers, + _split_into_segments, + extract_memories, +) + + +# ── extract_memories — empty / no markers ─────────────────────────────── + + +def test_extract_memories_empty_text(): + result = extract_memories("") + assert result == [] + + +def test_extract_memories_no_markers(): + result = extract_memories("The quick brown fox jumped over the lazy dog.") + assert result == [] + + +def test_extract_memories_short_text_skipped(): + # Paragraphs shorter than 20 chars are skipped + result = extract_memories("ok sure") + assert result == [] + + +# ── extract_memories — decision markers ───────────────────────────────── + + +def test_extract_memories_decision(): + text = ( + "We decided to go with PostgreSQL instead of MySQL " + "because the performance was better for our use case. " + "The trade-off was more complexity in setup." + ) + result = extract_memories(text) + assert len(result) >= 1 + assert any(m["memory_type"] == "decision" for m in result) + + +# ── extract_memories — preference markers ─────────────────────────────── + + +def test_extract_memories_preference(): + text = ( + "I prefer using snake_case in Python code. " + "Please always use type hints. " + "Never use wildcard imports." + ) + result = extract_memories(text) + assert len(result) >= 1 + assert any(m["memory_type"] == "preference" for m in result) + + +# ── extract_memories — milestone markers ──────────────────────────────── + + +def test_extract_memories_milestone(): + text = ( + "It finally works! After three days of debugging, " + "I figured out the issue. The breakthrough was realizing " + "the config file was cached. Got it working at 2am." + ) + result = extract_memories(text) + assert len(result) >= 1 + assert any(m["memory_type"] == "milestone" for m in result) + + +# ── extract_memories — problem markers ────────────────────────────────── + + +def test_extract_memories_problem(): + text = ( + "There's a critical bug in the auth module. " + "The error keeps crashing the server. " + "The root cause was a missing null check. " + "The problem is that tokens expire silently." + ) + result = extract_memories(text) + assert len(result) >= 1 + types = {m["memory_type"] for m in result} + assert "problem" in types or "milestone" in types # resolved problems become milestones + + +# ── extract_memories — emotional markers ──────────────────────────────── + + +def test_extract_memories_emotional(): + text = ( + "I feel so proud of what we built together. " + "I love working on this project, it makes me happy. " + "I'm grateful for the team and the beautiful code we wrote." + ) + result = extract_memories(text) + assert len(result) >= 1 + assert any(m["memory_type"] == "emotional" for m in result) + + +# ── extract_memories — chunk_index ────────────────────────────────────── + + +def test_extract_memories_chunk_index_increments(): + text = ( + "We decided to use React because it fits our team.\n\n" + "I prefer functional components always.\n\n" + "It works! We finally shipped the v1.0 release." + ) + result = extract_memories(text) + if len(result) >= 2: + indices = [m["chunk_index"] for m in result] + assert indices == list(range(len(result))) + + +# ── _score_markers ────────────────────────────────────────────────────── + + +def test_score_markers_with_matches(): + score, keywords = _score_markers( + "we decided to go with postgres because it is faster", + ALL_MARKERS["decision"], + ) + assert score > 0 + assert len(keywords) > 0 + + +def test_score_markers_no_matches(): + score, keywords = _score_markers("nothing relevant here", ALL_MARKERS["decision"]) + assert score == 0.0 + + +# ── _get_sentiment ────────────────────────────────────────────────────── + + +def test_get_sentiment_positive(): + assert _get_sentiment("I am so happy and proud of this breakthrough") == "positive" + + +def test_get_sentiment_negative(): + assert _get_sentiment("This bug caused a crash and total failure") == "negative" + + +def test_get_sentiment_neutral(): + assert _get_sentiment("The meeting is at three") == "neutral" + + +# ── _has_resolution ───────────────────────────────────────────────────── + + +def test_has_resolution_true(): + assert _has_resolution("I fixed the auth bug and it works now") is True + + +def test_has_resolution_false(): + assert _has_resolution("The server keeps crashing") is False + + +# ── _is_code_line ─────────────────────────────────────────────────────── + + +def test_is_code_line_detects_code(): + assert _is_code_line(" import os") is True + assert _is_code_line(" $ pip install flask") is True + assert _is_code_line(" ```python") is True + + +def test_is_code_line_allows_prose(): + assert _is_code_line("This is a regular sentence about coding.") is False + assert _is_code_line("") is False + + +# ── _extract_prose ────────────────────────────────────────────────────── + + +def test_extract_prose_strips_code_blocks(): + text = "Hello world\n```\nimport os\nprint('hi')\n```\nGoodbye" + result = _extract_prose(text) + assert "import os" not in result + assert "Hello world" in result + assert "Goodbye" in result + + +def test_extract_prose_returns_original_if_all_code(): + text = "import os\nfrom sys import argv" + result = _extract_prose(text) + # Falls back to original text if nothing left + assert len(result) > 0 + + +# ── _split_into_segments ─────────────────────────────────────────────── + + +def test_split_into_segments_by_paragraph(): + text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph." + result = _split_into_segments(text) + assert len(result) == 3 + + +def test_split_into_segments_by_turns(): + lines = [] + for i in range(5): + lines.append(f"Human: Question {i}") + lines.append(f"Assistant: Answer {i}") + text = "\n".join(lines) + result = _split_into_segments(text) + assert len(result) >= 3 # turn-based splitting should fire + + +def test_split_into_segments_single_block(): + # Many lines without double-newline produces chunked segments + lines = [f"Line {i} of the document" for i in range(30)] + text = "\n".join(lines) + result = _split_into_segments(text) + assert len(result) >= 1 + + +# ── ALL_MARKERS constant ─────────────────────────────────────────────── + + +def test_all_markers_has_five_types(): + assert set(ALL_MARKERS.keys()) == {"decision", "preference", "milestone", "problem", "emotional"} + + +# ── POSITIVE_WORDS / NEGATIVE_WORDS ──────────────────────────────────── + + +def test_positive_words(): + assert "happy" in POSITIVE_WORDS + assert "proud" in POSITIVE_WORDS + + +def test_negative_words(): + assert "bug" in NEGATIVE_WORDS + assert "crash" in NEGATIVE_WORDS diff --git a/tests/test_hooks_cli.py b/tests/test_hooks_cli.py index 8eeffed..0ff1fb3 100644 --- a/tests/test_hooks_cli.py +++ b/tests/test_hooks_cli.py @@ -1,17 +1,24 @@ import contextlib +import io import json from pathlib import Path from unittest.mock import patch +import pytest + from mempalace.hooks_cli import ( SAVE_INTERVAL, STOP_BLOCK_REASON, PRECOMPACT_BLOCK_REASON, _count_human_messages, + _log, + _maybe_auto_ingest, + _parse_harness_input, _sanitize_session_id, hook_stop, hook_session_start, hook_precompact, + run_hook, ) @@ -190,3 +197,204 @@ def test_precompact_always_blocks(tmp_path): ) assert result["decision"] == "block" assert result["reason"] == PRECOMPACT_BLOCK_REASON + + +# --- _log --- + + +def test_log_writes_to_hook_log(tmp_path): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + _log("test message") + log_path = tmp_path / "hook.log" + assert log_path.is_file() + content = log_path.read_text() + assert "test message" in content + + +def test_log_oserror_is_silenced(tmp_path): + """_log should not raise if the directory cannot be created.""" + with patch("mempalace.hooks_cli.STATE_DIR", Path("/nonexistent/deeply/nested/dir")): + # Should not raise + _log("this will fail silently") + + +# --- _maybe_auto_ingest --- + + +def test_maybe_auto_ingest_no_env(tmp_path): + """Without MEMPAL_DIR set, does nothing.""" + with patch.dict("os.environ", {}, clear=True): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + _maybe_auto_ingest() # should not raise + + +def test_maybe_auto_ingest_with_env(tmp_path): + """With MEMPAL_DIR set to a valid directory, spawns subprocess.""" + mempal_dir = tmp_path / "project" + mempal_dir.mkdir() + with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli.subprocess.Popen") as mock_popen: + _maybe_auto_ingest() + mock_popen.assert_called_once() + + +def test_maybe_auto_ingest_oserror(tmp_path): + """OSError during subprocess spawn is silenced.""" + mempal_dir = tmp_path / "project" + mempal_dir.mkdir() + with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli.subprocess.Popen", side_effect=OSError("fail")): + _maybe_auto_ingest() # should not raise + + +# --- _parse_harness_input --- + + +def test_parse_harness_input_unknown(): + """Unknown harness should sys.exit(1).""" + with pytest.raises(SystemExit) as exc_info: + _parse_harness_input({"session_id": "test"}, "unknown-harness") + assert exc_info.value.code == 1 + + +def test_parse_harness_input_valid(): + result = _parse_harness_input( + {"session_id": "abc-123", "stop_hook_active": True, "transcript_path": "/tmp/t.jsonl"}, + "claude-code", + ) + assert result["session_id"] == "abc-123" + assert result["stop_hook_active"] is True + + +# --- hook_stop with OSError on write --- + + +def test_stop_hook_oserror_on_last_save_read(tmp_path): + """When last_save_file has invalid content, falls back to 0.""" + transcript = tmp_path / "t.jsonl" + _write_transcript(transcript, [ + {"message": {"role": "user", "content": f"msg {i}"}} + for i in range(SAVE_INTERVAL) + ]) + # Write invalid content to last save file + (tmp_path / "test_last_save").write_text("not_a_number") + result = _capture_hook_output( + hook_stop, + {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}, + state_dir=tmp_path, + ) + assert result["decision"] == "block" + + +def test_stop_hook_oserror_on_write(tmp_path): + """When write to last_save_file fails, hook still outputs correctly.""" + transcript = tmp_path / "t.jsonl" + _write_transcript(transcript, [ + {"message": {"role": "user", "content": f"msg {i}"}} + for i in range(SAVE_INTERVAL) + ]) + + def bad_write_text(*args, **kwargs): + raise OSError("disk full") + + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch.object(Path, "write_text", bad_write_text): + result = _capture_hook_output( + hook_stop, + {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}, + state_dir=tmp_path, + ) + assert result["decision"] == "block" + + +# --- hook_precompact with MEMPAL_DIR --- + + +def test_precompact_with_mempal_dir(tmp_path): + """Precompact runs subprocess.run when MEMPAL_DIR is set.""" + mempal_dir = tmp_path / "project" + mempal_dir.mkdir() + with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}): + with patch("mempalace.hooks_cli.subprocess.run") as mock_run: + result = _capture_hook_output( + hook_precompact, + {"session_id": "test"}, + state_dir=tmp_path, + ) + assert result["decision"] == "block" + mock_run.assert_called_once() + + +def test_precompact_with_mempal_dir_oserror(tmp_path): + """Precompact handles OSError from subprocess gracefully.""" + mempal_dir = tmp_path / "project" + mempal_dir.mkdir() + with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}): + with patch("mempalace.hooks_cli.subprocess.run", side_effect=OSError("fail")): + result = _capture_hook_output( + hook_precompact, + {"session_id": "test"}, + state_dir=tmp_path, + ) + assert result["decision"] == "block" + + +# --- run_hook --- + + +def test_run_hook_dispatches_session_start(tmp_path): + """run_hook reads stdin JSON and dispatches to correct handler.""" + stdin_data = json.dumps({"session_id": "run-test"}) + with patch("sys.stdin", io.StringIO(stdin_data)): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli._output") as mock_output: + run_hook("session-start", "claude-code") + mock_output.assert_called_once_with({}) + + +def test_run_hook_dispatches_stop(tmp_path): + transcript = tmp_path / "t.jsonl" + _write_transcript(transcript, [ + {"message": {"role": "user", "content": f"msg {i}"}} + for i in range(3) + ]) + stdin_data = json.dumps({ + "session_id": "run-test", + "stop_hook_active": False, + "transcript_path": str(transcript), + }) + with patch("sys.stdin", io.StringIO(stdin_data)): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli._output") as mock_output: + run_hook("stop", "claude-code") + mock_output.assert_called_once_with({}) + + +def test_run_hook_dispatches_precompact(tmp_path): + stdin_data = json.dumps({"session_id": "run-test"}) + with patch("sys.stdin", io.StringIO(stdin_data)): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli._output") as mock_output: + run_hook("precompact", "claude-code") + mock_output.assert_called_once() + call_args = mock_output.call_args[0][0] + assert call_args["decision"] == "block" + + +def test_run_hook_unknown_hook(): + stdin_data = json.dumps({"session_id": "test"}) + with patch("sys.stdin", io.StringIO(stdin_data)): + with pytest.raises(SystemExit) as exc_info: + run_hook("nonexistent", "claude-code") + assert exc_info.value.code == 1 + + +def test_run_hook_invalid_json(tmp_path): + """Invalid stdin JSON should not crash — falls back to empty dict.""" + with patch("sys.stdin", io.StringIO("not valid json")): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli._output") as mock_output: + run_hook("session-start", "claude-code") + mock_output.assert_called_once_with({}) diff --git a/tests/test_instructions_cli.py b/tests/test_instructions_cli.py new file mode 100644 index 0000000..c99ed14 --- /dev/null +++ b/tests/test_instructions_cli.py @@ -0,0 +1,45 @@ +"""Tests for mempalace.instructions_cli — instruction text output.""" + +from unittest.mock import patch + +import pytest + +from mempalace.instructions_cli import AVAILABLE, INSTRUCTIONS_DIR, run_instructions + + +def test_run_instructions_valid_name(capsys): + """Valid name prints the .md file content.""" + name = "init" + expected = (INSTRUCTIONS_DIR / f"{name}.md").read_text() + run_instructions(name) + captured = capsys.readouterr() + assert captured.out.strip() == expected.strip() + + +def test_run_instructions_all_available(capsys): + """Every name in AVAILABLE should succeed without error.""" + for name in AVAILABLE: + run_instructions(name) + out = capsys.readouterr().out + assert len(out) > 0 + + +def test_run_instructions_invalid_name(capsys): + """Invalid name should sys.exit(1) and print error to stderr.""" + with pytest.raises(SystemExit) as exc_info: + run_instructions("nonexistent") + assert exc_info.value.code == 1 + captured = capsys.readouterr() + assert "Unknown instructions: nonexistent" in captured.err + assert "Available:" in captured.err + + +def test_run_instructions_missing_md_file(capsys, tmp_path): + """If the .md file is missing on disk, should sys.exit(1).""" + with patch("mempalace.instructions_cli.INSTRUCTIONS_DIR", tmp_path): + with patch("mempalace.instructions_cli.AVAILABLE", ["fakecmd"]): + with pytest.raises(SystemExit) as exc_info: + run_instructions("fakecmd") + assert exc_info.value.code == 1 + captured = capsys.readouterr() + assert "Instructions file not found" in captured.err diff --git a/tests/test_layers.py b/tests/test_layers.py new file mode 100644 index 0000000..1d06140 --- /dev/null +++ b/tests/test_layers.py @@ -0,0 +1,122 @@ +"""Tests for mempalace.layers — focused on Layer0.""" + +import os +from unittest.mock import patch + +from mempalace.layers import Layer0 + + +# ── Layer0 — with identity file ───────────────────────────────────────── + + +def test_layer0_reads_identity_file(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("I am Atlas, a personal AI assistant for Alice.") + layer = Layer0(identity_path=str(identity_file)) + text = layer.render() + assert "Atlas" in text + assert "Alice" in text + + +def test_layer0_caches_text(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("Hello world") + layer = Layer0(identity_path=str(identity_file)) + first = layer.render() + # Modify file after first read + identity_file.write_text("Changed content") + second = layer.render() + # Should return cached version + assert first == second + assert second == "Hello world" + + +def test_layer0_missing_file_returns_default(tmp_path): + missing = str(tmp_path / "nonexistent.txt") + layer = Layer0(identity_path=missing) + text = layer.render() + assert "No identity configured" in text + assert "identity.txt" in text + + +def test_layer0_token_estimate(tmp_path): + identity_file = tmp_path / "identity.txt" + content = "A" * 400 # 400 chars ~ 100 tokens + identity_file.write_text(content) + layer = Layer0(identity_path=str(identity_file)) + estimate = layer.token_estimate() + assert estimate == 100 + + +def test_layer0_token_estimate_empty(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("") + layer = Layer0(identity_path=str(identity_file)) + assert layer.token_estimate() == 0 + + +def test_layer0_strips_whitespace(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text(" Hello world \n\n") + layer = Layer0(identity_path=str(identity_file)) + text = layer.render() + assert text == "Hello world" + + +def test_layer0_default_path(): + layer = Layer0() + expected = os.path.expanduser("~/.mempalace/identity.txt") + assert layer.path == expected + + +# ── Layer1 — mocked chromadb ──────────────────────────────────────────── + + +def test_layer1_no_palace(): + """Layer1 returns helpful message when no palace exists.""" + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent/palace" + from mempalace.layers import Layer1 + + layer = Layer1(palace_path="/nonexistent/palace") + result = layer.generate() + assert "No palace found" in result or "No memories" in result + + +# ── Layer2 — mocked chromadb ──────────────────────────────────────────── + + +def test_layer2_no_palace(): + """Layer2 returns message when no palace exists.""" + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent/palace" + from mempalace.layers import Layer2 + + layer = Layer2(palace_path="/nonexistent/palace") + result = layer.retrieve(wing="test") + assert "No palace found" in result + + +# ── Layer3 — mocked chromadb ──────────────────────────────────────────── + + +def test_layer3_no_palace(): + """Layer3 returns message when no palace exists.""" + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent/palace" + from mempalace.layers import Layer3 + + layer = Layer3(palace_path="/nonexistent/palace") + result = layer.search("test query") + assert "No palace found" in result + + +def test_layer3_search_raw_no_palace(): + """Layer3.search_raw returns empty list when no palace exists.""" + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent/palace" + from mempalace.layers import Layer3 + + layer = Layer3(palace_path="/nonexistent/palace") + result = layer.search_raw("test query") + assert result == [] diff --git a/tests/test_onboarding.py b/tests/test_onboarding.py new file mode 100644 index 0000000..7a0f903 --- /dev/null +++ b/tests/test_onboarding.py @@ -0,0 +1,172 @@ +"""Tests for mempalace.onboarding.""" + +import os + +from mempalace.onboarding import ( + DEFAULT_WINGS, + _generate_aaak_bootstrap, + _warn_ambiguous, + quick_setup, +) + +# Force UTF-8 for Windows (source file contains Unicode symbols like hearts/stars) +os.environ["PYTHONUTF8"] = "1" + + +# ── DEFAULT_WINGS ─────────────────────────────────────────────────────── + + +def test_default_wings_has_expected_keys(): + assert "work" in DEFAULT_WINGS + assert "personal" in DEFAULT_WINGS + assert "combo" in DEFAULT_WINGS + + +def test_default_wings_work_has_projects(): + assert "projects" in DEFAULT_WINGS["work"] + + +def test_default_wings_personal_has_family(): + assert "family" in DEFAULT_WINGS["personal"] + + +def test_default_wings_combo_has_both(): + wings = DEFAULT_WINGS["combo"] + assert "family" in wings + assert "work" in wings + + +def test_default_wings_values_are_lists(): + for mode, wings in DEFAULT_WINGS.items(): + assert isinstance(wings, list), f"{mode} wings should be a list" + assert len(wings) >= 3, f"{mode} should have at least 3 wings" + + +# ── _warn_ambiguous ───────────────────────────────────────────────────── + + +def test_warn_ambiguous_flags_common_words(): + people = [ + {"name": "Grace", "relationship": "friend"}, + {"name": "Riley", "relationship": "daughter"}, + ] + result = _warn_ambiguous(people) + assert "Grace" in result + # Riley is not a common English word + assert "Riley" not in result + + +def test_warn_ambiguous_empty_list(): + result = _warn_ambiguous([]) + assert result == [] + + +def test_warn_ambiguous_no_ambiguous_names(): + people = [ + {"name": "Riley", "relationship": "daughter"}, + {"name": "Devon", "relationship": "friend"}, + ] + result = _warn_ambiguous(people) + assert result == [] + + +def test_warn_ambiguous_multiple_hits(): + people = [ + {"name": "Grace", "relationship": "friend"}, + {"name": "May", "relationship": "aunt"}, + {"name": "Joy", "relationship": "sister"}, + ] + result = _warn_ambiguous(people) + assert "Grace" in result + assert "May" in result + assert "Joy" in result + + +# ── quick_setup ───────────────────────────────────────────────────────── + + +def test_quick_setup_creates_registry(tmp_path): + registry = quick_setup( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + projects=["MemPalace"], + config_dir=tmp_path, + ) + assert "Riley" in registry.people + assert "MemPalace" in registry.projects + assert registry.mode == "personal" + + +def test_quick_setup_work_mode(tmp_path): + registry = quick_setup( + mode="work", + people=[{"name": "Alice", "relationship": "colleague", "context": "work"}], + projects=["Acme"], + config_dir=tmp_path, + ) + assert registry.mode == "work" + assert "Alice" in registry.people + assert "Acme" in registry.projects + + +def test_quick_setup_empty(tmp_path): + registry = quick_setup(mode="personal", people=[], config_dir=tmp_path) + assert len(registry.people) == 0 + assert len(registry.projects) == 0 + + +def test_quick_setup_saves_to_disk(tmp_path): + quick_setup( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + config_dir=tmp_path, + ) + assert (tmp_path / "entity_registry.json").exists() + + +# ── _generate_aaak_bootstrap ─────────────────────────────────────────── + + +def test_generate_aaak_bootstrap_creates_files(tmp_path): + people = [ + {"name": "Riley", "relationship": "daughter", "context": "personal"}, + {"name": "Devon", "relationship": "friend", "context": "personal"}, + ] + projects = ["MemPalace"] + wings = ["family", "creative"] + _generate_aaak_bootstrap(people, projects, wings, "personal", config_dir=tmp_path) + + assert (tmp_path / "aaak_entities.md").exists() + assert (tmp_path / "critical_facts.md").exists() + + +def test_generate_aaak_bootstrap_entities_content(tmp_path): + people = [{"name": "Riley", "relationship": "daughter", "context": "personal"}] + projects = ["MemPalace"] + wings = ["family"] + _generate_aaak_bootstrap(people, projects, wings, "personal", config_dir=tmp_path) + + content = (tmp_path / "aaak_entities.md").read_text() + assert "Riley" in content + assert "RIL" in content # entity code + assert "MemPalace" in content + + +def test_generate_aaak_bootstrap_facts_content(tmp_path): + people = [ + {"name": "Alice", "relationship": "colleague", "context": "work"}, + ] + projects = ["Acme"] + wings = ["projects"] + _generate_aaak_bootstrap(people, projects, wings, "work", config_dir=tmp_path) + + content = (tmp_path / "critical_facts.md").read_text() + assert "Alice" in content + assert "Acme" in content + assert "work" in content.lower() + + +def test_generate_aaak_bootstrap_empty_people(tmp_path): + _generate_aaak_bootstrap([], [], ["general"], "personal", config_dir=tmp_path) + assert (tmp_path / "aaak_entities.md").exists() + assert (tmp_path / "critical_facts.md").exists() diff --git a/tests/test_palace_graph.py b/tests/test_palace_graph.py new file mode 100644 index 0000000..7875e98 --- /dev/null +++ b/tests/test_palace_graph.py @@ -0,0 +1,218 @@ +"""Tests for mempalace.palace_graph — graph traversal layer. + +All ChromaDB access is mocked — no real database needed. +""" + +from unittest.mock import MagicMock, patch + + +def _make_fake_collection(metadatas, ids=None): + """Create a mock collection that returns the given metadata in batches.""" + if ids is None: + ids = [f"id_{i}" for i in range(len(metadatas))] + + col = MagicMock() + col.count.return_value = len(metadatas) + + def fake_get(limit=1000, offset=0, include=None): + batch_meta = metadatas[offset:offset + limit] + batch_ids = ids[offset:offset + limit] + return {"ids": batch_ids, "metadatas": batch_meta} + + col.get.side_effect = fake_get + return col + + +# Patch chromadb at import time so palace_graph can be imported +with patch.dict("sys.modules", {"chromadb": MagicMock()}): + from mempalace.palace_graph import ( + _fuzzy_match, + build_graph, + find_tunnels, + graph_stats, + traverse, + ) + + +# --- build_graph --- + + +class TestBuildGraph: + def test_empty_collection(self): + col = _make_fake_collection([]) + nodes, edges = build_graph(col=col) + assert nodes == {} + assert edges == [] + + def test_falsy_collection(self): + """When col is explicitly falsy, build_graph returns empty.""" + nodes, edges = build_graph(col=0) + assert nodes == {} + assert edges == [] + + def test_single_wing_no_edges(self): + col = _make_fake_collection([ + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-02"}, + ]) + nodes, edges = build_graph(col=col) + assert "auth" in nodes + assert nodes["auth"]["count"] == 2 + assert edges == [] + + def test_multi_wing_creates_edges(self): + col = _make_fake_collection([ + {"room": "chromadb", "wing": "wing_code", "hall": "databases", "date": "2026-01-01"}, + {"room": "chromadb", "wing": "wing_project", "hall": "databases", "date": "2026-01-02"}, + ]) + nodes, edges = build_graph(col=col) + assert "chromadb" in nodes + assert len(edges) == 1 + assert edges[0]["wing_a"] == "wing_code" + assert edges[0]["wing_b"] == "wing_project" + assert edges[0]["hall"] == "databases" + + def test_general_room_excluded(self): + col = _make_fake_collection([ + {"room": "general", "wing": "wing_code", "hall": "misc", "date": ""}, + ]) + nodes, edges = build_graph(col=col) + assert "general" not in nodes + + def test_missing_wing_excluded(self): + col = _make_fake_collection([ + {"room": "orphan", "wing": "", "hall": "misc", "date": ""}, + ]) + nodes, edges = build_graph(col=col) + assert "orphan" not in nodes + + def test_dates_capped_at_five(self): + col = _make_fake_collection([ + {"room": "busy", "wing": "w", "hall": "h", "date": f"2026-01-{i:02d}"} + for i in range(1, 10) + ]) + nodes, _ = build_graph(col=col) + assert len(nodes["busy"]["dates"]) <= 5 + + +# --- traverse --- + + +class TestTraverse: + def _build_col(self): + return _make_fake_collection([ + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + {"room": "login", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + {"room": "deploy", "wing": "wing_ops", "hall": "infra", "date": "2026-01-01"}, + ]) + + def test_traverse_known_room(self): + col = self._build_col() + result = traverse("auth", col=col) + assert isinstance(result, list) + rooms = [r["room"] for r in result] + assert "auth" in rooms + # login shares wing_code with auth + assert "login" in rooms + + def test_traverse_unknown_room(self): + col = self._build_col() + result = traverse("nonexistent", col=col) + assert isinstance(result, dict) + assert "error" in result + assert "suggestions" in result + + def test_traverse_max_hops(self): + col = self._build_col() + result = traverse("auth", col=col, max_hops=0) + # Only the start room itself at hop 0 + assert len(result) == 1 + assert result[0]["room"] == "auth" + + +# --- find_tunnels --- + + +class TestFindTunnels: + def _build_tunnel_col(self): + return _make_fake_collection([ + {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"}, + {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"}, + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + ]) + + def test_find_all_tunnels(self): + col = self._build_tunnel_col() + tunnels = find_tunnels(col=col) + assert len(tunnels) == 1 + assert tunnels[0]["room"] == "chromadb" + + def test_find_tunnels_with_wing_filter(self): + col = self._build_tunnel_col() + tunnels = find_tunnels(wing_a="wing_code", col=col) + assert len(tunnels) == 1 + + def test_find_tunnels_no_match(self): + col = self._build_tunnel_col() + tunnels = find_tunnels(wing_a="wing_nonexistent", col=col) + assert tunnels == [] + + def test_find_tunnels_both_wings(self): + col = self._build_tunnel_col() + tunnels = find_tunnels(wing_a="wing_code", wing_b="wing_project", col=col) + assert len(tunnels) == 1 + assert tunnels[0]["room"] == "chromadb" + + +# --- graph_stats --- + + +class TestGraphStats: + def test_empty_graph(self): + col = _make_fake_collection([]) + stats = graph_stats(col=col) + assert stats["total_rooms"] == 0 + assert stats["tunnel_rooms"] == 0 + assert stats["total_edges"] == 0 + + def test_stats_with_data(self): + col = _make_fake_collection([ + {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"}, + {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"}, + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + ]) + stats = graph_stats(col=col) + assert stats["total_rooms"] == 2 + assert stats["tunnel_rooms"] == 1 + assert stats["total_edges"] == 1 + assert "wing_code" in stats["rooms_per_wing"] + + +# --- _fuzzy_match --- + + +class TestFuzzyMatch: + def test_exact_substring(self): + nodes = {"chromadb-setup": {}, "auth-module": {}, "deploy-config": {}} + result = _fuzzy_match("chromadb", nodes) + assert "chromadb-setup" in result + + def test_partial_word_match(self): + nodes = {"chromadb-setup": {}, "auth-module": {}, "deploy-config": {}} + result = _fuzzy_match("auth", nodes) + assert "auth-module" in result + + def test_no_match(self): + nodes = {"chromadb-setup": {}, "auth-module": {}} + result = _fuzzy_match("zzzzz", nodes) + assert result == [] + + def test_hyphenated_query(self): + nodes = {"riley-college-apps": {}, "college-prep": {}} + result = _fuzzy_match("riley-college", nodes) + assert "riley-college-apps" in result + + def test_max_results(self): + nodes = {f"room-{i}": {} for i in range(20)} + result = _fuzzy_match("room", nodes, n=3) + assert len(result) <= 3 diff --git a/tests/test_room_detector_local.py b/tests/test_room_detector_local.py new file mode 100644 index 0000000..a64949f --- /dev/null +++ b/tests/test_room_detector_local.py @@ -0,0 +1,157 @@ +"""Tests for mempalace.room_detector_local.""" + +from mempalace.room_detector_local import ( + FOLDER_ROOM_MAP, + detect_rooms_from_files, + detect_rooms_from_folders, + save_config, +) + + +# ── FOLDER_ROOM_MAP ──────────────────────────────────────────────────── + + +def test_folder_room_map_has_expected_mappings(): + assert FOLDER_ROOM_MAP["frontend"] == "frontend" + assert FOLDER_ROOM_MAP["backend"] == "backend" + assert FOLDER_ROOM_MAP["docs"] == "documentation" + assert FOLDER_ROOM_MAP["tests"] == "testing" + assert FOLDER_ROOM_MAP["config"] == "configuration" + + +def test_folder_room_map_alternative_names(): + assert FOLDER_ROOM_MAP["front-end"] == "frontend" + assert FOLDER_ROOM_MAP["back-end"] == "backend" + assert FOLDER_ROOM_MAP["server"] == "backend" + assert FOLDER_ROOM_MAP["client"] == "frontend" + assert FOLDER_ROOM_MAP["api"] == "backend" + + +# ── detect_rooms_from_folders ─────────────────────────────────────────── + + +def test_detect_rooms_from_folders_standard_layout(tmp_path): + (tmp_path / "frontend").mkdir() + (tmp_path / "backend").mkdir() + (tmp_path / "docs").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + room_names = {r["name"] for r in rooms} + assert "frontend" in room_names + assert "backend" in room_names + assert "documentation" in room_names + + +def test_detect_rooms_from_folders_always_has_general(tmp_path): + rooms = detect_rooms_from_folders(str(tmp_path)) + room_names = {r["name"] for r in rooms} + assert "general" in room_names + + +def test_detect_rooms_from_folders_empty_dir(tmp_path): + rooms = detect_rooms_from_folders(str(tmp_path)) + # Should at least have "general" + assert len(rooms) >= 1 + assert any(r["name"] == "general" for r in rooms) + + +def test_detect_rooms_from_folders_skips_git(tmp_path): + (tmp_path / ".git").mkdir() + (tmp_path / "node_modules").mkdir() + (tmp_path / "frontend").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + room_names = {r["name"] for r in rooms} + assert ".git" not in room_names + assert "node_modules" not in room_names + + +def test_detect_rooms_from_folders_nested_dirs(tmp_path): + src = tmp_path / "src" + src.mkdir() + (src / "components").mkdir() + (src / "routes").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + room_names = {r["name"] for r in rooms} + # Nested dirs should be detected at one level deep + assert "frontend" in room_names or "backend" in room_names + + +def test_detect_rooms_from_folders_room_has_description(tmp_path): + (tmp_path / "docs").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + doc_room = next((r for r in rooms if r["name"] == "documentation"), None) + assert doc_room is not None + assert "description" in doc_room + assert "docs" in doc_room["description"] + + +def test_detect_rooms_from_folders_room_has_keywords(tmp_path): + (tmp_path / "frontend").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + fe_room = next((r for r in rooms if r["name"] == "frontend"), None) + assert fe_room is not None + assert "keywords" in fe_room + assert len(fe_room["keywords"]) > 0 + + +def test_detect_rooms_from_folders_custom_named_dirs(tmp_path): + (tmp_path / "mylib").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + room_names = {r["name"] for r in rooms} + # Custom dir names that don't match FOLDER_ROOM_MAP get added as-is + assert "mylib" in room_names or "general" in room_names + + +# ── detect_rooms_from_files ───────────────────────────────────────────── + + +def test_detect_rooms_from_files_with_matching_filenames(tmp_path): + # Create files whose names contain room keywords + for name in ["test_auth.py", "test_login.py", "test_api.py"]: + (tmp_path / name).write_text("content") + rooms = detect_rooms_from_files(str(tmp_path)) + room_names = {r["name"] for r in rooms} + assert "testing" in room_names or "general" in room_names + + +def test_detect_rooms_from_files_empty_dir(tmp_path): + rooms = detect_rooms_from_files(str(tmp_path)) + assert len(rooms) >= 1 + assert any(r["name"] == "general" for r in rooms) + + +def test_detect_rooms_from_files_caps_at_six(tmp_path): + # Create many files with different keywords to hit the cap + for keyword in ["test", "doc", "api", "config", "frontend", "backend", "design", "meeting"]: + for i in range(3): + (tmp_path / f"{keyword}_file_{i}.txt").write_text("content") + rooms = detect_rooms_from_files(str(tmp_path)) + assert len(rooms) <= 6 + + +# ── save_config ───────────────────────────────────────────────────────── + + +def test_save_config_creates_yaml(tmp_path): + rooms = [ + {"name": "frontend", "description": "UI files", "keywords": ["frontend"]}, + {"name": "backend", "description": "Server files", "keywords": ["backend"]}, + ] + save_config(str(tmp_path), "myproject", rooms) + config_file = tmp_path / "mempalace.yaml" + assert config_file.exists() + content = config_file.read_text() + assert "myproject" in content + assert "frontend" in content + assert "backend" in content + + +def test_save_config_valid_yaml(tmp_path): + import yaml + + rooms = [{"name": "general", "description": "All files", "keywords": []}] + save_config(str(tmp_path), "test_proj", rooms) + config_file = tmp_path / "mempalace.yaml" + data = yaml.safe_load(config_file.read_text()) + assert data["wing"] == "test_proj" + assert len(data["rooms"]) == 1 + assert data["rooms"][0]["name"] == "general" diff --git a/tests/test_spellcheck.py b/tests/test_spellcheck.py new file mode 100644 index 0000000..50da5f0 --- /dev/null +++ b/tests/test_spellcheck.py @@ -0,0 +1,158 @@ +"""Tests for mempalace.spellcheck — spell-correction utilities.""" + +from unittest.mock import patch + +from mempalace.spellcheck import ( + _edit_distance, + _get_system_words, + _should_skip, + spellcheck_transcript, + spellcheck_transcript_line, + spellcheck_user_text, +) + + +# --- _should_skip --- + + +class TestShouldSkip: + """Token-level skip logic.""" + + def test_short_tokens_skipped(self): + assert _should_skip("hi", set()) is True + assert _should_skip("ok", set()) is True + assert _should_skip("I", set()) is True + + def test_digits_skipped(self): + assert _should_skip("3am", set()) is True + assert _should_skip("top10", set()) is True + assert _should_skip("bge-large-v1.5", set()) is True + + def test_camelcase_skipped(self): + assert _should_skip("ChromaDB", set()) is True + assert _should_skip("MemPalace", set()) is True + + def test_allcaps_skipped(self): + assert _should_skip("NDCG", set()) is True + assert _should_skip("MAX_RESULTS", set()) is True + + def test_technical_skipped(self): + assert _should_skip("bge-large", set()) is True + assert _should_skip("train_test", set()) is True + + def test_url_skipped(self): + assert _should_skip("https://example.com", set()) is True + assert _should_skip("www.google.com", set()) is True + + def test_code_or_emoji_skipped(self): + assert _should_skip("`code`", set()) is True + assert _should_skip("**bold**", set()) is True + + def test_known_name_skipped(self): + assert _should_skip("mempalace", {"mempalace"}) is True + + def test_normal_word_not_skipped(self): + assert _should_skip("hello", set()) is False + assert _should_skip("question", set()) is False + + +# --- _edit_distance --- + + +class TestEditDistance: + def test_identical(self): + assert _edit_distance("hello", "hello") == 0 + + def test_empty_strings(self): + assert _edit_distance("", "abc") == 3 + assert _edit_distance("abc", "") == 3 + assert _edit_distance("", "") == 0 + + def test_single_edit(self): + assert _edit_distance("cat", "bat") == 1 # substitution + assert _edit_distance("cat", "cats") == 1 # insertion + assert _edit_distance("cats", "cat") == 1 # deletion + + def test_known_distance(self): + assert _edit_distance("kitten", "sitting") == 3 + + +# --- _get_system_words --- + + +def test_get_system_words_returns_set(): + result = _get_system_words() + assert isinstance(result, set) + + +# --- spellcheck_user_text --- + + +def test_spellcheck_user_text_passthrough_no_autocorrect(): + """When autocorrect is not installed, text passes through unchanged.""" + with patch("mempalace.spellcheck._get_speller", return_value=None): + text = "somee misspeledd textt" + assert spellcheck_user_text(text) == text + + +def test_spellcheck_user_text_with_speller(): + """When a speller is available, it corrects words.""" + def fake_speller(word): + corrections = {"knoe": "know", "befor": "before"} + return corrections.get(word, word) + + with patch("mempalace.spellcheck._get_speller", return_value=fake_speller): + with patch("mempalace.spellcheck._get_system_words", return_value=set()): + with patch("mempalace.spellcheck._load_known_names", return_value=set()): + result = spellcheck_user_text("knoe the question befor") + assert "know" in result + assert "before" in result + + +def test_spellcheck_preserves_technical_terms(): + """Technical terms should never be touched even with a speller.""" + def fake_speller(word): + return "WRONG" + + with patch("mempalace.spellcheck._get_speller", return_value=fake_speller): + with patch("mempalace.spellcheck._get_system_words", return_value=set()): + result = spellcheck_user_text("ChromaDB bge-large", known_names=set()) + assert "ChromaDB" in result + assert "bge-large" in result + assert "WRONG" not in result + + +# --- spellcheck_transcript_line --- + + +def test_transcript_line_user_turn(): + """Lines starting with '>' should be processed.""" + with patch("mempalace.spellcheck.spellcheck_user_text", return_value="corrected"): + result = spellcheck_transcript_line("> hello world") + assert "corrected" in result + + +def test_transcript_line_assistant_turn(): + """Lines not starting with '>' should pass through unchanged.""" + line = "This is an assistant response" + assert spellcheck_transcript_line(line) == line + + +def test_transcript_line_empty_user_turn(): + """A '> ' line with no message content should pass through.""" + line = "> " + assert spellcheck_transcript_line(line) == line + + +# --- spellcheck_transcript --- + + +def test_spellcheck_transcript_processes_content(): + """Full transcript: only '>' lines are touched.""" + content = "Assistant line\n> user line\nAnother assistant line" + with patch("mempalace.spellcheck.spellcheck_user_text", return_value="fixed"): + result = spellcheck_transcript(content) + lines = result.split("\n") + assert lines[0] == "Assistant line" + assert "fixed" in lines[1] + assert lines[2] == "Another assistant line"