fix(hooks): address Copilot review feedback on #1021
- _output(): use sys.modules.get() instead of unconditional import to avoid triggering mcp_server's stdout redirect as a side effect - _output(): write-all loop for os.write() to handle partial writes and EINTR; fall back to sys.stdout.buffer on OSError - _output() docstring: remove inaccurate _save_diary_direct reference - stop_hook_active guard: narrow except to ImportError/AttributeError, default silent_guard=False (safe: preserves block-mode loop prevention when config load fails) and log a warning instead of silently changing behavior - tests: two new regression tests covering the real-stdout-fd path and the fd-1 fallback path Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+33
-17
@@ -134,24 +134,35 @@ def _log(message: str):
|
|||||||
|
|
||||||
|
|
||||||
def _output(data: dict):
|
def _output(data: dict):
|
||||||
"""Print JSON to the real stdout, even if mcp_server has hijacked sys.stdout.
|
"""Print JSON to stdout without importing modules that may redirect streams.
|
||||||
|
|
||||||
mempalace.mcp_server redirects stdout → stderr at module import (fd and
|
If mempalace.mcp_server is already loaded, reuse its saved real stdout fd.
|
||||||
sys-level) to protect the MCP stdio protocol from ChromaDB's C-level
|
Otherwise, write directly to fd 1 so hook responses still go to stdout even
|
||||||
prints. Silent-save imports it transitively via _save_diary_direct, so
|
if sys.stdout has been redirected elsewhere.
|
||||||
sys.stdout is stderr by the time we get here. Claude Code reads hook
|
|
||||||
output from fd 1, so we write there directly using the saved fd.
|
|
||||||
"""
|
"""
|
||||||
payload = json.dumps(data, indent=2, ensure_ascii=False) + "\n"
|
payload = (json.dumps(data, indent=2, ensure_ascii=False) + "\n").encode("utf-8")
|
||||||
|
|
||||||
|
real_stdout_fd: int | None = None
|
||||||
|
mcp_mod = sys.modules.get("mempalace.mcp_server") or sys.modules.get(
|
||||||
|
f"{__package__}.mcp_server" if __package__ else "mcp_server"
|
||||||
|
)
|
||||||
|
if mcp_mod is not None:
|
||||||
|
real_stdout_fd = getattr(mcp_mod, "_REAL_STDOUT_FD", None)
|
||||||
|
|
||||||
|
fd = real_stdout_fd if real_stdout_fd is not None else 1
|
||||||
|
offset = 0
|
||||||
try:
|
try:
|
||||||
from .mcp_server import _REAL_STDOUT_FD
|
while offset < len(payload):
|
||||||
if _REAL_STDOUT_FD is not None:
|
try:
|
||||||
os.write(_REAL_STDOUT_FD, payload.encode("utf-8"))
|
offset += os.write(fd, payload[offset:])
|
||||||
return
|
except InterruptedError:
|
||||||
except Exception:
|
continue
|
||||||
|
return
|
||||||
|
except OSError:
|
||||||
pass
|
pass
|
||||||
sys.stdout.write(payload)
|
|
||||||
sys.stdout.flush()
|
sys.stdout.buffer.write(payload)
|
||||||
|
sys.stdout.buffer.flush()
|
||||||
|
|
||||||
|
|
||||||
def _get_mine_dir(transcript_path: str = "") -> str:
|
def _get_mine_dir(transcript_path: str = "") -> str:
|
||||||
@@ -230,11 +241,16 @@ def hook_stop(data: dict, harness: str):
|
|||||||
# no loop to prevent — and Claude Code's plugin dispatch sets this flag on every
|
# no loop to prevent — and Claude Code's plugin dispatch sets this flag on every
|
||||||
# fire after the first, which would otherwise suppress all subsequent auto-saves.
|
# fire after the first, which would otherwise suppress all subsequent auto-saves.
|
||||||
if str(stop_hook_active).lower() in ("true", "1", "yes"):
|
if str(stop_hook_active).lower() in ("true", "1", "yes"):
|
||||||
|
silent_guard = False
|
||||||
try:
|
try:
|
||||||
from .config import MempalaceConfig
|
from .config import MempalaceConfig
|
||||||
silent_guard = MempalaceConfig().hook_silent_save
|
except ImportError as exc:
|
||||||
except Exception:
|
_log(f"WARNING: could not import MempalaceConfig for stop guard: {exc}; preserving block-mode guard")
|
||||||
silent_guard = True
|
else:
|
||||||
|
try:
|
||||||
|
silent_guard = MempalaceConfig().hook_silent_save
|
||||||
|
except AttributeError as exc:
|
||||||
|
_log(f"WARNING: could not read hook_silent_save: {exc}; preserving block-mode guard")
|
||||||
if not silent_guard:
|
if not silent_guard:
|
||||||
_output({})
|
_output({})
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -218,6 +218,66 @@ def test_precompact_allows(tmp_path):
|
|||||||
# --- _log ---
|
# --- _log ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_output_writes_to_real_stdout_fd_when_mcp_server_loaded():
|
||||||
|
"""_output() must reach fd 1 even when mcp_server has redirected sys.stdout."""
|
||||||
|
import types
|
||||||
|
|
||||||
|
fake_module = types.ModuleType("mempalace.mcp_server")
|
||||||
|
|
||||||
|
read_fd, write_fd = os.pipe()
|
||||||
|
try:
|
||||||
|
fake_module._REAL_STDOUT_FD = write_fd
|
||||||
|
with patch.dict("sys.modules", {"mempalace.mcp_server": fake_module}):
|
||||||
|
from mempalace.hooks_cli import _output
|
||||||
|
|
||||||
|
_output({"systemMessage": "test"})
|
||||||
|
|
||||||
|
os.close(write_fd)
|
||||||
|
written = b""
|
||||||
|
while True:
|
||||||
|
chunk = os.read(read_fd, 4096)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
written += chunk
|
||||||
|
finally:
|
||||||
|
os.close(read_fd)
|
||||||
|
|
||||||
|
data = json.loads(written.decode())
|
||||||
|
assert data["systemMessage"] == "test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_output_falls_back_to_fd1_when_mcp_server_absent():
|
||||||
|
"""_output() writes to fd 1 directly when mcp_server is not loaded."""
|
||||||
|
read_fd, write_fd = os.pipe()
|
||||||
|
try:
|
||||||
|
orig_fd1 = os.dup(1)
|
||||||
|
os.dup2(write_fd, 1)
|
||||||
|
os.close(write_fd)
|
||||||
|
try:
|
||||||
|
modules_without_mcp = {k: v for k, v in __import__("sys").modules.items()
|
||||||
|
if "mcp_server" not in k}
|
||||||
|
with patch.dict("sys.modules", modules_without_mcp, clear=True):
|
||||||
|
from mempalace.hooks_cli import _output
|
||||||
|
_output({"continue": True})
|
||||||
|
finally:
|
||||||
|
os.dup2(orig_fd1, 1)
|
||||||
|
os.close(orig_fd1)
|
||||||
|
except Exception:
|
||||||
|
os.close(read_fd)
|
||||||
|
raise
|
||||||
|
|
||||||
|
written = b""
|
||||||
|
while True:
|
||||||
|
chunk = os.read(read_fd, 4096)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
written += chunk
|
||||||
|
os.close(read_fd)
|
||||||
|
|
||||||
|
data = json.loads(written.decode())
|
||||||
|
assert data["continue"] is True
|
||||||
|
|
||||||
|
|
||||||
def test_log_writes_to_hook_log(tmp_path):
|
def test_log_writes_to_hook_log(tmp_path):
|
||||||
with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
|
with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
|
||||||
_log("test message")
|
_log("test message")
|
||||||
|
|||||||
Reference in New Issue
Block a user