diff --git a/.agents/plugins/marketplace.json b/.agents/plugins/marketplace.json new file mode 100644 index 0000000..58223a0 --- /dev/null +++ b/.agents/plugins/marketplace.json @@ -0,0 +1,20 @@ +{ + "name": "mempalace", + "interface": { + "displayName": "MemPalace" + }, + "plugins": [ + { + "name": "mempalace", + "source": { + "source": "local", + "path": "./.codex-plugin" + }, + "policy": { + "installation": "AVAILABLE", + "authentication": "NONE" + }, + "category": "Coding" + } + ] +} diff --git a/.claude-plugin/.mcp.json b/.claude-plugin/.mcp.json new file mode 100644 index 0000000..b1e81ed --- /dev/null +++ b/.claude-plugin/.mcp.json @@ -0,0 +1,9 @@ +{ + "mempalace": { + "command": "python3", + "args": [ + "-m", + "mempalace.mcp_server" + ] + } +} diff --git a/.claude-plugin/README.md b/.claude-plugin/README.md new file mode 100644 index 0000000..fd98952 --- /dev/null +++ b/.claude-plugin/README.md @@ -0,0 +1,57 @@ +# MemPalace Claude Code Plugin + +A Claude Code plugin that gives your AI a persistent memory system. Mine projects and conversations into a searchable palace backed by ChromaDB, with 19 MCP tools, auto-save hooks, and 5 guided skills. + +## Prerequisites + +- Python 3.9+ + +## Installation + +### Claude Code Marketplace + +```bash +claude plugin marketplace add milla-jovovich/mempalace +claude plugin install --scope user mempalace +``` + +### Local Clone + +```bash +claude plugin add /path/to/mempalace +``` + +## Post-Install Setup + +After installing the plugin, run the init command to complete setup (pip install, MCP configuration, etc.): + +``` +/mempalace:init +``` + +## Available Slash Commands + +| Command | Description | +|---------|-------------| +| `/mempalace:help` | Show available tools, skills, and architecture | +| `/mempalace:init` | Set up MemPalace -- install, configure MCP, onboard | +| `/mempalace:search` | Search your memories across the palace | +| `/mempalace:mine` | Mine projects and conversations into the palace | +| `/mempalace:status` | Show palace overview -- wings, rooms, drawer counts | + +## Hooks + +MemPalace registers two hooks that run automatically: + +- **Stop** -- Saves conversation context every 15 messages. +- **PreCompact** -- Preserves important memories before context compaction. + +Set the `MEMPAL_DIR` environment variable to a directory path to automatically run `mempalace mine` on that directory during each save trigger. + +## MCP Server + +The plugin automatically configures a local MCP server with 19 tools for storing, searching, and managing memories. No manual MCP setup is required -- `/mempalace:init` handles everything. + +## Full Documentation + +See the main [README](../README.md) for complete documentation, architecture details, and advanced usage. diff --git a/.claude-plugin/commands/help.md b/.claude-plugin/commands/help.md new file mode 100644 index 0000000..2f56339 --- /dev/null +++ b/.claude-plugin/commands/help.md @@ -0,0 +1,6 @@ +--- +description: Show comprehensive MemPalace help — available skills, MCP tools, CLI commands, hooks, and architecture. +allowed-tools: Bash, Read +--- + +Invoke the generic mempalace skill (using the Skill tool) with the `help` command, then follow its instructions. diff --git a/.claude-plugin/commands/init.md b/.claude-plugin/commands/init.md new file mode 100644 index 0000000..ff27562 --- /dev/null +++ b/.claude-plugin/commands/init.md @@ -0,0 +1,6 @@ +--- +description: Set up MemPalace — install the package, initialize a palace, configure MCP server, and verify everything works. +allowed-tools: Bash, Read, Write, Edit, Glob, Grep +--- + +Invoke the generic mempalace skill (using the Skill tool) with the `init` command, then follow its instructions. diff --git a/.claude-plugin/commands/mine.md b/.claude-plugin/commands/mine.md new file mode 100644 index 0000000..edac2b0 --- /dev/null +++ b/.claude-plugin/commands/mine.md @@ -0,0 +1,7 @@ +--- +description: Mine projects and conversations into the MemPalace. Supports project files, conversation exports, and auto-classification. +argument-hint: Path to project or conversation export to mine. +allowed-tools: Bash, Read, Write, Edit, Glob, Grep +--- + +Invoke the generic mempalace skill (using the Skill tool) with the `mine` command, then follow its instructions. diff --git a/.claude-plugin/commands/search.md b/.claude-plugin/commands/search.md new file mode 100644 index 0000000..9fe8c34 --- /dev/null +++ b/.claude-plugin/commands/search.md @@ -0,0 +1,7 @@ +--- +description: Search your memories across the MemPalace using semantic search with wing/room filtering. +argument-hint: Search query, optionally with wing/room filters. +allowed-tools: Bash, Read +--- + +Invoke the generic mempalace skill (using the Skill tool) with the `search` command, then follow its instructions. diff --git a/.claude-plugin/commands/status.md b/.claude-plugin/commands/status.md new file mode 100644 index 0000000..a87f27b --- /dev/null +++ b/.claude-plugin/commands/status.md @@ -0,0 +1,6 @@ +--- +description: Show the current state of your memory palace — wings, rooms, drawer counts, and suggestions. +allowed-tools: Bash, Read +--- + +Invoke the generic mempalace skill (using the Skill tool) with the `status` command, then follow its instructions. diff --git a/.claude-plugin/hooks/hooks.json b/.claude-plugin/hooks/hooks.json new file mode 100644 index 0000000..f1f0a90 --- /dev/null +++ b/.claude-plugin/hooks/hooks.json @@ -0,0 +1,25 @@ +{ + "description": "MemPalace auto-save and pre-compact hooks", + "hooks": { + "Stop": [ + { + "hooks": [ + { + "type": "command", + "command": "bash ${CLAUDE_PLUGIN_ROOT}/hooks/mempal-stop-hook.sh" + } + ] + } + ], + "PreCompact": [ + { + "hooks": [ + { + "type": "command", + "command": "bash ${CLAUDE_PLUGIN_ROOT}/hooks/mempal-precompact-hook.sh" + } + ] + } + ] + } +} diff --git a/.claude-plugin/hooks/mempal-precompact-hook.sh b/.claude-plugin/hooks/mempal-precompact-hook.sh new file mode 100644 index 0000000..0ac46dd --- /dev/null +++ b/.claude-plugin/hooks/mempal-precompact-hook.sh @@ -0,0 +1,5 @@ +#!/bin/bash +# MemPalace PreCompact Hook — thin wrapper calling Python CLI +# All logic lives in mempalace.hooks_cli for cross-harness extensibility +INPUT=$(cat) +echo "$INPUT" | python3 -m mempalace hook run --hook precompact --harness claude-code diff --git a/.claude-plugin/hooks/mempal-stop-hook.sh b/.claude-plugin/hooks/mempal-stop-hook.sh new file mode 100644 index 0000000..cba3284 --- /dev/null +++ b/.claude-plugin/hooks/mempal-stop-hook.sh @@ -0,0 +1,5 @@ +#!/bin/bash +# MemPalace Stop Hook — thin wrapper calling Python CLI +# All logic lives in mempalace.hooks_cli for cross-harness extensibility +INPUT=$(cat) +echo "$INPUT" | python3 -m mempalace hook run --hook stop --harness claude-code diff --git a/.claude-plugin/marketplace.json b/.claude-plugin/marketplace.json new file mode 100644 index 0000000..6b23ccf --- /dev/null +++ b/.claude-plugin/marketplace.json @@ -0,0 +1,18 @@ +{ + "name": "mempalace", + "owner": { + "name": "milla-jovovich", + "url": "https://github.com/milla-jovovich" + }, + "plugins": [ + { + "name": "mempalace", + "source": "./.claude-plugin", + "description": "AI memory system — mine projects and conversations into a searchable palace. 19 MCP tools, auto-save hooks, guided setup.", + "version": "3.0.14", + "author": { + "name": "milla-jovovich" + } + } + ] +} diff --git a/.claude-plugin/plugin.json b/.claude-plugin/plugin.json new file mode 100644 index 0000000..fa05a15 --- /dev/null +++ b/.claude-plugin/plugin.json @@ -0,0 +1,29 @@ +{ + "name": "mempalace", + "version": "3.0.14", + "description": "Give your AI a memory — mine projects and conversations into a searchable palace. 19 MCP tools, auto-save hooks, and guided setup.", + "author": { + "name": "milla-jovovich" + }, + "license": "MIT", + "commands": [], + "mcpServers": { + "mempalace": { + "command": "python3", + "args": [ + "-m", + "mempalace.mcp_server" + ] + } + }, + "keywords": [ + "memory", + "ai", + "rag", + "mcp", + "chromadb", + "palace", + "search" + ], + "repository": "https://github.com/milla-jovovich/mempalace" +} diff --git a/.claude-plugin/skills/mempalace/SKILL.md b/.claude-plugin/skills/mempalace/SKILL.md new file mode 100644 index 0000000..ae60fca --- /dev/null +++ b/.claude-plugin/skills/mempalace/SKILL.md @@ -0,0 +1,35 @@ +--- +name: mempalace +description: MemPalace — mine projects and conversations into a searchable memory palace. Use when asked about mempalace, memory palace, mining memories, searching memories, or palace setup. +allowed-tools: Bash, Read, Write, Edit, Glob, Grep +--- + +# MemPalace + +A searchable memory palace for AI — mine projects and conversations, then search them semantically. + +## Prerequisites + +Ensure `mempalace` is installed: + +```bash +mempalace --version +``` + +If not installed: + +```bash +pip install mempalace +``` + +## Usage + +MemPalace provides dynamic instructions via the CLI. To get instructions for any operation: + +```bash +mempalace instructions +``` + +Where `` is one of: `help`, `init`, `mine`, `search`, `status`. + +Run the appropriate instructions command, then follow the returned instructions step by step. diff --git a/.codex-plugin/README.md b/.codex-plugin/README.md new file mode 100644 index 0000000..57dbc34 --- /dev/null +++ b/.codex-plugin/README.md @@ -0,0 +1,75 @@ +# MemPalace - Codex CLI Plugin + +Give your AI a persistent memory -- mine projects and conversations into a searchable palace backed by ChromaDB, with 19 MCP tools, auto-save hooks, and guided skills. + +## Prerequisites + +- Python 3.9+ +- Codex CLI installed and configured +- `pip install mempalace` + +## Installation + +### Local Install + +1. Copy or symlink the `.codex-plugin` directory into your project root: + +```bash +cp -r .codex-plugin /path/to/your/project/.codex-plugin +``` + +2. Verify the plugin is detected: + +```bash +codex --plugins +``` + +3. Initialize your palace: + +```bash +codex /init +``` + +### Git Install + +1. Clone the MemPalace repository: + +```bash +git clone https://github.com/milla-jovovich/mempalace.git +cd mempalace +``` + +2. Install the Python package: + +```bash +pip install -e . +``` + +3. The `.codex-plugin` directory is already in the repo root. Codex CLI will detect it automatically when you run Codex from inside the repository. + +4. Initialize your palace: + +```bash +codex /init +``` + +## Available Skills + +| Skill | Description | +|-------|-------------| +| `/help` | Show available commands and usage tips | +| `/init` | Initialize a new memory palace | +| `/search` | Semantic search across all mined memories | +| `/mine` | Mine a project or conversation into your palace | +| `/status` | Show palace status, room counts, and health | + +## Hooks + +The plugin includes auto-save hooks that run on session stop (every 15 messages) and before context compaction, automatically preserving conversation context into your palace. + +Set the `MEMPAL_DIR` environment variable to a directory path to automatically run `mempalace mine` on that directory during each save trigger. + +## Support + +- Repository: https://github.com/milla-jovovich/mempalace +- Issues: https://github.com/milla-jovovich/mempalace/issues diff --git a/.codex-plugin/hooks.json b/.codex-plugin/hooks.json new file mode 100644 index 0000000..46f7e66 --- /dev/null +++ b/.codex-plugin/hooks.json @@ -0,0 +1,37 @@ +{ + "hooks": { + "SessionStart": [ + { + "matcher": "*", + "hooks": [ + { + "type": "command", + "command": "${CODEX_PLUGIN_ROOT}/hooks/mempal-hook.sh session-start" + } + ] + } + ], + "Stop": [ + { + "matcher": "*", + "hooks": [ + { + "type": "command", + "command": "${CODEX_PLUGIN_ROOT}/hooks/mempal-hook.sh stop" + } + ] + } + ], + "PreCompact": [ + { + "matcher": "*", + "hooks": [ + { + "type": "command", + "command": "${CODEX_PLUGIN_ROOT}/hooks/mempal-hook.sh precompact" + } + ] + } + ] + } +} diff --git a/.codex-plugin/hooks/mempal-hook.sh b/.codex-plugin/hooks/mempal-hook.sh new file mode 100644 index 0000000..1cc0050 --- /dev/null +++ b/.codex-plugin/hooks/mempal-hook.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash +set -euo pipefail +HOOK_NAME="${1:?Usage: mempal-hook.sh }" +INPUT_FILE=$(mktemp) || { echo "Failed to create temp file" >&2; exit 1; } +cat > "$INPUT_FILE" +cat "$INPUT_FILE" | python3 -m mempalace hook run --hook "$HOOK_NAME" --harness codex +EXIT_CODE=$? +rm -f "$INPUT_FILE" 2>/dev/null +exit $EXIT_CODE diff --git a/.codex-plugin/plugin.json b/.codex-plugin/plugin.json new file mode 100644 index 0000000..5784847 --- /dev/null +++ b/.codex-plugin/plugin.json @@ -0,0 +1,52 @@ +{ + "name": "mempalace", + "version": "3.0.14", + "description": "Give your AI a memory — mine projects and conversations into a searchable palace. 19 MCP tools, auto-save hooks, and guided setup.", + "author": { + "name": "milla-jovovich" + }, + "homepage": "https://github.com/milla-jovovich/mempalace", + "repository": "https://github.com/milla-jovovich/mempalace", + "license": "MIT", + "keywords": [ + "memory", + "ai", + "rag", + "mcp", + "chromadb", + "palace", + "search" + ], + "skills": "./skills/", + "hooks": "./hooks.json", + "mcpServers": { + "mempalace": { + "command": "python3", + "args": [ + "-m", + "mempalace.mcp_server" + ] + } + }, + "interface": { + "displayName": "MemPalace", + "shortDescription": "AI memory system for Codex", + "longDescription": "Give your AI a persistent memory — mine projects and conversations into a searchable palace backed by ChromaDB, with 19 MCP tools, auto-save hooks, and guided skills.", + "developerName": "milla-jovovich", + "category": "Coding", + "capabilities": [ + "Interactive", + "Read", + "Write" + ], + "websiteURL": "https://github.com/milla-jovovich/mempalace", + "privacyPolicyURL": "https://github.com/milla-jovovich/mempalace", + "termsOfServiceURL": "https://github.com/milla-jovovich/mempalace", + "defaultPrompt": [ + "Search my memories for recent decisions", + "Mine this project into my memory palace", + "Show my palace status and room counts" + ], + "brandColor": "#7C3AED" + } +} diff --git a/.codex-plugin/skills/help/SKILL.md b/.codex-plugin/skills/help/SKILL.md new file mode 100644 index 0000000..d0e8f43 --- /dev/null +++ b/.codex-plugin/skills/help/SKILL.md @@ -0,0 +1,13 @@ +--- +name: help +description: Show MemPalace help — available commands, usage tips, and getting started guidance. +allowed-tools: Bash, Read +--- + +# MemPalace Help + +Run the following command and follow the returned instructions step by step: + +```bash +mempalace instructions help +``` diff --git a/.codex-plugin/skills/init/SKILL.md b/.codex-plugin/skills/init/SKILL.md new file mode 100644 index 0000000..cc0e2f9 --- /dev/null +++ b/.codex-plugin/skills/init/SKILL.md @@ -0,0 +1,13 @@ +--- +name: init +description: Initialize a new MemPalace — guided setup for your AI memory palace with ChromaDB backend. +allowed-tools: Bash, Read, Write, Edit +--- + +# MemPalace Init + +Run the following command and follow the returned instructions step by step: + +```bash +mempalace instructions init +``` diff --git a/.codex-plugin/skills/mine/SKILL.md b/.codex-plugin/skills/mine/SKILL.md new file mode 100644 index 0000000..1a94e29 --- /dev/null +++ b/.codex-plugin/skills/mine/SKILL.md @@ -0,0 +1,13 @@ +--- +name: mine +description: Mine a project or conversation into your MemPalace — extract and store memories for later retrieval. +allowed-tools: Bash, Read, Glob, Grep +--- + +# MemPalace Mine + +Run the following command and follow the returned instructions step by step: + +```bash +mempalace instructions mine +``` diff --git a/.codex-plugin/skills/search/SKILL.md b/.codex-plugin/skills/search/SKILL.md new file mode 100644 index 0000000..4d5bf4b --- /dev/null +++ b/.codex-plugin/skills/search/SKILL.md @@ -0,0 +1,13 @@ +--- +name: search +description: Search your MemPalace — semantic search across all mined memories, projects, and conversations. +allowed-tools: Bash, Read +--- + +# MemPalace Search + +Run the following command and follow the returned instructions step by step: + +```bash +mempalace instructions search +``` diff --git a/.codex-plugin/skills/status/SKILL.md b/.codex-plugin/skills/status/SKILL.md new file mode 100644 index 0000000..617d3be --- /dev/null +++ b/.codex-plugin/skills/status/SKILL.md @@ -0,0 +1,13 @@ +--- +name: status +description: Show MemPalace status — room counts, storage usage, and palace health. +allowed-tools: Bash, Read +--- + +# MemPalace Status + +Run the following command and follow the returned instructions step by step: + +```bash +mempalace instructions status +``` diff --git a/.github/workflows/bump-plugin-version.yml b/.github/workflows/bump-plugin-version.yml new file mode 100644 index 0000000..0867b3c --- /dev/null +++ b/.github/workflows/bump-plugin-version.yml @@ -0,0 +1,51 @@ +name: Bump Version + +on: + push: + branches: [main] + +jobs: + bump-version: + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - uses: actions/checkout@v6 + + - name: Bump patch version + run: | + CURRENT=$(python3 -c "exec(open('mempalace/version.py').read()); print(__version__)") + IFS='.' read -r MAJOR MINOR PATCH <<< "$CURRENT" + PATCH=$((PATCH + 1)) + NEW="${MAJOR}.${MINOR}.${PATCH}" + echo "__version__ = \"${NEW}\"" > mempalace/version.py + # Prepend docstring + sed -i '1i"""Single source of truth for the MemPalace package version."""\n' mempalace/version.py + echo "version=$NEW" >> "$GITHUB_OUTPUT" + id: version + + - name: Sync plugin.json + run: | + jq --arg v "${{ steps.version.outputs.version }}" '.version = $v' .claude-plugin/plugin.json > tmp.json && mv tmp.json .claude-plugin/plugin.json + + - name: Sync marketplace.json + run: | + jq --arg v "${{ steps.version.outputs.version }}" '.plugins[0].version = $v' .claude-plugin/marketplace.json > tmp.json && mv tmp.json .claude-plugin/marketplace.json + + - name: Sync codex plugin.json + run: | + jq --arg v "${{ steps.version.outputs.version }}" '.version = $v' .codex-plugin/plugin.json > tmp.json && mv tmp.json .codex-plugin/plugin.json + + - name: Sync pyproject.toml + run: | + sed -i "s/^version = \".*\"/version = \"${{ steps.version.outputs.version }}\"/" pyproject.toml + + - name: Commit and push + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git add mempalace/version.py .claude-plugin/plugin.json .claude-plugin/marketplace.json .codex-plugin/plugin.json pyproject.toml + if ! git diff --staged --quiet; then + git commit -m "chore: bump version to ${{ steps.version.outputs.version }}" + git push + fi diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8ccb15e..302c8e9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,7 +7,7 @@ on: branches: [main] jobs: - test: + test-linux: runs-on: ubuntu-latest strategy: matrix: @@ -18,8 +18,27 @@ jobs: with: python-version: ${{ matrix.python-version }} - run: pip install -e ".[dev]" - - run: python -m pytest tests/ -v + - run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=85 + test-windows: + runs-on: windows-latest + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.9" + - run: pip install -e ".[dev]" + - run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=85 + + test-macos: + runs-on: macos-latest + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.9" + - run: pip install -e ".[dev]" + - run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=85 lint: runs-on: ubuntu-latest steps: @@ -27,6 +46,6 @@ jobs: - uses: actions/setup-python@v6 with: python-version: "3.11" - - run: pip install ruff + - run: pip install "ruff>=0.4.0,<0.5" - run: ruff check . - run: ruff format --check . diff --git a/.gitignore b/.gitignore index 54c2d1b..c8b10cc 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ __pycache__/ *.pyc .pytest_cache/ mempal.yaml +.a5c/ diff --git a/README.md b/README.md index d05952f..a1a7ccb 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ Other memory systems try to fix this by letting AI decide what's worth rememberi
-[Quick Start](#quick-start) · [The Palace](#the-palace) · [AAAK Dialect](#aaak-compression) · [Benchmarks](#benchmarks) · [MCP Tools](#mcp-server) +[Quick Start](#quick-start) · [The Palace](#the-palace) · [AAAK Dialect](#aaak-dialect-experimental) · [Benchmarks](#benchmarks) · [MCP Tools](#mcp-server)
@@ -112,6 +112,17 @@ Three mining modes: **projects** (code and docs), **convos** (conversation expor After the one-time setup (install → init → mine), you don't run MemPalace commands manually. Your AI uses it for you. There are two ways, depending on which AI you use. +### With Claude Code (recommended) + +Native marketplace install: + +```bash +claude plugin marketplace add milla-jovovich/mempalace +claude plugin install --scope user mempalace +``` + +Restart Claude Code, then type `/skills` to verify "mempalace" appears. + ### With Claude, ChatGPT, Cursor, Gemini (MCP-compatible tools) ```bash @@ -439,6 +450,11 @@ Letta charges $20–200/mo for agent-managed memory. MemPalace does it with a wi ## MCP Server ```bash +# Via plugin (recommended) +claude plugin marketplace add milla-jovovich/mempalace +claude plugin install --scope user mempalace + +# Or manually claude mcp add mempalace -- python -m mempalace.mcp_server ``` @@ -509,6 +525,8 @@ Two hooks for Claude Code that automatically save memories during work: } ``` +**Optional auto-ingest:** Set the `MEMPAL_DIR` environment variable to a directory path and the hooks will automatically run `mempalace mine` on that directory during each save trigger (background on stop, synchronous on precompact). + --- ## Benchmarks diff --git a/mempalace/__init__.py b/mempalace/__init__.py index d7a138d..78d760b 100644 --- a/mempalace/__init__.py +++ b/mempalace/__init__.py @@ -1,6 +1,21 @@ """MemPalace — Give your AI a memory. No API key required.""" -from .cli import main -from .version import __version__ +import logging +import os +import platform + +from .cli import main # noqa: E402 +from .version import __version__ # noqa: E402 + +# ChromaDB 0.6.x ships a Posthog telemetry client whose capture() signature is +# incompatible with the bundled posthog library, producing noisy stderr warnings +# on every client operation ("Failed to send telemetry event … capture() takes +# 1 positional argument but 3 were given"). Silence just that logger. +logging.getLogger("chromadb.telemetry.product.posthog").setLevel(logging.CRITICAL) + +# ONNX Runtime's CoreML provider segfaults during vector queries on Apple Silicon. +# Force CPU execution unless the user has explicitly set a preference. +if platform.machine() == "arm64" and platform.system() == "Darwin": + os.environ.setdefault("ORT_DISABLE_COREML", "1") __all__ = ["main", "__version__"] diff --git a/mempalace/cli.py b/mempalace/cli.py index 1599b08..0a24abf 100644 --- a/mempalace/cli.py +++ b/mempalace/cli.py @@ -226,6 +226,20 @@ def cmd_repair(args): print(f"\n{'=' * 55}\n") +def cmd_hook(args): + """Run hook logic: reads JSON from stdin, outputs JSON to stdout.""" + from .hooks_cli import run_hook + + run_hook(hook_name=args.hook, harness=args.harness) + + +def cmd_instructions(args): + """Output skill instructions to stdout.""" + from .instructions_cli import run_instructions + + run_instructions(name=args.name) + + def cmd_compress(args): """Compress drawers in a wing using AAAK Dialect.""" import chromadb @@ -451,6 +465,35 @@ def main(): help="Only split files containing at least N sessions (default: 2)", ) + # hook + p_hook = sub.add_parser( + "hook", + help="Run hook logic (reads JSON from stdin, outputs JSON to stdout)", + ) + hook_sub = p_hook.add_subparsers(dest="hook_action") + p_hook_run = hook_sub.add_parser("run", help="Execute a hook") + p_hook_run.add_argument( + "--hook", + required=True, + choices=["session-start", "stop", "precompact"], + help="Hook name to run", + ) + p_hook_run.add_argument( + "--harness", + required=True, + choices=["claude-code", "codex"], + help="Harness type (determines stdin JSON format)", + ) + + # instructions + p_instructions = sub.add_parser( + "instructions", + help="Output skill instructions to stdout", + ) + instructions_sub = p_instructions.add_subparsers(dest="instructions_name") + for instr_name in ["init", "search", "mine", "help", "status"]: + instructions_sub.add_parser(instr_name, help=f"Output {instr_name} instructions") + # repair sub.add_parser( "repair", @@ -466,6 +509,23 @@ def main(): parser.print_help() return + # Handle two-level subcommands + if args.command == "hook": + if not getattr(args, "hook_action", None): + p_hook.print_help() + return + cmd_hook(args) + return + + if args.command == "instructions": + name = getattr(args, "instructions_name", None) + if not name: + p_instructions.print_help() + return + args.name = name + cmd_instructions(args) + return + dispatch = { "init": cmd_init, "mine": cmd_mine, diff --git a/mempalace/entity_registry.py b/mempalace/entity_registry.py index 24fef0a..2a4ad8d 100644 --- a/mempalace/entity_registry.py +++ b/mempalace/entity_registry.py @@ -309,7 +309,7 @@ class EntityRegistry: def save(self): self._path.parent.mkdir(parents=True, exist_ok=True) - self._path.write_text(json.dumps(self._data, indent=2)) + self._path.write_text(json.dumps(self._data, indent=2), encoding="utf-8") @staticmethod def _empty() -> dict: diff --git a/mempalace/hooks_cli.py b/mempalace/hooks_cli.py new file mode 100644 index 0000000..3f3fc09 --- /dev/null +++ b/mempalace/hooks_cli.py @@ -0,0 +1,226 @@ +""" +Hook logic for MemPalace — Python implementation of session-start, stop, and precompact hooks. + +Reads JSON from stdin, outputs JSON to stdout. +Supported hooks: session-start, stop, precompact +Supported harnesses: claude-code, codex (extensible to cursor, gemini, etc.) +""" + +import json +import os +import re +import subprocess +import sys +from datetime import datetime +from pathlib import Path + +SAVE_INTERVAL = 15 +STATE_DIR = Path.home() / ".mempalace" / "hook_state" + +STOP_BLOCK_REASON = ( + "AUTO-SAVE checkpoint. Save key topics, decisions, quotes, and code " + "from this session to your memory system. Organize into appropriate " + "categories. Use verbatim quotes where possible. Continue conversation " + "after saving." +) + +PRECOMPACT_BLOCK_REASON = ( + "COMPACTION IMMINENT. Save ALL topics, decisions, quotes, code, and " + "important context from this session to your memory system. Be thorough " + "\u2014 after compaction, detailed context will be lost. Organize into " + "appropriate categories. Use verbatim quotes where possible. Save " + "everything, then allow compaction to proceed." +) + + +def _sanitize_session_id(session_id: str) -> str: + """Only allow alnum, dash, underscore to prevent path traversal.""" + sanitized = re.sub(r"[^a-zA-Z0-9_-]", "", session_id) + return sanitized or "unknown" + + +def _count_human_messages(transcript_path: str) -> int: + """Count human messages in a JSONL transcript, skipping command-messages.""" + path = Path(transcript_path).expanduser() + if not path.is_file(): + return 0 + count = 0 + try: + with open(path, encoding="utf-8", errors="replace") as f: + for line in f: + try: + entry = json.loads(line) + msg = entry.get("message", {}) + if isinstance(msg, dict) and msg.get("role") == "user": + content = msg.get("content", "") + if isinstance(content, str): + if "" in content: + continue + elif isinstance(content, list): + text = " ".join( + b.get("text", "") for b in content if isinstance(b, dict) + ) + if "" in text: + continue + count += 1 + except (json.JSONDecodeError, AttributeError): + pass + except OSError: + return 0 + return count + + +def _log(message: str): + """Append to hook state log file.""" + try: + STATE_DIR.mkdir(parents=True, exist_ok=True) + log_path = STATE_DIR / "hook.log" + timestamp = datetime.now().strftime("%H:%M:%S") + with open(log_path, "a") as f: + f.write(f"[{timestamp}] {message}\n") + except OSError: + pass + + +def _output(data: dict): + """Print JSON to stdout with consistent formatting (pretty-printed).""" + print(json.dumps(data, indent=2, ensure_ascii=False)) + + +def _maybe_auto_ingest(): + """If MEMPAL_DIR is set and exists, run mempalace mine in background.""" + mempal_dir = os.environ.get("MEMPAL_DIR", "") + if mempal_dir and os.path.isdir(mempal_dir): + try: + log_path = STATE_DIR / "hook.log" + with open(log_path, "a") as log_f: + subprocess.Popen( + [sys.executable, "-m", "mempalace", "mine", mempal_dir], + stdout=log_f, + stderr=log_f, + ) + except OSError: + pass + + +SUPPORTED_HARNESSES = {"claude-code", "codex"} + + +def _parse_harness_input(data: dict, harness: str) -> dict: + """Parse stdin JSON according to the harness type.""" + if harness not in SUPPORTED_HARNESSES: + print(f"Unknown harness: {harness}", file=sys.stderr) + sys.exit(1) + return { + "session_id": _sanitize_session_id(str(data.get("session_id", "unknown"))), + "stop_hook_active": data.get("stop_hook_active", False), + "transcript_path": str(data.get("transcript_path", "")), + } + + +def hook_stop(data: dict, harness: str): + """Stop hook: block every N messages for auto-save.""" + parsed = _parse_harness_input(data, harness) + session_id = parsed["session_id"] + stop_hook_active = parsed["stop_hook_active"] + transcript_path = parsed["transcript_path"] + + # If already in a save cycle, let through (infinite-loop prevention) + if str(stop_hook_active).lower() in ("true", "1", "yes"): + _output({}) + return + + # Count human messages + exchange_count = _count_human_messages(transcript_path) + + # Track last save point + STATE_DIR.mkdir(parents=True, exist_ok=True) + last_save_file = STATE_DIR / f"{session_id}_last_save" + last_save = 0 + if last_save_file.is_file(): + try: + last_save = int(last_save_file.read_text().strip()) + except (ValueError, OSError): + last_save = 0 + + since_last = exchange_count - last_save + + _log(f"Session {session_id}: {exchange_count} exchanges, {since_last} since last save") + + if since_last >= SAVE_INTERVAL and exchange_count > 0: + # Update last save point + try: + last_save_file.write_text(str(exchange_count), encoding="utf-8") + except OSError: + pass + + _log(f"TRIGGERING SAVE at exchange {exchange_count}") + + # Optional: auto-ingest if MEMPAL_DIR is set + _maybe_auto_ingest() + + _output({"decision": "block", "reason": STOP_BLOCK_REASON}) + else: + _output({}) + + +def hook_session_start(data: dict, harness: str): + """Session start hook: initialize session tracking state.""" + parsed = _parse_harness_input(data, harness) + session_id = parsed["session_id"] + + _log(f"SESSION START for session {session_id}") + + # Initialize session state directory + STATE_DIR.mkdir(parents=True, exist_ok=True) + + # Pass through — no blocking on session start + _output({}) + + +def hook_precompact(data: dict, harness: str): + """Precompact hook: always block with comprehensive save instruction.""" + parsed = _parse_harness_input(data, harness) + session_id = parsed["session_id"] + + _log(f"PRE-COMPACT triggered for session {session_id}") + + # Optional: auto-ingest synchronously before compaction (so memories land first) + mempal_dir = os.environ.get("MEMPAL_DIR", "") + if mempal_dir and os.path.isdir(mempal_dir): + try: + log_path = STATE_DIR / "hook.log" + with open(log_path, "a") as log_f: + subprocess.run( + [sys.executable, "-m", "mempalace", "mine", mempal_dir], + stdout=log_f, + stderr=log_f, + timeout=60, + ) + except OSError: + pass + + # Always block -- compaction = save everything + _output({"decision": "block", "reason": PRECOMPACT_BLOCK_REASON}) + + +def run_hook(hook_name: str, harness: str): + """Main entry point: read stdin JSON, dispatch to hook handler.""" + try: + data = json.load(sys.stdin) + except (json.JSONDecodeError, EOFError): + _log("WARNING: Failed to parse stdin JSON, proceeding with empty data") + data = {} + + hooks = { + "session-start": hook_session_start, + "stop": hook_stop, + "precompact": hook_precompact, + } + + handler = hooks.get(hook_name) + if handler is None: + print(f"Unknown hook: {hook_name}", file=sys.stderr) + sys.exit(1) + + handler(data, harness) diff --git a/mempalace/instructions/help.md b/mempalace/instructions/help.md new file mode 100644 index 0000000..f18c1de --- /dev/null +++ b/mempalace/instructions/help.md @@ -0,0 +1,105 @@ +# MemPalace + +AI memory system. Store everything, find anything. Local, free, no API key. + +--- + +## Slash Commands + +| Command | Description | +|----------------------|--------------------------------| +| /mempalace:init | Install and set up MemPalace | +| /mempalace:search | Search your memories | +| /mempalace:mine | Mine projects and conversations| +| /mempalace:status | Palace overview and stats | +| /mempalace:help | This help message | + +--- + +## MCP Tools (19) + +### Palace (read) +- mempalace_status -- Palace status and stats +- mempalace_list_wings -- List all wings +- mempalace_list_rooms -- List rooms in a wing +- mempalace_get_taxonomy -- Get the full taxonomy tree +- mempalace_search -- Search memories by query +- mempalace_check_duplicate -- Check if a memory already exists +- mempalace_get_aaak_spec -- Get the AAAK specification + +### Palace (write) +- mempalace_add_drawer -- Add a new memory (drawer) +- mempalace_delete_drawer -- Delete a memory (drawer) + +### Knowledge Graph +- mempalace_kg_query -- Query the knowledge graph +- mempalace_kg_add -- Add a knowledge graph entry +- mempalace_kg_invalidate -- Invalidate a knowledge graph entry +- mempalace_kg_timeline -- View knowledge graph timeline +- mempalace_kg_stats -- Knowledge graph statistics + +### Navigation +- mempalace_traverse -- Traverse the palace structure +- mempalace_find_tunnels -- Find cross-wing connections +- mempalace_graph_stats -- Graph connectivity statistics + +### Agent Diary +- mempalace_diary_write -- Write a diary entry +- mempalace_diary_read -- Read diary entries + +--- + +## CLI Commands + + mempalace init Initialize a new palace + mempalace mine Mine a project (default mode) + mempalace mine --mode convos Mine conversation exports + mempalace search "query" Search your memories + mempalace split Split large transcript files + mempalace wake-up Load palace into context + mempalace compress Compress palace storage + mempalace status Show palace status + mempalace repair Rebuild vector index + mempalace hook run Run hook logic (for harness integration) + mempalace instructions Output skill instructions + +--- + +## Auto-Save Hooks + +- Stop hook -- Automatically saves memories every 15 messages. Counts human + messages in the session transcript (skipping command-messages). When the + threshold is reached, blocks the AI with a save instruction. Uses + ~/.mempalace/hook_state/ to track save points per session. If + stop_hook_active is true, passes through to prevent infinite loops. + +- PreCompact hook -- Emergency save before context compaction. Always blocks + with a comprehensive save instruction because compaction means the AI is + about to lose detailed context. + +Hooks read JSON from stdin and output JSON to stdout. They can be invoked via: + + echo '{"session_id":"abc","stop_hook_active":false,"transcript_path":"..."}' | mempalace hook run --hook stop --harness claude-code + +--- + +## Architecture + + Wings (projects/people) + +-- Rooms (topics) + +-- Closets (summaries) + +-- Drawers (verbatim memories) + + Halls connect rooms within a wing. + Tunnels connect rooms across wings. + +The palace is stored locally using ChromaDB for vector search and SQLite for +metadata. No cloud services or API keys required. + +--- + +## Getting Started + +1. /mempalace:init -- Set up your palace +2. /mempalace:mine -- Mine a project or conversation +3. /mempalace:search -- Find what you stored diff --git a/mempalace/instructions/init.md b/mempalace/instructions/init.md new file mode 100644 index 0000000..40fe8fc --- /dev/null +++ b/mempalace/instructions/init.md @@ -0,0 +1,69 @@ +# MemPalace Init + +Guide the user through a complete MemPalace setup. Follow each step in order, +stopping to report errors and attempt remediation before proceeding. + +## Step 1: Check Python version + +Run `python3 --version` (or `python --version` on Windows) and confirm the +version is 3.9 or higher. If Python is not found or the version is too old, +tell the user they need Python 3.9+ installed and stop. + +## Step 2: Check if mempalace is already installed + +Run `pip show mempalace` to see if the package is already present. If it is, +report the installed version and skip to Step 4. + +## Step 3: Install mempalace + +Run `pip install mempalace`. + +### Error handling -- pip failures + +If `pip install mempalace` fails, try these fallbacks in order: + +1. Try `pip3 install mempalace` +2. Try `python -m pip install mempalace` (or `python3 -m pip install mempalace`) +3. If the error mentions missing build tools or compilation failures (commonly + from chromadb or its native dependencies): + - On Linux/macOS: suggest `sudo apt-get install build-essential python3-dev` + (Debian/Ubuntu) or `xcode-select --install` (macOS) + - On Windows: suggest installing Microsoft C++ Build Tools from + https://visualstudio.microsoft.com/visual-cpp-build-tools/ + - Then retry the install command +4. If all attempts fail, report the error clearly and stop. + +## Step 4: Ask for project directory + +Ask the user which project directory they want to initialize with MemPalace. +Offer the current working directory as the default. Wait for their response +before continuing. + +## Step 5: Initialize the palace + +Run `mempalace init ` where `` is the directory from Step 4. + +If this fails, report the error and stop. + +## Step 6: Configure MCP server + +Run the following command to register the MemPalace MCP server with Claude: + + claude mcp add mempalace -- python -m mempalace.mcp_server + +If this fails, report the error but continue to the next step (MCP +configuration can be done manually later). + +## Step 7: Verify installation + +Run `mempalace status` and confirm the output shows a healthy palace. + +If the command fails or reports errors, walk the user through troubleshooting +based on the output. + +## Step 8: Show next steps + +Tell the user setup is complete and suggest these next actions: + +- Use /mempalace:mine to start adding data to their palace +- Use /mempalace:search to query their palace and retrieve stored knowledge diff --git a/mempalace/instructions/mine.md b/mempalace/instructions/mine.md new file mode 100644 index 0000000..ec8c250 --- /dev/null +++ b/mempalace/instructions/mine.md @@ -0,0 +1,64 @@ +# MemPalace Mine + +When the user invokes this skill, follow these steps: + +## 1. Ask what to mine + +Ask the user what they want to mine and where the source data is located. +Clarify: +- Is it a project directory (code, docs, notes)? +- Is it conversation exports (Claude, ChatGPT, Slack)? +- Do they want auto-classification (decisions, milestones, problems)? + +## 2. Choose the mining mode + +There are three mining modes: + +### Project mining + + mempalace mine + +Mines code files, documentation, and notes from a project directory. + +### Conversation mining + + mempalace mine --mode convos + +Mines conversation exports from Claude, ChatGPT, or Slack into the palace. + +### General extraction (auto-classify) + + mempalace mine --mode convos --extract general + +Auto-classifies mined content into decisions, milestones, and problems. + +## 3. Optionally split mega-files first + +If the source directory contains very large files, suggest splitting them +before mining: + + mempalace split [--dry-run] + +Use --dry-run first to preview what will be split without making changes. + +## 4. Optionally tag with a wing + +If the user wants to organize mined content under a specific wing, add the +--wing flag: + + mempalace mine --wing + +## 5. Show progress and results + +Run the selected mining command and display progress as it executes. After +completion, summarize the results including: +- Number of items mined +- Categories or classifications applied +- Any warnings or skipped files + +## 6. Suggest next steps + +After mining completes, suggest the user try: +- /mempalace:search -- search the newly mined content +- /mempalace:status -- check the current state of their palace +- Mine more data from additional sources diff --git a/mempalace/instructions/search.md b/mempalace/instructions/search.md new file mode 100644 index 0000000..0b6a813 --- /dev/null +++ b/mempalace/instructions/search.md @@ -0,0 +1,57 @@ +# MemPalace Search + +When the user wants to search their MemPalace memories, follow these steps: + +## 1. Parse the Search Query + +Extract the core search intent from the user's message. Identify any explicit +or implicit filters: +- Wing -- a top-level category (e.g., "work", "personal", "research") +- Room -- a sub-category within a wing +- Keywords / semantic query -- the actual search terms + +## 2. Determine Wing/Room Filters + +If the user mentions a specific domain, topic area, or context, map it to the +appropriate wing and/or room. If unsure, omit filters to search globally. You +can discover the taxonomy first if needed. + +## 3. Use MCP Tools (Preferred) + +If MCP tools are available, use them in this priority order: + +- mempalace_search(query, wing, room) -- Primary search tool. Pass the semantic + query and any wing/room filters. +- mempalace_list_wings -- Discover all available wings. Use when the user asks + what categories exist or you need to resolve a wing name. +- mempalace_list_rooms(wing) -- List rooms within a specific wing. Use to help + the user navigate or to resolve a room name. +- mempalace_get_taxonomy -- Retrieve the full wing/room/drawer tree. Use when + the user wants an overview of their entire memory structure. +- mempalace_traverse(room) -- Walk the knowledge graph starting from a room. + Use when the user wants to explore connections and related memories. +- mempalace_find_tunnels(wing1, wing2) -- Find cross-wing connections (tunnels) + between two wings. Use when the user asks about relationships between + different knowledge domains. + +## 4. CLI Fallback + +If MCP tools are not available, fall back to the CLI: + + mempalace search "query" [--wing X] [--room Y] + +## 5. Present Results + +When presenting search results: +- Always include source attribution: wing, room, and drawer for each result +- Show relevance or similarity scores if available +- Group results by wing/room when returning multiple hits +- Quote or summarize the memory content clearly + +## 6. Offer Next Steps + +After presenting results, offer the user options to go deeper: +- Drill deeper -- search within a specific room or narrow the query +- Traverse -- explore the knowledge graph from a related room +- Check tunnels -- look for cross-wing connections if the topic spans domains +- Browse taxonomy -- show the full structure for manual exploration diff --git a/mempalace/instructions/status.md b/mempalace/instructions/status.md new file mode 100644 index 0000000..ceb902b --- /dev/null +++ b/mempalace/instructions/status.md @@ -0,0 +1,49 @@ +# MemPalace Status + +Display the current state of the user's memory palace. + +## Step 1: Gather Palace Status + +Check if MCP tools are available (look for mempalace_status in available tools). + +- If MCP is available: Call the mempalace_status tool to retrieve palace state. +- If MCP is not available: Run the CLI command: mempalace status + +## Step 2: Display Wing/Room/Drawer Counts + +Present the palace structure counts clearly: +- Number of wings +- Number of rooms +- Number of drawers +- Total memories stored + +Keep the output concise -- use a brief summary format, not verbose tables. + +## Step 3: Knowledge Graph Stats (MCP only) + +If MCP tools are available, also call: +- mempalace_kg_stats -- for a knowledge graph overview (triple count, entity + count, relationship types) +- mempalace_graph_stats -- for connectivity information (connected components, + average connections per entity) + +Present these alongside the palace counts in a unified summary. + +## Step 4: Suggest Next Actions + +Based on the current state, suggest one relevant action: + +- Empty palace (zero memories): Suggest "Try /mempalace:mine to add data from + files, URLs, or text." +- Has data but no knowledge graph (memories exist but KG stats show zero + triples): Suggest "Consider adding knowledge graph triples for richer + queries." +- Healthy palace (has memories and KG data): Suggest "Use /mempalace:search to + query your memories." + +## Output Style + +- Be concise and informative -- aim for a quick glance, not a report. +- Use short labels and numbers, not prose paragraphs. +- If any step fails or a tool is unavailable, note it briefly and continue + with what is available. diff --git a/mempalace/instructions_cli.py b/mempalace/instructions_cli.py new file mode 100644 index 0000000..239d721 --- /dev/null +++ b/mempalace/instructions_cli.py @@ -0,0 +1,28 @@ +""" +Instruction text output for MemPalace CLI commands. + +Each instruction lives as a .md file in the instructions/ directory +inside the package. The CLI reads and prints the file content. +""" + +import sys +from pathlib import Path + +INSTRUCTIONS_DIR = Path(__file__).parent / "instructions" + +AVAILABLE = ["init", "search", "mine", "help", "status"] + + +def run_instructions(name: str): + """Read and print the instruction .md file for the given name.""" + if name not in AVAILABLE: + print(f"Unknown instructions: {name}", file=sys.stderr) + print(f"Available: {', '.join(sorted(AVAILABLE))}", file=sys.stderr) + sys.exit(1) + + md_path = INSTRUCTIONS_DIR / f"{name}.md" + if not md_path.is_file(): + print(f"Instructions file not found: {md_path}", file=sys.stderr) + sys.exit(1) + + print(md_path.read_text()) diff --git a/mempalace/mcp_server.py b/mempalace/mcp_server.py index 5bf56dd..c4a570a 100644 --- a/mempalace/mcp_server.py +++ b/mempalace/mcp_server.py @@ -2,7 +2,7 @@ """ MemPalace MCP Server — read/write palace access for Claude Code ================================================================ -Install: claude mcp add mempalace -- python -m mempalace.mcp_server +Install: claude mcp add mempalace -- python -m mempalace.mcp_server [--palace /path/to/palace] Tools (read): mempalace_status — total drawers, wing/room breakdown @@ -17,6 +17,8 @@ Tools (write): mempalace_delete_drawer — remove a drawer by ID """ +import argparse +import os import sys import json import logging @@ -32,21 +34,50 @@ import chromadb from .knowledge_graph import KnowledgeGraph -_kg = KnowledgeGraph() - logging.basicConfig(level=logging.INFO, format="%(message)s", stream=sys.stderr) logger = logging.getLogger("mempalace_mcp") + +def _parse_args(): + parser = argparse.ArgumentParser(description="MemPalace MCP Server") + parser.add_argument( + "--palace", + metavar="PATH", + help="Path to the palace directory (overrides config file and env var)", + ) + args, unknown = parser.parse_known_args() + if unknown: + logger.debug("Ignoring unknown args: %s", unknown) + return args + + +_args = _parse_args() + +if _args.palace: + os.environ["MEMPALACE_PALACE_PATH"] = os.path.abspath(_args.palace) + _config = MempalaceConfig() +if _args.palace: + _kg = KnowledgeGraph(db_path=os.path.join(_config.palace_path, "knowledge_graph.sqlite3")) +else: + _kg = KnowledgeGraph() + + +_client_cache = None +_collection_cache = None def _get_collection(create=False): - """Return the ChromaDB collection, or None on failure.""" + """Return the ChromaDB collection, caching the client between calls.""" + global _client_cache, _collection_cache try: - client = chromadb.PersistentClient(path=_config.palace_path) + if _client_cache is None: + _client_cache = chromadb.PersistentClient(path=_config.palace_path) if create: - return client.get_or_create_collection(_config.collection_name) - return client.get_collection(_config.collection_name) + _collection_cache = _client_cache.get_or_create_collection(_config.collection_name) + elif _collection_cache is None: + _collection_cache = _client_cache.get_collection(_config.collection_name) + return _collection_cache except Exception: return None @@ -270,19 +301,18 @@ def tool_add_drawer( if not col: return _no_palace() - # Duplicate check - dup = tool_check_duplicate(content, threshold=0.9) - if dup.get("is_duplicate"): - return { - "success": False, - "reason": "duplicate", - "matches": dup["matches"], - } + drawer_id = f"drawer_{wing}_{room}_{hashlib.md5(content.encode()).hexdigest()[:16]}" - drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((content[:100] + datetime.now().isoformat()).encode()).hexdigest()[:16]}" + # Idempotency: if the deterministic ID already exists, return success as a no-op. + try: + existing = col.get(ids=[drawer_id]) + if existing and existing["ids"]: + return {"success": True, "reason": "already_exists", "drawer_id": drawer_id} + except Exception: + pass try: - col.add( + col.upsert( ids=[drawer_id], documents=[content], metadatas=[ diff --git a/mempalace/miner.py b/mempalace/miner.py index 7b4e949..66fbe03 100644 --- a/mempalace/miner.py +++ b/mempalace/miner.py @@ -403,10 +403,22 @@ def get_collection(palace_path: str): def file_already_mined(collection, source_file: str) -> bool: - """Fast check: has this file been filed before?""" + """Fast check: has this file been filed before and is unchanged? + + Compares the stored mtime in drawer metadata against the file's current + mtime. Returns False (needs re-mining) when the file has been modified + since it was last mined, or when no mtime was stored. + """ try: results = collection.get(where={"source_file": source_file}, limit=1) - return len(results.get("ids", [])) > 0 + if not results.get("ids"): + return False + stored_meta = results["metadatas"][0] if results.get("metadatas") else {} + stored_mtime = stored_meta.get("source_mtime") + if stored_mtime is None: + return False + current_mtime = os.path.getmtime(source_file) + return float(stored_mtime) == current_mtime except Exception: return False @@ -417,24 +429,26 @@ def add_drawer( """Add one drawer to the palace.""" drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((source_file + str(chunk_index)).encode(), usedforsecurity=False).hexdigest()[:16]}" try: - collection.add( + metadata = { + "wing": wing, + "room": room, + "source_file": source_file, + "chunk_index": chunk_index, + "added_by": agent, + "filed_at": datetime.now().isoformat(), + } + # Store file mtime so we can detect modifications later. + try: + metadata["source_mtime"] = os.path.getmtime(source_file) + except OSError: + pass + collection.upsert( documents=[content], ids=[drawer_id], - metadatas=[ - { - "wing": wing, - "room": room, - "source_file": source_file, - "chunk_index": chunk_index, - "added_by": agent, - "filed_at": datetime.now().isoformat(), - } - ], + metadatas=[metadata], ) return True - except Exception as e: - if "already exists" in str(e).lower() or "duplicate" in str(e).lower(): - return False + except Exception: raise @@ -451,29 +465,29 @@ def process_file( rooms: list, agent: str, dry_run: bool, -) -> int: - """Read, chunk, route, and file one file. Returns drawer count.""" +) -> tuple: + """Read, chunk, route, and file one file. Returns (drawer_count, room_name).""" # Skip if already filed source_file = str(filepath) if not dry_run and file_already_mined(collection, source_file): - return 0 + return 0, None try: content = filepath.read_text(encoding="utf-8", errors="replace") except OSError: - return 0 + return 0, None content = content.strip() if len(content) < MIN_CHUNK_SIZE: - return 0 + return 0, None room = detect_room(filepath, content, rooms, project_path) chunks = chunk_text(content, source_file) if dry_run: print(f" [DRY RUN] {filepath.name} → room:{room} ({len(chunks)} drawers)") - return len(chunks) + return len(chunks), room drawers_added = 0 for chunk in chunks: @@ -489,7 +503,7 @@ def process_file( if added: drawers_added += 1 - return drawers_added + return drawers_added, room # ============================================================================= @@ -608,7 +622,7 @@ def mine( room_counts = defaultdict(int) for i, filepath in enumerate(files, 1): - drawers = process_file( + drawers, room = process_file( filepath=filepath, project_path=project_path, collection=collection, @@ -621,7 +635,6 @@ def mine( files_skipped += 1 else: total_drawers += drawers - room = detect_room(filepath, "", rooms, project_path) room_counts[room] += 1 if not dry_run: print(f" ✓ [{i:4}/{len(files)}] {filepath.name[:50]:50} +{drawers}") diff --git a/mempalace/onboarding.py b/mempalace/onboarding.py index f578d91..70f7b54 100644 --- a/mempalace/onboarding.py +++ b/mempalace/onboarding.py @@ -312,7 +312,7 @@ def _generate_aaak_bootstrap( ] ) - (mempalace_dir / "aaak_entities.md").write_text("\n".join(registry_lines)) + (mempalace_dir / "aaak_entities.md").write_text("\n".join(registry_lines), encoding="utf-8") # Critical facts bootstrap (pre-palace — before any mining) facts_lines = [ @@ -359,7 +359,7 @@ def _generate_aaak_bootstrap( ] ) - (mempalace_dir / "critical_facts.md").write_text("\n".join(facts_lines)) + (mempalace_dir / "critical_facts.md").write_text("\n".join(facts_lines), encoding="utf-8") def run_onboarding( diff --git a/mempalace/split_mega_files.py b/mempalace/split_mega_files.py index 80bbae4..ae801df 100644 --- a/mempalace/split_mega_files.py +++ b/mempalace/split_mega_files.py @@ -219,7 +219,7 @@ def split_file(filepath, output_dir, dry_run=False): if dry_run: print(f" [{i + 1}/{len(boundaries) - 1}] {name} ({len(chunk)} lines)") else: - out_path.write_text("".join(chunk)) + out_path.write_text("".join(chunk), encoding="utf-8") print(f" ✓ {name} ({len(chunk)} lines)") written.append(out_path) diff --git a/mempalace/version.py b/mempalace/version.py index 08a910f..e56289e 100644 --- a/mempalace/version.py +++ b/mempalace/version.py @@ -1,3 +1,3 @@ """Single source of truth for the MemPalace package version.""" -__version__ = "3.0.0" +__version__ = "3.0.14" diff --git a/pyproject.toml b/pyproject.toml index 4862873..7b201da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mempalace" -version = "3.0.0" +version = "3.0.14" description = "Give your AI a memory — mine projects and conversations into a searchable palace. No API key required." readme = "README.md" requires-python = ">=3.9" @@ -38,11 +38,11 @@ Repository = "https://github.com/milla-jovovich/mempalace" mempalace = "mempalace:main" [project.optional-dependencies] -dev = ["pytest>=7.0", "ruff>=0.4.0"] +dev = ["pytest>=7.0", "pytest-cov>=4.0", "ruff>=0.4.0", "psutil>=5.9"] spellcheck = ["autocorrect>=2.0"] [dependency-groups] -dev = ["pytest>=7.0", "ruff>=0.4.0"] +dev = ["pytest>=7.0", "pytest-cov>=4.0", "ruff>=0.4.0", "psutil>=5.9"] [build-system] requires = ["hatchling"] @@ -64,3 +64,21 @@ quote-style = "double" [tool.pytest.ini_options] testpaths = ["tests"] +pythonpath = ["."] +addopts = "-m 'not benchmark and not slow and not stress'" +markers = [ + "benchmark: scale/performance benchmark tests", + "slow: tests that take more than 30 seconds", + "stress: destructive scale tests (100K+ drawers)", +] + +[tool.coverage.run] +source = ["mempalace"] + +[tool.coverage.report] +fail_under = 85 +show_missing = true +exclude_lines = [ + "if __name__", + "pragma: no cover", +] diff --git a/tests/benchmarks/README.md b/tests/benchmarks/README.md new file mode 100644 index 0000000..965afee --- /dev/null +++ b/tests/benchmarks/README.md @@ -0,0 +1,138 @@ +# MemPalace Scale Benchmark Suite + +106 tests that benchmark mempalace at scale to validate real-world performance limits. + +## Why + +MemPalace has strong academic scores (96.6% R@5 on LongMemEval) but no empirical data on how it behaves at scale. Key unknowns: + +- `tool_status()` loads ALL metadata into memory — at what palace size does this OOM? +- `PersistentClient` is re-instantiated on every MCP call — what's the overhead? +- Modified files are never re-ingested — what's the skip-check cost at scale? +- How does query latency degrade as the palace grows from 1K to 100K drawers? +- Does wing/room filtering actually improve retrieval, and by how much? +- At what per-room drawer count does recall break regardless of filtering? + +This suite finds those answers. + +## Quick Start + +```bash +# Fast smoke test (~2 min) +uv run pytest tests/benchmarks/ -v --bench-scale=small -m "benchmark and not slow" + +# Full small scale (~35 min) +uv run pytest tests/benchmarks/ -v --bench-scale=small + +# Medium scale with JSON report +uv run pytest tests/benchmarks/ -v --bench-scale=medium --bench-report=results.json + +# Stress test (local only, very slow) +uv run pytest tests/benchmarks/ -v --bench-scale=stress -m stress +``` + +## Scale Levels + +| Level | Drawers | Wings | Rooms/Wing | KG Triples | Use case | +|---------|---------|-------|------------|------------|---------------------| +| small | 1,000 | 3 | 5 | 200 | CI, quick checks | +| medium | 10,000 | 8 | 12 | 2,000 | Pre-release testing | +| large | 50,000 | 15 | 20 | 10,000 | Scale limit finding | +| stress | 100,000 | 25 | 30 | 50,000 | Breaking point | + +## Test Modules + +### Critical Path + +| File | What it tests | +|------|--------------| +| `test_mcp_bench.py` | MCP tool response times, unbounded metadata fetch, client re-instantiation overhead | +| `test_chromadb_stress.py` | ChromaDB breaking point, query degradation curve, batch vs sequential insert | +| `test_memory_profile.py` | RSS/heap growth over repeated operations, leak detection | + +### Performance Baselines + +| File | What it tests | +|------|--------------| +| `test_ingest_bench.py` | Mining throughput (files/sec, drawers/sec), peak RSS, chunking speed, re-ingest skip overhead | +| `test_search_bench.py` | Query latency vs palace size, recall@k with planted needles, concurrent queries, n_results scaling | + +### Architectural Validation + +| File | What it tests | +|------|--------------| +| `test_palace_boost.py` | Retrieval improvement from wing/room filtering at different scales | +| `test_recall_threshold.py` | Per-room recall ceiling — isolates embedding model limit with all drawers in one bucket | +| `test_knowledge_graph_bench.py` | Triple insertion rate, temporal query accuracy, SQLite concurrent access | +| `test_layers_bench.py` | MemoryStack wake-up cost, Layer1 unbounded fetch, token budget compliance | + +## Architecture + +``` +tests/benchmarks/ + conftest.py # --bench-scale / --bench-report CLI options, fixtures, markers + data_generator.py # Deterministic data factory (seeded RNG, planted needles) + report.py # JSON report writer + regression checker + test_*.py # 9 test modules (106 tests total) +``` + +### Data Generator + +`PalaceDataGenerator(seed=42, scale="small")` produces deterministic, realistic test data: + +- **`generate_project_tree()`** — writes real files + `mempalace.yaml` for `mine()` to ingest +- **`populate_palace_directly()`** — bypasses mining, inserts directly into ChromaDB (10-100x faster for search/MCP benchmarks) +- **`generate_kg_triples()`** — entity-relationship triples with temporal validity +- **`generate_search_queries()`** — queries with known-good answers for recall measurement + +**Planted needles**: Unique identifiable content (e.g., `NEEDLE_0042: PostgreSQL vacuum autovacuum threshold...`) seeded into specific wings/rooms. Search queries target these needles, enabling recall@k measurement without an LLM judge. + +### JSON Reports + +When run with `--bench-report=path.json`, produces machine-readable output: + +```json +{ + "timestamp": "2026-04-07T...", + "git_sha": "abc123", + "scale": "small", + "system": {"os": "linux", "cpu_count": 8}, + "results": { + "mcp_status": {"latency_ms_at_1000": 45.2, "rss_delta_mb_at_5000": 12.3}, + "search": {"avg_latency_ms_at_5000": 23.1, "recall_at_5": 0.92}, + "chromadb_insert": {"sequential_ms": 8500, "batched_ms": 1200, "speedup_ratio": 7.1} + } +} +``` + +### Regression Detection + +```python +from tests.benchmarks.report import check_regression + +regressions = check_regression("current.json", "baseline.json", threshold=0.2) +# Returns list of metric descriptions that degraded beyond 20% +``` + +## CI Integration + +The GitHub Actions workflow runs benchmarks on PRs at small scale: + +```yaml +benchmark: + runs-on: ubuntu-latest + if: github.event_name == 'pull_request' + # Runs: pytest tests/benchmarks/ -m "benchmark and not stress and not slow" --bench-scale=small +``` + +Existing unit tests are isolated with `--ignore=tests/benchmarks`. + +## Markers + +- `@pytest.mark.benchmark` — all benchmark tests +- `@pytest.mark.slow` — tests taking >30s even at small scale +- `@pytest.mark.stress` — tests that should only run at large/stress scale + +## Dependencies + +Only one new dependency beyond the existing dev stack: `psutil` (for cross-platform RSS measurement). `tracemalloc` and `resource` are stdlib. diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py new file mode 100644 index 0000000..0ef3255 --- /dev/null +++ b/tests/benchmarks/__init__.py @@ -0,0 +1 @@ +# MemPalace scale benchmark suite diff --git a/tests/benchmarks/conftest.py b/tests/benchmarks/conftest.py new file mode 100644 index 0000000..bd3f201 --- /dev/null +++ b/tests/benchmarks/conftest.py @@ -0,0 +1,144 @@ +"""Benchmark-specific pytest configuration, fixtures, and CLI options.""" + +import json +import os +import tempfile + +import pytest + + +SCALE_OPTIONS = ["small", "medium", "large", "stress"] + + +def pytest_addoption(parser): + parser.addoption( + "--bench-scale", + default="small", + choices=SCALE_OPTIONS, + help="Scale level for benchmark tests: small (1K), medium (10K), large (50K), stress (100K)", + ) + parser.addoption( + "--bench-report", + default=None, + help="Path for JSON benchmark report output", + ) + + +@pytest.fixture(scope="session") +def bench_scale(request): + """The configured benchmark scale level.""" + return request.config.getoption("--bench-scale") + + +@pytest.fixture(scope="session") +def bench_report_path(request): + """Path for JSON report output, or None.""" + return request.config.getoption("--bench-report") + + +@pytest.fixture +def palace_dir(tmp_path): + """Isolated palace directory for a single test.""" + p = tmp_path / "palace" + p.mkdir() + return str(p) + + +@pytest.fixture +def kg_db(tmp_path): + """Isolated KG SQLite path for a single test.""" + return str(tmp_path / "test_kg.sqlite3") + + +@pytest.fixture +def config_dir(tmp_path): + """Isolated config directory for monkeypatching MempalaceConfig.""" + d = tmp_path / "config" + d.mkdir() + config = {"palace_path": str(tmp_path / "palace"), "collection_name": "mempalace_drawers"} + with open(d / "config.json", "w") as f: + json.dump(config, f) + return str(d) + + +@pytest.fixture +def project_dir(tmp_path): + """Temporary project directory for mining tests.""" + d = tmp_path / "project" + d.mkdir() + return d + + +# ── Session-scoped result collector ────────────────────────────────────── + + +class BenchmarkResults: + """Collect benchmark metrics across all tests in a session.""" + + def __init__(self): + self.results = {} + + def record(self, category: str, metric: str, value): + if category not in self.results: + self.results[category] = {} + self.results[category][metric] = value + + +@pytest.fixture(scope="session") +def bench_results(): + """Session-scoped results collector shared by all benchmark tests.""" + return BenchmarkResults() + + +def pytest_terminal_summary(terminalreporter, config): + """Write JSON benchmark report after all tests complete.""" + report_path = config.getoption("--bench-report", default=None) + if not report_path: + return + + # Collect results written by individual tests via record_metric() + import platform + import subprocess + + try: + git_sha = subprocess.check_output( + ["git", "rev-parse", "--short", "HEAD"], text=True, stderr=subprocess.DEVNULL + ).strip() + except Exception: + git_sha = "unknown" + + try: + import chromadb + + chromadb_version = chromadb.__version__ + except Exception: + chromadb_version = "unknown" + + report = { + "timestamp": __import__("datetime").datetime.now().isoformat(), + "git_sha": git_sha, + "python_version": platform.python_version(), + "chromadb_version": chromadb_version, + "scale": config.getoption("--bench-scale", default="small"), + "system": { + "os": platform.system().lower(), + "cpu_count": os.cpu_count(), + "platform": platform.platform(), + }, + "results": {}, + } + + # Read results from the temp file written by record_metric() + results_file = os.path.join(tempfile.gettempdir(), "mempalace_bench_results.json") + if os.path.exists(results_file): + try: + with open(results_file) as f: + report["results"] = json.load(f) + os.unlink(results_file) + except Exception: + pass + + os.makedirs(os.path.dirname(os.path.abspath(report_path)), exist_ok=True) + with open(report_path, "w") as f: + json.dump(report, f, indent=2) + terminalreporter.write_line(f"\nBenchmark report written to: {report_path}") diff --git a/tests/benchmarks/data_generator.py b/tests/benchmarks/data_generator.py new file mode 100644 index 0000000..0184239 --- /dev/null +++ b/tests/benchmarks/data_generator.py @@ -0,0 +1,568 @@ +""" +Deterministic data factory for MemPalace scale benchmarks. + +Generates realistic project files, conversations, and KG triples at +configurable scale levels. All randomness uses seeded RNG for reproducibility. + +Planted "needle" drawers enable recall measurement without an LLM judge. +""" + +import hashlib +import os +import random +from datetime import datetime, timedelta +from pathlib import Path + +import chromadb +import yaml + + +# ── Scale configurations ───────────────────────────────────────────────── + +SCALE_CONFIGS = { + "small": { + "drawers": 1_000, + "wings": 3, + "rooms_per_wing": 5, + "kg_entities": 50, + "kg_triples": 200, + "needles": 20, + "search_queries": 20, + }, + "medium": { + "drawers": 10_000, + "wings": 8, + "rooms_per_wing": 12, + "kg_entities": 200, + "kg_triples": 2_000, + "needles": 50, + "search_queries": 50, + }, + "large": { + "drawers": 50_000, + "wings": 15, + "rooms_per_wing": 20, + "kg_entities": 500, + "kg_triples": 10_000, + "needles": 100, + "search_queries": 100, + }, + "stress": { + "drawers": 100_000, + "wings": 25, + "rooms_per_wing": 30, + "kg_entities": 1_000, + "kg_triples": 50_000, + "needles": 200, + "search_queries": 200, + }, +} + +# ── Vocabulary banks for realistic content ─────────────────────────────── + +WING_NAMES = [ + "webapp", + "backend_api", + "mobile_app", + "data_pipeline", + "ml_platform", + "devops", + "auth_service", + "payments", + "analytics", + "docs_site", + "cli_tool", + "dashboard", + "notification_service", + "search_engine", + "user_mgmt", + "inventory", + "reporting", + "testing_infra", + "monitoring", + "email_service", + "chat_bot", + "file_storage", + "scheduler", + "gateway", + "marketplace", +] + +ROOM_NAMES = [ + "backend", + "frontend", + "api", + "database", + "auth", + "tests", + "docs", + "config", + "deployment", + "models", + "views", + "controllers", + "middleware", + "utils", + "schemas", + "migrations", + "fixtures", + "scripts", + "styles", + "components", + "hooks", + "services", + "routes", + "templates", + "static", + "media", + "logging", + "cache", + "queue", + "workers", +] + +TECH_TERMS = [ + "authentication", + "authorization", + "middleware", + "endpoint", + "REST API", + "GraphQL", + "WebSocket", + "database migration", + "ORM", + "query optimization", + "caching strategy", + "load balancer", + "rate limiting", + "pagination", + "serialization", + "validation", + "error handling", + "logging framework", + "monitoring", + "deployment pipeline", + "CI/CD", + "containerization", + "microservice", + "event sourcing", + "message queue", + "pub/sub", + "connection pooling", + "session management", + "token refresh", + "CORS", + "SSL termination", + "health check", + "circuit breaker", + "retry logic", + "batch processing", + "stream processing", + "data pipeline", + "ETL", + "feature flag", + "A/B testing", + "blue-green deployment", + "canary release", +] + +CODE_SNIPPETS = [ + "def process_request(data):\n validated = schema.validate(data)\n result = handler.execute(validated)\n return Response(result, status=200)\n", + "class UserRepository:\n def __init__(self, db):\n self.db = db\n def find_by_id(self, user_id):\n return self.db.query(User).filter(User.id == user_id).first()\n", + "async def fetch_data(url, timeout=30):\n async with aiohttp.ClientSession() as session:\n async with session.get(url, timeout=timeout) as resp:\n return await resp.json()\n", + "const handleSubmit = async (formData) => {\n try {\n const response = await api.post('/users', formData);\n dispatch({ type: 'USER_CREATED', payload: response.data });\n } catch (error) {\n setError(error.message);\n }\n};\n", + "SELECT u.name, COUNT(o.id) as order_count\nFROM users u\nLEFT JOIN orders o ON u.id = o.user_id\nWHERE u.created_at > '2025-01-01'\nGROUP BY u.name\nHAVING COUNT(o.id) > 5\nORDER BY order_count DESC;\n", +] + +PROSE_TEMPLATES = [ + "The {component} module handles {task}. It was refactored in {month} to improve {quality}. Key design decision: {decision}.", + "Bug report: {component} fails when {condition}. Root cause: {cause}. Fixed by {fix}. Regression test added in {test_file}.", + "Architecture decision: switched from {old_tech} to {new_tech} for {reason}. Migration completed {date}. Performance improved by {percent}%.", + "Meeting notes: discussed {topic} with {person}. Agreed to {action}. Deadline: {deadline}. Follow-up: {followup}.", + "Feature spec: {feature_name} allows users to {capability}. Dependencies: {deps}. Estimated effort: {effort} days.", +] + +ENTITY_NAMES = [ + "Alice", + "Bob", + "Carol", + "Dave", + "Eve", + "Frank", + "Grace", + "Heidi", + "Ivan", + "Judy", + "Karl", + "Linda", + "Mike", + "Nina", + "Oscar", + "Pat", + "Quinn", + "Rita", + "Steve", + "Tina", + "Ursula", + "Victor", + "Wendy", + "Xander", +] + +ENTITY_TYPES = ["person", "project", "tool", "concept", "team", "service"] + +PREDICATES = [ + "works_on", + "manages", + "reports_to", + "collaborates_with", + "created", + "maintains", + "uses", + "depends_on", + "replaced", + "reviewed", + "deployed", + "tested", + "documented", + "mentors", + "leads", + "contributes_to", +] + + +class PalaceDataGenerator: + """Generate deterministic, realistic test data at configurable scale.""" + + def __init__(self, seed=42, scale="small"): + self.rng = random.Random(seed) + self.scale = scale + self.cfg = SCALE_CONFIGS[scale] + self.wings = WING_NAMES[: self.cfg["wings"]] + self.rooms_by_wing = {} + for wing in self.wings: + n = self.cfg["rooms_per_wing"] + rooms = self.rng.sample(ROOM_NAMES, min(n, len(ROOM_NAMES))) + self.rooms_by_wing[wing] = rooms + # Planted needles for recall measurement + self.needles = [] + self._generate_needles() + + def _generate_needles(self): + """Create unique needle content for recall testing.""" + topics = [ + "Fibonacci sequence optimization uses memoization with O(n) space complexity", + "PostgreSQL vacuum autovacuum threshold set to 50 percent for table users", + "Redis cluster failover timeout configured at 30 seconds with sentinel monitoring", + "Kubernetes horizontal pod autoscaler targets 70 percent CPU utilization", + "GraphQL subscription uses WebSocket transport with heartbeat interval 25 seconds", + "JWT token rotation policy requires refresh every 15 minutes with sliding window", + "Elasticsearch index sharding strategy uses 5 primary shards with 1 replica each", + "Docker multi-stage build reduces image size from 1.2GB to 180MB for production", + "Apache Kafka consumer group rebalance timeout set to 45 seconds", + "MongoDB change streams resume token persisted every 100 operations", + "gRPC streaming uses bidirectional flow control with 64KB window size", + "Prometheus alerting rule fires when p99 latency exceeds 500ms for 5 minutes", + "Terraform state locking uses DynamoDB with consistent reads enabled", + "Nginx rate limiting configured at 100 requests per second with burst of 50", + "SQLAlchemy connection pool size set to 20 with max overflow of 10 connections", + "React concurrent mode uses startTransition for non-urgent state updates", + "AWS Lambda cold start mitigation uses provisioned concurrency of 10 instances", + "Git bisect automated with custom test script for regression hunting", + "OpenTelemetry trace sampling rate set to 10 percent in production environment", + "Celery worker prefetch multiplier set to 1 for fair task distribution", + ] + for i in range(self.cfg["needles"]): + topic = topics[i % len(topics)] + wing = self.rng.choice(self.wings) + room = self.rng.choice(self.rooms_by_wing[wing]) + needle_id = f"NEEDLE_{i:04d}" + content = f"{needle_id}: {topic}. This is a unique planted needle for recall benchmarking at scale." + self.needles.append( + { + "id": needle_id, + "content": content, + "wing": wing, + "room": room, + "query": topic.split(" uses ")[0] + if " uses " in topic + else topic.split(" set to ")[0] + if " set to " in topic + else topic[:60], + } + ) + + def _random_text(self, min_chars=600, max_chars=900): + """Generate a random text block of realistic content.""" + parts = [] + total = 0 + target = self.rng.randint(min_chars, max_chars) + while total < target: + choice = self.rng.random() + if choice < 0.3: + text = self.rng.choice(CODE_SNIPPETS) + elif choice < 0.7: + template = self.rng.choice(PROSE_TEMPLATES) + text = template.format( + component=self.rng.choice(ROOM_NAMES), + task=self.rng.choice(TECH_TERMS), + month=self.rng.choice(["January", "February", "March", "April", "May"]), + quality=self.rng.choice( + ["performance", "readability", "test coverage", "latency"] + ), + decision=self.rng.choice(TECH_TERMS), + condition=self.rng.choice(TECH_TERMS) + " is null", + cause=self.rng.choice(["race condition", "null pointer", "timeout", "OOM"]), + fix="adding " + self.rng.choice(TECH_TERMS), + test_file=f"test_{self.rng.choice(ROOM_NAMES)}.py", + old_tech=self.rng.choice(["MySQL", "Flask", "REST", "Jenkins"]), + new_tech=self.rng.choice( + ["PostgreSQL", "FastAPI", "GraphQL", "GitHub Actions"] + ), + reason=self.rng.choice(TECH_TERMS), + date=f"2025-{self.rng.randint(1, 12):02d}-{self.rng.randint(1, 28):02d}", + percent=self.rng.randint(10, 80), + topic=self.rng.choice(TECH_TERMS), + person=self.rng.choice(ENTITY_NAMES), + action=self.rng.choice(["refactor", "migrate", "optimize", "test"]), + deadline=f"2025-{self.rng.randint(1, 12):02d}-{self.rng.randint(1, 28):02d}", + followup=self.rng.choice(TECH_TERMS), + feature_name=self.rng.choice(TECH_TERMS), + capability=self.rng.choice(TECH_TERMS), + deps=", ".join(self.rng.sample(TECH_TERMS, 2)), + effort=self.rng.randint(1, 15), + ) + else: + words = self.rng.sample(TECH_TERMS, min(5, len(TECH_TERMS))) + text = ( + " ".join(words) + + ". " + + self.rng.choice(TECH_TERMS) + + " implementation details follow.\n" + ) + parts.append(text) + total += len(text) + return "\n".join(parts)[:max_chars] + + # ── Project tree generation (for mine() tests) ─────────────────────── + + def generate_project_tree(self, base_path, wing=None, rooms=None, n_files=50): + """ + Write realistic project files + mempalace.yaml to base_path. + + Returns the project path suitable for passing to mine(). + """ + base = Path(base_path) + base.mkdir(parents=True, exist_ok=True) + wing = wing or self.rng.choice(self.wings) + rooms = rooms or self.rooms_by_wing.get(wing, ["general"]) + + # Write mempalace.yaml + room_defs = [{"name": r, "description": f"{r} code and docs"} for r in rooms] + with open(base / "mempalace.yaml", "w") as f: + yaml.dump({"wing": wing, "rooms": room_defs}, f) + + # Write files distributed across room directories + files_written = 0 + for i in range(n_files): + room = rooms[i % len(rooms)] + room_dir = base / room + room_dir.mkdir(parents=True, exist_ok=True) + + ext = self.rng.choice([".py", ".js", ".md", ".ts", ".yaml"]) + filename = f"file_{i:04d}{ext}" + content = self._random_text(400, 2000) + (room_dir / filename).write_text(content, encoding="utf-8") + files_written += 1 + + return str(base), wing, rooms, files_written + + # ── Conversation file generation (for mine_convos() tests) ─────────── + + def generate_conversation_files(self, base_path, wing=None, n_files=20): + """Write conversation transcript files for convo_miner tests.""" + base = Path(base_path) + base.mkdir(parents=True, exist_ok=True) + wing = wing or self.rng.choice(self.wings) + + for i in range(n_files): + lines = [] + n_exchanges = self.rng.randint(5, 20) + for j in range(n_exchanges): + user_msg = f"> User: {self.rng.choice(TECH_TERMS)}? How does {self.rng.choice(TECH_TERMS)} work with {self.rng.choice(TECH_TERMS)}?" + ai_msg = self._random_text(200, 600) + lines.append(user_msg) + lines.append(ai_msg) + lines.append("") + + (base / f"convo_{i:04d}.txt").write_text("\n".join(lines), encoding="utf-8") + + return str(base), wing + + # ── Direct palace population (bypasses mining for speed) ───────────── + + def populate_palace_directly(self, palace_path, n_drawers=None, include_needles=True): + """ + Insert drawers directly into ChromaDB, bypassing the mining pipeline. + + Much faster than mining for benchmarks that only care about + search/MCP behavior on a pre-populated palace. + + Returns (client, collection, needle_info). + """ + n_drawers = n_drawers or self.cfg["drawers"] + os.makedirs(palace_path, exist_ok=True) + client = chromadb.PersistentClient(path=palace_path) + col = client.get_or_create_collection("mempalace_drawers") + + batch_size = 500 + docs = [] + ids = [] + metas = [] + + # Insert needles first + needle_info = [] + if include_needles: + for needle in self.needles: + needle_id = f"drawer_{needle['wing']}_{needle['room']}_{hashlib.md5(needle['id'].encode()).hexdigest()[:16]}" + docs.append(needle["content"]) + ids.append(needle_id) + metas.append( + { + "wing": needle["wing"], + "room": needle["room"], + "source_file": f"needle_{needle['id']}.txt", + "chunk_index": 0, + "added_by": "benchmark", + "filed_at": datetime.now().isoformat(), + } + ) + needle_info.append( + { + "id": needle_id, + "query": needle["query"], + "wing": needle["wing"], + "room": needle["room"], + } + ) + + # Fill remaining drawers with realistic content + remaining = n_drawers - len(docs) + for i in range(remaining): + wing = self.wings[i % len(self.wings)] + rooms = self.rooms_by_wing[wing] + room = rooms[i % len(rooms)] + content = self._random_text(400, 800) + drawer_id = f"drawer_{wing}_{room}_{hashlib.md5(f'gen_{i}'.encode()).hexdigest()[:16]}" + + docs.append(content) + ids.append(drawer_id) + metas.append( + { + "wing": wing, + "room": room, + "source_file": f"generated_{i:06d}.txt", + "chunk_index": i % 10, + "added_by": "benchmark", + "filed_at": datetime.now().isoformat(), + } + ) + + # Flush in batches + if len(docs) >= batch_size: + col.add(documents=docs, ids=ids, metadatas=metas) + docs, ids, metas = [], [], [] + + # Flush remainder + if docs: + col.add(documents=docs, ids=ids, metadatas=metas) + + return client, col, needle_info + + # ── KG triple generation ───────────────────────────────────────────── + + def generate_kg_triples(self, n_entities=None, n_triples=None): + """ + Generate realistic entity-relationship triples. + + Returns (entities, triples) where: + entities = [(name, type), ...] + triples = [(subject, predicate, object, valid_from, valid_to), ...] + """ + n_entities = n_entities or self.cfg["kg_entities"] + n_triples = n_triples or self.cfg["kg_triples"] + + # Generate entities + entities = [] + entity_names = [] + for i in range(n_entities): + if i < len(ENTITY_NAMES): + name = ENTITY_NAMES[i] + else: + name = f"Entity_{i:04d}" + etype = self.rng.choice(ENTITY_TYPES) + entities.append((name, etype)) + entity_names.append(name) + + # Generate triples + triples = [] + base_date = datetime(2024, 1, 1) + for i in range(n_triples): + subject = self.rng.choice(entity_names) + obj = self.rng.choice(entity_names) + while obj == subject: + obj = self.rng.choice(entity_names) + predicate = self.rng.choice(PREDICATES) + days_offset = self.rng.randint(0, 730) + valid_from = (base_date + timedelta(days=days_offset)).strftime("%Y-%m-%d") + # 30% chance of having a valid_to + valid_to = None + if self.rng.random() < 0.3: + end_offset = self.rng.randint(30, 365) + valid_to = (base_date + timedelta(days=days_offset + end_offset)).strftime( + "%Y-%m-%d" + ) + triples.append((subject, predicate, obj, valid_from, valid_to)) + + return entities, triples + + # ── Search query generation ────────────────────────────────────────── + + def generate_search_queries(self, n_queries=None): + """ + Generate search queries with expected results. + + Returns list of {"query": str, "expected_wing": str|None, "expected_room": str|None, "is_needle": bool}. + Needle queries have known-good answers for recall measurement. + """ + n_queries = n_queries or self.cfg["search_queries"] + queries = [] + + # Half are needle queries (known-good answers) + n_needle = min(n_queries // 2, len(self.needles)) + for needle in self.needles[:n_needle]: + queries.append( + { + "query": needle["query"], + "expected_wing": needle["wing"], + "expected_room": needle["room"], + "needle_id": needle["id"], + "is_needle": True, + } + ) + + # Other half are generic queries (measure latency, not recall) + n_generic = n_queries - n_needle + for _ in range(n_generic): + queries.append( + { + "query": self.rng.choice(TECH_TERMS) + " " + self.rng.choice(TECH_TERMS), + "expected_wing": None, + "expected_room": None, + "needle_id": None, + "is_needle": False, + } + ) + + self.rng.shuffle(queries) + return queries diff --git a/tests/benchmarks/report.py b/tests/benchmarks/report.py new file mode 100644 index 0000000..61ac937 --- /dev/null +++ b/tests/benchmarks/report.py @@ -0,0 +1,117 @@ +""" +Benchmark report utilities — JSON output and regression detection. + +Each test records metrics via record_metric(). At session end, the +conftest.py pytest_terminal_summary hook writes the collected results. +""" + +import json +import os +import tempfile + + +RESULTS_FILE = os.path.join(tempfile.gettempdir(), "mempalace_bench_results.json") + + +def record_metric(category: str, metric: str, value): + """Append a metric to the session results file (JSON on disk).""" + results = {} + if os.path.exists(RESULTS_FILE): + try: + with open(RESULTS_FILE) as f: + results = json.load(f) + except (json.JSONDecodeError, OSError): + results = {} + + if category not in results: + results[category] = {} + results[category][metric] = value + + with open(RESULTS_FILE, "w") as f: + json.dump(results, f, indent=2) + + +def check_regression(current_report: str, baseline_report: str, threshold: float = 0.2): + """ + Compare current benchmark results against a baseline. + + Returns a list of regression descriptions. Empty list = no regressions. + + threshold: fractional degradation allowed (0.2 = 20% worse is OK). + """ + with open(current_report) as f: + current = json.load(f) + with open(baseline_report) as f: + baseline = json.load(f) + + regressions = [] + # Keywords for metric direction — checked in order, first match wins. + # "improvement" is checked before "latency" so that composite names + # like "latency_improvement_pct" are classified correctly. + _higher_is_better_kw = [ + "improvement", + "recall", + "throughput", + "per_sec", + "files_per_sec", + "drawers_per_sec", + "triples_per_sec", + "speedup", + ] + _higher_is_worse_kw = [ + "latency", + "rss", + "memory", + "oom", + "lock_failures", + "elapsed", + "p50_ms", + "p95_ms", + "p99_ms", + "rss_delta_mb", + "peak_rss_mb", + "errors", + "failures", + ] + + def _metric_direction(name: str) -> str: + """Return 'higher_better', 'higher_worse', or 'unknown'.""" + low = name.lower() + for kw in _higher_is_better_kw: + if kw in low: + return "higher_better" + for kw in _higher_is_worse_kw: + if kw in low: + return "higher_worse" + return "unknown" + + for category in baseline.get("results", {}): + if category not in current.get("results", {}): + continue + for metric, base_val in baseline["results"][category].items(): + if metric not in current["results"][category]: + continue + curr_val = current["results"][category][metric] + if not isinstance(base_val, (int, float)) or not isinstance(curr_val, (int, float)): + continue + if base_val == 0: + continue + + direction = _metric_direction(metric) + + if direction == "higher_worse": + # Higher is worse — check if current exceeds baseline by threshold + if curr_val > base_val * (1 + threshold): + pct = ((curr_val - base_val) / base_val) * 100 + regressions.append( + f"{category}/{metric}: {base_val:.2f} -> {curr_val:.2f} ({pct:+.1f}%, threshold {threshold * 100:.0f}%)" + ) + elif direction == "higher_better": + # Lower is worse — check if current is below baseline by threshold + if curr_val < base_val * (1 - threshold): + pct = ((curr_val - base_val) / base_val) * 100 + regressions.append( + f"{category}/{metric}: {base_val:.2f} -> {curr_val:.2f} ({pct:+.1f}%, threshold {threshold * 100:.0f}%)" + ) + + return regressions diff --git a/tests/benchmarks/test_chromadb_stress.py b/tests/benchmarks/test_chromadb_stress.py new file mode 100644 index 0000000..1a77529 --- /dev/null +++ b/tests/benchmarks/test_chromadb_stress.py @@ -0,0 +1,206 @@ +""" +ChromaDB stress tests — find the breaking point. + +Tests the raw ChromaDB patterns used by mempalace to determine: + - At what collection size does col.get(include=["metadatas"]) become dangerous? + - How does query latency degrade as collection grows? + - How much faster is batched insertion vs sequential? +""" + +import os +import time + +import chromadb +import pytest + +from tests.benchmarks.data_generator import PalaceDataGenerator +from tests.benchmarks.report import record_metric + + +def _get_rss_mb(): + try: + import psutil + + return psutil.Process().memory_info().rss / (1024 * 1024) + except ImportError: + import resource + import platform + + usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + if platform.system() == "Darwin": + return usage / (1024 * 1024) + return usage / 1024 + + +@pytest.mark.benchmark +class TestGetAllMetadatasOOM: + """ + The specific pattern causing finding #3: + col.get(include=["metadatas"]) with NO limit. + + Measures RSS growth to find when this becomes dangerous. + """ + + SIZES = [1_000, 2_500, 5_000, 10_000] + + @pytest.mark.parametrize("n_drawers", SIZES) + def test_get_all_metadatas_rss(self, n_drawers, tmp_path, bench_scale): + """RSS growth from fetching all metadata at once.""" + gen = PalaceDataGenerator(seed=42, scale=bench_scale) + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=n_drawers, include_needles=False) + + client = chromadb.PersistentClient(path=palace_path) + col = client.get_collection("mempalace_drawers") + + rss_before = _get_rss_mb() + start = time.perf_counter() + all_meta = col.get(include=["metadatas"])["metadatas"] + elapsed_ms = (time.perf_counter() - start) * 1000 + rss_after = _get_rss_mb() + + assert len(all_meta) == n_drawers + rss_delta = rss_after - rss_before + + record_metric("chromadb_get_all", f"rss_delta_mb_at_{n_drawers}", round(rss_delta, 2)) + record_metric("chromadb_get_all", f"latency_ms_at_{n_drawers}", round(elapsed_ms, 1)) + + +@pytest.mark.benchmark +class TestQueryDegradation: + """Measure query latency as collection grows.""" + + SIZES = [1_000, 2_500, 5_000, 10_000] + + @pytest.mark.parametrize("n_drawers", SIZES) + def test_query_latency_at_size(self, n_drawers, tmp_path, bench_scale): + gen = PalaceDataGenerator(seed=42, scale=bench_scale) + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=n_drawers, include_needles=False) + + client = chromadb.PersistentClient(path=palace_path) + col = client.get_collection("mempalace_drawers") + + queries = [ + "authentication middleware optimization", + "database connection pooling strategy", + "error handling retry logic", + "deployment pipeline configuration", + "load balancer health check", + ] + + latencies = [] + for q in queries: + start = time.perf_counter() + results = col.query(query_texts=[q], n_results=5, include=["documents", "distances"]) + elapsed_ms = (time.perf_counter() - start) * 1000 + latencies.append(elapsed_ms) + assert results["documents"][0] # got results + + avg_ms = sum(latencies) / len(latencies) + p95_ms = sorted(latencies)[int(len(latencies) * 0.95)] + + record_metric("chromadb_query", f"avg_latency_ms_at_{n_drawers}", round(avg_ms, 1)) + record_metric("chromadb_query", f"p95_latency_ms_at_{n_drawers}", round(p95_ms, 1)) + + +@pytest.mark.benchmark +class TestBulkInsertPerformance: + """Compare batch insertion vs sequential add_drawer pattern.""" + + def test_sequential_vs_batched(self, tmp_path): + """The current miner uses single-document add(). How much faster is batching?""" + n_docs = 500 + gen = PalaceDataGenerator(seed=42) + + # Generate content + contents = [gen._random_text(400, 800) for _ in range(n_docs)] + + # Sequential insertion (mimics add_drawer pattern) + palace_seq = str(tmp_path / "seq") + os.makedirs(palace_seq) + client_seq = chromadb.PersistentClient(path=palace_seq) + col_seq = client_seq.get_or_create_collection("mempalace_drawers") + + start = time.perf_counter() + for i, content in enumerate(contents): + col_seq.add( + documents=[content], + ids=[f"seq_{i}"], + metadatas=[{"wing": "test", "room": "bench", "chunk_index": i}], + ) + sequential_ms = (time.perf_counter() - start) * 1000 + + # Batched insertion + palace_batch = str(tmp_path / "batch") + os.makedirs(palace_batch) + client_batch = chromadb.PersistentClient(path=palace_batch) + col_batch = client_batch.get_or_create_collection("mempalace_drawers") + + batch_size = 100 + start = time.perf_counter() + for batch_start in range(0, n_docs, batch_size): + batch_end = min(batch_start + batch_size, n_docs) + batch_docs = contents[batch_start:batch_end] + batch_ids = [f"batch_{i}" for i in range(batch_start, batch_end)] + batch_metas = [ + {"wing": "test", "room": "bench", "chunk_index": i} + for i in range(batch_start, batch_end) + ] + col_batch.add(documents=batch_docs, ids=batch_ids, metadatas=batch_metas) + batched_ms = (time.perf_counter() - start) * 1000 + + speedup = sequential_ms / max(batched_ms, 0.01) + + assert col_seq.count() == n_docs + assert col_batch.count() == n_docs + + record_metric("chromadb_insert", "sequential_ms", round(sequential_ms, 1)) + record_metric("chromadb_insert", "batched_ms", round(batched_ms, 1)) + record_metric("chromadb_insert", "speedup_ratio", round(speedup, 2)) + record_metric("chromadb_insert", "n_docs", n_docs) + record_metric("chromadb_insert", "batch_size", batch_size) + + +@pytest.mark.benchmark +@pytest.mark.slow +class TestMaxCollectionSize: + """Incrementally grow collection to find practical limits.""" + + def test_incremental_growth(self, tmp_path, bench_scale): + """Add drawers in batches, measure latency per batch.""" + gen = PalaceDataGenerator(seed=42, scale=bench_scale) + cfg = gen.cfg + target = min(cfg["drawers"], 10_000) # cap at 10K for this test + + palace_path = str(tmp_path / "palace") + os.makedirs(palace_path) + client = chromadb.PersistentClient(path=palace_path) + col = client.get_or_create_collection("mempalace_drawers") + + batch_size = 500 + batch_times = [] + total_inserted = 0 + + for batch_num in range(0, target, batch_size): + n = min(batch_size, target - batch_num) + docs = [gen._random_text(400, 800) for _ in range(n)] + ids = [f"growth_{batch_num + i}" for i in range(n)] + metas = [ + {"wing": gen.wings[i % len(gen.wings)], "room": "bench", "chunk_index": i} + for i in range(batch_num, batch_num + n) + ] + + start = time.perf_counter() + col.add(documents=docs, ids=ids, metadatas=metas) + batch_ms = (time.perf_counter() - start) * 1000 + total_inserted += n + batch_times.append({"at_size": total_inserted, "batch_ms": round(batch_ms, 1)}) + + assert col.count() == total_inserted + + # Record first and last batch times to show degradation + record_metric("chromadb_growth", "first_batch_ms", batch_times[0]["batch_ms"]) + record_metric("chromadb_growth", "last_batch_ms", batch_times[-1]["batch_ms"]) + record_metric("chromadb_growth", "total_inserted", total_inserted) + record_metric("chromadb_growth", "batch_times", batch_times) diff --git a/tests/benchmarks/test_ingest_bench.py b/tests/benchmarks/test_ingest_bench.py new file mode 100644 index 0000000..2b4ea5b --- /dev/null +++ b/tests/benchmarks/test_ingest_bench.py @@ -0,0 +1,169 @@ +""" +Ingestion throughput benchmarks. + +Measures mining performance at scale: + - Files/sec and drawers/sec through the full mine() pipeline + - Peak RSS during mining + - Chunking throughput isolated from ChromaDB + - Re-ingest skip overhead (finding #11: file_already_mined check) +""" + +import time + +import chromadb +import pytest + +from tests.benchmarks.data_generator import PalaceDataGenerator +from tests.benchmarks.report import record_metric + + +def _get_rss_mb(): + try: + import psutil + + return psutil.Process().memory_info().rss / (1024 * 1024) + except ImportError: + import resource + import platform + + usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + if platform.system() == "Darwin": + return usage / (1024 * 1024) + return usage / 1024 + + +@pytest.mark.benchmark +class TestMineThroughput: + """Measure the full mine() pipeline throughput.""" + + @pytest.mark.parametrize("n_files", [20, 50, 100]) + def test_mine_files_per_second(self, n_files, tmp_path, bench_scale): + """End-to-end mining throughput: generate files, mine, count drawers.""" + gen = PalaceDataGenerator(seed=42, scale=bench_scale) + project_path, wing, rooms, files_written = gen.generate_project_tree( + tmp_path / "project", n_files=n_files + ) + palace_path = str(tmp_path / "palace") + + from mempalace.miner import mine + + start = time.perf_counter() + mine(project_path, palace_path) + elapsed = time.perf_counter() - start + + client = chromadb.PersistentClient(path=palace_path) + col = client.get_collection("mempalace_drawers") + drawer_count = col.count() + + files_per_sec = files_written / max(elapsed, 0.001) + drawers_per_sec = drawer_count / max(elapsed, 0.001) + + record_metric("ingest", f"files_per_sec_at_{n_files}", round(files_per_sec, 1)) + record_metric("ingest", f"drawers_per_sec_at_{n_files}", round(drawers_per_sec, 1)) + record_metric("ingest", f"elapsed_sec_at_{n_files}", round(elapsed, 2)) + record_metric("ingest", f"drawers_created_at_{n_files}", drawer_count) + + def test_mine_peak_rss(self, tmp_path, bench_scale): + """Track peak RSS during a mining run.""" + import threading + + gen = PalaceDataGenerator(seed=42, scale=bench_scale) + project_path, wing, rooms, files_written = gen.generate_project_tree( + tmp_path / "project", n_files=100 + ) + palace_path = str(tmp_path / "palace") + + from mempalace.miner import mine + + rss_samples = [] + stop_sampling = threading.Event() + + def sample_rss(): + while not stop_sampling.is_set(): + rss_samples.append(_get_rss_mb()) + stop_sampling.wait(0.1) + + sampler = threading.Thread(target=sample_rss, daemon=True) + sampler.start() + + rss_before = _get_rss_mb() + mine(project_path, palace_path) + stop_sampling.set() + sampler.join(timeout=1) + + peak_rss = max(rss_samples) if rss_samples else _get_rss_mb() + rss_delta = peak_rss - rss_before + + record_metric("ingest", "peak_rss_mb", round(peak_rss, 1)) + record_metric("ingest", "rss_delta_mb", round(rss_delta, 1)) + + +@pytest.mark.benchmark +class TestChunkThroughput: + """Isolate chunking performance from ChromaDB insertion.""" + + @pytest.mark.parametrize("content_size_kb", [1, 10, 100]) + def test_chunk_text_throughput(self, content_size_kb): + """Measure chunk_text speed for different content sizes.""" + from mempalace.miner import chunk_text + + gen = PalaceDataGenerator(seed=42) + # Generate content of target size + content = gen._random_text(content_size_kb * 500, content_size_kb * 1200) + # Pad to approximate target KB + while len(content) < content_size_kb * 1024: + content += "\n" + gen._random_text(200, 500) + + n_iterations = 50 + start = time.perf_counter() + total_chunks = 0 + for _ in range(n_iterations): + chunks = chunk_text(content, "bench_file.py") + total_chunks += len(chunks) + elapsed = time.perf_counter() - start + + chunks_per_sec = total_chunks / max(elapsed, 0.001) + kb_per_sec = (len(content) * n_iterations / 1024) / max(elapsed, 0.001) + + record_metric( + "chunking", f"chunks_per_sec_at_{content_size_kb}kb", round(chunks_per_sec, 1) + ) + record_metric("chunking", f"kb_per_sec_at_{content_size_kb}kb", round(kb_per_sec, 1)) + + +@pytest.mark.benchmark +class TestReingestSkipOverhead: + """Finding #11: file_already_mined() check overhead at scale.""" + + def test_skip_check_cost(self, tmp_path): + """Mine files, then re-mine — measure cost of skip checks.""" + gen = PalaceDataGenerator(seed=42, scale="small") + project_path, wing, rooms, files_written = gen.generate_project_tree( + tmp_path / "project", n_files=50 + ) + palace_path = str(tmp_path / "palace") + + from mempalace.miner import mine + + # First mine + mine(project_path, palace_path) + client = chromadb.PersistentClient(path=palace_path) + col = client.get_collection("mempalace_drawers") + initial_count = col.count() + + # Re-mine (all files should be skipped) + start = time.perf_counter() + mine(project_path, palace_path) + skip_elapsed = time.perf_counter() - start + + # Verify no new drawers added + final_count = col.count() + assert final_count == initial_count, "Re-mine should not add new drawers" + + record_metric("reingest", "skip_check_elapsed_sec", round(skip_elapsed, 2)) + record_metric("reingest", "files_checked", files_written) + record_metric( + "reingest", + "skip_check_per_file_ms", + round(skip_elapsed * 1000 / max(files_written, 1), 1), + ) diff --git a/tests/benchmarks/test_knowledge_graph_bench.py b/tests/benchmarks/test_knowledge_graph_bench.py new file mode 100644 index 0000000..60236bc --- /dev/null +++ b/tests/benchmarks/test_knowledge_graph_bench.py @@ -0,0 +1,290 @@ +""" +Knowledge graph benchmarks — SQLite temporal KG at scale. + +Tests triple insertion throughput, query latency, temporal accuracy, +and SQLite concurrent access behavior. +""" + +import threading +import time + +import pytest + +from tests.benchmarks.data_generator import PalaceDataGenerator +from tests.benchmarks.report import record_metric + + +@pytest.mark.benchmark +class TestTripleInsertionRate: + """Measure triples/sec at different scales.""" + + @pytest.mark.parametrize("n_triples", [200, 1_000, 5_000]) + def test_insertion_throughput(self, n_triples, tmp_path): + gen = PalaceDataGenerator(seed=42, scale="small") + entities, triples = gen.generate_kg_triples( + n_entities=min(n_triples // 2, 200), n_triples=n_triples + ) + + from mempalace.knowledge_graph import KnowledgeGraph + + kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3")) + + # Insert entities first + for name, etype in entities: + kg.add_entity(name, etype) + + # Measure triple insertion + start = time.perf_counter() + for subject, predicate, obj, valid_from, valid_to in triples: + kg.add_triple(subject, predicate, obj, valid_from=valid_from, valid_to=valid_to) + elapsed = time.perf_counter() - start + + triples_per_sec = n_triples / max(elapsed, 0.001) + + record_metric("kg_insert", f"triples_per_sec_at_{n_triples}", round(triples_per_sec, 1)) + record_metric("kg_insert", f"elapsed_sec_at_{n_triples}", round(elapsed, 3)) + + +@pytest.mark.benchmark +class TestQueryEntityLatency: + """Query latency for entities with varying relationship counts.""" + + def test_query_latency_vs_relationships(self, tmp_path): + """Create entities with 10, 50, 100 relationships and measure query time.""" + from mempalace.knowledge_graph import KnowledgeGraph + + kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3")) + + # Create a hub entity connected to many others + kg.add_entity("Hub", "person") + target_counts = [10, 50, 100] + + for target in target_counts: + for i in range(target): + entity_name = f"Node_{target}_{i}" + kg.add_entity(entity_name, "project") + kg.add_triple("Hub", "works_on", entity_name, valid_from="2025-01-01") + + # Measure query for Hub (which has sum(target_counts) relationships) + latencies = [] + for _ in range(20): + start = time.perf_counter() + kg.query_entity("Hub") + elapsed_ms = (time.perf_counter() - start) * 1000 + latencies.append(elapsed_ms) + + avg_ms = sum(latencies) / len(latencies) + total_rels = sum(target_counts) + + record_metric("kg_query", f"avg_ms_with_{total_rels}_rels", round(avg_ms, 2)) + record_metric("kg_query", "total_relationships", total_rels) + + +@pytest.mark.benchmark +class TestTimelinePerformance: + """timeline() with no entity filter does a full table scan.""" + + @pytest.mark.parametrize("n_triples", [200, 1_000, 5_000]) + def test_timeline_latency(self, n_triples, tmp_path): + from mempalace.knowledge_graph import KnowledgeGraph + + gen = PalaceDataGenerator(seed=42) + entities, triples = gen.generate_kg_triples( + n_entities=min(n_triples // 2, 200), n_triples=n_triples + ) + + kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3")) + for name, etype in entities: + kg.add_entity(name, etype) + for subject, predicate, obj, valid_from, valid_to in triples: + kg.add_triple(subject, predicate, obj, valid_from=valid_from, valid_to=valid_to) + + # Measure timeline (no filter = full scan with LIMIT 100) + latencies = [] + for _ in range(10): + start = time.perf_counter() + kg.timeline() + elapsed_ms = (time.perf_counter() - start) * 1000 + latencies.append(elapsed_ms) + + avg_ms = sum(latencies) / len(latencies) + record_metric("kg_timeline", f"avg_ms_at_{n_triples}", round(avg_ms, 2)) + + +@pytest.mark.benchmark +class TestTemporalQueryAccuracy: + """Verify temporal filtering correctness at scale.""" + + def test_as_of_filtering(self, tmp_path): + """Insert triples with known temporal ranges, verify as_of queries.""" + from mempalace.knowledge_graph import KnowledgeGraph + + kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3")) + + kg.add_entity("Alice", "person") + kg.add_entity("ProjectA", "project") + kg.add_entity("ProjectB", "project") + + # Alice worked on ProjectA from 2024-01 to 2024-06 + kg.add_triple( + "Alice", "works_on", "ProjectA", valid_from="2024-01-01", valid_to="2024-06-30" + ) + # Alice worked on ProjectB from 2024-07 onwards + kg.add_triple("Alice", "works_on", "ProjectB", valid_from="2024-07-01") + + # Add noise triples + gen = PalaceDataGenerator(seed=42) + entities, triples = gen.generate_kg_triples(n_entities=50, n_triples=500) + for name, etype in entities: + kg.add_entity(name, etype) + for subject, predicate, obj, valid_from, valid_to in triples: + kg.add_triple(subject, predicate, obj, valid_from=valid_from, valid_to=valid_to) + + # Query Alice as of March 2024 — should find ProjectA + result_march = kg.query_entity("Alice", as_of="2024-03-15") + # Query Alice as of September 2024 — should find ProjectB + result_sept = kg.query_entity("Alice", as_of="2024-09-15") + + record_metric( + "kg_temporal", + "march_query_results", + len(result_march) if isinstance(result_march, list) else 0, + ) + record_metric( + "kg_temporal", + "sept_query_results", + len(result_sept) if isinstance(result_sept, list) else 0, + ) + + +@pytest.mark.benchmark +class TestSQLiteConcurrentAccess: + """Test concurrent read/write behavior with SQLite (finding #8).""" + + def test_concurrent_writers(self, tmp_path): + """N threads writing triples simultaneously — count lock failures.""" + from mempalace.knowledge_graph import KnowledgeGraph + + kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3")) + + # Pre-create entities + for i in range(100): + kg.add_entity(f"Entity_{i}", "concept") + + n_threads = 4 + triples_per_thread = 50 + lock_failures = [] + successes = [] + + def writer(thread_id): + fails = 0 + ok = 0 + for i in range(triples_per_thread): + try: + kg.add_triple( + f"Entity_{thread_id * 10}", + "relates_to", + f"Entity_{(thread_id * 10 + i) % 100}", + valid_from="2025-01-01", + ) + ok += 1 + except Exception: + fails += 1 + lock_failures.append(fails) + successes.append(ok) + + threads = [threading.Thread(target=writer, args=(t,)) for t in range(n_threads)] + start = time.perf_counter() + for t in threads: + t.start() + for t in threads: + t.join(timeout=30) + elapsed = time.perf_counter() - start + + total_failures = sum(lock_failures) + total_successes = sum(successes) + + record_metric("kg_concurrent", "total_failures", total_failures) + record_metric("kg_concurrent", "total_successes", total_successes) + record_metric("kg_concurrent", "elapsed_sec", round(elapsed, 2)) + record_metric("kg_concurrent", "threads", n_threads) + record_metric("kg_concurrent", "triples_per_thread", triples_per_thread) + + def test_concurrent_read_write(self, tmp_path): + """Readers and writers running simultaneously.""" + from mempalace.knowledge_graph import KnowledgeGraph + + kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3")) + + # Seed some data + for i in range(50): + kg.add_entity(f"E_{i}", "concept") + for i in range(200): + kg.add_triple(f"E_{i % 50}", "links", f"E_{(i + 1) % 50}", valid_from="2025-01-01") + + read_errors = [] + write_errors = [] + + def reader(): + fails = 0 + for i in range(50): + try: + kg.query_entity(f"E_{i % 50}") + except Exception: + fails += 1 + read_errors.append(fails) + + def writer(): + fails = 0 + for i in range(50): + try: + kg.add_triple( + f"E_{i % 50}", "new_rel", f"E_{(i + 7) % 50}", valid_from="2025-06-01" + ) + except Exception: + fails += 1 + write_errors.append(fails) + + threads = [ + threading.Thread(target=reader), + threading.Thread(target=reader), + threading.Thread(target=writer), + threading.Thread(target=writer), + ] + for t in threads: + t.start() + for t in threads: + t.join(timeout=30) + + record_metric("kg_concurrent_rw", "read_errors", sum(read_errors)) + record_metric("kg_concurrent_rw", "write_errors", sum(write_errors)) + + +@pytest.mark.benchmark +class TestKGStats: + """Measure stats() performance as graph grows.""" + + @pytest.mark.parametrize("n_triples", [200, 1_000, 5_000]) + def test_stats_latency(self, n_triples, tmp_path): + from mempalace.knowledge_graph import KnowledgeGraph + + gen = PalaceDataGenerator(seed=42) + entities, triples = gen.generate_kg_triples( + n_entities=min(n_triples // 2, 200), n_triples=n_triples + ) + + kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3")) + for name, etype in entities: + kg.add_entity(name, etype) + for subject, predicate, obj, valid_from, valid_to in triples: + kg.add_triple(subject, predicate, obj, valid_from=valid_from, valid_to=valid_to) + + latencies = [] + for _ in range(10): + start = time.perf_counter() + kg.stats() + elapsed_ms = (time.perf_counter() - start) * 1000 + latencies.append(elapsed_ms) + + avg_ms = sum(latencies) / len(latencies) + record_metric("kg_stats", f"avg_ms_at_{n_triples}", round(avg_ms, 2)) diff --git a/tests/benchmarks/test_layers_bench.py b/tests/benchmarks/test_layers_bench.py new file mode 100644 index 0000000..7496588 --- /dev/null +++ b/tests/benchmarks/test_layers_bench.py @@ -0,0 +1,209 @@ +""" +Memory stack (layers.py) benchmarks. + +Tests MemoryStack.wake_up(), Layer1.generate(), and Layer2/L3 +at scale. Layer1 has the same unbounded col.get() as tool_status. +""" + +import time + +import pytest + +from tests.benchmarks.data_generator import PalaceDataGenerator +from tests.benchmarks.report import record_metric + + +def _get_rss_mb(): + try: + import psutil + + return psutil.Process().memory_info().rss / (1024 * 1024) + except ImportError: + import resource + import platform + + usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + if platform.system() == "Darwin": + return usage / (1024 * 1024) + return usage / 1024 + + +@pytest.mark.benchmark +class TestWakeUpCost: + """Measure wake_up() time (L0 + L1) at different palace sizes.""" + + SIZES = [500, 1_000, 2_500, 5_000] + + @pytest.mark.parametrize("n_drawers", SIZES) + def test_wakeup_latency(self, n_drawers, tmp_path, bench_scale): + """L0+L1 generation time grows with palace size because L1 fetches all.""" + gen = PalaceDataGenerator(seed=42, scale=bench_scale) + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=n_drawers, include_needles=False) + + # Create identity file + identity_path = str(tmp_path / "identity.txt") + with open(identity_path, "w") as f: + f.write("I am a test AI. Traits: precise, fast.\n") + + from mempalace.layers import MemoryStack + + stack = MemoryStack(palace_path=palace_path, identity_path=identity_path) + + latencies = [] + for _ in range(5): + start = time.perf_counter() + text = stack.wake_up() + elapsed_ms = (time.perf_counter() - start) * 1000 + latencies.append(elapsed_ms) + assert "L0" in text or "L1" in text or "IDENTITY" in text or "ESSENTIAL" in text + + avg_ms = sum(latencies) / len(latencies) + record_metric("layers_wakeup", f"avg_ms_at_{n_drawers}", round(avg_ms, 1)) + + +@pytest.mark.benchmark +class TestLayer1UnboundedFetch: + """Layer1.generate() fetches ALL drawers — same pattern as tool_status.""" + + SIZES = [500, 1_000, 2_500, 5_000] + + @pytest.mark.parametrize("n_drawers", SIZES) + def test_layer1_rss_growth(self, n_drawers, tmp_path): + """Track RSS from Layer1 fetching all drawers at different sizes.""" + gen = PalaceDataGenerator(seed=42, scale="small") + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=n_drawers, include_needles=False) + + from mempalace.layers import Layer1 + + layer = Layer1(palace_path=palace_path) + + rss_before = _get_rss_mb() + start = time.perf_counter() + text = layer.generate() + elapsed_ms = (time.perf_counter() - start) * 1000 + rss_after = _get_rss_mb() + + rss_delta = rss_after - rss_before + assert "L1" in text + + record_metric("layer1", f"latency_ms_at_{n_drawers}", round(elapsed_ms, 1)) + record_metric("layer1", f"rss_delta_mb_at_{n_drawers}", round(rss_delta, 2)) + + def test_layer1_wing_filtered(self, tmp_path): + """Wing-filtered Layer1 should fetch fewer drawers.""" + gen = PalaceDataGenerator(seed=42, scale="small") + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=2_000, include_needles=False) + + from mempalace.layers import Layer1 + + wing = gen.wings[0] + + # Unfiltered + layer_all = Layer1(palace_path=palace_path) + start = time.perf_counter() + layer_all.generate() + unfiltered_ms = (time.perf_counter() - start) * 1000 + + # Wing-filtered + layer_wing = Layer1(palace_path=palace_path, wing=wing) + start = time.perf_counter() + layer_wing.generate() + filtered_ms = (time.perf_counter() - start) * 1000 + + record_metric("layer1_filter", "unfiltered_ms", round(unfiltered_ms, 1)) + record_metric("layer1_filter", "filtered_ms", round(filtered_ms, 1)) + if unfiltered_ms > 0: + record_metric( + "layer1_filter", "speedup_pct", round((1 - filtered_ms / unfiltered_ms) * 100, 1) + ) + + +@pytest.mark.benchmark +class TestWakeUpTokenBudget: + """Verify L0+L1 stays within token budget even at large palace sizes.""" + + SIZES = [500, 1_000, 2_500, 5_000] + + @pytest.mark.parametrize("n_drawers", SIZES) + def test_token_budget(self, n_drawers, tmp_path): + """L1 has MAX_CHARS=3200 cap. Verify it holds at scale.""" + gen = PalaceDataGenerator(seed=42, scale="small") + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=n_drawers, include_needles=False) + + identity_path = str(tmp_path / "identity.txt") + with open(identity_path, "w") as f: + f.write("I am a benchmark AI.\n") + + from mempalace.layers import MemoryStack + + stack = MemoryStack(palace_path=palace_path, identity_path=identity_path) + text = stack.wake_up() + token_estimate = len(text) // 4 + + # Budget is ~600-900 tokens. Allow up to 1200 for safety margin. + record_metric("wakeup_budget", f"tokens_at_{n_drawers}", token_estimate) + record_metric("wakeup_budget", f"chars_at_{n_drawers}", len(text)) + + assert ( + token_estimate < 1200 + ), f"Wake-up exceeded budget: ~{token_estimate} tokens at {n_drawers} drawers" + + +@pytest.mark.benchmark +class TestLayer2Retrieval: + """Layer2 on-demand retrieval with filters.""" + + def test_layer2_latency(self, tmp_path, bench_scale): + """L2 retrieval with wing filter at scale.""" + gen = PalaceDataGenerator(seed=42, scale=bench_scale) + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=2_000, include_needles=False) + + from mempalace.layers import Layer2 + + layer = Layer2(palace_path=palace_path) + wing = gen.wings[0] + + latencies = [] + for _ in range(10): + start = time.perf_counter() + layer.retrieve(wing=wing, n_results=10) + elapsed_ms = (time.perf_counter() - start) * 1000 + latencies.append(elapsed_ms) + + avg_ms = sum(latencies) / len(latencies) + record_metric("layer2", "avg_retrieval_ms", round(avg_ms, 1)) + + +@pytest.mark.benchmark +class TestLayer3Search: + """Layer3 semantic search through the MemoryStack interface.""" + + def test_layer3_latency(self, tmp_path, bench_scale): + """L3 search latency through MemoryStack.""" + gen = PalaceDataGenerator(seed=42, scale=bench_scale) + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=2_000, include_needles=False) + + identity_path = str(tmp_path / "identity.txt") + with open(identity_path, "w") as f: + f.write("I am a benchmark AI.\n") + + from mempalace.layers import MemoryStack + + stack = MemoryStack(palace_path=palace_path, identity_path=identity_path) + + queries = ["authentication", "database", "deployment", "testing", "monitoring"] + latencies = [] + for q in queries: + start = time.perf_counter() + stack.search(q, n_results=5) + elapsed_ms = (time.perf_counter() - start) * 1000 + latencies.append(elapsed_ms) + + avg_ms = sum(latencies) / len(latencies) + record_metric("layer3", "avg_search_ms", round(avg_ms, 1)) diff --git a/tests/benchmarks/test_mcp_bench.py b/tests/benchmarks/test_mcp_bench.py new file mode 100644 index 0000000..4e8330b --- /dev/null +++ b/tests/benchmarks/test_mcp_bench.py @@ -0,0 +1,226 @@ +""" +MCP server tool performance benchmarks. + +Validates production readiness findings: + - Finding #3: tool_status() unbounded col.get(include=["metadatas"]) → OOM + - Finding #7: _get_collection() re-instantiates PersistentClient every call + - Finding #3 variants: tool_list_wings(), tool_get_taxonomy() same pattern + +Calls MCP tool handler functions directly with monkeypatched _config. +""" + +import time + +import chromadb +import pytest + +from tests.benchmarks.data_generator import PalaceDataGenerator +from tests.benchmarks.report import record_metric + + +# ── Helpers ────────────────────────────────────────────────────────────── + + +def _make_palace(tmp_path, n_drawers, scale="small"): + """Create a palace with exactly n_drawers, return palace_path.""" + gen = PalaceDataGenerator(seed=42, scale=scale) + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=n_drawers, include_needles=False) + return palace_path + + +def _patch_mcp_config(monkeypatch, palace_path, tmp_path): + """Monkeypatch mcp_server._config and _kg to point at test dirs.""" + from mempalace.config import MempalaceConfig + from mempalace.knowledge_graph import KnowledgeGraph + + cfg = MempalaceConfig(config_dir=str(tmp_path / "cfg")) + # Override palace_path directly on the object + monkeypatch.setattr(cfg, "_file_config", {"palace_path": palace_path}) + + import mempalace.mcp_server as mcp_mod + + monkeypatch.setattr(mcp_mod, "_config", cfg) + monkeypatch.setattr(mcp_mod, "_kg", KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))) + + +def _get_rss_mb(): + """Get current process RSS in MB.""" + try: + import psutil + + return psutil.Process().memory_info().rss / (1024 * 1024) + except ImportError: + import resource + + # ru_maxrss is in KB on Linux, bytes on macOS + import platform + + usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + if platform.system() == "Darwin": + return usage / (1024 * 1024) + return usage / 1024 + + +# ── Tests ──────────────────────────────────────────────────────────────── + + +@pytest.mark.benchmark +class TestToolStatusOOM: + """Finding #3: tool_status loads ALL metadata into memory.""" + + SIZES = [500, 1_000, 2_500, 5_000] + + @pytest.mark.parametrize("n_drawers", SIZES) + def test_tool_status_rss_growth(self, n_drawers, tmp_path, monkeypatch): + """Measure RSS growth from tool_status at different palace sizes.""" + palace_path = _make_palace(tmp_path, n_drawers) + _patch_mcp_config(monkeypatch, palace_path, tmp_path) + + from mempalace.mcp_server import tool_status + + rss_before = _get_rss_mb() + result = tool_status() + rss_after = _get_rss_mb() + + rss_delta = rss_after - rss_before + assert "error" not in result, f"tool_status failed: {result}" + assert result["total_drawers"] == n_drawers + + record_metric("mcp_status", f"rss_delta_mb_at_{n_drawers}", round(rss_delta, 2)) + + @pytest.mark.parametrize("n_drawers", SIZES) + def test_tool_status_latency(self, n_drawers, tmp_path, monkeypatch): + """Measure tool_status response time at different palace sizes.""" + palace_path = _make_palace(tmp_path, n_drawers) + _patch_mcp_config(monkeypatch, palace_path, tmp_path) + + from mempalace.mcp_server import tool_status + + # Warm up + tool_status() + + start = time.perf_counter() + result = tool_status() + elapsed_ms = (time.perf_counter() - start) * 1000 + + assert "error" not in result + record_metric("mcp_status", f"latency_ms_at_{n_drawers}", round(elapsed_ms, 1)) + + +@pytest.mark.benchmark +class TestToolListWingsUnbounded: + """Finding #3 variant: tool_list_wings also fetches ALL metadata.""" + + @pytest.mark.parametrize("n_drawers", [500, 1_000, 2_500, 5_000]) + def test_list_wings_latency(self, n_drawers, tmp_path, monkeypatch): + palace_path = _make_palace(tmp_path, n_drawers) + _patch_mcp_config(monkeypatch, palace_path, tmp_path) + + from mempalace.mcp_server import tool_list_wings + + start = time.perf_counter() + result = tool_list_wings() + elapsed_ms = (time.perf_counter() - start) * 1000 + + assert "wings" in result + record_metric("mcp_list_wings", f"latency_ms_at_{n_drawers}", round(elapsed_ms, 1)) + + +@pytest.mark.benchmark +class TestToolGetTaxonomyUnbounded: + """Finding #3 variant: tool_get_taxonomy also fetches ALL metadata.""" + + @pytest.mark.parametrize("n_drawers", [500, 1_000, 2_500, 5_000]) + def test_get_taxonomy_latency(self, n_drawers, tmp_path, monkeypatch): + palace_path = _make_palace(tmp_path, n_drawers) + _patch_mcp_config(monkeypatch, palace_path, tmp_path) + + from mempalace.mcp_server import tool_get_taxonomy + + start = time.perf_counter() + result = tool_get_taxonomy() + elapsed_ms = (time.perf_counter() - start) * 1000 + + assert "taxonomy" in result + record_metric("mcp_taxonomy", f"latency_ms_at_{n_drawers}", round(elapsed_ms, 1)) + + +@pytest.mark.benchmark +class TestClientReinstantiation: + """Finding #7: _get_collection() creates new PersistentClient every call.""" + + def test_reinstantiation_overhead(self, tmp_path, monkeypatch): + """Measure cost of 50 _get_collection() calls vs a cached client.""" + palace_path = _make_palace(tmp_path, 500) + _patch_mcp_config(monkeypatch, palace_path, tmp_path) + + from mempalace.mcp_server import _get_collection + + n_calls = 50 + + # Measure re-instantiation (current behavior) + start = time.perf_counter() + for _ in range(n_calls): + col = _get_collection() + assert col is not None + uncached_ms = (time.perf_counter() - start) * 1000 + + # Measure cached client (what it should be) + client = chromadb.PersistentClient(path=palace_path) + cached_col = client.get_collection("mempalace_drawers") + start = time.perf_counter() + for _ in range(n_calls): + _ = cached_col.count() + cached_ms = (time.perf_counter() - start) * 1000 + + overhead_ratio = uncached_ms / max(cached_ms, 0.01) + + record_metric("client_reinstantiation", "uncached_total_ms", round(uncached_ms, 1)) + record_metric("client_reinstantiation", "cached_total_ms", round(cached_ms, 1)) + record_metric("client_reinstantiation", "overhead_ratio", round(overhead_ratio, 2)) + record_metric("client_reinstantiation", "n_calls", n_calls) + + +@pytest.mark.benchmark +class TestToolSearchLatency: + """tool_search uses query() not get(), should scale better.""" + + @pytest.mark.parametrize("n_drawers", [500, 1_000, 2_500, 5_000]) + def test_search_latency(self, n_drawers, tmp_path, monkeypatch): + palace_path = _make_palace(tmp_path, n_drawers) + _patch_mcp_config(monkeypatch, palace_path, tmp_path) + + from mempalace.mcp_server import tool_search + + queries = ["authentication middleware", "database migration", "error handling"] + latencies = [] + for q in queries: + start = time.perf_counter() + result = tool_search(query=q, limit=5) + elapsed_ms = (time.perf_counter() - start) * 1000 + latencies.append(elapsed_ms) + assert "error" not in result + + avg_ms = sum(latencies) / len(latencies) + record_metric("mcp_search", f"avg_latency_ms_at_{n_drawers}", round(avg_ms, 1)) + + +@pytest.mark.benchmark +class TestDuplicateCheckCost: + """tool_add_drawer calls tool_check_duplicate first — measure overhead.""" + + @pytest.mark.parametrize("n_drawers", [500, 1_000, 2_500]) + def test_duplicate_check_latency(self, n_drawers, tmp_path, monkeypatch): + palace_path = _make_palace(tmp_path, n_drawers) + _patch_mcp_config(monkeypatch, palace_path, tmp_path) + + from mempalace.mcp_server import tool_check_duplicate + + test_content = "This is unique test content for duplicate checking benchmark." + start = time.perf_counter() + result = tool_check_duplicate(content=test_content) + elapsed_ms = (time.perf_counter() - start) * 1000 + + assert "error" not in result + record_metric("mcp_duplicate_check", f"latency_ms_at_{n_drawers}", round(elapsed_ms, 1)) diff --git a/tests/benchmarks/test_memory_profile.py b/tests/benchmarks/test_memory_profile.py new file mode 100644 index 0000000..b299b2d --- /dev/null +++ b/tests/benchmarks/test_memory_profile.py @@ -0,0 +1,181 @@ +""" +Memory profiling benchmarks — detect leaks and measure RSS growth. + +Uses tracemalloc for heap snapshots and psutil/resource for RSS. +Targets the highest-risk code paths: + - Repeated search() calls (PersistentClient re-instantiation) + - Repeated tool_status() calls (unbounded metadata fetch) + - Layer1.generate() (fetches all drawers) +""" + +import tracemalloc + +import pytest + +from tests.benchmarks.data_generator import PalaceDataGenerator +from tests.benchmarks.report import record_metric + + +def _get_rss_mb(): + try: + import psutil + + return psutil.Process().memory_info().rss / (1024 * 1024) + except ImportError: + import resource + import platform + + usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + if platform.system() == "Darwin": + return usage / (1024 * 1024) + return usage / 1024 + + +@pytest.mark.benchmark +class TestSearchMemoryProfile: + """Track RSS growth over repeated search_memories() calls.""" + + def test_search_rss_growth(self, tmp_path): + """Issue 200 searches and track RSS every 50 calls.""" + gen = PalaceDataGenerator(seed=42, scale="small") + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=1_000, include_needles=False) + + from mempalace.searcher import search_memories + + n_calls = 200 + check_interval = 50 + queries = ["authentication", "database", "deployment", "error handling", "testing"] + rss_readings = [] + rss_readings.append(("start", _get_rss_mb())) + + for i in range(n_calls): + q = queries[i % len(queries)] + search_memories(q, palace_path=palace_path, n_results=5) + if (i + 1) % check_interval == 0: + rss_readings.append((f"after_{i + 1}", _get_rss_mb())) + + start_rss = rss_readings[0][1] + end_rss = rss_readings[-1][1] + growth = end_rss - start_rss + + record_metric("memory_search", "rss_start_mb", round(start_rss, 2)) + record_metric("memory_search", "rss_end_mb", round(end_rss, 2)) + record_metric("memory_search", "rss_growth_mb", round(growth, 2)) + record_metric("memory_search", "n_calls", n_calls) + record_metric( + "memory_search", "growth_per_100_calls_mb", round(growth / (n_calls / 100), 2) + ) + + +@pytest.mark.benchmark +class TestToolStatusMemoryProfile: + """Track RSS growth from repeated tool_status() calls.""" + + def test_tool_status_repeated_calls(self, tmp_path, monkeypatch): + """tool_status loads ALL metadata each call — does it leak?""" + gen = PalaceDataGenerator(seed=42, scale="small") + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=2_000, include_needles=False) + + from mempalace.config import MempalaceConfig + from mempalace.knowledge_graph import KnowledgeGraph + import mempalace.mcp_server as mcp_mod + + cfg = MempalaceConfig(config_dir=str(tmp_path / "cfg")) + monkeypatch.setattr(cfg, "_file_config", {"palace_path": palace_path}) + monkeypatch.setattr(mcp_mod, "_config", cfg) + monkeypatch.setattr(mcp_mod, "_kg", KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))) + + from mempalace.mcp_server import tool_status + + n_calls = 50 + rss_readings = [] + rss_readings.append(("start", _get_rss_mb())) + + for i in range(n_calls): + result = tool_status() + assert result["total_drawers"] == 2_000 + if (i + 1) % 10 == 0: + rss_readings.append((f"after_{i + 1}", _get_rss_mb())) + + start_rss = rss_readings[0][1] + end_rss = rss_readings[-1][1] + growth = end_rss - start_rss + + record_metric("memory_tool_status", "rss_start_mb", round(start_rss, 2)) + record_metric("memory_tool_status", "rss_end_mb", round(end_rss, 2)) + record_metric("memory_tool_status", "rss_growth_mb", round(growth, 2)) + record_metric("memory_tool_status", "n_calls", n_calls) + record_metric("memory_tool_status", "palace_size", 2_000) + + +@pytest.mark.benchmark +class TestLayer1MemoryProfile: + """Layer1.generate() fetches ALL drawers — same risk as tool_status.""" + + def test_layer1_repeated_generate(self, tmp_path): + """Layer1 fetches all drawers for scoring. Track memory over repeats.""" + gen = PalaceDataGenerator(seed=42, scale="small") + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=2_000, include_needles=False) + + from mempalace.layers import Layer1 + + layer = Layer1(palace_path=palace_path) + + n_calls = 30 + rss_readings = [] + rss_readings.append(("start", _get_rss_mb())) + + for i in range(n_calls): + text = layer.generate() + assert "L1" in text + if (i + 1) % 10 == 0: + rss_readings.append((f"after_{i + 1}", _get_rss_mb())) + + start_rss = rss_readings[0][1] + end_rss = rss_readings[-1][1] + growth = end_rss - start_rss + + record_metric("memory_layer1", "rss_start_mb", round(start_rss, 2)) + record_metric("memory_layer1", "rss_end_mb", round(end_rss, 2)) + record_metric("memory_layer1", "rss_growth_mb", round(growth, 2)) + record_metric("memory_layer1", "n_calls", n_calls) + + +@pytest.mark.benchmark +class TestHeapSnapshot: + """Use tracemalloc to identify top memory allocators during search.""" + + def test_search_heap_top_allocators(self, tmp_path): + """Identify which code paths allocate the most memory during search.""" + gen = PalaceDataGenerator(seed=42, scale="small") + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=1_000, include_needles=False) + + from mempalace.searcher import search_memories + + tracemalloc.start() + snap_before = tracemalloc.take_snapshot() + + for i in range(100): + search_memories("test query", palace_path=palace_path, n_results=5) + + snap_after = tracemalloc.take_snapshot() + tracemalloc.stop() + + stats = snap_after.compare_to(snap_before, "lineno") + top_allocators = [] + for stat in stats[:10]: + top_allocators.append( + { + "file": str(stat.traceback), + "size_kb": round(stat.size / 1024, 1), + "count": stat.count, + } + ) + + total_growth_kb = sum(s["size_kb"] for s in top_allocators) + record_metric("heap_search", "top_10_growth_kb", round(total_growth_kb, 1)) + record_metric("heap_search", "n_searches", 100) diff --git a/tests/benchmarks/test_palace_boost.py b/tests/benchmarks/test_palace_boost.py new file mode 100644 index 0000000..ca90784 --- /dev/null +++ b/tests/benchmarks/test_palace_boost.py @@ -0,0 +1,176 @@ +""" +Palace boost validation — does wing/room filtering actually help? + +Quantifies the retrieval improvement from the palace spatial metaphor. +Uses planted needles to measure recall with and without filtering +at different scales. +""" + +import time + +import pytest + +from tests.benchmarks.data_generator import PalaceDataGenerator +from tests.benchmarks.report import record_metric + + +@pytest.mark.benchmark +class TestFilteredVsUnfilteredRecall: + """Quantify palace boost: recall improvement from wing/room filtering.""" + + SIZES = [1_000, 2_500, 5_000] + + @pytest.mark.parametrize("n_drawers", SIZES) + def test_palace_boost_recall(self, n_drawers, tmp_path, bench_scale): + """Compare recall@5 with/without wing filter at increasing scale.""" + gen = PalaceDataGenerator(seed=42, scale=bench_scale) + palace_path = str(tmp_path / "palace") + _, _, needle_info = gen.populate_palace_directly( + palace_path, n_drawers=n_drawers, include_needles=True + ) + + from mempalace.searcher import search_memories + + n_queries = min(10, len(needle_info)) + unfiltered_hits = 0 + wing_filtered_hits = 0 + room_filtered_hits = 0 + + for needle in needle_info[:n_queries]: + # Unfiltered search + result = search_memories(needle["query"], palace_path=palace_path, n_results=5) + texts = [h["text"] for h in result.get("results", [])] + if any("NEEDLE_" in t for t in texts[:5]): + unfiltered_hits += 1 + + # Wing-filtered search + result = search_memories( + needle["query"], palace_path=palace_path, wing=needle["wing"], n_results=5 + ) + texts = [h["text"] for h in result.get("results", [])] + if any("NEEDLE_" in t for t in texts[:5]): + wing_filtered_hits += 1 + + # Wing+room filtered search + result = search_memories( + needle["query"], + palace_path=palace_path, + wing=needle["wing"], + room=needle["room"], + n_results=5, + ) + texts = [h["text"] for h in result.get("results", [])] + if any("NEEDLE_" in t for t in texts[:5]): + room_filtered_hits += 1 + + recall_none = unfiltered_hits / max(n_queries, 1) + recall_wing = wing_filtered_hits / max(n_queries, 1) + recall_room = room_filtered_hits / max(n_queries, 1) + + boost_wing = recall_wing - recall_none + boost_room = recall_room - recall_none + + record_metric("palace_boost", f"recall_unfiltered_at_{n_drawers}", round(recall_none, 3)) + record_metric("palace_boost", f"recall_wing_filtered_at_{n_drawers}", round(recall_wing, 3)) + record_metric("palace_boost", f"recall_room_filtered_at_{n_drawers}", round(recall_room, 3)) + record_metric("palace_boost", f"wing_boost_at_{n_drawers}", round(boost_wing, 3)) + record_metric("palace_boost", f"room_boost_at_{n_drawers}", round(boost_room, 3)) + + +@pytest.mark.benchmark +class TestFilterLatencyBenefit: + """Does filtering reduce query latency by narrowing the search space?""" + + def test_filter_speedup(self, tmp_path, bench_scale): + """Compare latency: no filter vs wing vs wing+room.""" + gen = PalaceDataGenerator(seed=42, scale=bench_scale) + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=5_000, include_needles=False) + + from mempalace.searcher import search_memories + + wing = gen.wings[0] + room = gen.rooms_by_wing[wing][0] + query = "authentication middleware optimization" + n_runs = 10 + + # No filter + latencies_none = [] + for _ in range(n_runs): + start = time.perf_counter() + search_memories(query, palace_path=palace_path, n_results=5) + latencies_none.append((time.perf_counter() - start) * 1000) + + # Wing filter + latencies_wing = [] + for _ in range(n_runs): + start = time.perf_counter() + search_memories(query, palace_path=palace_path, wing=wing, n_results=5) + latencies_wing.append((time.perf_counter() - start) * 1000) + + # Wing + room filter + latencies_room = [] + for _ in range(n_runs): + start = time.perf_counter() + search_memories(query, palace_path=palace_path, wing=wing, room=room, n_results=5) + latencies_room.append((time.perf_counter() - start) * 1000) + + avg_none = sum(latencies_none) / len(latencies_none) + avg_wing = sum(latencies_wing) / len(latencies_wing) + avg_room = sum(latencies_room) / len(latencies_room) + + record_metric("filter_latency", "avg_unfiltered_ms", round(avg_none, 1)) + record_metric("filter_latency", "avg_wing_filtered_ms", round(avg_wing, 1)) + record_metric("filter_latency", "avg_room_filtered_ms", round(avg_room, 1)) + if avg_none > 0: + record_metric( + "filter_latency", "wing_speedup_pct", round((1 - avg_wing / avg_none) * 100, 1) + ) + record_metric( + "filter_latency", "room_speedup_pct", round((1 - avg_room / avg_none) * 100, 1) + ) + + +@pytest.mark.benchmark +class TestBoostAtIncreasingScale: + """Does the palace boost increase as the palace grows?""" + + def test_boost_scaling(self, tmp_path, bench_scale): + """Measure wing-filtered recall improvement at multiple sizes.""" + sizes = [500, 1_000, 2_500] + boosts = [] + + for size in sizes: + gen = PalaceDataGenerator(seed=42, scale=bench_scale) + palace_path = str(tmp_path / f"palace_{size}") + _, _, needle_info = gen.populate_palace_directly( + palace_path, n_drawers=size, include_needles=True + ) + + from mempalace.searcher import search_memories + + n_queries = min(8, len(needle_info)) + unfiltered_hits = 0 + filtered_hits = 0 + + for needle in needle_info[:n_queries]: + result = search_memories(needle["query"], palace_path=palace_path, n_results=5) + if any("NEEDLE_" in h["text"] for h in result.get("results", [])[:5]): + unfiltered_hits += 1 + + result = search_memories( + needle["query"], palace_path=palace_path, wing=needle["wing"], n_results=5 + ) + if any("NEEDLE_" in h["text"] for h in result.get("results", [])[:5]): + filtered_hits += 1 + + recall_none = unfiltered_hits / max(n_queries, 1) + recall_filtered = filtered_hits / max(n_queries, 1) + boost = recall_filtered - recall_none + boosts.append({"size": size, "boost": boost}) + + record_metric("boost_scaling", "boosts_by_size", boosts) + # Check if boost increases with scale (the hypothesis) + if len(boosts) >= 2: + trend_positive = boosts[-1]["boost"] >= boosts[0]["boost"] + record_metric("boost_scaling", "trend_positive", trend_positive) diff --git a/tests/benchmarks/test_recall_threshold.py b/tests/benchmarks/test_recall_threshold.py new file mode 100644 index 0000000..afe2323 --- /dev/null +++ b/tests/benchmarks/test_recall_threshold.py @@ -0,0 +1,182 @@ +""" +Recall threshold test — find the per-bucket size where retrieval breaks. + +The palace_boost tests showed room-filtered recall of 1.0, but only because +each room had ~333 drawers. This test concentrates ALL drawers into a single +wing+room to find the actual embedding model limit. +""" + +import hashlib +import os +from datetime import datetime + +import chromadb +import pytest + +from tests.benchmarks.data_generator import PalaceDataGenerator +from tests.benchmarks.report import record_metric + + +NEEDLE_TOPICS = [ + "Fibonacci sequence optimization uses memoization with O(n) space complexity", + "PostgreSQL vacuum autovacuum threshold set to 50 percent for table users", + "Redis cluster failover timeout configured at 30 seconds with sentinel monitoring", + "Kubernetes horizontal pod autoscaler targets 70 percent CPU utilization", + "GraphQL subscription uses WebSocket transport with heartbeat interval 25 seconds", + "JWT token rotation policy requires refresh every 15 minutes with sliding window", + "Elasticsearch index sharding strategy uses 5 primary shards with 1 replica each", + "Docker multi-stage build reduces image size from 1.2GB to 180MB for production", + "Apache Kafka consumer group rebalance timeout set to 45 seconds", + "MongoDB change streams resume token persisted every 100 operations", +] + +NEEDLE_QUERIES = [ + "Fibonacci sequence optimization memoization", + "PostgreSQL vacuum autovacuum threshold", + "Redis cluster failover timeout sentinel", + "Kubernetes horizontal pod autoscaler CPU", + "GraphQL subscription WebSocket heartbeat", + "JWT token rotation policy refresh", + "Elasticsearch index sharding primary replica", + "Docker multi-stage build image size production", + "Apache Kafka consumer group rebalance", + "MongoDB change streams resume token", +] + + +def _populate_single_room(palace_path, n_drawers, n_needles=10): + """Pack all drawers into one wing+room, plant needles, return queries.""" + gen = PalaceDataGenerator(seed=42, scale="small") + os.makedirs(palace_path, exist_ok=True) + client = chromadb.PersistentClient(path=palace_path) + col = client.get_or_create_collection("mempalace_drawers") + + batch_size = 500 + docs, ids, metas = [], [], [] + + # Plant needles + for i in range(n_needles): + needle_id = f"NEEDLE_{i:04d}" + content = f"{needle_id}: {NEEDLE_TOPICS[i]}. Unique planted needle for threshold test." + drawer_id = f"drawer_single_room_{hashlib.md5(needle_id.encode()).hexdigest()[:16]}" + docs.append(content) + ids.append(drawer_id) + metas.append( + { + "wing": "concentrated", + "room": "single_room", + "source_file": f"needle_{i}.txt", + "chunk_index": 0, + "added_by": "threshold_bench", + "filed_at": datetime.now().isoformat(), + } + ) + + # Fill with noise — all in the SAME room + remaining = n_drawers - len(docs) + for i in range(remaining): + content = gen._random_text(400, 800) + drawer_id = f"drawer_single_room_{hashlib.md5(f'noise_{i}'.encode()).hexdigest()[:16]}" + docs.append(content) + ids.append(drawer_id) + metas.append( + { + "wing": "concentrated", + "room": "single_room", + "source_file": f"noise_{i:06d}.txt", + "chunk_index": i % 10, + "added_by": "threshold_bench", + "filed_at": datetime.now().isoformat(), + } + ) + + if len(docs) >= batch_size: + col.add(documents=docs, ids=ids, metadatas=metas) + docs, ids, metas = [], [], [] + + if docs: + col.add(documents=docs, ids=ids, metadatas=metas) + + return client, col + + +@pytest.mark.benchmark +class TestRecallThresholdSingleRoom: + """ + All drawers in one room — isolates the embedding model's retrieval limit. + + Room filtering can't help here. This is the true ceiling. + """ + + SIZES = [250, 500, 1_000, 2_000, 3_000, 5_000] + + @pytest.mark.parametrize("n_drawers", SIZES) + def test_single_room_recall(self, n_drawers, tmp_path): + """Recall@5 and @10 with all drawers in one bucket.""" + palace_path = str(tmp_path / "palace") + _populate_single_room(palace_path, n_drawers, n_needles=10) + + from mempalace.searcher import search_memories + + hits_at_5 = 0 + hits_at_10 = 0 + n_queries = len(NEEDLE_QUERIES) + + for i, query in enumerate(NEEDLE_QUERIES): + result = search_memories( + query, + palace_path=palace_path, + wing="concentrated", + room="single_room", + n_results=10, + ) + if "error" in result: + continue + + texts = [h["text"] for h in result.get("results", [])] + needle_id = f"NEEDLE_{i:04d}" + + found_at_5 = any(needle_id in t for t in texts[:5]) + found_at_10 = any(needle_id in t for t in texts[:10]) + + if found_at_5: + hits_at_5 += 1 + if found_at_10: + hits_at_10 += 1 + + recall_5 = hits_at_5 / n_queries + recall_10 = hits_at_10 / n_queries + + record_metric("single_room_recall", f"recall_at_5_at_{n_drawers}", round(recall_5, 3)) + record_metric("single_room_recall", f"recall_at_10_at_{n_drawers}", round(recall_10, 3)) + + @pytest.mark.parametrize("n_drawers", SIZES) + def test_single_room_no_filter_recall(self, n_drawers, tmp_path): + """Same test but WITHOUT wing/room filter — pure unfiltered search.""" + palace_path = str(tmp_path / "palace") + _populate_single_room(palace_path, n_drawers, n_needles=10) + + from mempalace.searcher import search_memories + + hits_at_5 = 0 + hits_at_10 = 0 + n_queries = len(NEEDLE_QUERIES) + + for i, query in enumerate(NEEDLE_QUERIES): + result = search_memories(query, palace_path=palace_path, n_results=10) + if "error" in result: + continue + + texts = [h["text"] for h in result.get("results", [])] + needle_id = f"NEEDLE_{i:04d}" + + if any(needle_id in t for t in texts[:5]): + hits_at_5 += 1 + if any(needle_id in t for t in texts[:10]): + hits_at_10 += 1 + + recall_5 = hits_at_5 / n_queries + recall_10 = hits_at_10 / n_queries + + record_metric("single_room_unfiltered", f"recall_at_5_at_{n_drawers}", round(recall_5, 3)) + record_metric("single_room_unfiltered", f"recall_at_10_at_{n_drawers}", round(recall_10, 3)) diff --git a/tests/benchmarks/test_search_bench.py b/tests/benchmarks/test_search_bench.py new file mode 100644 index 0000000..3cb7785 --- /dev/null +++ b/tests/benchmarks/test_search_bench.py @@ -0,0 +1,234 @@ +""" +Search performance benchmarks. + +Measures query latency, recall@k, and concurrent search behavior +as palace size grows. Uses planted needles for recall measurement. +""" + +import time +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest + +from tests.benchmarks.data_generator import PalaceDataGenerator +from tests.benchmarks.report import record_metric + + +@pytest.mark.benchmark +class TestSearchLatencyVsSize: + """Query latency scaling as palace grows.""" + + SIZES = [500, 1_000, 2_500, 5_000] + + @pytest.mark.parametrize("n_drawers", SIZES) + def test_search_latency_curve(self, n_drawers, tmp_path, bench_scale): + """Measure average search latency at different palace sizes.""" + gen = PalaceDataGenerator(seed=42, scale=bench_scale) + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=n_drawers, include_needles=False) + + from mempalace.searcher import search_memories + + queries = [ + "authentication middleware", + "database optimization", + "error handling patterns", + "deployment configuration", + "testing strategy", + ] + + latencies = [] + for q in queries: + start = time.perf_counter() + result = search_memories(q, palace_path=palace_path, n_results=5) + elapsed_ms = (time.perf_counter() - start) * 1000 + latencies.append(elapsed_ms) + assert "error" not in result + + avg_ms = sum(latencies) / len(latencies) + sorted_lat = sorted(latencies) + p50_ms = sorted_lat[len(sorted_lat) // 2] + p95_ms = sorted_lat[int(len(sorted_lat) * 0.95)] + + record_metric("search", f"avg_latency_ms_at_{n_drawers}", round(avg_ms, 1)) + record_metric("search", f"p50_ms_at_{n_drawers}", round(p50_ms, 1)) + record_metric("search", f"p95_ms_at_{n_drawers}", round(p95_ms, 1)) + + +@pytest.mark.benchmark +class TestSearchRecallAtScale: + """Planted needle recall — does accuracy degrade as palace grows?""" + + SIZES = [500, 1_000, 2_500, 5_000] + + @pytest.mark.parametrize("n_drawers", SIZES) + def test_recall_at_k(self, n_drawers, tmp_path, bench_scale): + """Recall@5 and Recall@10 using planted needles.""" + gen = PalaceDataGenerator(seed=42, scale=bench_scale) + palace_path = str(tmp_path / "palace") + _, _, needle_info = gen.populate_palace_directly( + palace_path, n_drawers=n_drawers, include_needles=True + ) + + from mempalace.searcher import search_memories + + hits_at_5 = 0 + hits_at_10 = 0 + total_needle_queries = min(10, len(needle_info)) + + for needle in needle_info[:total_needle_queries]: + result = search_memories(needle["query"], palace_path=palace_path, n_results=10) + if "error" in result: + continue + + texts = [h["text"] for h in result.get("results", [])] + + # Check if needle content appears in top 5 + found_at_5 = any("NEEDLE_" in t for t in texts[:5]) + found_at_10 = any("NEEDLE_" in t for t in texts[:10]) + + if found_at_5: + hits_at_5 += 1 + if found_at_10: + hits_at_10 += 1 + + recall_at_5 = hits_at_5 / max(total_needle_queries, 1) + recall_at_10 = hits_at_10 / max(total_needle_queries, 1) + + record_metric("search_recall", f"recall_at_5_at_{n_drawers}", round(recall_at_5, 3)) + record_metric("search_recall", f"recall_at_10_at_{n_drawers}", round(recall_at_10, 3)) + + +@pytest.mark.benchmark +class TestSearchFilteredVsUnfiltered: + """Compare search performance with and without wing/room filters.""" + + def test_filter_impact(self, tmp_path, bench_scale): + """Measure latency and recall difference with wing filtering.""" + gen = PalaceDataGenerator(seed=42, scale=bench_scale) + palace_path = str(tmp_path / "palace") + _, _, needle_info = gen.populate_palace_directly( + palace_path, n_drawers=2_000, include_needles=True + ) + + from mempalace.searcher import search_memories + + filtered_latencies = [] + unfiltered_latencies = [] + filtered_hits = 0 + unfiltered_hits = 0 + n_queries = min(10, len(needle_info)) + + for needle in needle_info[:n_queries]: + # Unfiltered + start = time.perf_counter() + result_unfiltered = search_memories( + needle["query"], palace_path=palace_path, n_results=5 + ) + unfiltered_latencies.append((time.perf_counter() - start) * 1000) + if any("NEEDLE_" in h["text"] for h in result_unfiltered.get("results", [])[:5]): + unfiltered_hits += 1 + + # Filtered by wing + start = time.perf_counter() + result_filtered = search_memories( + needle["query"], + palace_path=palace_path, + wing=needle["wing"], + n_results=5, + ) + filtered_latencies.append((time.perf_counter() - start) * 1000) + if any("NEEDLE_" in h["text"] for h in result_filtered.get("results", [])[:5]): + filtered_hits += 1 + + avg_unfiltered = sum(unfiltered_latencies) / max(len(unfiltered_latencies), 1) + avg_filtered = sum(filtered_latencies) / max(len(filtered_latencies), 1) + latency_improvement = ((avg_unfiltered - avg_filtered) / max(avg_unfiltered, 0.01)) * 100 + + record_metric("search_filter", "avg_unfiltered_ms", round(avg_unfiltered, 1)) + record_metric("search_filter", "avg_filtered_ms", round(avg_filtered, 1)) + record_metric("search_filter", "latency_improvement_pct", round(latency_improvement, 1)) + record_metric( + "search_filter", "unfiltered_recall_at_5", round(unfiltered_hits / max(n_queries, 1), 3) + ) + record_metric( + "search_filter", "filtered_recall_at_5", round(filtered_hits / max(n_queries, 1), 3) + ) + + +@pytest.mark.benchmark +class TestConcurrentSearch: + """Concurrent query performance — tests PersistentClient contention.""" + + def test_concurrent_queries(self, tmp_path): + """Issue N simultaneous queries and measure p50/p95/p99.""" + gen = PalaceDataGenerator(seed=42, scale="small") + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=2_000, include_needles=False) + + from mempalace.searcher import search_memories + + queries = [ + "authentication", + "database", + "deployment", + "error handling", + "testing", + "monitoring", + "caching", + "middleware", + "serialization", + "validation", + ] * 3 # 30 total queries + + def run_search(query): + start = time.perf_counter() + result = search_memories(query, palace_path=palace_path, n_results=5) + elapsed = (time.perf_counter() - start) * 1000 + return elapsed, "error" not in result + + # Concurrent execution + latencies = [] + errors = 0 + with ThreadPoolExecutor(max_workers=4) as executor: + futures = {executor.submit(run_search, q): q for q in queries} + for future in as_completed(futures): + elapsed, success = future.result() + latencies.append(elapsed) + if not success: + errors += 1 + + sorted_lat = sorted(latencies) + n = len(sorted_lat) + + record_metric("concurrent_search", "p50_ms", round(sorted_lat[n // 2], 1)) + record_metric("concurrent_search", "p95_ms", round(sorted_lat[int(n * 0.95)], 1)) + record_metric("concurrent_search", "p99_ms", round(sorted_lat[int(n * 0.99)], 1)) + record_metric("concurrent_search", "avg_ms", round(sum(sorted_lat) / n, 1)) + record_metric("concurrent_search", "error_count", errors) + record_metric("concurrent_search", "total_queries", len(queries)) + record_metric("concurrent_search", "workers", 4) + + +@pytest.mark.benchmark +class TestSearchNResultsScaling: + """How does n_results affect query latency?""" + + @pytest.mark.parametrize("n_results", [1, 5, 10, 25, 50]) + def test_n_results_latency(self, n_results, tmp_path): + gen = PalaceDataGenerator(seed=42, scale="small") + palace_path = str(tmp_path / "palace") + gen.populate_palace_directly(palace_path, n_drawers=2_000, include_needles=False) + + from mempalace.searcher import search_memories + + latencies = [] + for _ in range(5): + start = time.perf_counter() + search_memories( + "authentication middleware", palace_path=palace_path, n_results=n_results + ) + latencies.append((time.perf_counter() - start) * 1000) + + avg_ms = sum(latencies) / len(latencies) + record_metric("search_n_results", f"avg_ms_at_n_{n_results}", round(avg_ms, 1)) diff --git a/tests/conftest.py b/tests/conftest.py index 22b5e42..7a3e55a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,6 +34,24 @@ from mempalace.config import MempalaceConfig # noqa: E402 from mempalace.knowledge_graph import KnowledgeGraph # noqa: E402 +@pytest.fixture(autouse=True) +def _reset_mcp_cache(): + """Reset the MCP server's cached ChromaDB client/collection between tests.""" + + def _clear_cache(): + try: + from mempalace import mcp_server + + mcp_server._client_cache = None + mcp_server._collection_cache = None + except (ImportError, AttributeError): + pass + + _clear_cache() + yield + _clear_cache() + + @pytest.fixture(scope="session", autouse=True) def _isolate_home(): """Ensure HOME points to a temp dir for the entire test session. @@ -84,7 +102,9 @@ def collection(palace_path): """A ChromaDB collection pre-seeded in the temp palace.""" client = chromadb.PersistentClient(path=palace_path) col = client.get_or_create_collection("mempalace_drawers") - return col + yield col + client.delete_collection("mempalace_drawers") + del client @pytest.fixture diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..879d276 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,609 @@ +"""Tests for mempalace.cli — the main CLI dispatcher.""" + +import argparse +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from mempalace.cli import ( + cmd_compress, + cmd_hook, + cmd_init, + cmd_instructions, + cmd_mine, + cmd_repair, + cmd_search, + cmd_split, + cmd_status, + cmd_wakeup, + main, +) + + +# ── cmd_status ───────────────────────────────────────────────────────── + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_status_default_palace(mock_config_cls): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace(palace=None) + mock_miner = MagicMock() + with patch.dict("sys.modules", {"mempalace.miner": mock_miner}): + cmd_status(args) + mock_miner.status.assert_called_once_with(palace_path="/fake/palace") + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_status_custom_palace(mock_config_cls): + args = argparse.Namespace(palace="~/my_palace") + mock_miner = MagicMock() + with patch.dict("sys.modules", {"mempalace.miner": mock_miner}): + cmd_status(args) + import os + + expected = os.path.expanduser("~/my_palace") + mock_miner.status.assert_called_once_with(palace_path=expected) + + +# ── cmd_search ───────────────────────────────────────────────────────── + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_search_calls_search(mock_config_cls): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace( + palace=None, query="test query", wing="mywing", room="myroom", results=3 + ) + with patch("mempalace.searcher.search") as mock_search: + cmd_search(args) + mock_search.assert_called_once_with( + query="test query", + palace_path="/fake/palace", + wing="mywing", + room="myroom", + n_results=3, + ) + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_search_error_exits(mock_config_cls): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace(palace=None, query="q", wing=None, room=None, results=5) + from mempalace.searcher import SearchError + + with patch("mempalace.searcher.search", side_effect=SearchError("fail")): + with pytest.raises(SystemExit) as exc_info: + cmd_search(args) + assert exc_info.value.code == 1 + + +# ── cmd_instructions ─────────────────────────────────────────────────── + + +def test_cmd_instructions_calls_run_instructions(): + args = argparse.Namespace(name="help") + with patch("mempalace.instructions_cli.run_instructions") as mock_run: + cmd_instructions(args) + mock_run.assert_called_once_with(name="help") + + +# ── cmd_hook ─────────────────────────────────────────────────────────── + + +def test_cmd_hook_calls_run_hook(): + args = argparse.Namespace(hook="session-start", harness="claude-code") + with patch("mempalace.hooks_cli.run_hook") as mock_run: + cmd_hook(args) + mock_run.assert_called_once_with(hook_name="session-start", harness="claude-code") + + +# ── cmd_init ─────────────────────────────────────────────────────────── + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_init_no_entities(mock_config_cls, tmp_path): + args = argparse.Namespace(dir=str(tmp_path), yes=True) + with ( + patch("mempalace.entity_detector.scan_for_detection", return_value=[]), + patch("mempalace.room_detector_local.detect_rooms_local") as mock_rooms, + ): + cmd_init(args) + mock_rooms.assert_called_once_with(project_dir=str(tmp_path), yes=True) + mock_config_cls.return_value.init.assert_called_once() + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_init_with_entities(mock_config_cls, tmp_path): + fake_files = [tmp_path / "a.txt"] + detected = {"people": [{"name": "Alice"}], "projects": [], "uncertain": []} + confirmed = {"people": ["Alice"], "projects": []} + args = argparse.Namespace(dir=str(tmp_path), yes=True) + with ( + patch("mempalace.entity_detector.scan_for_detection", return_value=fake_files), + patch("mempalace.entity_detector.detect_entities", return_value=detected), + patch("mempalace.entity_detector.confirm_entities", return_value=confirmed), + patch("mempalace.room_detector_local.detect_rooms_local"), + patch("builtins.open", MagicMock()), + ): + cmd_init(args) + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_init_with_entities_zero_total(mock_config_cls, tmp_path, capsys): + """When entities detected but total is 0, prints 'No entities' message.""" + fake_files = [tmp_path / "a.txt"] + detected = {"people": [], "projects": [], "uncertain": []} + args = argparse.Namespace(dir=str(tmp_path), yes=False) + with ( + patch("mempalace.entity_detector.scan_for_detection", return_value=fake_files), + patch("mempalace.entity_detector.detect_entities", return_value=detected), + patch("mempalace.room_detector_local.detect_rooms_local"), + ): + cmd_init(args) + out = capsys.readouterr().out + assert "No entities detected" in out + + +# ── cmd_mine ─────────────────────────────────────────────────────────── + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_mine_projects_mode(mock_config_cls): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace( + dir="/src", + palace=None, + mode="projects", + wing=None, + agent="mempalace", + limit=0, + dry_run=False, + no_gitignore=False, + include_ignored=[], + extract="exchange", + ) + with patch("mempalace.miner.mine") as mock_mine: + cmd_mine(args) + mock_mine.assert_called_once_with( + project_dir="/src", + palace_path="/fake/palace", + wing_override=None, + agent="mempalace", + limit=0, + dry_run=False, + respect_gitignore=True, + include_ignored=[], + ) + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_mine_convos_mode(mock_config_cls): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace( + dir="/chats", + palace=None, + mode="convos", + wing="mywing", + agent="me", + limit=10, + dry_run=True, + no_gitignore=False, + include_ignored=[], + extract="general", + ) + with patch("mempalace.convo_miner.mine_convos") as mock_mine: + cmd_mine(args) + mock_mine.assert_called_once_with( + convo_dir="/chats", + palace_path="/fake/palace", + wing="mywing", + agent="me", + limit=10, + dry_run=True, + extract_mode="general", + ) + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_mine_include_ignored_comma_split(mock_config_cls): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace( + dir="/src", + palace=None, + mode="projects", + wing=None, + agent="mempalace", + limit=0, + dry_run=False, + no_gitignore=False, + include_ignored=["a.txt,b.txt", "c.txt"], + extract="exchange", + ) + with patch("mempalace.miner.mine") as mock_mine: + cmd_mine(args) + mock_mine.assert_called_once() + call_kwargs = mock_mine.call_args[1] + assert call_kwargs["include_ignored"] == ["a.txt", "b.txt", "c.txt"] + + +# ── cmd_wakeup ───────────────────────────────────────────────────────── + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_wakeup(mock_config_cls, capsys): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace(palace=None, wing=None) + mock_stack = MagicMock() + mock_stack.wake_up.return_value = "Hello world context" + with patch("mempalace.layers.MemoryStack", return_value=mock_stack): + cmd_wakeup(args) + out = capsys.readouterr().out + assert "Hello world context" in out + assert "tokens" in out + + +# ── cmd_split ────────────────────────────────────────────────────────── + + +def test_cmd_split_basic(): + args = argparse.Namespace(dir="/chats", output_dir=None, dry_run=False, min_sessions=2) + with patch("mempalace.split_mega_files.main") as mock_main: + cmd_split(args) + mock_main.assert_called_once() + + +def test_cmd_split_all_options(): + args = argparse.Namespace(dir="/chats", output_dir="/out", dry_run=True, min_sessions=5) + with patch("mempalace.split_mega_files.main") as mock_main: + cmd_split(args) + mock_main.assert_called_once() + # sys.argv should be restored + assert sys.argv[0] != "mempalace split" + + +# ── main() argparse dispatch ────────────────────────────────────────── + + +def test_main_no_args_prints_help(capsys): + with patch("sys.argv", ["mempalace"]): + main() + out = capsys.readouterr().out + assert "MemPalace" in out + + +def test_main_status_dispatches(): + with ( + patch("sys.argv", ["mempalace", "status"]), + patch("mempalace.cli.cmd_status") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_search_dispatches(): + with ( + patch("sys.argv", ["mempalace", "search", "my query"]), + patch("mempalace.cli.cmd_search") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_init_dispatches(): + with ( + patch("sys.argv", ["mempalace", "init", "/some/dir"]), + patch("mempalace.cli.cmd_init") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_mine_dispatches(): + with ( + patch("sys.argv", ["mempalace", "mine", "/some/dir"]), + patch("mempalace.cli.cmd_mine") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_wakeup_dispatches(): + with ( + patch("sys.argv", ["mempalace", "wake-up"]), + patch("mempalace.cli.cmd_wakeup") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_split_dispatches(): + with ( + patch("sys.argv", ["mempalace", "split", "/chats"]), + patch("mempalace.cli.cmd_split") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_hook_no_subcommand_prints_help(capsys): + with patch("sys.argv", ["mempalace", "hook"]): + main() + out = capsys.readouterr().out + assert "hook" in out.lower() or "run" in out.lower() + + +def test_main_hook_run_dispatches(): + with ( + patch( + "sys.argv", + ["mempalace", "hook", "run", "--hook", "session-start", "--harness", "claude-code"], + ), + patch("mempalace.cli.cmd_hook") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_instructions_no_subcommand_prints_help(capsys): + with patch("sys.argv", ["mempalace", "instructions"]): + main() + out = capsys.readouterr().out + assert "instructions" in out.lower() or "init" in out.lower() + + +def test_main_instructions_dispatches(): + with ( + patch("sys.argv", ["mempalace", "instructions", "help"]), + patch("mempalace.cli.cmd_instructions") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_repair_dispatches(): + with ( + patch("sys.argv", ["mempalace", "repair"]), + patch("mempalace.cli.cmd_repair") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_compress_dispatches(): + with ( + patch("sys.argv", ["mempalace", "compress"]), + patch("mempalace.cli.cmd_compress") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +# ── cmd_repair ───────────────────────────────────────────────────────── + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_repair_no_palace(mock_config_cls, tmp_path, capsys): + mock_config_cls.return_value.palace_path = str(tmp_path / "nonexistent") + args = argparse.Namespace(palace=None) + mock_chromadb = MagicMock() + with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + cmd_repair(args) + out = capsys.readouterr().out + assert "No palace found" in out + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_repair_error_reading(mock_config_cls, tmp_path, capsys): + palace_dir = tmp_path / "palace" + palace_dir.mkdir() + mock_config_cls.return_value.palace_path = str(palace_dir) + args = argparse.Namespace(palace=None) + mock_chromadb = MagicMock() + mock_client = MagicMock() + mock_client.get_collection.side_effect = Exception("corrupt db") + mock_chromadb.PersistentClient.return_value = mock_client + with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + cmd_repair(args) + out = capsys.readouterr().out + assert "Error reading palace" in out + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_repair_zero_drawers(mock_config_cls, tmp_path, capsys): + palace_dir = tmp_path / "palace" + palace_dir.mkdir() + mock_config_cls.return_value.palace_path = str(palace_dir) + args = argparse.Namespace(palace=None) + mock_chromadb = MagicMock() + mock_col = MagicMock() + mock_col.count.return_value = 0 + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + cmd_repair(args) + out = capsys.readouterr().out + assert "Nothing to repair" in out + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_repair_success(mock_config_cls, tmp_path, capsys): + palace_dir = tmp_path / "palace" + palace_dir.mkdir() + mock_config_cls.return_value.palace_path = str(palace_dir) + args = argparse.Namespace(palace=None) + mock_chromadb = MagicMock() + mock_col = MagicMock() + mock_col.count.return_value = 2 + mock_col.get.return_value = { + "ids": ["id1", "id2"], + "documents": ["doc1", "doc2"], + "metadatas": [{"wing": "a"}, {"wing": "b"}], + } + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_new_col = MagicMock() + mock_client.create_collection.return_value = mock_new_col + mock_chromadb.PersistentClient.return_value = mock_client + with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + cmd_repair(args) + out = capsys.readouterr().out + assert "Repair complete" in out + assert "2 drawers rebuilt" in out + + +# ── cmd_compress ─────────────────────────────────────────────────────── + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_compress_no_palace(mock_config_cls, capsys): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None) + mock_chromadb = MagicMock() + mock_chromadb.PersistentClient.side_effect = Exception("no palace") + with ( + patch.dict("sys.modules", {"chromadb": mock_chromadb}), + pytest.raises(SystemExit), + ): + cmd_compress(args) + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_compress_no_drawers(mock_config_cls, capsys): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace(palace=None, wing="mywing", dry_run=False, config=None) + mock_chromadb = MagicMock() + mock_col = MagicMock() + mock_col.get.return_value = {"documents": [], "metadatas": [], "ids": []} + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + cmd_compress(args) + out = capsys.readouterr().out + assert "No drawers found" in out + + +def _make_mock_dialect_module(dialect_instance): + """Create a mock dialect module with a Dialect class that returns the given instance.""" + mock_mod = MagicMock() + mock_mod.Dialect.return_value = dialect_instance + mock_mod.Dialect.from_config.return_value = dialect_instance + mock_mod.Dialect.count_tokens = MagicMock(side_effect=lambda x: len(x) // 4) + return mock_mod + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_compress_dry_run(mock_config_cls, capsys): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace(palace=None, wing=None, dry_run=True, config=None) + mock_chromadb = MagicMock() + mock_col = MagicMock() + mock_col.get.side_effect = [ + { + "documents": ["some long text here for testing"], + "metadatas": [{"wing": "test", "room": "general", "source_file": "test.txt"}], + "ids": ["id1"], + }, + {"documents": [], "metadatas": [], "ids": []}, + ] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + mock_dialect = MagicMock() + mock_dialect.compress.return_value = "compressed" + mock_dialect.compression_stats.return_value = { + "original_chars": 100, + "compressed_chars": 30, + "original_tokens": 25, + "compressed_tokens": 8, + "ratio": 3.3, + } + mock_dialect_mod = _make_mock_dialect_module(mock_dialect) + + with patch.dict( + "sys.modules", + { + "chromadb": mock_chromadb, + "mempalace.dialect": mock_dialect_mod, + }, + ): + cmd_compress(args) + out = capsys.readouterr().out + assert "dry run" in out.lower() + assert "Compressing" in out + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_compress_with_config(mock_config_cls, tmp_path, capsys): + mock_config_cls.return_value.palace_path = "/fake/palace" + config_file = tmp_path / "entities.json" + config_file.write_text('{"people": [], "projects": []}') + args = argparse.Namespace(palace=None, wing=None, dry_run=True, config=str(config_file)) + mock_chromadb = MagicMock() + mock_col = MagicMock() + mock_col.get.return_value = {"documents": [], "metadatas": [], "ids": []} + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + mock_dialect = MagicMock() + mock_dialect_mod = _make_mock_dialect_module(mock_dialect) + + with patch.dict( + "sys.modules", + { + "chromadb": mock_chromadb, + "mempalace.dialect": mock_dialect_mod, + }, + ): + cmd_compress(args) + out = capsys.readouterr().out + assert "Loaded entity config" in out + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_compress_stores_results(mock_config_cls, capsys): + """Non-dry-run compress stores to mempalace_compressed collection.""" + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None) + mock_chromadb = MagicMock() + mock_col = MagicMock() + mock_col.get.side_effect = [ + { + "documents": ["text"], + "metadatas": [{"wing": "w", "room": "r", "source_file": "f.txt"}], + "ids": ["id1"], + }, + {"documents": [], "metadatas": [], "ids": []}, + ] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_comp_col = MagicMock() + mock_client.get_or_create_collection.return_value = mock_comp_col + mock_chromadb.PersistentClient.return_value = mock_client + + mock_dialect = MagicMock() + mock_dialect.compress.return_value = "compressed" + mock_dialect.compression_stats.return_value = { + "original_chars": 100, + "compressed_chars": 30, + "original_tokens": 25, + "compressed_tokens": 8, + "ratio": 3.3, + } + mock_dialect_mod = _make_mock_dialect_module(mock_dialect) + + with patch.dict( + "sys.modules", + { + "chromadb": mock_chromadb, + "mempalace.dialect": mock_dialect_mod, + }, + ): + cmd_compress(args) + out = capsys.readouterr().out + assert "Stored" in out + mock_comp_col.upsert.assert_called_once() diff --git a/tests/test_config_extra.py b/tests/test_config_extra.py new file mode 100644 index 0000000..d0d9b5d --- /dev/null +++ b/tests/test_config_extra.py @@ -0,0 +1,79 @@ +"""Extra tests for mempalace.config to cover remaining gaps.""" + +import json +import os + +from mempalace.config import MempalaceConfig + + +def test_config_bad_json(tmp_path): + """Bad JSON in config file falls back to empty.""" + (tmp_path / "config.json").write_text("not json", encoding="utf-8") + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert cfg.palace_path # still returns default + + +def test_people_map_from_file(tmp_path): + (tmp_path / "people_map.json").write_text(json.dumps({"bob": "Robert"}), encoding="utf-8") + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert cfg.people_map == {"bob": "Robert"} + + +def test_people_map_bad_json(tmp_path): + (tmp_path / "people_map.json").write_text("bad", encoding="utf-8") + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert cfg.people_map == {} + + +def test_people_map_missing(tmp_path): + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert cfg.people_map == {} + + +def test_topic_wings_default(tmp_path): + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert isinstance(cfg.topic_wings, list) + assert "emotions" in cfg.topic_wings + + +def test_hall_keywords_default(tmp_path): + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert isinstance(cfg.hall_keywords, dict) + assert "technical" in cfg.hall_keywords + + +def test_init_idempotent(tmp_path): + cfg = MempalaceConfig(config_dir=str(tmp_path)) + cfg.init() + cfg.init() # second call should not overwrite + with open(tmp_path / "config.json") as f: + data = json.load(f) + assert "palace_path" in data + + +def test_save_people_map(tmp_path): + cfg = MempalaceConfig(config_dir=str(tmp_path)) + result = cfg.save_people_map({"alice": "Alice Smith"}) + assert result.exists() + with open(result) as f: + data = json.load(f) + assert data["alice"] == "Alice Smith" + + +def test_env_mempal_palace_path(tmp_path): + """MEMPAL_PALACE_PATH (legacy) should also work.""" + os.environ.pop("MEMPALACE_PALACE_PATH", None) + os.environ["MEMPAL_PALACE_PATH"] = "/legacy/path" + try: + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert cfg.palace_path == "/legacy/path" + finally: + del os.environ["MEMPAL_PALACE_PATH"] + + +def test_collection_name_from_config(tmp_path): + (tmp_path / "config.json").write_text( + json.dumps({"collection_name": "custom_col"}), encoding="utf-8" + ) + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert cfg.collection_name == "custom_col" diff --git a/tests/test_convo_miner.py b/tests/test_convo_miner.py index 788c46d..0ac0019 100644 --- a/tests/test_convo_miner.py +++ b/tests/test_convo_miner.py @@ -23,4 +23,4 @@ def test_convo_mining(): results = col.query(query_texts=["memory persistence"], n_results=1) assert len(results["documents"][0]) > 0 - shutil.rmtree(tmpdir) + shutil.rmtree(tmpdir, ignore_errors=True) diff --git a/tests/test_convo_miner_unit.py b/tests/test_convo_miner_unit.py new file mode 100644 index 0000000..3c7e8f2 --- /dev/null +++ b/tests/test_convo_miner_unit.py @@ -0,0 +1,102 @@ +"""Unit tests for convo_miner pure functions (no chromadb needed).""" + +from mempalace.convo_miner import ( + chunk_exchanges, + detect_convo_room, + scan_convos, +) + + +class TestChunkExchanges: + def test_exchange_chunking(self): + content = ( + "> What is memory?\n" + "Memory is persistence of information over time.\n\n" + "> Why does it matter?\n" + "It enables continuity across sessions and conversations.\n\n" + "> How do we build it?\n" + "With structured storage and retrieval mechanisms.\n" + ) + chunks = chunk_exchanges(content) + assert len(chunks) >= 2 + assert all("content" in c and "chunk_index" in c for c in chunks) + + def test_paragraph_fallback(self): + """Content without '>' lines falls back to paragraph chunking.""" + content = ( + "This is a long paragraph about memory systems. " * 10 + "\n\n" + "This is another paragraph about storage. " * 10 + "\n\n" + "And a third paragraph about retrieval. " * 10 + ) + chunks = chunk_exchanges(content) + assert len(chunks) >= 2 + + def test_paragraph_line_group_fallback(self): + """Long content with no paragraph breaks chunks by line groups.""" + lines = [f"Line {i}: some content that is meaningful" for i in range(60)] + content = "\n".join(lines) + chunks = chunk_exchanges(content) + assert len(chunks) >= 1 + + def test_empty_content(self): + chunks = chunk_exchanges("") + assert chunks == [] + + def test_short_content_skipped(self): + chunks = chunk_exchanges("> hi\nbye") + # Too short to produce chunks (below MIN_CHUNK_SIZE) + assert isinstance(chunks, list) + + +class TestDetectConvoRoom: + def test_technical_room(self): + content = "Let me debug this python function and fix the code error in the api" + assert detect_convo_room(content) == "technical" + + def test_planning_room(self): + content = "We need to plan the roadmap for the next sprint and set milestone deadlines" + assert detect_convo_room(content) == "planning" + + def test_architecture_room(self): + content = "The architecture uses a service layer with component interface and module design" + assert detect_convo_room(content) == "architecture" + + def test_decisions_room(self): + content = "We decided to switch and migrated to the new framework after we chose it" + assert detect_convo_room(content) == "decisions" + + def test_general_fallback(self): + content = "Hello, how are you doing today? The weather is nice." + assert detect_convo_room(content) == "general" + + +class TestScanConvos: + def test_scan_finds_txt_and_md(self, tmp_path): + (tmp_path / "chat.txt").write_text("hello", encoding="utf-8") + (tmp_path / "notes.md").write_text("world", encoding="utf-8") + (tmp_path / "image.png").write_bytes(b"fake") + files = scan_convos(str(tmp_path)) + extensions = {f.suffix for f in files} + assert ".txt" in extensions + assert ".md" in extensions + assert ".png" not in extensions + + def test_scan_skips_git_dir(self, tmp_path): + git_dir = tmp_path / ".git" + git_dir.mkdir() + (git_dir / "config.txt").write_text("git stuff", encoding="utf-8") + (tmp_path / "chat.txt").write_text("hello", encoding="utf-8") + files = scan_convos(str(tmp_path)) + assert len(files) == 1 + + def test_scan_skips_meta_json(self, tmp_path): + (tmp_path / "chat.meta.json").write_text("{}", encoding="utf-8") + (tmp_path / "chat.json").write_text("{}", encoding="utf-8") + files = scan_convos(str(tmp_path)) + names = [f.name for f in files] + assert "chat.json" in names + assert "chat.meta.json" not in names + + def test_scan_empty_dir(self, tmp_path): + files = scan_convos(str(tmp_path)) + assert files == [] diff --git a/tests/test_entity_detector.py b/tests/test_entity_detector.py new file mode 100644 index 0000000..91f0e29 --- /dev/null +++ b/tests/test_entity_detector.py @@ -0,0 +1,380 @@ +"""Tests for mempalace.entity_detector.""" + +import os +from unittest.mock import patch + +from mempalace.entity_detector import ( + PROSE_EXTENSIONS, + STOPWORDS, + _print_entity_list, + classify_entity, + confirm_entities, + detect_entities, + extract_candidates, + scan_for_detection, + score_entity, +) + + +# ── extract_candidates ────────────────────────────────────────────────── + + +def test_extract_candidates_finds_frequent_names(): + text = "Riley said hello. Riley laughed. Riley smiled. Riley waved." + result = extract_candidates(text) + assert "Riley" in result + assert result["Riley"] >= 3 + + +def test_extract_candidates_ignores_stopwords(): + # "The" appears many times but is a stopword + text = "The The The The The The" + result = extract_candidates(text) + assert "The" not in result + + +def test_extract_candidates_requires_min_frequency(): + text = "Riley said hi. Devon waved." + result = extract_candidates(text) + # Each name appears only once, below the threshold of 3 + assert "Riley" not in result + assert "Devon" not in result + + +def test_extract_candidates_finds_multi_word_names(): + # Multi-word names need 3+ occurrences and no stopwords + text = "Claude Code is great. Claude Code rocks. Claude Code works. Claude Code rules." + result = extract_candidates(text) + assert "Claude Code" in result + + +def test_extract_candidates_empty_text(): + result = extract_candidates("") + assert result == {} + + +# ── score_entity ──────────────────────────────────────────────────────── + + +def test_score_entity_person_verbs(): + text = "Riley said hello. Riley asked why. Riley told me." + lines = text.splitlines() + result = score_entity("Riley", text, lines) + assert result["person_score"] > 0 + assert len(result["person_signals"]) > 0 + + +def test_score_entity_project_verbs(): + text = "We are building ChromaDB. We deployed ChromaDB. Install ChromaDB." + lines = text.splitlines() + result = score_entity("ChromaDB", text, lines) + assert result["project_score"] > 0 + assert len(result["project_signals"]) > 0 + + +def test_score_entity_dialogue_markers(): + text = "Riley: Hey, how are you?\nRiley: I'm fine." + lines = text.splitlines() + result = score_entity("Riley", text, lines) + assert result["person_score"] > 0 + + +def test_score_entity_code_ref(): + text = "Check out ChromaDB.py for details. Also ChromaDB.js is good." + lines = text.splitlines() + result = score_entity("ChromaDB", text, lines) + assert result["project_score"] > 0 + + +def test_score_entity_no_signals(): + text = "Nothing interesting here at all." + lines = text.splitlines() + result = score_entity("Riley", text, lines) + assert result["person_score"] == 0 + assert result["project_score"] == 0 + + +# ── classify_entity ───────────────────────────────────────────────────── + + +def test_classify_entity_no_signals_gives_uncertain(): + scores = { + "person_score": 0, + "project_score": 0, + "person_signals": [], + "project_signals": [], + } + result = classify_entity("Foo", 10, scores) + assert result["type"] == "uncertain" + assert result["name"] == "Foo" + + +def test_classify_entity_strong_project(): + scores = { + "person_score": 0, + "project_score": 10, + "person_signals": [], + "project_signals": ["project verb (5x)", "code file reference (2x)"], + } + result = classify_entity("ChromaDB", 5, scores) + assert result["type"] == "project" + + +def test_classify_entity_strong_person_needs_two_signal_types(): + scores = { + "person_score": 10, + "project_score": 0, + "person_signals": [ + "dialogue marker (3x)", + "'Riley ...' action (4x)", + ], + "project_signals": [], + } + result = classify_entity("Riley", 8, scores) + assert result["type"] == "person" + + +def test_classify_entity_pronoun_only_is_uncertain(): + scores = { + "person_score": 8, + "project_score": 0, + "person_signals": ["pronoun nearby (4x)"], + "project_signals": [], + } + result = classify_entity("Riley", 5, scores) + assert result["type"] == "uncertain" + + +def test_classify_entity_mixed_signals(): + scores = { + "person_score": 5, + "project_score": 5, + "person_signals": ["pronoun nearby (2x)"], + "project_signals": ["project verb (2x)"], + } + result = classify_entity("Lantern", 5, scores) + assert result["type"] == "uncertain" + assert "mixed signals" in result["signals"][-1] + + +# ── detect_entities (integration) ─────────────────────────────────────── + + +def test_detect_entities_with_person_file(tmp_path): + f = tmp_path / "notes.txt" + content = "\n".join( + [ + "Riley said hello today.", + "Riley asked about the project.", + "Riley told me she was happy.", + "Riley: I think we should go.", + "Hey Riley, thanks for the help.", + "Riley laughed and smiled.", + "Riley decided to join.", + "Riley pushed the change.", + ] + ) + f.write_text(content) + result = detect_entities([f]) + all_names = [e["name"] for cat in result.values() for e in cat] + assert "Riley" in all_names + + +def test_detect_entities_with_project_file(tmp_path): + f = tmp_path / "readme.txt" + # "ChromaDB" has uppercase+lowercase mix but extract_candidates looks + # for /[A-Z][a-z]{1,19}/ — so we need a name that matches that regex. + # Use "Lantern" which matches the capitalized-word pattern. + content = "\n".join( + [ + "The Lantern project is great.", + "Building Lantern was fun.", + "We deployed Lantern today.", + "Install Lantern with pip install Lantern.", + "Check Lantern.py for the source.", + "Lantern v2 is faster.", + ] + ) + f.write_text(content) + result = detect_entities([f]) + all_names = [e["name"] for cat in result.values() for e in cat] + assert "Lantern" in all_names + + +def test_detect_entities_empty_files(tmp_path): + f = tmp_path / "empty.txt" + f.write_text("") + result = detect_entities([f]) + assert result == {"people": [], "projects": [], "uncertain": []} + + +def test_detect_entities_handles_missing_file(tmp_path): + missing = tmp_path / "nonexistent.txt" + result = detect_entities([missing]) + assert result == {"people": [], "projects": [], "uncertain": []} + + +def test_detect_entities_respects_max_files(tmp_path): + files = [] + for i in range(5): + f = tmp_path / f"file{i}.txt" + f.write_text("Riley said hello. " * 10) + files.append(f) + # max_files=2 should only read 2 files + result = detect_entities(files, max_files=2) + # Should still work without error + assert isinstance(result, dict) + + +# ── scan_for_detection ────────────────────────────────────────────────── + + +def test_scan_for_detection_finds_prose(tmp_path): + (tmp_path / "notes.md").write_text("hello") + (tmp_path / "data.txt").write_text("world") + (tmp_path / "code.py").write_text("import os") + files = scan_for_detection(str(tmp_path)) + extensions = {os.path.splitext(str(f))[1] for f in files} + # Prose files should be found + assert ".md" in extensions or ".txt" in extensions + + +def test_scan_for_detection_skips_git_dir(tmp_path): + git_dir = tmp_path / ".git" + git_dir.mkdir() + (git_dir / "config.txt").write_text("git config") + (tmp_path / "readme.md").write_text("hello") + files = scan_for_detection(str(tmp_path)) + file_strs = [str(f) for f in files] + assert not any(".git" in f for f in file_strs) + + +# ── module-level constants ────────────────────────────────────────────── + + +def test_stopwords_contains_common_words(): + assert "the" in STOPWORDS + assert "import" in STOPWORDS + assert "class" in STOPWORDS + + +def test_prose_extensions(): + assert ".txt" in PROSE_EXTENSIONS + assert ".md" in PROSE_EXTENSIONS + + +# ── _print_entity_list ───────────────────────────────────────────────── + + +def test_print_entity_list_with_entities(capsys): + entities = [ + {"name": "Alice", "confidence": 0.9, "signals": ["dialogue marker (3x)"]}, + {"name": "Bob", "confidence": 0.5, "signals": []}, + ] + _print_entity_list(entities, "PEOPLE") + out = capsys.readouterr().out + assert "PEOPLE" in out + assert "Alice" in out + assert "Bob" in out + + +def test_print_entity_list_empty(capsys): + _print_entity_list([], "PEOPLE") + out = capsys.readouterr().out + assert "none detected" in out + + +# ── confirm_entities ─────────────────────────────────────────────────── + + +def test_confirm_entities_yes_mode(): + detected = { + "people": [{"name": "Alice", "confidence": 0.9, "signals": ["test"]}], + "projects": [{"name": "Acme", "confidence": 0.8, "signals": ["test"]}], + "uncertain": [{"name": "Foo", "confidence": 0.4, "signals": ["test"]}], + } + result = confirm_entities(detected, yes=True) + assert result["people"] == ["Alice"] + assert result["projects"] == ["Acme"] + + +def test_confirm_entities_accept_all(): + detected = { + "people": [{"name": "Alice", "confidence": 0.9, "signals": ["test"]}], + "projects": [], + "uncertain": [], + } + with patch("builtins.input", side_effect=["", "n"]): + result = confirm_entities(detected, yes=False) + assert "Alice" in result["people"] + + +def test_confirm_entities_edit_reclassify_uncertain(): + detected = { + "people": [], + "projects": [], + "uncertain": [ + {"name": "Foo", "confidence": 0.4, "signals": ["test"]}, + {"name": "Bar", "confidence": 0.4, "signals": ["test"]}, + ], + } + with patch( + "builtins.input", + side_effect=[ + "edit", # choice + "p", # Foo -> person + "s", # Bar -> skip + "", # no removals from people + "", # no removals from projects + "n", # don't add missing + ], + ): + result = confirm_entities(detected, yes=False) + assert "Foo" in result["people"] + assert "Bar" not in result["people"] + assert "Bar" not in result["projects"] + + +def test_confirm_entities_add_mode(): + detected = { + "people": [], + "projects": [], + "uncertain": [], + } + with patch( + "builtins.input", + side_effect=[ + "add", # choice = add + "NewPerson", # name + "p", # person + "NewProj", # name + "r", # project + "", # stop adding + ], + ): + result = confirm_entities(detected, yes=False) + assert "NewPerson" in result["people"] + assert "NewProj" in result["projects"] + + +# ── scan_for_detection fallback ──────────────────────────────────────── + + +def test_scan_for_detection_fallback_to_all_readable(tmp_path): + """When fewer than 3 prose files, falls back to include all readable files.""" + (tmp_path / "one.md").write_text("hello") + (tmp_path / "two.txt").write_text("world") + # Only 2 prose files, so it should also include code files + (tmp_path / "code.py").write_text("import os") + (tmp_path / "app.js").write_text("console.log()") + files = scan_for_detection(str(tmp_path)) + extensions = {os.path.splitext(str(f))[1] for f in files} + assert ".py" in extensions or ".js" in extensions + + +def test_scan_for_detection_max_files(tmp_path): + """Caps to max_files.""" + for i in range(20): + (tmp_path / f"note{i}.md").write_text(f"content {i}") + files = scan_for_detection(str(tmp_path), max_files=5) + assert len(files) <= 5 diff --git a/tests/test_entity_registry.py b/tests/test_entity_registry.py new file mode 100644 index 0000000..b92bf84 --- /dev/null +++ b/tests/test_entity_registry.py @@ -0,0 +1,313 @@ +"""Tests for mempalace.entity_registry.""" + +from unittest.mock import patch + +from mempalace.entity_registry import ( + COMMON_ENGLISH_WORDS, + PERSON_CONTEXT_PATTERNS, + EntityRegistry, +) + + +# ── COMMON_ENGLISH_WORDS ──────────────────────────────────────────────── + + +def test_common_english_words_has_expected_entries(): + assert "ever" in COMMON_ENGLISH_WORDS + assert "grace" in COMMON_ENGLISH_WORDS + assert "will" in COMMON_ENGLISH_WORDS + assert "may" in COMMON_ENGLISH_WORDS + assert "monday" in COMMON_ENGLISH_WORDS + + +def test_common_english_words_is_lowercase(): + for word in COMMON_ENGLISH_WORDS: + assert word == word.lower(), f"{word} should be lowercase" + + +# ── PERSON_CONTEXT_PATTERNS ───────────────────────────────────────────── + + +def test_person_context_patterns_is_nonempty(): + assert len(PERSON_CONTEXT_PATTERNS) > 0 + + +# ── EntityRegistry creation and empty state ───────────────────────────── + + +def test_load_from_nonexistent_dir(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + assert registry.people == {} + assert registry.projects == [] + assert registry.mode == "personal" + assert registry.ambiguous_flags == [] + + +def test_save_and_load_roundtrip(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="work", + people=[{"name": "Alice", "relationship": "colleague", "context": "work"}], + projects=["MemPalace"], + ) + # Load again from same dir + loaded = EntityRegistry.load(config_dir=tmp_path) + assert loaded.mode == "work" + assert "Alice" in loaded.people + assert "MemPalace" in loaded.projects + + +def test_save_creates_file(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.save() + assert (tmp_path / "entity_registry.json").exists() + + +# ── seed ──────────────────────────────────────────────────────────────── + + +def test_seed_registers_people(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[ + {"name": "Riley", "relationship": "daughter", "context": "personal"}, + {"name": "Devon", "relationship": "friend", "context": "personal"}, + ], + projects=["MemPalace"], + ) + assert "Riley" in registry.people + assert "Devon" in registry.people + assert registry.people["Riley"]["relationship"] == "daughter" + assert registry.people["Riley"]["source"] == "onboarding" + assert registry.people["Riley"]["confidence"] == 1.0 + + +def test_seed_registers_projects(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="work", people=[], projects=["Acme", "Widget"]) + assert registry.projects == ["Acme", "Widget"] + + +def test_seed_sets_mode(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="combo", people=[], projects=[]) + assert registry.mode == "combo" + + +def test_seed_flags_ambiguous_names(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[ + {"name": "Grace", "relationship": "friend", "context": "personal"}, + {"name": "Riley", "relationship": "daughter", "context": "personal"}, + ], + projects=[], + ) + assert "grace" in registry.ambiguous_flags + # Riley is not a common English word + assert "riley" not in registry.ambiguous_flags + + +def test_seed_with_aliases(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Maxwell", "relationship": "friend", "context": "personal"}], + projects=[], + aliases={"Max": "Maxwell"}, + ) + assert "Maxwell" in registry.people + assert "Max" in registry.people + assert registry.people["Max"].get("canonical") == "Maxwell" + + +def test_seed_skips_empty_names(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "", "relationship": "", "context": "personal"}], + projects=[], + ) + assert len(registry.people) == 0 + + +# ── lookup ────────────────────────────────────────────────────────────── + + +def test_lookup_known_person(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + projects=[], + ) + result = registry.lookup("Riley") + assert result["type"] == "person" + assert result["confidence"] == 1.0 + assert result["name"] == "Riley" + + +def test_lookup_known_project(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="work", people=[], projects=["MemPalace"]) + result = registry.lookup("MemPalace") + assert result["type"] == "project" + assert result["confidence"] == 1.0 + + +def test_lookup_unknown_word(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="personal", people=[], projects=[]) + result = registry.lookup("Xyzzy") + assert result["type"] == "unknown" + assert result["confidence"] == 0.0 + + +def test_lookup_case_insensitive(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + projects=[], + ) + result = registry.lookup("riley") + assert result["type"] == "person" + + +def test_lookup_alias(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Maxwell", "relationship": "friend", "context": "personal"}], + projects=[], + aliases={"Max": "Maxwell"}, + ) + result = registry.lookup("Max") + assert result["type"] == "person" + + +# ── disambiguation ────────────────────────────────────────────────────── + + +def test_lookup_ambiguous_word_as_person(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Grace", "relationship": "friend", "context": "personal"}], + projects=[], + ) + result = registry.lookup("Grace", context="I went with Grace today") + assert result["type"] == "person" + + +def test_lookup_ambiguous_word_as_concept(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Ever", "relationship": "friend", "context": "personal"}], + projects=[], + ) + result = registry.lookup("Ever", context="have you ever tried this") + assert result["type"] == "concept" + + +# ── research (Wikipedia) — mocked ────────────────────────────────────── + + +def test_research_caches_result(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="personal", people=[], projects=[]) + + mock_result = { + "inferred_type": "person", + "confidence": 0.80, + "wiki_summary": "Saoirse is an Irish given name.", + "wiki_title": "Saoirse", + } + + with patch("mempalace.entity_registry._wikipedia_lookup", return_value=mock_result): + result = registry.research("Saoirse", auto_confirm=True) + assert result["inferred_type"] == "person" + + # Second call should use cache, not call Wikipedia again + with patch( + "mempalace.entity_registry._wikipedia_lookup", + side_effect=AssertionError("should not be called"), + ): + cached = registry.research("Saoirse") + assert cached["inferred_type"] == "person" + + +def test_confirm_research_adds_to_people(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="personal", people=[], projects=[]) + + mock_result = { + "inferred_type": "person", + "confidence": 0.80, + "wiki_summary": "Saoirse is a name", + "wiki_title": "Saoirse", + } + with patch("mempalace.entity_registry._wikipedia_lookup", return_value=mock_result): + registry.research("Saoirse", auto_confirm=False) + + registry.confirm_research("Saoirse", entity_type="person", relationship="friend") + assert "Saoirse" in registry.people + assert registry.people["Saoirse"]["source"] == "wiki" + + +# ── extract_people_from_query ─────────────────────────────────────────── + + +def test_extract_people_from_query(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[ + {"name": "Riley", "relationship": "daughter", "context": "personal"}, + {"name": "Devon", "relationship": "friend", "context": "personal"}, + ], + projects=[], + ) + found = registry.extract_people_from_query("What did Riley say about the weather?") + assert "Riley" in found + assert "Devon" not in found + + +# ── extract_unknown_candidates ────────────────────────────────────────── + + +def test_extract_unknown_candidates(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="personal", people=[], projects=[]) + unknowns = registry.extract_unknown_candidates("Saoirse went to the store") + assert "Saoirse" in unknowns + + +def test_extract_unknown_candidates_skips_known(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + projects=[], + ) + unknowns = registry.extract_unknown_candidates("Riley went to the store") + assert "Riley" not in unknowns + + +# ── summary ───────────────────────────────────────────────────────────── + + +def test_summary(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + projects=["MemPalace"], + ) + s = registry.summary() + assert "personal" in s + assert "Riley" in s + assert "MemPalace" in s diff --git a/tests/test_general_extractor.py b/tests/test_general_extractor.py new file mode 100644 index 0000000..0f5d46c --- /dev/null +++ b/tests/test_general_extractor.py @@ -0,0 +1,248 @@ +"""Tests for mempalace.general_extractor.""" + +from mempalace.general_extractor import ( + ALL_MARKERS, + NEGATIVE_WORDS, + POSITIVE_WORDS, + _extract_prose, + _get_sentiment, + _has_resolution, + _is_code_line, + _score_markers, + _split_into_segments, + extract_memories, +) + + +# ── extract_memories — empty / no markers ─────────────────────────────── + + +def test_extract_memories_empty_text(): + result = extract_memories("") + assert result == [] + + +def test_extract_memories_no_markers(): + result = extract_memories("The quick brown fox jumped over the lazy dog.") + assert result == [] + + +def test_extract_memories_short_text_skipped(): + # Paragraphs shorter than 20 chars are skipped + result = extract_memories("ok sure") + assert result == [] + + +# ── extract_memories — decision markers ───────────────────────────────── + + +def test_extract_memories_decision(): + text = ( + "We decided to go with PostgreSQL instead of MySQL " + "because the performance was better for our use case. " + "The trade-off was more complexity in setup." + ) + result = extract_memories(text) + assert len(result) >= 1 + assert any(m["memory_type"] == "decision" for m in result) + + +# ── extract_memories — preference markers ─────────────────────────────── + + +def test_extract_memories_preference(): + text = ( + "I prefer using snake_case in Python code. " + "Please always use type hints. " + "Never use wildcard imports." + ) + result = extract_memories(text) + assert len(result) >= 1 + assert any(m["memory_type"] == "preference" for m in result) + + +# ── extract_memories — milestone markers ──────────────────────────────── + + +def test_extract_memories_milestone(): + text = ( + "It finally works! After three days of debugging, " + "I figured out the issue. The breakthrough was realizing " + "the config file was cached. Got it working at 2am." + ) + result = extract_memories(text) + assert len(result) >= 1 + assert any(m["memory_type"] == "milestone" for m in result) + + +# ── extract_memories — problem markers ────────────────────────────────── + + +def test_extract_memories_problem(): + text = ( + "There's a critical bug in the auth module. " + "The error keeps crashing the server. " + "The root cause was a missing null check. " + "The problem is that tokens expire silently." + ) + result = extract_memories(text) + assert len(result) >= 1 + types = {m["memory_type"] for m in result} + assert "problem" in types or "milestone" in types # resolved problems become milestones + + +# ── extract_memories — emotional markers ──────────────────────────────── + + +def test_extract_memories_emotional(): + text = ( + "I feel so proud of what we built together. " + "I love working on this project, it makes me happy. " + "I'm grateful for the team and the beautiful code we wrote." + ) + result = extract_memories(text) + assert len(result) >= 1 + assert any(m["memory_type"] == "emotional" for m in result) + + +# ── extract_memories — chunk_index ────────────────────────────────────── + + +def test_extract_memories_chunk_index_increments(): + text = ( + "We decided to use React because it fits our team.\n\n" + "I prefer functional components always.\n\n" + "It works! We finally shipped the v1.0 release." + ) + result = extract_memories(text) + if len(result) >= 2: + indices = [m["chunk_index"] for m in result] + assert indices == list(range(len(result))) + + +# ── _score_markers ────────────────────────────────────────────────────── + + +def test_score_markers_with_matches(): + score, keywords = _score_markers( + "we decided to go with postgres because it is faster", + ALL_MARKERS["decision"], + ) + assert score > 0 + assert len(keywords) > 0 + + +def test_score_markers_no_matches(): + score, keywords = _score_markers("nothing relevant here", ALL_MARKERS["decision"]) + assert score == 0.0 + + +# ── _get_sentiment ────────────────────────────────────────────────────── + + +def test_get_sentiment_positive(): + assert _get_sentiment("I am so happy and proud of this breakthrough") == "positive" + + +def test_get_sentiment_negative(): + assert _get_sentiment("This bug caused a crash and total failure") == "negative" + + +def test_get_sentiment_neutral(): + assert _get_sentiment("The meeting is at three") == "neutral" + + +# ── _has_resolution ───────────────────────────────────────────────────── + + +def test_has_resolution_true(): + assert _has_resolution("I fixed the auth bug and it works now") is True + + +def test_has_resolution_false(): + assert _has_resolution("The server keeps crashing") is False + + +# ── _is_code_line ─────────────────────────────────────────────────────── + + +def test_is_code_line_detects_code(): + assert _is_code_line(" import os") is True + assert _is_code_line(" $ pip install flask") is True + assert _is_code_line(" ```python") is True + + +def test_is_code_line_allows_prose(): + assert _is_code_line("This is a regular sentence about coding.") is False + assert _is_code_line("") is False + + +# ── _extract_prose ────────────────────────────────────────────────────── + + +def test_extract_prose_strips_code_blocks(): + text = "Hello world\n```\nimport os\nprint('hi')\n```\nGoodbye" + result = _extract_prose(text) + assert "import os" not in result + assert "Hello world" in result + assert "Goodbye" in result + + +def test_extract_prose_returns_original_if_all_code(): + text = "import os\nfrom sys import argv" + result = _extract_prose(text) + # Falls back to original text if nothing left + assert len(result) > 0 + + +# ── _split_into_segments ─────────────────────────────────────────────── + + +def test_split_into_segments_by_paragraph(): + text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph." + result = _split_into_segments(text) + assert len(result) == 3 + + +def test_split_into_segments_by_turns(): + lines = [] + for i in range(5): + lines.append(f"Human: Question {i}") + lines.append(f"Assistant: Answer {i}") + text = "\n".join(lines) + result = _split_into_segments(text) + assert len(result) >= 3 # turn-based splitting should fire + + +def test_split_into_segments_single_block(): + # Many lines without double-newline produces chunked segments + lines = [f"Line {i} of the document" for i in range(30)] + text = "\n".join(lines) + result = _split_into_segments(text) + assert len(result) >= 1 + + +# ── ALL_MARKERS constant ─────────────────────────────────────────────── + + +def test_all_markers_has_five_types(): + assert set(ALL_MARKERS.keys()) == { + "decision", + "preference", + "milestone", + "problem", + "emotional", + } + + +# ── POSITIVE_WORDS / NEGATIVE_WORDS ──────────────────────────────────── + + +def test_positive_words(): + assert "happy" in POSITIVE_WORDS + assert "proud" in POSITIVE_WORDS + + +def test_negative_words(): + assert "bug" in NEGATIVE_WORDS + assert "crash" in NEGATIVE_WORDS diff --git a/tests/test_hooks_cli.py b/tests/test_hooks_cli.py new file mode 100644 index 0000000..5a1870e --- /dev/null +++ b/tests/test_hooks_cli.py @@ -0,0 +1,420 @@ +import contextlib +import io +import json +from pathlib import Path +from unittest.mock import patch + +import pytest + +from mempalace.hooks_cli import ( + SAVE_INTERVAL, + STOP_BLOCK_REASON, + PRECOMPACT_BLOCK_REASON, + _count_human_messages, + _log, + _maybe_auto_ingest, + _parse_harness_input, + _sanitize_session_id, + hook_stop, + hook_session_start, + hook_precompact, + run_hook, +) + + +# --- _sanitize_session_id --- + + +def test_sanitize_normal_id(): + assert _sanitize_session_id("abc-123_XYZ") == "abc-123_XYZ" + + +def test_sanitize_strips_dangerous_chars(): + assert _sanitize_session_id("../../etc/passwd") == "etcpasswd" + + +def test_sanitize_empty_returns_unknown(): + assert _sanitize_session_id("") == "unknown" + assert _sanitize_session_id("!!!") == "unknown" + + +# --- _count_human_messages --- + + +def _write_transcript(path: Path, entries: list[dict]): + with open(path, "w", encoding="utf-8") as f: + for entry in entries: + f.write(json.dumps(entry) + "\n") + + +def test_count_human_messages_basic(tmp_path): + transcript = tmp_path / "t.jsonl" + _write_transcript( + transcript, + [ + {"message": {"role": "user", "content": "hello"}}, + {"message": {"role": "assistant", "content": "hi"}}, + {"message": {"role": "user", "content": "bye"}}, + ], + ) + assert _count_human_messages(str(transcript)) == 2 + + +def test_count_skips_command_messages(tmp_path): + transcript = tmp_path / "t.jsonl" + _write_transcript( + transcript, + [ + {"message": {"role": "user", "content": "status"}}, + {"message": {"role": "user", "content": "real question"}}, + ], + ) + assert _count_human_messages(str(transcript)) == 1 + + +def test_count_handles_list_content(tmp_path): + transcript = tmp_path / "t.jsonl" + _write_transcript( + transcript, + [ + {"message": {"role": "user", "content": [{"type": "text", "text": "hello"}]}}, + { + "message": { + "role": "user", + "content": [{"type": "text", "text": "x"}], + } + }, + ], + ) + assert _count_human_messages(str(transcript)) == 1 + + +def test_count_missing_file(): + assert _count_human_messages("/nonexistent/path.jsonl") == 0 + + +def test_count_empty_file(tmp_path): + transcript = tmp_path / "t.jsonl" + transcript.write_text("") + assert _count_human_messages(str(transcript)) == 0 + + +def test_count_malformed_json_lines(tmp_path): + transcript = tmp_path / "t.jsonl" + transcript.write_text('not json\n{"message": {"role": "user", "content": "ok"}}\n') + assert _count_human_messages(str(transcript)) == 1 + + +# --- hook_stop --- + + +def _capture_hook_output(hook_fn, data, harness="claude-code", state_dir=None): + """Run a hook and capture its JSON stdout output.""" + import io + + buf = io.StringIO() + patches = [patch("mempalace.hooks_cli._output", side_effect=lambda d: buf.write(json.dumps(d)))] + if state_dir: + patches.append(patch("mempalace.hooks_cli.STATE_DIR", state_dir)) + with contextlib.ExitStack() as stack: + for p in patches: + stack.enter_context(p) + hook_fn(data, harness) + return json.loads(buf.getvalue()) + + +def test_stop_hook_passthrough_when_active(tmp_path): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + result = _capture_hook_output( + hook_stop, + {"session_id": "test", "stop_hook_active": True, "transcript_path": ""}, + state_dir=tmp_path, + ) + assert result == {} + + +def test_stop_hook_passthrough_when_active_string(tmp_path): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + result = _capture_hook_output( + hook_stop, + {"session_id": "test", "stop_hook_active": "true", "transcript_path": ""}, + state_dir=tmp_path, + ) + assert result == {} + + +def test_stop_hook_passthrough_below_interval(tmp_path): + transcript = tmp_path / "t.jsonl" + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL - 1)], + ) + result = _capture_hook_output( + hook_stop, + {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}, + state_dir=tmp_path, + ) + assert result == {} + + +def test_stop_hook_blocks_at_interval(tmp_path): + transcript = tmp_path / "t.jsonl" + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], + ) + result = _capture_hook_output( + hook_stop, + {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}, + state_dir=tmp_path, + ) + assert result["decision"] == "block" + assert result["reason"] == STOP_BLOCK_REASON + + +def test_stop_hook_tracks_save_point(tmp_path): + transcript = tmp_path / "t.jsonl" + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], + ) + data = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)} + + # First call blocks + result = _capture_hook_output(hook_stop, data, state_dir=tmp_path) + assert result["decision"] == "block" + + # Second call with same count passes through (already saved) + result = _capture_hook_output(hook_stop, data, state_dir=tmp_path) + assert result == {} + + +# --- hook_session_start --- + + +def test_session_start_passes_through(tmp_path): + result = _capture_hook_output( + hook_session_start, + {"session_id": "test"}, + state_dir=tmp_path, + ) + assert result == {} + + +# --- hook_precompact --- + + +def test_precompact_always_blocks(tmp_path): + result = _capture_hook_output( + hook_precompact, + {"session_id": "test"}, + state_dir=tmp_path, + ) + assert result["decision"] == "block" + assert result["reason"] == PRECOMPACT_BLOCK_REASON + + +# --- _log --- + + +def test_log_writes_to_hook_log(tmp_path): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + _log("test message") + log_path = tmp_path / "hook.log" + assert log_path.is_file() + content = log_path.read_text() + assert "test message" in content + + +def test_log_oserror_is_silenced(tmp_path): + """_log should not raise if the directory cannot be created.""" + with patch("mempalace.hooks_cli.STATE_DIR", Path("/nonexistent/deeply/nested/dir")): + # Should not raise + _log("this will fail silently") + + +# --- _maybe_auto_ingest --- + + +def test_maybe_auto_ingest_no_env(tmp_path): + """Without MEMPAL_DIR set, does nothing.""" + with patch.dict("os.environ", {}, clear=True): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + _maybe_auto_ingest() # should not raise + + +def test_maybe_auto_ingest_with_env(tmp_path): + """With MEMPAL_DIR set to a valid directory, spawns subprocess.""" + mempal_dir = tmp_path / "project" + mempal_dir.mkdir() + with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli.subprocess.Popen") as mock_popen: + _maybe_auto_ingest() + mock_popen.assert_called_once() + + +def test_maybe_auto_ingest_oserror(tmp_path): + """OSError during subprocess spawn is silenced.""" + mempal_dir = tmp_path / "project" + mempal_dir.mkdir() + with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli.subprocess.Popen", side_effect=OSError("fail")): + _maybe_auto_ingest() # should not raise + + +# --- _parse_harness_input --- + + +def test_parse_harness_input_unknown(): + """Unknown harness should sys.exit(1).""" + with pytest.raises(SystemExit) as exc_info: + _parse_harness_input({"session_id": "test"}, "unknown-harness") + assert exc_info.value.code == 1 + + +def test_parse_harness_input_valid(): + result = _parse_harness_input( + {"session_id": "abc-123", "stop_hook_active": True, "transcript_path": "/tmp/t.jsonl"}, + "claude-code", + ) + assert result["session_id"] == "abc-123" + assert result["stop_hook_active"] is True + + +# --- hook_stop with OSError on write --- + + +def test_stop_hook_oserror_on_last_save_read(tmp_path): + """When last_save_file has invalid content, falls back to 0.""" + transcript = tmp_path / "t.jsonl" + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], + ) + # Write invalid content to last save file + (tmp_path / "test_last_save").write_text("not_a_number") + result = _capture_hook_output( + hook_stop, + {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}, + state_dir=tmp_path, + ) + assert result["decision"] == "block" + + +def test_stop_hook_oserror_on_write(tmp_path): + """When write to last_save_file fails, hook still outputs correctly.""" + transcript = tmp_path / "t.jsonl" + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], + ) + + def bad_write_text(*args, **kwargs): + raise OSError("disk full") + + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch.object(Path, "write_text", bad_write_text): + result = _capture_hook_output( + hook_stop, + { + "session_id": "test", + "stop_hook_active": False, + "transcript_path": str(transcript), + }, + state_dir=tmp_path, + ) + assert result["decision"] == "block" + + +# --- hook_precompact with MEMPAL_DIR --- + + +def test_precompact_with_mempal_dir(tmp_path): + """Precompact runs subprocess.run when MEMPAL_DIR is set.""" + mempal_dir = tmp_path / "project" + mempal_dir.mkdir() + with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}): + with patch("mempalace.hooks_cli.subprocess.run") as mock_run: + result = _capture_hook_output( + hook_precompact, + {"session_id": "test"}, + state_dir=tmp_path, + ) + assert result["decision"] == "block" + mock_run.assert_called_once() + + +def test_precompact_with_mempal_dir_oserror(tmp_path): + """Precompact handles OSError from subprocess gracefully.""" + mempal_dir = tmp_path / "project" + mempal_dir.mkdir() + with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}): + with patch("mempalace.hooks_cli.subprocess.run", side_effect=OSError("fail")): + result = _capture_hook_output( + hook_precompact, + {"session_id": "test"}, + state_dir=tmp_path, + ) + assert result["decision"] == "block" + + +# --- run_hook --- + + +def test_run_hook_dispatches_session_start(tmp_path): + """run_hook reads stdin JSON and dispatches to correct handler.""" + stdin_data = json.dumps({"session_id": "run-test"}) + with patch("sys.stdin", io.StringIO(stdin_data)): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli._output") as mock_output: + run_hook("session-start", "claude-code") + mock_output.assert_called_once_with({}) + + +def test_run_hook_dispatches_stop(tmp_path): + transcript = tmp_path / "t.jsonl" + _write_transcript( + transcript, [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(3)] + ) + stdin_data = json.dumps( + { + "session_id": "run-test", + "stop_hook_active": False, + "transcript_path": str(transcript), + } + ) + with patch("sys.stdin", io.StringIO(stdin_data)): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli._output") as mock_output: + run_hook("stop", "claude-code") + mock_output.assert_called_once_with({}) + + +def test_run_hook_dispatches_precompact(tmp_path): + stdin_data = json.dumps({"session_id": "run-test"}) + with patch("sys.stdin", io.StringIO(stdin_data)): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli._output") as mock_output: + run_hook("precompact", "claude-code") + mock_output.assert_called_once() + call_args = mock_output.call_args[0][0] + assert call_args["decision"] == "block" + + +def test_run_hook_unknown_hook(): + stdin_data = json.dumps({"session_id": "test"}) + with patch("sys.stdin", io.StringIO(stdin_data)): + with pytest.raises(SystemExit) as exc_info: + run_hook("nonexistent", "claude-code") + assert exc_info.value.code == 1 + + +def test_run_hook_invalid_json(tmp_path): + """Invalid stdin JSON should not crash — falls back to empty dict.""" + with patch("sys.stdin", io.StringIO("not valid json")): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli._output") as mock_output: + run_hook("session-start", "claude-code") + mock_output.assert_called_once_with({}) diff --git a/tests/test_instructions_cli.py b/tests/test_instructions_cli.py new file mode 100644 index 0000000..c99ed14 --- /dev/null +++ b/tests/test_instructions_cli.py @@ -0,0 +1,45 @@ +"""Tests for mempalace.instructions_cli — instruction text output.""" + +from unittest.mock import patch + +import pytest + +from mempalace.instructions_cli import AVAILABLE, INSTRUCTIONS_DIR, run_instructions + + +def test_run_instructions_valid_name(capsys): + """Valid name prints the .md file content.""" + name = "init" + expected = (INSTRUCTIONS_DIR / f"{name}.md").read_text() + run_instructions(name) + captured = capsys.readouterr() + assert captured.out.strip() == expected.strip() + + +def test_run_instructions_all_available(capsys): + """Every name in AVAILABLE should succeed without error.""" + for name in AVAILABLE: + run_instructions(name) + out = capsys.readouterr().out + assert len(out) > 0 + + +def test_run_instructions_invalid_name(capsys): + """Invalid name should sys.exit(1) and print error to stderr.""" + with pytest.raises(SystemExit) as exc_info: + run_instructions("nonexistent") + assert exc_info.value.code == 1 + captured = capsys.readouterr() + assert "Unknown instructions: nonexistent" in captured.err + assert "Available:" in captured.err + + +def test_run_instructions_missing_md_file(capsys, tmp_path): + """If the .md file is missing on disk, should sys.exit(1).""" + with patch("mempalace.instructions_cli.INSTRUCTIONS_DIR", tmp_path): + with patch("mempalace.instructions_cli.AVAILABLE", ["fakecmd"]): + with pytest.raises(SystemExit) as exc_info: + run_instructions("fakecmd") + assert exc_info.value.code == 1 + captured = capsys.readouterr() + assert "Instructions file not found" in captured.err diff --git a/tests/test_knowledge_graph_extra.py b/tests/test_knowledge_graph_extra.py new file mode 100644 index 0000000..29605bb --- /dev/null +++ b/tests/test_knowledge_graph_extra.py @@ -0,0 +1,105 @@ +"""Extra knowledge graph tests for seed_from_entity_facts and query_relationship.""" + +import pytest + +from mempalace.knowledge_graph import KnowledgeGraph + + +@pytest.fixture +def kg(tmp_path): + return KnowledgeGraph(db_path=str(tmp_path / "kg.db")) + + +class TestSeedFromEntityFacts: + def test_seed_person_with_partner(self, kg): + facts = { + "alice": { + "full_name": "Alice Smith", + "type": "person", + "gender": "female", + "partner": "bob", + "relationship": "husband", + } + } + kg.seed_from_entity_facts(facts) + stats = kg.stats() + assert stats["entities"] >= 1 + results = kg.query_entity("Alice Smith", direction="outgoing") + predicates = {r["predicate"] for r in results} + assert "married_to" in predicates + assert "is_partner_of" in predicates + + def test_seed_child(self, kg): + facts = { + "max": { + "full_name": "Max", + "type": "person", + "birthday": "2015-04-01", + "parent": "alice", + "relationship": "daughter", + } + } + kg.seed_from_entity_facts(facts) + results = kg.query_entity("Max", direction="outgoing") + predicates = {r["predicate"] for r in results} + assert "child_of" in predicates + assert "is_child_of" in predicates + + def test_seed_sibling(self, kg): + facts = { + "emma": { + "full_name": "Emma", + "type": "person", + "relationship": "brother", + "sibling": "max", + } + } + kg.seed_from_entity_facts(facts) + results = kg.query_entity("Emma", direction="outgoing") + predicates = {r["predicate"] for r in results} + assert "is_sibling_of" in predicates + + def test_seed_dog(self, kg): + facts = { + "rex": { + "full_name": "Rex", + "type": "animal", + "relationship": "dog", + "owner": "alice", + } + } + kg.seed_from_entity_facts(facts) + results = kg.query_entity("Rex", direction="outgoing") + predicates = {r["predicate"] for r in results} + assert "is_pet_of" in predicates + + def test_seed_with_interests(self, kg): + facts = { + "max": { + "full_name": "Max", + "type": "person", + "interests": ["swimming", "chess"], + } + } + kg.seed_from_entity_facts(facts) + results = kg.query_entity("Max", direction="outgoing") + objects = {r["object"] for r in results if r["predicate"] == "loves"} + assert "Swimming" in objects + assert "Chess" in objects + + def test_seed_minimal_facts(self, kg): + """Facts with no relationships just create entities.""" + facts = {"bob": {"full_name": "Bob"}} + kg.seed_from_entity_facts(facts) + stats = kg.stats() + assert stats["entities"] >= 1 + + +class TestQueryRelationshipWithTime: + def test_query_relationship_with_as_of(self, kg): + kg.add_triple("Alice", "works_at", "Acme", valid_from="2020-01-01", valid_to="2024-12-31") + kg.add_triple("Alice", "works_at", "NewCo", valid_from="2025-01-01") + results = kg.query_relationship("works_at", as_of="2023-06-01") + objects = [r["object"] for r in results] + assert "Acme" in objects + assert "NewCo" not in objects diff --git a/tests/test_layers.py b/tests/test_layers.py new file mode 100644 index 0000000..46b60e9 --- /dev/null +++ b/tests/test_layers.py @@ -0,0 +1,719 @@ +"""Tests for mempalace.layers — Layer0, Layer1, Layer2, Layer3, MemoryStack.""" + +import os +from unittest.mock import MagicMock, patch + +from mempalace.layers import Layer0, Layer1, Layer2, Layer3, MemoryStack + + +# ── Layer0 — with identity file ───────────────────────────────────────── + + +def test_layer0_reads_identity_file(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("I am Atlas, a personal AI assistant for Alice.") + layer = Layer0(identity_path=str(identity_file)) + text = layer.render() + assert "Atlas" in text + assert "Alice" in text + + +def test_layer0_caches_text(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("Hello world") + layer = Layer0(identity_path=str(identity_file)) + first = layer.render() + identity_file.write_text("Changed content") + second = layer.render() + assert first == second + assert second == "Hello world" + + +def test_layer0_missing_file_returns_default(tmp_path): + missing = str(tmp_path / "nonexistent.txt") + layer = Layer0(identity_path=missing) + text = layer.render() + assert "No identity configured" in text + assert "identity.txt" in text + + +def test_layer0_token_estimate(tmp_path): + identity_file = tmp_path / "identity.txt" + content = "A" * 400 + identity_file.write_text(content) + layer = Layer0(identity_path=str(identity_file)) + estimate = layer.token_estimate() + assert estimate == 100 + + +def test_layer0_token_estimate_empty(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("") + layer = Layer0(identity_path=str(identity_file)) + assert layer.token_estimate() == 0 + + +def test_layer0_strips_whitespace(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text(" Hello world \n\n") + layer = Layer0(identity_path=str(identity_file)) + text = layer.render() + assert text == "Hello world" + + +def test_layer0_default_path(): + layer = Layer0() + expected = os.path.expanduser("~/.mempalace/identity.txt") + assert layer.path == expected + + +# ── Layer1 — mocked chromadb ──────────────────────────────────────────── + + +def _mock_chromadb_for_layer(docs, metas, monkeypatch=None): + """Return a mock PersistentClient whose collection.get returns docs/metas.""" + mock_col = MagicMock() + # First batch returns data, second batch returns empty (end of pagination) + mock_col.get.side_effect = [ + {"documents": docs, "metadatas": metas}, + {"documents": [], "metadatas": []}, + ] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + return mock_client + + +def test_layer1_no_palace(): + """Layer1 returns helpful message when no palace exists.""" + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent/palace" + layer = Layer1(palace_path="/nonexistent/palace") + result = layer.generate() + assert "No palace found" in result or "No memories" in result + + +def test_layer1_generates_essential_story(): + docs = [ + "Important memory about project decisions", + "Key architectural choice for the backend", + ] + metas = [ + {"room": "decisions", "source_file": "meeting.txt", "importance": 5}, + {"room": "architecture", "source_file": "design.txt", "importance": 4}, + ] + mock_client = _mock_chromadb_for_layer(docs, metas) + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer1(palace_path="/fake") + result = layer.generate() + + assert "ESSENTIAL STORY" in result + assert "project decisions" in result + + +def test_layer1_empty_palace(): + mock_col = MagicMock() + mock_col.get.return_value = {"documents": [], "metadatas": []} + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer1(palace_path="/fake") + result = layer.generate() + + assert "No memories" in result + + +def test_layer1_with_wing_filter(): + docs = ["Memory about project X"] + metas = [{"room": "general", "source_file": "x.txt", "importance": 3}] + mock_client = _mock_chromadb_for_layer(docs, metas) + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer1(palace_path="/fake", wing="project_x") + result = layer.generate() + + assert "ESSENTIAL STORY" in result + # Verify wing filter was passed + call_kwargs = mock_client.get_collection.return_value.get.call_args_list[0][1] + assert call_kwargs.get("where") == {"wing": "project_x"} + + +def test_layer1_truncates_long_snippets(): + docs = ["A" * 300] + metas = [{"room": "general", "source_file": "long.txt"}] + mock_client = _mock_chromadb_for_layer(docs, metas) + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer1(palace_path="/fake") + result = layer.generate() + + assert "..." in result + + +def test_layer1_respects_max_chars(): + """L1 stops adding entries once MAX_CHARS is reached.""" + docs = [f"Memory number {i} with substantial content padding here" for i in range(30)] + metas = [{"room": "general", "source_file": f"f{i}.txt", "importance": 5} for i in range(30)] + mock_client = _mock_chromadb_for_layer(docs, metas) + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer1(palace_path="/fake") + layer.MAX_CHARS = 200 # Very low cap to trigger truncation + result = layer.generate() + + assert "more in L3 search" in result + + +def test_layer1_importance_from_various_keys(): + """Layer1 tries importance, emotional_weight, weight keys.""" + docs = ["mem1", "mem2", "mem3"] + metas = [ + {"room": "r", "emotional_weight": 5}, + {"room": "r", "weight": 1}, + {"room": "r"}, # no weight key, defaults to 3 + ] + mock_client = _mock_chromadb_for_layer(docs, metas) + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer1(palace_path="/fake") + result = layer.generate() + + assert "ESSENTIAL STORY" in result + + +def test_layer1_batch_exception_breaks(): + """If col.get raises on a batch, loop breaks gracefully.""" + mock_col = MagicMock() + mock_col.get.side_effect = [ + {"documents": ["doc1"], "metadatas": [{"room": "r"}]}, + RuntimeError("batch error"), + ] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer1(palace_path="/fake") + result = layer.generate() + + assert "ESSENTIAL STORY" in result + + +# ── Layer2 — mocked chromadb ──────────────────────────────────────────── + + +def test_layer2_no_palace(): + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent/palace" + layer = Layer2(palace_path="/nonexistent/palace") + result = layer.retrieve(wing="test") + assert "No palace found" in result + + +def test_layer2_retrieve_with_wing(): + mock_col = MagicMock() + mock_col.get.return_value = { + "documents": ["Some memory about the project"], + "metadatas": [{"room": "backend", "source_file": "notes.txt"}], + } + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer2(palace_path="/fake") + result = layer.retrieve(wing="project") + + assert "ON-DEMAND" in result + assert "memory about the project" in result + + +def test_layer2_retrieve_with_room(): + mock_col = MagicMock() + mock_col.get.return_value = { + "documents": ["Backend architecture notes"], + "metadatas": [{"room": "architecture", "source_file": "arch.txt"}], + } + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer2(palace_path="/fake") + result = layer.retrieve(room="architecture") + + assert "ON-DEMAND" in result + + +def test_layer2_retrieve_wing_and_room(): + mock_col = MagicMock() + mock_col.get.return_value = { + "documents": ["Filtered result"], + "metadatas": [{"room": "backend", "source_file": "x.txt"}], + } + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer2(palace_path="/fake") + result = layer.retrieve(wing="proj", room="backend") + + assert "ON-DEMAND" in result + call_kwargs = mock_col.get.call_args[1] + assert "$and" in call_kwargs.get("where", {}) + + +def test_layer2_retrieve_empty(): + mock_col = MagicMock() + mock_col.get.return_value = {"documents": [], "metadatas": []} + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer2(palace_path="/fake") + result = layer.retrieve(wing="missing") + + assert "No drawers found" in result + + +def test_layer2_retrieve_no_filter(): + mock_col = MagicMock() + mock_col.get.return_value = {"documents": [], "metadatas": []} + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer2(palace_path="/fake") + layer.retrieve() + + # No where filter should be passed + call_kwargs = mock_col.get.call_args[1] + assert "where" not in call_kwargs + + +def test_layer2_retrieve_error(): + mock_col = MagicMock() + mock_col.get.side_effect = RuntimeError("db error") + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer2(palace_path="/fake") + result = layer.retrieve(wing="test") + + assert "Retrieval error" in result + + +def test_layer2_truncates_long_snippets(): + mock_col = MagicMock() + mock_col.get.return_value = { + "documents": ["B" * 400], + "metadatas": [{"room": "r", "source_file": "s.txt"}], + } + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer2(palace_path="/fake") + result = layer.retrieve(wing="test") + + assert "..." in result + + +# ── Layer3 — mocked chromadb ──────────────────────────────────────────── + + +def _mock_query_results(docs, metas, dists): + return { + "documents": [docs], + "metadatas": [metas], + "distances": [dists], + } + + +def test_layer3_no_palace(): + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent/palace" + layer = Layer3(palace_path="/nonexistent/palace") + result = layer.search("test query") + assert "No palace found" in result + + +def test_layer3_search_raw_no_palace(): + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent/palace" + layer = Layer3(palace_path="/nonexistent/palace") + result = layer.search_raw("test query") + assert result == [] + + +def test_layer3_search_with_results(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results( + ["Found this important memory"], + [{"wing": "project", "room": "backend", "source_file": "notes.txt"}], + [0.2], + ) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + result = layer.search("important") + + assert "SEARCH RESULTS" in result + assert "important memory" in result + assert "sim=0.8" in result + + +def test_layer3_search_no_results(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results([], [], []) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + result = layer.search("nothing") + + assert "No results found" in result + + +def test_layer3_search_with_wing_filter(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results( + ["result"], + [{"wing": "proj", "room": "r"}], + [0.1], + ) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + layer.search("q", wing="proj") + + call_kwargs = mock_col.query.call_args[1] + assert call_kwargs["where"] == {"wing": "proj"} + + +def test_layer3_search_with_room_filter(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results( + ["result"], + [{"wing": "w", "room": "backend"}], + [0.1], + ) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + layer.search("q", room="backend") + + call_kwargs = mock_col.query.call_args[1] + assert call_kwargs["where"] == {"room": "backend"} + + +def test_layer3_search_with_wing_and_room(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results( + ["result"], + [{"wing": "proj", "room": "backend"}], + [0.1], + ) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + layer.search("q", wing="proj", room="backend") + + call_kwargs = mock_col.query.call_args[1] + assert "$and" in call_kwargs["where"] + + +def test_layer3_search_error(): + mock_col = MagicMock() + mock_col.query.side_effect = RuntimeError("search failed") + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + result = layer.search("q") + + assert "Search error" in result + + +def test_layer3_search_truncates_long_docs(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results( + ["C" * 400], + [{"wing": "w", "room": "r", "source_file": "s.txt"}], + [0.1], + ) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + result = layer.search("q") + + assert "..." in result + + +def test_layer3_search_raw_returns_dicts(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results( + ["doc text"], + [{"wing": "proj", "room": "backend", "source_file": "f.txt"}], + [0.3], + ) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + hits = layer.search_raw("q") + + assert len(hits) == 1 + assert hits[0]["text"] == "doc text" + assert hits[0]["wing"] == "proj" + assert hits[0]["similarity"] == 0.7 + assert "metadata" in hits[0] + + +def test_layer3_search_raw_with_filters(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results( + ["doc"], + [{"wing": "w", "room": "r"}], + [0.1], + ) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + layer.search_raw("q", wing="w", room="r") + + call_kwargs = mock_col.query.call_args[1] + assert "$and" in call_kwargs["where"] + + +def test_layer3_search_raw_error(): + mock_col = MagicMock() + mock_col.query.side_effect = RuntimeError("fail") + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + result = layer.search_raw("q") + + assert result == [] + + +# ── MemoryStack ───────────────────────────────────────────────────────── + + +def test_memory_stack_wake_up(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("I am Atlas.") + + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent" + stack = MemoryStack( + palace_path="/nonexistent", + identity_path=str(identity_file), + ) + result = stack.wake_up() + + assert "Atlas" in result + # L1 will say no palace found + assert "No palace" in result or "No memories" in result + + +def test_memory_stack_wake_up_with_wing(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("I am Atlas.") + + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent" + stack = MemoryStack( + palace_path="/nonexistent", + identity_path=str(identity_file), + ) + result = stack.wake_up(wing="my_project") + + assert stack.l1.wing == "my_project" + assert "Atlas" in result + + +def test_memory_stack_recall(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("I am Atlas.") + + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent" + stack = MemoryStack( + palace_path="/nonexistent", + identity_path=str(identity_file), + ) + result = stack.recall(wing="test") + + assert "No palace found" in result + + +def test_memory_stack_search(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("I am Atlas.") + + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent" + stack = MemoryStack( + palace_path="/nonexistent", + identity_path=str(identity_file), + ) + result = stack.search("test query") + + assert "No palace found" in result + + +def test_memory_stack_status(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("I am Atlas.") + + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent" + stack = MemoryStack( + palace_path="/nonexistent", + identity_path=str(identity_file), + ) + result = stack.status() + + assert result["palace_path"] == "/nonexistent" + assert result["total_drawers"] == 0 + assert "L0_identity" in result + assert "L1_essential" in result + assert "L2_on_demand" in result + assert "L3_deep_search" in result + + +def test_memory_stack_status_with_palace(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("I am Atlas.") + + mock_col = MagicMock() + mock_col.count.return_value = 42 + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + stack = MemoryStack( + palace_path="/fake", + identity_path=str(identity_file), + ) + result = stack.status() + + assert result["total_drawers"] == 42 + assert result["L0_identity"]["exists"] is True diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index cf37a27..24258a9 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -9,25 +9,26 @@ via monkeypatch to avoid touching real data. import json -def _patch_mcp_server(monkeypatch, config, palace_path, kg): +def _patch_mcp_server(monkeypatch, config, kg): """Patch the mcp_server module globals to use test fixtures.""" from mempalace import mcp_server - assert getattr(config, "palace_path", None) == palace_path, ( - f"config.palace_path ({getattr(config, 'palace_path', None)!r}) does not match palace_path fixture ({palace_path!r})" - ) monkeypatch.setattr(mcp_server, "_config", config) monkeypatch.setattr(mcp_server, "_kg", kg) def _get_collection(palace_path, create=False): - """Helper to get collection from test palace.""" + """Helper to get collection from test palace. + + Returns (client, collection) so callers can clean up the client + when they are done. + """ import chromadb client = chromadb.PersistentClient(path=palace_path) if create: - return client.get_or_create_collection("mempalace_drawers") - return client.get_collection("mempalace_drawers") + return client, client.get_or_create_collection("mempalace_drawers") + return client, client.get_collection("mempalace_drawers") # ── Protocol Layer ────────────────────────────────────────────────────── @@ -77,11 +78,12 @@ class TestHandleRequest: assert resp["error"]["code"] == -32601 def test_tools_call_dispatches(self, monkeypatch, config, palace_path, seeded_kg): - _patch_mcp_server(monkeypatch, config, palace_path, seeded_kg) + _patch_mcp_server(monkeypatch, config, seeded_kg) from mempalace.mcp_server import handle_request # Create a collection so status works - _get_collection(palace_path, create=True) + _client, _col = _get_collection(palace_path, create=True) + del _client resp = handle_request( { @@ -100,8 +102,9 @@ class TestHandleRequest: class TestReadTools: def test_status_empty_palace(self, monkeypatch, config, palace_path, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) - _get_collection(palace_path, create=True) + _patch_mcp_server(monkeypatch, config, kg) + _client, _col = _get_collection(palace_path, create=True) + del _client from mempalace.mcp_server import tool_status result = tool_status() @@ -109,7 +112,7 @@ class TestReadTools: assert result["wings"] == {} def test_status_with_data(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_status result = tool_status() @@ -118,7 +121,7 @@ class TestReadTools: assert "notes" in result["wings"] def test_list_wings(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_list_wings result = tool_list_wings() @@ -126,7 +129,7 @@ class TestReadTools: assert result["wings"]["notes"] == 1 def test_list_rooms_all(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_list_rooms result = tool_list_rooms() @@ -135,7 +138,7 @@ class TestReadTools: assert "planning" in result["rooms"] def test_list_rooms_filtered(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_list_rooms result = tool_list_rooms(wing="project") @@ -143,7 +146,7 @@ class TestReadTools: assert "planning" not in result["rooms"] def test_get_taxonomy(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_get_taxonomy result = tool_get_taxonomy() @@ -152,8 +155,7 @@ class TestReadTools: assert result["taxonomy"]["notes"]["planning"] == 1 def test_no_palace_returns_error(self, monkeypatch, config, kg): - config._file_config["palace_path"] = "/nonexistent/path" - _patch_mcp_server(monkeypatch, config, "/nonexistent/path", kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_status result = tool_status() @@ -165,7 +167,7 @@ class TestReadTools: class TestSearchTool: def test_search_basic(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_search result = tool_search(query="JWT authentication tokens") @@ -176,14 +178,14 @@ class TestSearchTool: assert "JWT" in top["text"] or "authentication" in top["text"].lower() def test_search_with_wing_filter(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_search result = tool_search(query="planning", wing="notes") assert all(r["wing"] == "notes" for r in result["results"]) def test_search_with_room_filter(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_search result = tool_search(query="database", room="backend") @@ -195,8 +197,9 @@ class TestSearchTool: class TestWriteTools: def test_add_drawer(self, monkeypatch, config, palace_path, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) - _get_collection(palace_path, create=True) + _patch_mcp_server(monkeypatch, config, kg) + _client, _col = _get_collection(palace_path, create=True) + del _client from mempalace.mcp_server import tool_add_drawer result = tool_add_drawer( @@ -210,8 +213,9 @@ class TestWriteTools: assert result["drawer_id"].startswith("drawer_test_wing_test_room_") def test_add_drawer_duplicate_detection(self, monkeypatch, config, palace_path, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) - _get_collection(palace_path, create=True) + _patch_mcp_server(monkeypatch, config, kg) + _client, _col = _get_collection(palace_path, create=True) + del _client from mempalace.mcp_server import tool_add_drawer content = "This is a unique test memory about Rust ownership and borrowing." @@ -219,11 +223,11 @@ class TestWriteTools: assert result1["success"] is True result2 = tool_add_drawer(wing="w", room="r", content=content) - assert result2["success"] is False - assert result2["reason"] == "duplicate" + assert result2["success"] is True + assert result2["reason"] == "already_exists" def test_delete_drawer(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_delete_drawer result = tool_delete_drawer("drawer_proj_backend_aaa") @@ -231,14 +235,14 @@ class TestWriteTools: assert seeded_collection.count() == 3 def test_delete_drawer_not_found(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_delete_drawer result = tool_delete_drawer("nonexistent_drawer") assert result["success"] is False def test_check_duplicate(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_check_duplicate # Exact match text from seeded_collection should be flagged @@ -262,7 +266,7 @@ class TestWriteTools: class TestKGTools: def test_kg_add(self, monkeypatch, config, palace_path, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_kg_add result = tool_kg_add( @@ -274,14 +278,14 @@ class TestKGTools: assert result["success"] is True def test_kg_query(self, monkeypatch, config, palace_path, seeded_kg): - _patch_mcp_server(monkeypatch, config, palace_path, seeded_kg) + _patch_mcp_server(monkeypatch, config, seeded_kg) from mempalace.mcp_server import tool_kg_query result = tool_kg_query(entity="Max") assert result["count"] > 0 def test_kg_invalidate(self, monkeypatch, config, palace_path, seeded_kg): - _patch_mcp_server(monkeypatch, config, palace_path, seeded_kg) + _patch_mcp_server(monkeypatch, config, seeded_kg) from mempalace.mcp_server import tool_kg_invalidate result = tool_kg_invalidate( @@ -293,14 +297,14 @@ class TestKGTools: assert result["success"] is True def test_kg_timeline(self, monkeypatch, config, palace_path, seeded_kg): - _patch_mcp_server(monkeypatch, config, palace_path, seeded_kg) + _patch_mcp_server(monkeypatch, config, seeded_kg) from mempalace.mcp_server import tool_kg_timeline result = tool_kg_timeline(entity="Alice") assert result["count"] > 0 def test_kg_stats(self, monkeypatch, config, palace_path, seeded_kg): - _patch_mcp_server(monkeypatch, config, palace_path, seeded_kg) + _patch_mcp_server(monkeypatch, config, seeded_kg) from mempalace.mcp_server import tool_kg_stats result = tool_kg_stats() @@ -312,8 +316,9 @@ class TestKGTools: class TestDiaryTools: def test_diary_write_and_read(self, monkeypatch, config, palace_path, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) - _get_collection(palace_path, create=True) + _patch_mcp_server(monkeypatch, config, kg) + _client, _col = _get_collection(palace_path, create=True) + del _client from mempalace.mcp_server import tool_diary_write, tool_diary_read w = tool_diary_write( @@ -330,8 +335,9 @@ class TestDiaryTools: assert "authentication" in r["entries"][0]["content"] def test_diary_read_empty(self, monkeypatch, config, palace_path, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) - _get_collection(palace_path, create=True) + _patch_mcp_server(monkeypatch, config, kg) + _client, _col = _get_collection(palace_path, create=True) + del _client from mempalace.mcp_server import tool_diary_read r = tool_diary_read(agent_name="Nobody") diff --git a/tests/test_miner.py b/tests/test_miner.py index 337e949..efe55a7 100644 --- a/tests/test_miner.py +++ b/tests/test_miner.py @@ -47,7 +47,7 @@ def test_project_mining(): col = client.get_collection("mempalace_drawers") assert col.count() > 0 finally: - shutil.rmtree(tmpdir) + shutil.rmtree(tmpdir, ignore_errors=True) def test_scan_project_respects_gitignore(): diff --git a/tests/test_normalize.py b/tests/test_normalize.py index c304c9d..fc50251 100644 --- a/tests/test_normalize.py +++ b/tests/test_normalize.py @@ -1,31 +1,501 @@ -import os import json -import tempfile -from mempalace.normalize import normalize +from unittest.mock import patch + +from mempalace.normalize import ( + _extract_content, + _messages_to_transcript, + _try_chatgpt_json, + _try_claude_ai_json, + _try_claude_code_jsonl, + _try_codex_jsonl, + _try_normalize_json, + _try_slack_json, + normalize, +) -def test_plain_text(): - f = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) - f.write("Hello world\nSecond line\n") - f.close() - result = normalize(f.name) +# ── normalize() top-level ────────────────────────────────────────────── + + +def test_plain_text(tmp_path): + f = tmp_path / "plain.txt" + f.write_text("Hello world\nSecond line\n") + result = normalize(str(f)) assert "Hello world" in result - os.unlink(f.name) -def test_claude_json(): +def test_claude_json(tmp_path): data = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}] - f = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) - json.dump(data, f) - f.close() - result = normalize(f.name) + f = tmp_path / "claude.json" + f.write_text(json.dumps(data)) + result = normalize(str(f)) assert "Hi" in result - os.unlink(f.name) -def test_empty(): - f = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) - f.close() - result = normalize(f.name) +def test_empty(tmp_path): + f = tmp_path / "empty.txt" + f.write_text("") + result = normalize(str(f)) assert result.strip() == "" - os.unlink(f.name) + + +def test_normalize_io_error(): + """normalize raises IOError for unreadable file.""" + try: + normalize("/nonexistent/path/file.txt") + assert False, "Should have raised" + except IOError as e: + assert "Could not read" in str(e) + + +def test_normalize_already_has_markers(tmp_path): + """Files with >= 3 '>' lines pass through unchanged.""" + content = "> question 1\nanswer 1\n> question 2\nanswer 2\n> question 3\nanswer 3\n" + f = tmp_path / "markers.txt" + f.write_text(content) + result = normalize(str(f)) + assert result == content + + +def test_normalize_json_content_detected_by_brace(tmp_path): + """A .txt file starting with [ triggers JSON parsing.""" + data = [{"role": "user", "content": "Hey"}, {"role": "assistant", "content": "Hi there"}] + f = tmp_path / "chat.txt" + f.write_text(json.dumps(data)) + result = normalize(str(f)) + assert "Hey" in result + + +def test_normalize_whitespace_only(tmp_path): + f = tmp_path / "ws.txt" + f.write_text(" \n \n ") + result = normalize(str(f)) + assert result.strip() == "" + + +# ── _extract_content ─────────────────────────────────────────────────── + + +def test_extract_content_string(): + assert _extract_content("hello") == "hello" + + +def test_extract_content_list_of_strings(): + assert _extract_content(["hello", "world"]) == "hello world" + + +def test_extract_content_list_of_blocks(): + blocks = [{"type": "text", "text": "hello"}, {"type": "image", "url": "x"}] + assert _extract_content(blocks) == "hello" + + +def test_extract_content_dict(): + assert _extract_content({"text": "hello"}) == "hello" + + +def test_extract_content_none(): + assert _extract_content(None) == "" + + +def test_extract_content_mixed_list(): + blocks = ["plain", {"type": "text", "text": "block"}] + assert _extract_content(blocks) == "plain block" + + +# ── _try_claude_code_jsonl ───────────────────────────────────────────── + + +def test_claude_code_jsonl_valid(): + lines = [ + json.dumps({"type": "human", "message": {"content": "What is X?"}}), + json.dumps({"type": "assistant", "message": {"content": "X is Y."}}), + ] + result = _try_claude_code_jsonl("\n".join(lines)) + assert result is not None + assert "> What is X?" in result + assert "X is Y." in result + + +def test_claude_code_jsonl_user_type(): + lines = [ + json.dumps({"type": "user", "message": {"content": "Q"}}), + json.dumps({"type": "assistant", "message": {"content": "A"}}), + ] + result = _try_claude_code_jsonl("\n".join(lines)) + assert result is not None + assert "> Q" in result + + +def test_claude_code_jsonl_too_few_messages(): + lines = [json.dumps({"type": "human", "message": {"content": "only one"}})] + result = _try_claude_code_jsonl("\n".join(lines)) + assert result is None + + +def test_claude_code_jsonl_invalid_json_lines(): + lines = [ + "not json", + json.dumps({"type": "human", "message": {"content": "Q"}}), + json.dumps({"type": "assistant", "message": {"content": "A"}}), + ] + result = _try_claude_code_jsonl("\n".join(lines)) + assert result is not None + + +def test_claude_code_jsonl_non_dict_entries(): + lines = [ + json.dumps([1, 2, 3]), + json.dumps({"type": "human", "message": {"content": "Q"}}), + json.dumps({"type": "assistant", "message": {"content": "A"}}), + ] + result = _try_claude_code_jsonl("\n".join(lines)) + assert result is not None + + +# ── _try_codex_jsonl ─────────────────────────────────────────────────── + + +def test_codex_jsonl_valid(): + lines = [ + json.dumps({"type": "session_meta", "payload": {}}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), + ] + result = _try_codex_jsonl("\n".join(lines)) + assert result is not None + assert "> Q" in result + + +def test_codex_jsonl_no_session_meta(): + """Without session_meta, codex parser returns None.""" + lines = [ + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), + ] + result = _try_codex_jsonl("\n".join(lines)) + assert result is None + + +def test_codex_jsonl_skips_non_event_msg(): + lines = [ + json.dumps({"type": "session_meta"}), + json.dumps({"type": "response_item", "payload": {"type": "user_message", "message": "X"}}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), + ] + result = _try_codex_jsonl("\n".join(lines)) + assert result is not None + assert "X" not in result.split("> Q")[0] + + +def test_codex_jsonl_non_string_message(): + lines = [ + json.dumps({"type": "session_meta"}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": 123}}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), + ] + result = _try_codex_jsonl("\n".join(lines)) + assert result is not None + + +def test_codex_jsonl_empty_text_skipped(): + lines = [ + json.dumps({"type": "session_meta"}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": " "}}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), + ] + result = _try_codex_jsonl("\n".join(lines)) + assert result is not None + + +def test_codex_jsonl_payload_not_dict(): + lines = [ + json.dumps({"type": "session_meta"}), + json.dumps({"type": "event_msg", "payload": "not a dict"}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), + ] + result = _try_codex_jsonl("\n".join(lines)) + assert result is not None + + +# ── _try_claude_ai_json ─────────────────────────────────────────────── + + +def test_claude_ai_flat_messages(): + data = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + ] + result = _try_claude_ai_json(data) + assert result is not None + assert "> Hello" in result + + +def test_claude_ai_dict_with_messages_key(): + data = { + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + ] + } + result = _try_claude_ai_json(data) + assert result is not None + + +def test_claude_ai_privacy_export(): + data = [ + { + "chat_messages": [ + {"role": "human", "content": "Q1"}, + {"role": "ai", "content": "A1"}, + ] + } + ] + result = _try_claude_ai_json(data) + assert result is not None + assert "> Q1" in result + + +def test_claude_ai_not_a_list(): + result = _try_claude_ai_json("not a list") + assert result is None + + +def test_claude_ai_too_few_messages(): + data = [{"role": "user", "content": "Hello"}] + result = _try_claude_ai_json(data) + assert result is None + + +def test_claude_ai_dict_with_chat_messages_key(): + data = { + "chat_messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "World"}, + ] + } + result = _try_claude_ai_json(data) + assert result is not None + + +def test_claude_ai_privacy_export_non_dict_items(): + """Non-dict items in privacy export are skipped.""" + data = [ + { + "chat_messages": [ + "not a dict", + {"role": "user", "content": "Q"}, + {"role": "assistant", "content": "A"}, + ] + }, + "not a convo", + ] + result = _try_claude_ai_json(data) + assert result is not None + + +# ── _try_chatgpt_json ───────────────────────────────────────────────── + + +def test_chatgpt_json_valid(): + data = { + "mapping": { + "root": { + "parent": None, + "message": None, + "children": ["msg1"], + }, + "msg1": { + "parent": "root", + "message": { + "author": {"role": "user"}, + "content": {"parts": ["Hello ChatGPT"]}, + }, + "children": ["msg2"], + }, + "msg2": { + "parent": "msg1", + "message": { + "author": {"role": "assistant"}, + "content": {"parts": ["Hello! How can I help?"]}, + }, + "children": [], + }, + } + } + result = _try_chatgpt_json(data) + assert result is not None + assert "> Hello ChatGPT" in result + + +def test_chatgpt_json_no_mapping(): + result = _try_chatgpt_json({"data": []}) + assert result is None + + +def test_chatgpt_json_not_dict(): + result = _try_chatgpt_json([1, 2, 3]) + assert result is None + + +def test_chatgpt_json_fallback_root(): + """Root node has a message (no synthetic root), uses fallback.""" + data = { + "mapping": { + "root": { + "parent": None, + "message": { + "author": {"role": "system"}, + "content": {"parts": ["system prompt"]}, + }, + "children": ["msg1"], + }, + "msg1": { + "parent": "root", + "message": { + "author": {"role": "user"}, + "content": {"parts": ["Hello"]}, + }, + "children": ["msg2"], + }, + "msg2": { + "parent": "msg1", + "message": { + "author": {"role": "assistant"}, + "content": {"parts": ["Hi there"]}, + }, + "children": [], + }, + } + } + result = _try_chatgpt_json(data) + assert result is not None + + +def test_chatgpt_json_too_few_messages(): + data = { + "mapping": { + "root": { + "parent": None, + "message": None, + "children": ["msg1"], + }, + "msg1": { + "parent": "root", + "message": { + "author": {"role": "user"}, + "content": {"parts": ["Only one"]}, + }, + "children": [], + }, + } + } + result = _try_chatgpt_json(data) + assert result is None + + +# ── _try_slack_json ──────────────────────────────────────────────────── + + +def test_slack_json_valid(): + data = [ + {"type": "message", "user": "U1", "text": "Hello"}, + {"type": "message", "user": "U2", "text": "Hi there"}, + ] + result = _try_slack_json(data) + assert result is not None + assert "Hello" in result + + +def test_slack_json_not_a_list(): + result = _try_slack_json({"type": "message"}) + assert result is None + + +def test_slack_json_too_few_messages(): + data = [{"type": "message", "user": "U1", "text": "Hello"}] + result = _try_slack_json(data) + assert result is None + + +def test_slack_json_skips_non_message_types(): + data = [ + {"type": "channel_join", "user": "U1", "text": "joined"}, + {"type": "message", "user": "U1", "text": "Hello"}, + {"type": "message", "user": "U2", "text": "Hi"}, + ] + result = _try_slack_json(data) + assert result is not None + + +def test_slack_json_three_users(): + """Three speakers get alternating roles.""" + data = [ + {"type": "message", "user": "U1", "text": "Hello"}, + {"type": "message", "user": "U2", "text": "Hi"}, + {"type": "message", "user": "U3", "text": "Hey"}, + ] + result = _try_slack_json(data) + assert result is not None + + +def test_slack_json_empty_text_skipped(): + data = [ + {"type": "message", "user": "U1", "text": ""}, + {"type": "message", "user": "U1", "text": "Hello"}, + {"type": "message", "user": "U2", "text": "Hi"}, + ] + result = _try_slack_json(data) + assert result is not None + + +def test_slack_json_username_fallback(): + data = [ + {"type": "message", "username": "bot1", "text": "Hello"}, + {"type": "message", "username": "bot2", "text": "Hi"}, + ] + result = _try_slack_json(data) + assert result is not None + + +# ── _try_normalize_json ──────────────────────────────────────────────── + + +def test_try_normalize_json_invalid_json(): + result = _try_normalize_json("not json at all {{{") + assert result is None + + +def test_try_normalize_json_valid_but_unknown_schema(): + result = _try_normalize_json(json.dumps({"random": "data"})) + assert result is None + + +# ── _messages_to_transcript ──────────────────────────────────────────── + + +def test_messages_to_transcript_basic(): + msgs = [("user", "Q"), ("assistant", "A")] + with patch("mempalace.normalize.spellcheck_user_text", side_effect=lambda x: x, create=True): + result = _messages_to_transcript(msgs, spellcheck=False) + assert "> Q" in result + assert "A" in result + + +def test_messages_to_transcript_consecutive_users(): + """Two user messages in a row (no assistant between).""" + msgs = [("user", "Q1"), ("user", "Q2"), ("assistant", "A")] + result = _messages_to_transcript(msgs, spellcheck=False) + assert "> Q1" in result + assert "> Q2" in result + + +def test_messages_to_transcript_assistant_first(): + """Leading assistant message (no user before it).""" + msgs = [("assistant", "preamble"), ("user", "Q"), ("assistant", "A")] + result = _messages_to_transcript(msgs, spellcheck=False) + assert "preamble" in result + assert "> Q" in result diff --git a/tests/test_onboarding.py b/tests/test_onboarding.py new file mode 100644 index 0000000..ea7a37b --- /dev/null +++ b/tests/test_onboarding.py @@ -0,0 +1,452 @@ +"""Tests for mempalace.onboarding.""" + +import os +from unittest.mock import patch + +from mempalace.onboarding import ( + DEFAULT_WINGS, + _ask, + _ask_mode, + _ask_people, + _ask_projects, + _ask_wings, + _auto_detect, + _generate_aaak_bootstrap, + _header, + _hr, + _warn_ambiguous, + _yn, + quick_setup, + run_onboarding, +) + +# Force UTF-8 for Windows (source file contains Unicode symbols like hearts/stars) +os.environ["PYTHONUTF8"] = "1" + + +# ── DEFAULT_WINGS ─────────────────────────────────────────────────────── + + +def test_default_wings_has_expected_keys(): + assert "work" in DEFAULT_WINGS + assert "personal" in DEFAULT_WINGS + assert "combo" in DEFAULT_WINGS + + +def test_default_wings_work_has_projects(): + assert "projects" in DEFAULT_WINGS["work"] + + +def test_default_wings_personal_has_family(): + assert "family" in DEFAULT_WINGS["personal"] + + +def test_default_wings_combo_has_both(): + wings = DEFAULT_WINGS["combo"] + assert "family" in wings + assert "work" in wings + + +def test_default_wings_values_are_lists(): + for mode, wings in DEFAULT_WINGS.items(): + assert isinstance(wings, list), f"{mode} wings should be a list" + assert len(wings) >= 3, f"{mode} should have at least 3 wings" + + +# ── _warn_ambiguous ───────────────────────────────────────────────────── + + +def test_warn_ambiguous_flags_common_words(): + people = [ + {"name": "Grace", "relationship": "friend"}, + {"name": "Riley", "relationship": "daughter"}, + ] + result = _warn_ambiguous(people) + assert "Grace" in result + # Riley is not a common English word + assert "Riley" not in result + + +def test_warn_ambiguous_empty_list(): + result = _warn_ambiguous([]) + assert result == [] + + +def test_warn_ambiguous_no_ambiguous_names(): + people = [ + {"name": "Riley", "relationship": "daughter"}, + {"name": "Devon", "relationship": "friend"}, + ] + result = _warn_ambiguous(people) + assert result == [] + + +def test_warn_ambiguous_multiple_hits(): + people = [ + {"name": "Grace", "relationship": "friend"}, + {"name": "May", "relationship": "aunt"}, + {"name": "Joy", "relationship": "sister"}, + ] + result = _warn_ambiguous(people) + assert "Grace" in result + assert "May" in result + assert "Joy" in result + + +# ── quick_setup ───────────────────────────────────────────────────────── + + +def test_quick_setup_creates_registry(tmp_path): + registry = quick_setup( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + projects=["MemPalace"], + config_dir=tmp_path, + ) + assert "Riley" in registry.people + assert "MemPalace" in registry.projects + assert registry.mode == "personal" + + +def test_quick_setup_work_mode(tmp_path): + registry = quick_setup( + mode="work", + people=[{"name": "Alice", "relationship": "colleague", "context": "work"}], + projects=["Acme"], + config_dir=tmp_path, + ) + assert registry.mode == "work" + assert "Alice" in registry.people + assert "Acme" in registry.projects + + +def test_quick_setup_empty(tmp_path): + registry = quick_setup(mode="personal", people=[], config_dir=tmp_path) + assert len(registry.people) == 0 + assert len(registry.projects) == 0 + + +def test_quick_setup_saves_to_disk(tmp_path): + quick_setup( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + config_dir=tmp_path, + ) + assert (tmp_path / "entity_registry.json").exists() + + +# ── _generate_aaak_bootstrap ─────────────────────────────────────────── + + +def test_generate_aaak_bootstrap_creates_files(tmp_path): + people = [ + {"name": "Riley", "relationship": "daughter", "context": "personal"}, + {"name": "Devon", "relationship": "friend", "context": "personal"}, + ] + projects = ["MemPalace"] + wings = ["family", "creative"] + _generate_aaak_bootstrap(people, projects, wings, "personal", config_dir=tmp_path) + + assert (tmp_path / "aaak_entities.md").exists() + assert (tmp_path / "critical_facts.md").exists() + + +def test_generate_aaak_bootstrap_entities_content(tmp_path): + people = [{"name": "Riley", "relationship": "daughter", "context": "personal"}] + projects = ["MemPalace"] + wings = ["family"] + _generate_aaak_bootstrap(people, projects, wings, "personal", config_dir=tmp_path) + + content = (tmp_path / "aaak_entities.md").read_text() + assert "Riley" in content + assert "RIL" in content # entity code + assert "MemPalace" in content + + +def test_generate_aaak_bootstrap_facts_content(tmp_path): + people = [ + {"name": "Alice", "relationship": "colleague", "context": "work"}, + ] + projects = ["Acme"] + wings = ["projects"] + _generate_aaak_bootstrap(people, projects, wings, "work", config_dir=tmp_path) + + content = (tmp_path / "critical_facts.md").read_text() + assert "Alice" in content + assert "Acme" in content + assert "work" in content.lower() + + +def test_generate_aaak_bootstrap_empty_people(tmp_path): + _generate_aaak_bootstrap([], [], ["general"], "personal", config_dir=tmp_path) + assert (tmp_path / "aaak_entities.md").exists() + assert (tmp_path / "critical_facts.md").exists() + + +def test_generate_aaak_bootstrap_collision(tmp_path): + """Two people with same 3-letter code get different codes.""" + people = [ + {"name": "Alice", "relationship": "friend", "context": "work"}, + {"name": "Alison", "relationship": "coworker", "context": "work"}, + ] + _generate_aaak_bootstrap(people, [], ["work"], "work", config_dir=tmp_path) + content = (tmp_path / "aaak_entities.md").read_text() + assert "ALI" in content + assert "ALIS" in content + + +def test_generate_aaak_bootstrap_no_relationship(tmp_path): + """Person without relationship string still generates entry.""" + people = [{"name": "Bob", "context": "work"}] + _generate_aaak_bootstrap(people, [], ["work"], "work", config_dir=tmp_path) + content = (tmp_path / "aaak_entities.md").read_text() + assert "BOB=Bob" in content + + +# ── _hr, _header ────────────────────────────────────────────────────── + + +def test_hr_prints_line(capsys): + _hr() + out = capsys.readouterr().out + assert "─" in out + + +def test_header_prints_banner(capsys): + _header("Test Title") + out = capsys.readouterr().out + assert "Test Title" in out + assert "=" in out + + +# ── _ask ────────────────────────────────────────────────────────────── + + +def test_ask_with_default_uses_default(): + with patch("builtins.input", return_value=""): + result = _ask("prompt", default="fallback") + assert result == "fallback" + + +def test_ask_with_default_uses_input(): + with patch("builtins.input", return_value="custom"): + result = _ask("prompt", default="fallback") + assert result == "custom" + + +def test_ask_no_default(): + with patch("builtins.input", return_value="answer"): + result = _ask("prompt") + assert result == "answer" + + +# ── _yn ─────────────────────────────────────────────────────────────── + + +def test_yn_default_yes_empty_input(): + with patch("builtins.input", return_value=""): + assert _yn("continue?") is True + + +def test_yn_default_no_empty_input(): + with patch("builtins.input", return_value=""): + assert _yn("continue?", default="n") is False + + +def test_yn_explicit_yes(): + with patch("builtins.input", return_value="yes"): + assert _yn("continue?", default="n") is True + + +def test_yn_explicit_no(): + with patch("builtins.input", return_value="no"): + assert _yn("continue?") is False + + +# ── _ask_mode ───────────────────────────────────────────────────────── + + +def test_ask_mode_work(): + with patch("builtins.input", return_value="1"): + assert _ask_mode() == "work" + + +def test_ask_mode_personal(): + with patch("builtins.input", return_value="2"): + assert _ask_mode() == "personal" + + +def test_ask_mode_combo(): + with patch("builtins.input", return_value="3"): + assert _ask_mode() == "combo" + + +def test_ask_mode_retries_on_bad_input(): + with patch("builtins.input", side_effect=["x", "bad", "1"]): + assert _ask_mode() == "work" + + +# ── _ask_people ─────────────────────────────────────────────────────── + + +def test_ask_people_personal_mode(): + with patch("builtins.input", side_effect=["Alice, daughter", "", "done"]): + people, aliases = _ask_people("personal") + assert len(people) == 1 + assert people[0]["name"] == "Alice" + assert people[0]["relationship"] == "daughter" + + +def test_ask_people_work_mode(): + with patch("builtins.input", side_effect=["Bob, manager", "", "done"]): + people, aliases = _ask_people("work") + assert len(people) == 1 + assert people[0]["name"] == "Bob" + assert people[0]["context"] == "work" + + +def test_ask_people_combo_mode(): + with patch( + "builtins.input", + side_effect=[ + "Alice, daughter", + "", + "done", # personal + "Bob, boss", + "done", # work + ], + ): + people, aliases = _ask_people("combo") + assert len(people) == 2 + + +def test_ask_people_with_nickname(): + with patch("builtins.input", side_effect=["Alice, daughter", "Ali", "done"]): + people, aliases = _ask_people("personal") + assert aliases == {"Ali": "Alice"} + + +def test_ask_people_empty_name_skipped(): + with patch("builtins.input", side_effect=["", "done"]): + people, aliases = _ask_people("personal") + assert len(people) == 0 + + +# ── _ask_projects ───────────────────────────────────────────────────── + + +def test_ask_projects_personal_returns_empty(): + result = _ask_projects("personal") + assert result == [] + + +def test_ask_projects_work_mode(): + with patch("builtins.input", side_effect=["Acme", "BigCo", "done"]): + result = _ask_projects("work") + assert result == ["Acme", "BigCo"] + + +def test_ask_projects_empty_entry_stops(): + with patch("builtins.input", side_effect=["Acme", ""]): + result = _ask_projects("work") + assert result == ["Acme"] + + +# ── _ask_wings ──────────────────────────────────────────────────────── + + +def test_ask_wings_accept_defaults(): + with patch("builtins.input", return_value=""): + result = _ask_wings("work") + assert result == DEFAULT_WINGS["work"] + + +def test_ask_wings_custom(): + with patch("builtins.input", return_value="alpha, beta, gamma"): + result = _ask_wings("personal") + assert result == ["alpha", "beta", "gamma"] + + +# ── _auto_detect ────────────────────────────────────────────────────── + + +def test_auto_detect_no_files(tmp_path): + result = _auto_detect(str(tmp_path), []) + assert result == [] + + +def test_auto_detect_filters_known(tmp_path): + known = [{"name": "Alice"}] + fake_detected = { + "people": [ + {"name": "Alice", "confidence": 0.9, "signals": ["test"]}, + {"name": "Bob", "confidence": 0.8, "signals": ["test"]}, + ], + "projects": [], + "uncertain": [], + } + with ( + patch("mempalace.onboarding.scan_for_detection", return_value=["file.txt"]), + patch("mempalace.onboarding.detect_entities", return_value=fake_detected), + ): + result = _auto_detect(str(tmp_path), known) + names = [p["name"] for p in result] + assert "Alice" not in names + assert "Bob" in names + + +def test_auto_detect_filters_low_confidence(tmp_path): + fake_detected = { + "people": [{"name": "Bob", "confidence": 0.5, "signals": ["test"]}], + "projects": [], + "uncertain": [], + } + with ( + patch("mempalace.onboarding.scan_for_detection", return_value=["file.txt"]), + patch("mempalace.onboarding.detect_entities", return_value=fake_detected), + ): + result = _auto_detect(str(tmp_path), []) + assert len(result) == 0 + + +def test_auto_detect_handles_exception(tmp_path): + with patch("mempalace.onboarding.scan_for_detection", side_effect=Exception("boom")): + result = _auto_detect(str(tmp_path), []) + assert result == [] + + +# ── run_onboarding ──────────────────────────────────────────────────── + + +def test_run_onboarding_basic_flow(tmp_path): + """Test the full onboarding flow with minimal mocking.""" + with ( + patch("mempalace.onboarding._ask_mode", return_value="work"), + patch( + "mempalace.onboarding._ask_people", + return_value=([{"name": "Bob", "relationship": "boss", "context": "work"}], {}), + ), + patch("mempalace.onboarding._ask_projects", return_value=["Acme"]), + patch("mempalace.onboarding._ask_wings", return_value=["projects", "team"]), + patch("mempalace.onboarding._yn", return_value=False), + patch("mempalace.onboarding._warn_ambiguous", return_value=[]), + ): + registry = run_onboarding(directory=".", config_dir=tmp_path, auto_detect=False) + assert "Bob" in registry.people + assert "Acme" in registry.projects + + +def test_run_onboarding_with_ambiguous_names(tmp_path): + """Onboarding prints a warning for ambiguous names.""" + with ( + patch("mempalace.onboarding._ask_mode", return_value="personal"), + patch( + "mempalace.onboarding._ask_people", + return_value=([{"name": "Grace", "relationship": "friend", "context": "personal"}], {}), + ), + patch("mempalace.onboarding._ask_projects", return_value=[]), + patch("mempalace.onboarding._ask_wings", return_value=["family"]), + patch("mempalace.onboarding._yn", return_value=False), + ): + registry = run_onboarding(directory=".", config_dir=tmp_path, auto_detect=False) + assert "Grace" in registry.people diff --git a/tests/test_palace_graph.py b/tests/test_palace_graph.py new file mode 100644 index 0000000..ddda272 --- /dev/null +++ b/tests/test_palace_graph.py @@ -0,0 +1,244 @@ +"""Tests for mempalace.palace_graph — graph traversal layer. + +All ChromaDB access is mocked — no real database needed. +""" + +from unittest.mock import MagicMock, patch + + +def _make_fake_collection(metadatas, ids=None): + """Create a mock collection that returns the given metadata in batches.""" + if ids is None: + ids = [f"id_{i}" for i in range(len(metadatas))] + + col = MagicMock() + col.count.return_value = len(metadatas) + + def fake_get(limit=1000, offset=0, include=None): + batch_meta = metadatas[offset : offset + limit] + batch_ids = ids[offset : offset + limit] + return {"ids": batch_ids, "metadatas": batch_meta} + + col.get.side_effect = fake_get + return col + + +# Patch chromadb at import time so palace_graph can be imported +with patch.dict("sys.modules", {"chromadb": MagicMock()}): + from mempalace.palace_graph import ( + _fuzzy_match, + build_graph, + find_tunnels, + graph_stats, + traverse, + ) + + +# --- build_graph --- + + +class TestBuildGraph: + def test_empty_collection(self): + col = _make_fake_collection([]) + nodes, edges = build_graph(col=col) + assert nodes == {} + assert edges == [] + + def test_falsy_collection(self): + """When col is explicitly falsy, build_graph returns empty.""" + nodes, edges = build_graph(col=0) + assert nodes == {} + assert edges == [] + + def test_single_wing_no_edges(self): + col = _make_fake_collection( + [ + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-02"}, + ] + ) + nodes, edges = build_graph(col=col) + assert "auth" in nodes + assert nodes["auth"]["count"] == 2 + assert edges == [] + + def test_multi_wing_creates_edges(self): + col = _make_fake_collection( + [ + { + "room": "chromadb", + "wing": "wing_code", + "hall": "databases", + "date": "2026-01-01", + }, + { + "room": "chromadb", + "wing": "wing_project", + "hall": "databases", + "date": "2026-01-02", + }, + ] + ) + nodes, edges = build_graph(col=col) + assert "chromadb" in nodes + assert len(edges) == 1 + assert edges[0]["wing_a"] == "wing_code" + assert edges[0]["wing_b"] == "wing_project" + assert edges[0]["hall"] == "databases" + + def test_general_room_excluded(self): + col = _make_fake_collection( + [ + {"room": "general", "wing": "wing_code", "hall": "misc", "date": ""}, + ] + ) + nodes, edges = build_graph(col=col) + assert "general" not in nodes + + def test_missing_wing_excluded(self): + col = _make_fake_collection( + [ + {"room": "orphan", "wing": "", "hall": "misc", "date": ""}, + ] + ) + nodes, edges = build_graph(col=col) + assert "orphan" not in nodes + + def test_dates_capped_at_five(self): + col = _make_fake_collection( + [ + {"room": "busy", "wing": "w", "hall": "h", "date": f"2026-01-{i:02d}"} + for i in range(1, 10) + ] + ) + nodes, _ = build_graph(col=col) + assert len(nodes["busy"]["dates"]) <= 5 + + +# --- traverse --- + + +class TestTraverse: + def _build_col(self): + return _make_fake_collection( + [ + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + {"room": "login", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + {"room": "deploy", "wing": "wing_ops", "hall": "infra", "date": "2026-01-01"}, + ] + ) + + def test_traverse_known_room(self): + col = self._build_col() + result = traverse("auth", col=col) + assert isinstance(result, list) + rooms = [r["room"] for r in result] + assert "auth" in rooms + # login shares wing_code with auth + assert "login" in rooms + + def test_traverse_unknown_room(self): + col = self._build_col() + result = traverse("nonexistent", col=col) + assert isinstance(result, dict) + assert "error" in result + assert "suggestions" in result + + def test_traverse_max_hops(self): + col = self._build_col() + result = traverse("auth", col=col, max_hops=0) + # Only the start room itself at hop 0 + assert len(result) == 1 + assert result[0]["room"] == "auth" + + +# --- find_tunnels --- + + +class TestFindTunnels: + def _build_tunnel_col(self): + return _make_fake_collection( + [ + {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"}, + {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"}, + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + ] + ) + + def test_find_all_tunnels(self): + col = self._build_tunnel_col() + tunnels = find_tunnels(col=col) + assert len(tunnels) == 1 + assert tunnels[0]["room"] == "chromadb" + + def test_find_tunnels_with_wing_filter(self): + col = self._build_tunnel_col() + tunnels = find_tunnels(wing_a="wing_code", col=col) + assert len(tunnels) == 1 + + def test_find_tunnels_no_match(self): + col = self._build_tunnel_col() + tunnels = find_tunnels(wing_a="wing_nonexistent", col=col) + assert tunnels == [] + + def test_find_tunnels_both_wings(self): + col = self._build_tunnel_col() + tunnels = find_tunnels(wing_a="wing_code", wing_b="wing_project", col=col) + assert len(tunnels) == 1 + assert tunnels[0]["room"] == "chromadb" + + +# --- graph_stats --- + + +class TestGraphStats: + def test_empty_graph(self): + col = _make_fake_collection([]) + stats = graph_stats(col=col) + assert stats["total_rooms"] == 0 + assert stats["tunnel_rooms"] == 0 + assert stats["total_edges"] == 0 + + def test_stats_with_data(self): + col = _make_fake_collection( + [ + {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"}, + {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"}, + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + ] + ) + stats = graph_stats(col=col) + assert stats["total_rooms"] == 2 + assert stats["tunnel_rooms"] == 1 + assert stats["total_edges"] == 1 + assert "wing_code" in stats["rooms_per_wing"] + + +# --- _fuzzy_match --- + + +class TestFuzzyMatch: + def test_exact_substring(self): + nodes = {"chromadb-setup": {}, "auth-module": {}, "deploy-config": {}} + result = _fuzzy_match("chromadb", nodes) + assert "chromadb-setup" in result + + def test_partial_word_match(self): + nodes = {"chromadb-setup": {}, "auth-module": {}, "deploy-config": {}} + result = _fuzzy_match("auth", nodes) + assert "auth-module" in result + + def test_no_match(self): + nodes = {"chromadb-setup": {}, "auth-module": {}} + result = _fuzzy_match("zzzzz", nodes) + assert result == [] + + def test_hyphenated_query(self): + nodes = {"riley-college-apps": {}, "college-prep": {}} + result = _fuzzy_match("riley-college", nodes) + assert "riley-college-apps" in result + + def test_max_results(self): + nodes = {f"room-{i}": {} for i in range(20)} + result = _fuzzy_match("room", nodes, n=3) + assert len(result) <= 3 diff --git a/tests/test_room_detector_local.py b/tests/test_room_detector_local.py new file mode 100644 index 0000000..11963e4 --- /dev/null +++ b/tests/test_room_detector_local.py @@ -0,0 +1,264 @@ +"""Tests for mempalace.room_detector_local.""" + +from unittest.mock import MagicMock, patch + +from mempalace.room_detector_local import ( + FOLDER_ROOM_MAP, + detect_rooms_from_files, + detect_rooms_from_folders, + detect_rooms_local, + get_user_approval, + print_proposed_structure, + save_config, +) + + +# ── FOLDER_ROOM_MAP ──────────────────────────────────────────────────── + + +def test_folder_room_map_has_expected_mappings(): + assert FOLDER_ROOM_MAP["frontend"] == "frontend" + assert FOLDER_ROOM_MAP["backend"] == "backend" + assert FOLDER_ROOM_MAP["docs"] == "documentation" + assert FOLDER_ROOM_MAP["tests"] == "testing" + assert FOLDER_ROOM_MAP["config"] == "configuration" + + +def test_folder_room_map_alternative_names(): + assert FOLDER_ROOM_MAP["front-end"] == "frontend" + assert FOLDER_ROOM_MAP["back-end"] == "backend" + assert FOLDER_ROOM_MAP["server"] == "backend" + assert FOLDER_ROOM_MAP["client"] == "frontend" + assert FOLDER_ROOM_MAP["api"] == "backend" + + +# ── detect_rooms_from_folders ─────────────────────────────────────────── + + +def test_detect_rooms_from_folders_standard_layout(tmp_path): + (tmp_path / "frontend").mkdir() + (tmp_path / "backend").mkdir() + (tmp_path / "docs").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + room_names = {r["name"] for r in rooms} + assert "frontend" in room_names + assert "backend" in room_names + assert "documentation" in room_names + + +def test_detect_rooms_from_folders_always_has_general(tmp_path): + rooms = detect_rooms_from_folders(str(tmp_path)) + room_names = {r["name"] for r in rooms} + assert "general" in room_names + + +def test_detect_rooms_from_folders_empty_dir(tmp_path): + rooms = detect_rooms_from_folders(str(tmp_path)) + # Should at least have "general" + assert len(rooms) >= 1 + assert any(r["name"] == "general" for r in rooms) + + +def test_detect_rooms_from_folders_skips_git(tmp_path): + (tmp_path / ".git").mkdir() + (tmp_path / "node_modules").mkdir() + (tmp_path / "frontend").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + room_names = {r["name"] for r in rooms} + assert ".git" not in room_names + assert "node_modules" not in room_names + + +def test_detect_rooms_from_folders_nested_dirs(tmp_path): + src = tmp_path / "src" + src.mkdir() + (src / "components").mkdir() + (src / "routes").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + room_names = {r["name"] for r in rooms} + # Nested dirs should be detected at one level deep + assert "frontend" in room_names or "backend" in room_names + + +def test_detect_rooms_from_folders_room_has_description(tmp_path): + (tmp_path / "docs").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + doc_room = next((r for r in rooms if r["name"] == "documentation"), None) + assert doc_room is not None + assert "description" in doc_room + assert "docs" in doc_room["description"] + + +def test_detect_rooms_from_folders_room_has_keywords(tmp_path): + (tmp_path / "frontend").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + fe_room = next((r for r in rooms if r["name"] == "frontend"), None) + assert fe_room is not None + assert "keywords" in fe_room + assert len(fe_room["keywords"]) > 0 + + +def test_detect_rooms_from_folders_custom_named_dirs(tmp_path): + (tmp_path / "mylib").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + room_names = {r["name"] for r in rooms} + # Custom dir names that don't match FOLDER_ROOM_MAP get added as-is + assert "mylib" in room_names or "general" in room_names + + +# ── detect_rooms_from_files ───────────────────────────────────────────── + + +def test_detect_rooms_from_files_with_matching_filenames(tmp_path): + # Create files whose names contain room keywords + for name in ["test_auth.py", "test_login.py", "test_api.py"]: + (tmp_path / name).write_text("content") + rooms = detect_rooms_from_files(str(tmp_path)) + room_names = {r["name"] for r in rooms} + assert "testing" in room_names or "general" in room_names + + +def test_detect_rooms_from_files_empty_dir(tmp_path): + rooms = detect_rooms_from_files(str(tmp_path)) + assert len(rooms) >= 1 + assert any(r["name"] == "general" for r in rooms) + + +def test_detect_rooms_from_files_caps_at_six(tmp_path): + # Create many files with different keywords to hit the cap + for keyword in ["test", "doc", "api", "config", "frontend", "backend", "design", "meeting"]: + for i in range(3): + (tmp_path / f"{keyword}_file_{i}.txt").write_text("content") + rooms = detect_rooms_from_files(str(tmp_path)) + assert len(rooms) <= 6 + + +# ── save_config ───────────────────────────────────────────────────────── + + +def test_save_config_creates_yaml(tmp_path): + rooms = [ + {"name": "frontend", "description": "UI files", "keywords": ["frontend"]}, + {"name": "backend", "description": "Server files", "keywords": ["backend"]}, + ] + save_config(str(tmp_path), "myproject", rooms) + config_file = tmp_path / "mempalace.yaml" + assert config_file.exists() + content = config_file.read_text() + assert "myproject" in content + assert "frontend" in content + assert "backend" in content + + +def test_save_config_valid_yaml(tmp_path): + import yaml + + rooms = [{"name": "general", "description": "All files", "keywords": []}] + save_config(str(tmp_path), "test_proj", rooms) + config_file = tmp_path / "mempalace.yaml" + data = yaml.safe_load(config_file.read_text()) + assert data["wing"] == "test_proj" + assert len(data["rooms"]) == 1 + assert data["rooms"][0]["name"] == "general" + + +# ── print_proposed_structure ────────────────────────────────────────── + + +def test_print_proposed_structure(capsys): + rooms = [ + {"name": "frontend", "description": "UI files"}, + {"name": "general", "description": "Everything else"}, + ] + print_proposed_structure("myapp", rooms, 42, "folder structure") + out = capsys.readouterr().out + assert "myapp" in out + assert "frontend" in out + assert "42 files" in out + assert "folder structure" in out + + +# ── get_user_approval ───────────────────────────────────────────────── + + +def test_get_user_approval_accept_all(): + rooms = [{"name": "frontend", "description": "UI"}] + with patch("builtins.input", return_value=""): + result = get_user_approval(rooms) + assert result == rooms + + +def test_get_user_approval_edit_remove(): + rooms = [ + {"name": "frontend", "description": "UI"}, + {"name": "backend", "description": "Server"}, + ] + with patch("builtins.input", side_effect=["edit", "1", "n"]): + result = get_user_approval(rooms) + # Room 1 (frontend) removed + assert len(result) == 1 + assert result[0]["name"] == "backend" + + +def test_get_user_approval_add_room(): + rooms = [{"name": "general", "description": "All files"}] + with patch( + "builtins.input", + side_effect=[ + "add", + "custom_room", + "My custom room", + "", + ], + ): + result = get_user_approval(rooms) + names = [r["name"] for r in result] + assert "custom_room" in names + + +# ── detect_rooms_local ──────────────────────────────────────────────── + + +def test_detect_rooms_local_yes_mode(tmp_path): + (tmp_path / "docs").mkdir() + (tmp_path / "docs" / "readme.md").write_text("hello") + mock_miner = MagicMock() + mock_miner.scan_project.return_value = ["file1.py"] + with patch.dict("sys.modules", {"mempalace.miner": mock_miner}): + detect_rooms_local(str(tmp_path), yes=True) + assert (tmp_path / "mempalace.yaml").exists() + + +def test_detect_rooms_local_fallback_to_files(tmp_path): + """When folder detection gives only 'general', falls back to file patterns.""" + for i in range(3): + (tmp_path / f"test_file_{i}.py").write_text("content") + mock_miner = MagicMock() + mock_miner.scan_project.return_value = ["f1", "f2"] + with patch.dict("sys.modules", {"mempalace.miner": mock_miner}): + detect_rooms_local(str(tmp_path), yes=True) + assert (tmp_path / "mempalace.yaml").exists() + + +def test_detect_rooms_local_missing_dir(): + """Non-existent directory causes sys.exit.""" + import pytest + + with pytest.raises(SystemExit): + detect_rooms_local("/nonexistent/path/that/does/not/exist", yes=True) + + +def test_detect_rooms_local_interactive(tmp_path): + (tmp_path / "src").mkdir() + (tmp_path / "src" / "main.py").write_text("code") + mock_miner = MagicMock() + mock_miner.scan_project.return_value = ["f1"] + with ( + patch.dict("sys.modules", {"mempalace.miner": mock_miner}), + patch( + "mempalace.room_detector_local.get_user_approval", + return_value=[{"name": "general", "description": "All files", "keywords": []}], + ), + ): + detect_rooms_local(str(tmp_path), yes=False) + assert (tmp_path / "mempalace.yaml").exists() diff --git a/tests/test_searcher.py b/tests/test_searcher.py index 44a05aa..94f22b4 100644 --- a/tests/test_searcher.py +++ b/tests/test_searcher.py @@ -1,10 +1,18 @@ """ -test_searcher.py — Tests for the programmatic search_memories API. +test_searcher.py -- Tests for both search() (CLI) and search_memories() (API). -Tests the library-facing search interface (not the CLI print variant). +Uses the real ChromaDB fixtures from conftest.py for integration tests, +plus mock-based tests for error paths. """ -from mempalace.searcher import search_memories +from unittest.mock import MagicMock, patch + +import pytest + +from mempalace.searcher import SearchError, search, search_memories + + +# ── search_memories (API) ────────────────────────────────────────────── class TestSearchMemories: @@ -30,8 +38,8 @@ class TestSearchMemories: result = search_memories("code", palace_path, n_results=2) assert len(result["results"]) <= 2 - def test_no_palace_returns_error(self): - result = search_memories("anything", "/nonexistent/path") + def test_no_palace_returns_error(self, tmp_path): + result = search_memories("anything", str(tmp_path / "missing")) assert "error" in result def test_result_fields(self, palace_path, seeded_collection): @@ -43,3 +51,75 @@ class TestSearchMemories: assert "source_file" in hit assert "similarity" in hit assert isinstance(hit["similarity"], float) + + def test_search_memories_query_error(self): + """search_memories returns error dict when query raises.""" + mock_col = MagicMock() + mock_col.query.side_effect = RuntimeError("query failed") + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with patch("mempalace.searcher.chromadb.PersistentClient", return_value=mock_client): + result = search_memories("test", "/fake/path") + assert "error" in result + assert "query failed" in result["error"] + + def test_search_memories_filters_in_result(self, palace_path, seeded_collection): + result = search_memories("test", palace_path, wing="project", room="backend") + assert result["filters"]["wing"] == "project" + assert result["filters"]["room"] == "backend" + + +# ── search() (CLI print function) ───────────────────────────────────── + + +class TestSearchCLI: + def test_search_prints_results(self, palace_path, seeded_collection, capsys): + search("JWT authentication", palace_path) + captured = capsys.readouterr() + assert "JWT" in captured.out or "authentication" in captured.out + + def test_search_with_wing_filter(self, palace_path, seeded_collection, capsys): + search("planning", palace_path, wing="notes") + captured = capsys.readouterr() + assert "Results for" in captured.out + + def test_search_with_room_filter(self, palace_path, seeded_collection, capsys): + search("database", palace_path, room="backend") + captured = capsys.readouterr() + assert "Room:" in captured.out + + def test_search_with_wing_and_room(self, palace_path, seeded_collection, capsys): + search("code", palace_path, wing="project", room="frontend") + captured = capsys.readouterr() + assert "Wing:" in captured.out + assert "Room:" in captured.out + + def test_search_no_palace_raises(self, tmp_path): + with pytest.raises(SearchError, match="No palace found"): + search("anything", str(tmp_path / "missing")) + + def test_search_no_results(self, palace_path, collection, capsys): + """Empty collection returns no results message.""" + # collection is empty (no seeded data) + result = search("xyzzy_nonexistent_query", palace_path, n_results=1) + captured = capsys.readouterr() + # Either prints "No results" or returns None + assert result is None or "No results" in captured.out + + def test_search_query_error_raises(self): + """search raises SearchError when query fails.""" + mock_col = MagicMock() + mock_col.query.side_effect = RuntimeError("boom") + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with patch("mempalace.searcher.chromadb.PersistentClient", return_value=mock_client): + with pytest.raises(SearchError, match="Search error"): + search("test", "/fake/path") + + def test_search_n_results(self, palace_path, seeded_collection, capsys): + search("code", palace_path, n_results=1) + captured = capsys.readouterr() + # Should have output with at least one result block + assert "[1]" in captured.out diff --git a/tests/test_spellcheck.py b/tests/test_spellcheck.py new file mode 100644 index 0000000..f2c7484 --- /dev/null +++ b/tests/test_spellcheck.py @@ -0,0 +1,160 @@ +"""Tests for mempalace.spellcheck — spell-correction utilities.""" + +from unittest.mock import patch + +from mempalace.spellcheck import ( + _edit_distance, + _get_system_words, + _should_skip, + spellcheck_transcript, + spellcheck_transcript_line, + spellcheck_user_text, +) + + +# --- _should_skip --- + + +class TestShouldSkip: + """Token-level skip logic.""" + + def test_short_tokens_skipped(self): + assert _should_skip("hi", set()) is True + assert _should_skip("ok", set()) is True + assert _should_skip("I", set()) is True + + def test_digits_skipped(self): + assert _should_skip("3am", set()) is True + assert _should_skip("top10", set()) is True + assert _should_skip("bge-large-v1.5", set()) is True + + def test_camelcase_skipped(self): + assert _should_skip("ChromaDB", set()) is True + assert _should_skip("MemPalace", set()) is True + + def test_allcaps_skipped(self): + assert _should_skip("NDCG", set()) is True + assert _should_skip("MAX_RESULTS", set()) is True + + def test_technical_skipped(self): + assert _should_skip("bge-large", set()) is True + assert _should_skip("train_test", set()) is True + + def test_url_skipped(self): + assert _should_skip("https://example.com", set()) is True + assert _should_skip("www.google.com", set()) is True + + def test_code_or_emoji_skipped(self): + assert _should_skip("`code`", set()) is True + assert _should_skip("**bold**", set()) is True + + def test_known_name_skipped(self): + assert _should_skip("mempalace", {"mempalace"}) is True + + def test_normal_word_not_skipped(self): + assert _should_skip("hello", set()) is False + assert _should_skip("question", set()) is False + + +# --- _edit_distance --- + + +class TestEditDistance: + def test_identical(self): + assert _edit_distance("hello", "hello") == 0 + + def test_empty_strings(self): + assert _edit_distance("", "abc") == 3 + assert _edit_distance("abc", "") == 3 + assert _edit_distance("", "") == 0 + + def test_single_edit(self): + assert _edit_distance("cat", "bat") == 1 # substitution + assert _edit_distance("cat", "cats") == 1 # insertion + assert _edit_distance("cats", "cat") == 1 # deletion + + def test_known_distance(self): + assert _edit_distance("kitten", "sitting") == 3 + + +# --- _get_system_words --- + + +def test_get_system_words_returns_set(): + result = _get_system_words() + assert isinstance(result, set) + + +# --- spellcheck_user_text --- + + +def test_spellcheck_user_text_passthrough_no_autocorrect(): + """When autocorrect is not installed, text passes through unchanged.""" + with patch("mempalace.spellcheck._get_speller", return_value=None): + text = "somee misspeledd textt" + assert spellcheck_user_text(text) == text + + +def test_spellcheck_user_text_with_speller(): + """When a speller is available, it corrects words.""" + + def fake_speller(word): + corrections = {"knoe": "know", "befor": "before"} + return corrections.get(word, word) + + with patch("mempalace.spellcheck._get_speller", return_value=fake_speller): + with patch("mempalace.spellcheck._get_system_words", return_value=set()): + with patch("mempalace.spellcheck._load_known_names", return_value=set()): + result = spellcheck_user_text("knoe the question befor") + assert "know" in result + assert "before" in result + + +def test_spellcheck_preserves_technical_terms(): + """Technical terms should never be touched even with a speller.""" + + def fake_speller(word): + return "WRONG" + + with patch("mempalace.spellcheck._get_speller", return_value=fake_speller): + with patch("mempalace.spellcheck._get_system_words", return_value=set()): + result = spellcheck_user_text("ChromaDB bge-large", known_names=set()) + assert "ChromaDB" in result + assert "bge-large" in result + assert "WRONG" not in result + + +# --- spellcheck_transcript_line --- + + +def test_transcript_line_user_turn(): + """Lines starting with '>' should be processed.""" + with patch("mempalace.spellcheck.spellcheck_user_text", return_value="corrected"): + result = spellcheck_transcript_line("> hello world") + assert "corrected" in result + + +def test_transcript_line_assistant_turn(): + """Lines not starting with '>' should pass through unchanged.""" + line = "This is an assistant response" + assert spellcheck_transcript_line(line) == line + + +def test_transcript_line_empty_user_turn(): + """A '> ' line with no message content should pass through.""" + line = "> " + assert spellcheck_transcript_line(line) == line + + +# --- spellcheck_transcript --- + + +def test_spellcheck_transcript_processes_content(): + """Full transcript: only '>' lines are touched.""" + content = "Assistant line\n> user line\nAnother assistant line" + with patch("mempalace.spellcheck.spellcheck_user_text", return_value="fixed"): + result = spellcheck_transcript(content) + lines = result.split("\n") + assert lines[0] == "Assistant line" + assert "fixed" in lines[1] + assert lines[2] == "Another assistant line" diff --git a/tests/test_spellcheck_extra.py b/tests/test_spellcheck_extra.py new file mode 100644 index 0000000..567cb01 --- /dev/null +++ b/tests/test_spellcheck_extra.py @@ -0,0 +1,72 @@ +"""Extra spellcheck tests covering _load_known_names and speller edge cases.""" + +from unittest.mock import patch, MagicMock + +from mempalace.spellcheck import ( + _load_known_names, + spellcheck_user_text, +) + + +class TestLoadKnownNames: + def test_returns_names_from_registry(self): + mock_reg = MagicMock() + mock_reg._data = { + "entities": { + "e1": {"canonical": "Alice", "aliases": ["ali"]}, + "e2": {"canonical": "Bob", "aliases": []}, + } + } + with patch("mempalace.entity_registry.EntityRegistry") as MockER: + MockER.load.return_value = mock_reg + names = _load_known_names() + assert "alice" in names + assert "ali" in names + assert "bob" in names + + def test_returns_empty_on_exception(self): + with patch( + "mempalace.entity_registry.EntityRegistry.load", + side_effect=Exception("no registry"), + ): + names = _load_known_names() + assert names == set() + + +class TestSpellerEdgeCases: + def test_capitalized_word_skipped(self): + """Capitalized words (likely proper nouns) are not corrected.""" + + def fake_speller(word): + return "WRONG" + + with patch("mempalace.spellcheck._get_speller", return_value=fake_speller): + with patch("mempalace.spellcheck._get_system_words", return_value=set()): + with patch("mempalace.spellcheck._load_known_names", return_value=set()): + result = spellcheck_user_text("Alice went home") + assert "Alice" in result + assert "WRONG" not in result + + def test_system_word_not_corrected(self): + """Words in system dict should not be corrected.""" + + def fake_speller(word): + return "WRONG" + + with patch("mempalace.spellcheck._get_speller", return_value=fake_speller): + with patch("mempalace.spellcheck._get_system_words", return_value={"coherently"}): + with patch("mempalace.spellcheck._load_known_names", return_value=set()): + result = spellcheck_user_text("coherently") + assert "coherently" in result + + def test_high_edit_distance_rejected(self): + """Corrections with too many edits are rejected.""" + + def fake_speller(word): + return "completely_different_word" + + with patch("mempalace.spellcheck._get_speller", return_value=fake_speller): + with patch("mempalace.spellcheck._get_system_words", return_value=set()): + with patch("mempalace.spellcheck._load_known_names", return_value=set()): + result = spellcheck_user_text("hello") + assert "hello" in result diff --git a/tests/test_split_mega_files.py b/tests/test_split_mega_files.py index 70c7f84..c1db02b 100644 --- a/tests/test_split_mega_files.py +++ b/tests/test_split_mega_files.py @@ -3,6 +3,9 @@ import json from mempalace import split_mega_files as smf +# ── Config loading ───────────────────────────────────────────────────── + + def test_load_known_people_falls_back_when_config_missing(monkeypatch, tmp_path): monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", tmp_path / "missing.json") smf._KNOWN_NAMES_CACHE = None @@ -46,3 +49,244 @@ def test_extract_people_detects_names_from_content(monkeypatch): monkeypatch.setattr(smf, "KNOWN_PEOPLE", ["Alice", "Ben"]) people = smf.extract_people(["> Alice reviewed the change with Ben\n"]) assert people == ["Alice", "Ben"] + + +# ── Config: force_reload and invalid JSON ────────────────────────────── + + +def test_load_known_names_force_reload(monkeypatch, tmp_path): + config_path = tmp_path / "known_names.json" + config_path.write_text(json.dumps(["Alice"])) + monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path) + smf._KNOWN_NAMES_CACHE = None + + smf._load_known_names_config() + assert smf._KNOWN_NAMES_CACHE == ["Alice"] + + config_path.write_text(json.dumps(["Bob"])) + smf._load_known_names_config(force_reload=True) + assert smf._KNOWN_NAMES_CACHE == ["Bob"] + + +def test_load_known_names_invalid_json(monkeypatch, tmp_path): + config_path = tmp_path / "known_names.json" + config_path.write_text("not json {{{") + monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path) + smf._KNOWN_NAMES_CACHE = None + + result = smf._load_known_names_config() + assert result is None + + +def test_load_known_names_caching(monkeypatch, tmp_path): + config_path = tmp_path / "known_names.json" + config_path.write_text(json.dumps(["Alice"])) + monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path) + smf._KNOWN_NAMES_CACHE = None + + smf._load_known_names_config() + # Second call returns cached value without re-reading + config_path.write_text(json.dumps(["Changed"])) + result = smf._load_known_names_config() + assert result == ["Alice"] + + +# ── is_true_session_start ────────────────────────────────────────────── + + +def test_is_true_session_start_yes(): + lines = ["Claude Code v1.0", "Some content", "More content", "", "", ""] + assert smf.is_true_session_start(lines, 0) is True + + +def test_is_true_session_start_no_ctrl_e(): + lines = [ + "Claude Code v1.0", + "Ctrl+E to show 5 previous messages", + "", + "", + "", + "", + ] + assert smf.is_true_session_start(lines, 0) is False + + +def test_is_true_session_start_no_previous_messages(): + lines = [ + "Claude Code v1.0", + "Some text", + "previous messages here", + "", + "", + "", + ] + assert smf.is_true_session_start(lines, 0) is False + + +# ── find_session_boundaries ──────────────────────────────────────────── + + +def test_find_session_boundaries_two_sessions(): + lines = [ + "Claude Code v1.0", + "content 1", + "", + "", + "", + "", + "", + "Claude Code v1.0", + "content 2", + "", + "", + "", + "", + "", + ] + boundaries = smf.find_session_boundaries(lines) + assert boundaries == [0, 7] + + +def test_find_session_boundaries_none(): + lines = ["Just some text", "No sessions here"] + assert smf.find_session_boundaries(lines) == [] + + +def test_find_session_boundaries_context_restore_skipped(): + lines = [ + "Claude Code v1.0", + "content", + "", + "", + "", + "", + "", + "Claude Code v1.0", + "Ctrl+E to show 5 previous messages", + "", + "", + "", + "", + ] + boundaries = smf.find_session_boundaries(lines) + assert len(boundaries) == 1 + + +# ── extract_timestamp ────────────────────────────────────────────────── + + +def test_extract_timestamp_found(): + lines = ["⏺ 2:30 PM Wednesday, March 25, 2026"] + human, iso = smf.extract_timestamp(lines) + assert human == "2026-03-25_230PM" + assert iso == "2026-03-25" + + +def test_extract_timestamp_not_found(): + lines = ["No timestamp here"] + human, iso = smf.extract_timestamp(lines) + assert human is None + assert iso is None + + +def test_extract_timestamp_only_checks_first_50(): + lines = ["filler\n"] * 51 + ["⏺ 1:00 AM Monday, January 01, 2026"] + human, iso = smf.extract_timestamp(lines) + assert human is None + + +# ── extract_subject ──────────────────────────────────────────────────── + + +def test_extract_subject_found(): + lines = ["> How do we handle authentication?"] + subject = smf.extract_subject(lines) + assert "authentication" in subject.lower() + + +def test_extract_subject_skips_commands(): + lines = ["> cd /some/dir", "> git status", "> What is the plan?"] + subject = smf.extract_subject(lines) + assert "plan" in subject.lower() + + +def test_extract_subject_fallback(): + lines = ["No prompts at all", "Just text"] + subject = smf.extract_subject(lines) + assert subject == "session" + + +def test_extract_subject_short_prompt_skipped(): + lines = ["> ok", "> yes", "> What about the deployment strategy?"] + subject = smf.extract_subject(lines) + assert "deployment" in subject.lower() + + +def test_extract_subject_truncated(): + lines = ["> " + "a" * 100] + subject = smf.extract_subject(lines) + assert len(subject) <= 60 + + +# ── split_file ───────────────────────────────────────────────────────── + + +def _make_mega_file(tmp_path, n_sessions=3, lines_per_session=15): + """Create a mega-file with N sessions.""" + content = "" + for i in range(n_sessions): + content += f"Claude Code v1.{i}\n" + content += f"> What about topic {i} and how it works?\n" + for j in range(lines_per_session - 2): + content += f"Line {j} of session {i}\n" + path = tmp_path / "mega.txt" + path.write_text(content) + return path + + +def test_split_file_creates_output(tmp_path): + mega = _make_mega_file(tmp_path) + out_dir = tmp_path / "output" + out_dir.mkdir() + written = smf.split_file(str(mega), str(out_dir)) + assert len(written) >= 2 + for p in written: + assert p.exists() + + +def test_split_file_dry_run(tmp_path): + mega = _make_mega_file(tmp_path) + out_dir = tmp_path / "output" + out_dir.mkdir() + written = smf.split_file(str(mega), str(out_dir), dry_run=True) + assert len(written) >= 2 + for p in written: + assert not p.exists() + + +def test_split_file_not_mega(tmp_path): + """File with fewer than 2 sessions is not split.""" + path = tmp_path / "single.txt" + path.write_text("Claude Code v1.0\nJust one session\n" + "line\n" * 20) + written = smf.split_file(str(path), str(tmp_path)) + assert written == [] + + +def test_split_file_output_dir_none(tmp_path): + """When output_dir is None, writes to same dir as source.""" + mega = _make_mega_file(tmp_path) + written = smf.split_file(str(mega), None) + assert len(written) >= 2 + for p in written: + assert str(p.parent) == str(tmp_path) + + +def test_split_file_tiny_fragments_skipped(tmp_path): + """Tiny chunks (< 10 lines) are skipped.""" + content = "Claude Code v1.0\nline\n" * 2 + "Claude Code v1.0\n" + "line\n" * 20 + path = tmp_path / "tiny.txt" + path.write_text(content) + written = smf.split_file(str(path), str(tmp_path)) + # The first chunk is very small, should be skipped + for p in written: + assert p.stat().st_size > 0