import contextlib import json from pathlib import Path from unittest.mock import patch from mempalace.hooks_cli import ( SAVE_INTERVAL, STOP_BLOCK_REASON, PRECOMPACT_BLOCK_REASON, _count_human_messages, _sanitize_session_id, hook_stop, hook_session_start, hook_precompact, ) # --- _sanitize_session_id --- def test_sanitize_normal_id(): assert _sanitize_session_id("abc-123_XYZ") == "abc-123_XYZ" def test_sanitize_strips_dangerous_chars(): assert _sanitize_session_id("../../etc/passwd") == "etcpasswd" def test_sanitize_empty_returns_unknown(): assert _sanitize_session_id("") == "unknown" assert _sanitize_session_id("!!!") == "unknown" # --- _count_human_messages --- def _write_transcript(path: Path, entries: list[dict]): with open(path, "w", encoding="utf-8") as f: for entry in entries: f.write(json.dumps(entry) + "\n") def test_count_human_messages_basic(tmp_path): transcript = tmp_path / "t.jsonl" _write_transcript( transcript, [ {"message": {"role": "user", "content": "hello"}}, {"message": {"role": "assistant", "content": "hi"}}, {"message": {"role": "user", "content": "bye"}}, ], ) assert _count_human_messages(str(transcript)) == 2 def test_count_skips_command_messages(tmp_path): transcript = tmp_path / "t.jsonl" _write_transcript( transcript, [ {"message": {"role": "user", "content": "status"}}, {"message": {"role": "user", "content": "real question"}}, ], ) assert _count_human_messages(str(transcript)) == 1 def test_count_handles_list_content(tmp_path): transcript = tmp_path / "t.jsonl" _write_transcript( transcript, [ {"message": {"role": "user", "content": [{"type": "text", "text": "hello"}]}}, { "message": { "role": "user", "content": [{"type": "text", "text": "x"}], } }, ], ) assert _count_human_messages(str(transcript)) == 1 def test_count_missing_file(): assert _count_human_messages("/nonexistent/path.jsonl") == 0 def test_count_empty_file(tmp_path): transcript = tmp_path / "t.jsonl" transcript.write_text("") assert _count_human_messages(str(transcript)) == 0 def test_count_malformed_json_lines(tmp_path): transcript = tmp_path / "t.jsonl" transcript.write_text('not json\n{"message": {"role": "user", "content": "ok"}}\n') assert _count_human_messages(str(transcript)) == 1 # --- hook_stop --- def _capture_hook_output(hook_fn, data, harness="claude-code", state_dir=None): """Run a hook and capture its JSON stdout output.""" import io buf = io.StringIO() patches = [patch("mempalace.hooks_cli._output", side_effect=lambda d: buf.write(json.dumps(d)))] if state_dir: patches.append(patch("mempalace.hooks_cli.STATE_DIR", state_dir)) with contextlib.ExitStack() as stack: for p in patches: stack.enter_context(p) hook_fn(data, harness) return json.loads(buf.getvalue()) def test_stop_hook_passthrough_when_active(tmp_path): with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): result = _capture_hook_output( hook_stop, {"session_id": "test", "stop_hook_active": True, "transcript_path": ""}, state_dir=tmp_path, ) assert result == {} def test_stop_hook_passthrough_when_active_string(tmp_path): with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): result = _capture_hook_output( hook_stop, {"session_id": "test", "stop_hook_active": "true", "transcript_path": ""}, state_dir=tmp_path, ) assert result == {} def test_stop_hook_passthrough_below_interval(tmp_path): transcript = tmp_path / "t.jsonl" _write_transcript( transcript, [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL - 1)], ) result = _capture_hook_output( hook_stop, {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}, state_dir=tmp_path, ) assert result == {} def test_stop_hook_blocks_at_interval(tmp_path): transcript = tmp_path / "t.jsonl" _write_transcript( transcript, [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], ) result = _capture_hook_output( hook_stop, {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}, state_dir=tmp_path, ) assert result["decision"] == "block" assert result["reason"] == STOP_BLOCK_REASON def test_stop_hook_tracks_save_point(tmp_path): transcript = tmp_path / "t.jsonl" _write_transcript( transcript, [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], ) data = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)} # First call blocks result = _capture_hook_output(hook_stop, data, state_dir=tmp_path) assert result["decision"] == "block" # Second call with same count passes through (already saved) result = _capture_hook_output(hook_stop, data, state_dir=tmp_path) assert result == {} # --- hook_session_start --- def test_session_start_passes_through(tmp_path): result = _capture_hook_output( hook_session_start, {"session_id": "test"}, state_dir=tmp_path, ) assert result == {} # --- hook_precompact --- def test_precompact_always_blocks(tmp_path): result = _capture_hook_output( hook_precompact, {"session_id": "test"}, state_dir=tmp_path, ) assert result["decision"] == "block" assert result["reason"] == PRECOMPACT_BLOCK_REASON