fix(hooks): address Copilot review feedback on #1021

- _output(): use sys.modules.get() instead of unconditional import to
  avoid triggering mcp_server's stdout redirect as a side effect
- _output(): write-all loop for os.write() to handle partial writes and
  EINTR; fall back to sys.stdout.buffer on OSError
- _output() docstring: remove inaccurate _save_diary_direct reference
- stop_hook_active guard: narrow except to ImportError/AttributeError,
  default silent_guard=False (safe: preserves block-mode loop prevention
  when config load fails) and log a warning instead of silently changing
  behavior
- tests: two new regression tests covering the real-stdout-fd path and
  the fd-1 fallback path

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
jp
2026-04-18 15:48:19 -07:00
parent 6a3a5c7a3d
commit 5deb815f0b
2 changed files with 93 additions and 17 deletions
+33 -17
View File
@@ -134,24 +134,35 @@ def _log(message: str):
def _output(data: dict): def _output(data: dict):
"""Print JSON to the real stdout, even if mcp_server has hijacked sys.stdout. """Print JSON to stdout without importing modules that may redirect streams.
mempalace.mcp_server redirects stdout → stderr at module import (fd and If mempalace.mcp_server is already loaded, reuse its saved real stdout fd.
sys-level) to protect the MCP stdio protocol from ChromaDB's C-level Otherwise, write directly to fd 1 so hook responses still go to stdout even
prints. Silent-save imports it transitively via _save_diary_direct, so if sys.stdout has been redirected elsewhere.
sys.stdout is stderr by the time we get here. Claude Code reads hook
output from fd 1, so we write there directly using the saved fd.
""" """
payload = json.dumps(data, indent=2, ensure_ascii=False) + "\n" payload = (json.dumps(data, indent=2, ensure_ascii=False) + "\n").encode("utf-8")
real_stdout_fd: int | None = None
mcp_mod = sys.modules.get("mempalace.mcp_server") or sys.modules.get(
f"{__package__}.mcp_server" if __package__ else "mcp_server"
)
if mcp_mod is not None:
real_stdout_fd = getattr(mcp_mod, "_REAL_STDOUT_FD", None)
fd = real_stdout_fd if real_stdout_fd is not None else 1
offset = 0
try: try:
from .mcp_server import _REAL_STDOUT_FD while offset < len(payload):
if _REAL_STDOUT_FD is not None: try:
os.write(_REAL_STDOUT_FD, payload.encode("utf-8")) offset += os.write(fd, payload[offset:])
return except InterruptedError:
except Exception: continue
return
except OSError:
pass pass
sys.stdout.write(payload)
sys.stdout.flush() sys.stdout.buffer.write(payload)
sys.stdout.buffer.flush()
def _get_mine_dir(transcript_path: str = "") -> str: def _get_mine_dir(transcript_path: str = "") -> str:
@@ -230,11 +241,16 @@ def hook_stop(data: dict, harness: str):
# no loop to prevent — and Claude Code's plugin dispatch sets this flag on every # no loop to prevent — and Claude Code's plugin dispatch sets this flag on every
# fire after the first, which would otherwise suppress all subsequent auto-saves. # fire after the first, which would otherwise suppress all subsequent auto-saves.
if str(stop_hook_active).lower() in ("true", "1", "yes"): if str(stop_hook_active).lower() in ("true", "1", "yes"):
silent_guard = False
try: try:
from .config import MempalaceConfig from .config import MempalaceConfig
silent_guard = MempalaceConfig().hook_silent_save except ImportError as exc:
except Exception: _log(f"WARNING: could not import MempalaceConfig for stop guard: {exc}; preserving block-mode guard")
silent_guard = True else:
try:
silent_guard = MempalaceConfig().hook_silent_save
except AttributeError as exc:
_log(f"WARNING: could not read hook_silent_save: {exc}; preserving block-mode guard")
if not silent_guard: if not silent_guard:
_output({}) _output({})
return return
+60
View File
@@ -218,6 +218,66 @@ def test_precompact_allows(tmp_path):
# --- _log --- # --- _log ---
def test_output_writes_to_real_stdout_fd_when_mcp_server_loaded():
"""_output() must reach fd 1 even when mcp_server has redirected sys.stdout."""
import types
fake_module = types.ModuleType("mempalace.mcp_server")
read_fd, write_fd = os.pipe()
try:
fake_module._REAL_STDOUT_FD = write_fd
with patch.dict("sys.modules", {"mempalace.mcp_server": fake_module}):
from mempalace.hooks_cli import _output
_output({"systemMessage": "test"})
os.close(write_fd)
written = b""
while True:
chunk = os.read(read_fd, 4096)
if not chunk:
break
written += chunk
finally:
os.close(read_fd)
data = json.loads(written.decode())
assert data["systemMessage"] == "test"
def test_output_falls_back_to_fd1_when_mcp_server_absent():
"""_output() writes to fd 1 directly when mcp_server is not loaded."""
read_fd, write_fd = os.pipe()
try:
orig_fd1 = os.dup(1)
os.dup2(write_fd, 1)
os.close(write_fd)
try:
modules_without_mcp = {k: v for k, v in __import__("sys").modules.items()
if "mcp_server" not in k}
with patch.dict("sys.modules", modules_without_mcp, clear=True):
from mempalace.hooks_cli import _output
_output({"continue": True})
finally:
os.dup2(orig_fd1, 1)
os.close(orig_fd1)
except Exception:
os.close(read_fd)
raise
written = b""
while True:
chunk = os.read(read_fd, 4096)
if not chunk:
break
written += chunk
os.close(read_fd)
data = json.loads(written.decode())
assert data["continue"] is True
def test_log_writes_to_hook_log(tmp_path): def test_log_writes_to_hook_log(tmp_path):
with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
_log("test message") _log("test message")