From c7bd2cd8e4fa6df88e44bbca33fdf7437126e9dc Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Fri, 24 Apr 2026 00:46:31 -0300 Subject: [PATCH 1/8] feat(convo): parse Claude Code conversation dirs into project entities MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Claude Code stores sessions under `~/.claude/projects/<slug>/<session>.jsonl` where `<slug>` is the original CWD with `/` replaced by `-`. That encoding is lossy — can't distinguish `foo-bar` (one segment) from `foo/bar` (two) — so slug-decoding alone produces wrong names for any hyphenated project. Fortunately, every message record carries a `cwd` field with the true path. This scanner reads one record per session to recover the accurate project name deterministically, falling back to slug-decoding only if the JSONL is malformed or empty. Output shape matches project_scanner.ProjectInfo so the discover orchestrator can union results across sources. Session count doubles as a density signal for ranking. 22 unit tests cover: root detection, cwd extraction with malformed input tolerance, fallback slug decoding, name resolution using the newest session (so renames win), and dedup when two encoded dirs resolve to the same project. --- mempalace/convo_scanner.py | 152 +++++++++++++++++++++++++++ tests/test_convo_scanner.py | 199 ++++++++++++++++++++++++++++++++++++ 2 files changed, 351 insertions(+) create mode 100644 mempalace/convo_scanner.py create mode 100644 tests/test_convo_scanner.py diff --git a/mempalace/convo_scanner.py b/mempalace/convo_scanner.py new file mode 100644 index 0000000..bb8fbef --- /dev/null +++ b/mempalace/convo_scanner.py @@ -0,0 +1,152 @@ +""" +convo_scanner.py — Parse Claude Code conversation directories into ProjectInfo. + +Claude Code stores sessions under ``~/.claude/projects/<slug>/<session>.jsonl``, +where the ``<slug>`` is the original CWD with ``/`` replaced by ``-``. 
That +encoding is lossy: we can't tell whether ``foo-bar`` in a slug is the +literal project name ``foo-bar`` or two path segments ``foo/bar``. + +Fortunately, every message record in the JSONL carries a ``cwd`` field with +the true path. This scanner reads one record per session to recover the +accurate project name, falling back to slug-decoding only if the JSONL +is malformed or empty. + +Output is the same ``ProjectInfo`` shape used by ``project_scanner``, so the +``discover_entities`` orchestrator can mix-and-match sources. + +Public: + is_claude_projects_root(path) -> bool + scan_claude_projects(path) -> list[ProjectInfo] +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Optional + +from mempalace.project_scanner import ProjectInfo + + +MAX_HEADER_LINES = 20 # lines to read per session looking for `cwd` + + +def is_claude_projects_root(path: Path) -> bool: + """Return True if path looks like `.claude/projects/`. + + Heuristic: at least one child dir whose name starts with ``-`` and which + contains at least one ``.jsonl`` file. + """ + if not path.is_dir(): + return False + try: + children = list(path.iterdir()) + except OSError: + return False + for child in children: + if not (child.is_dir() and child.name.startswith("-")): + continue + try: + if any(p.suffix == ".jsonl" for p in child.iterdir() if p.is_file()): + return True + except OSError: + continue + return False + + +def _extract_cwd_from_session(session_file: Path) -> Optional[str]: + """Return the ``cwd`` from the first message record that carries one. + + Returns None if the file can't be read, has no JSON, or no record has cwd. 
+ """ + try: + with open(session_file, encoding="utf-8", errors="replace") as f: + for i, line in enumerate(f): + if i >= MAX_HEADER_LINES: + break + line = line.strip() + if not line: + continue + try: + obj = json.loads(line) + except json.JSONDecodeError: + continue + cwd = obj.get("cwd") + if isinstance(cwd, str) and cwd: + return cwd + except OSError: + return None + return None + + +def _decode_slug_fallback(slug: str) -> str: + """Best-effort project name from slug when cwd is unavailable. + + The slug is lossy (`/` and `-` both become `-`). Last non-empty segment + is the closest guess at the project name, preserving kebab-case is + impossible without cwd. + """ + stripped = slug.lstrip("-") + parts = [p for p in stripped.split("-") if p] + return parts[-1] if parts else slug + + +def _resolve_project_name(project_dir: Path) -> str: + """Read one session's cwd to recover the original project name. + + Falls back to slug-decoding if no session has a readable cwd. + """ + sessions = sorted( + (p for p in project_dir.iterdir() if p.is_file() and p.suffix == ".jsonl"), + key=lambda p: p.stat().st_mtime, + reverse=True, # newest first — most likely to be well-formed + ) + for session in sessions: + cwd = _extract_cwd_from_session(session) + if cwd: + return Path(cwd).name or cwd + return _decode_slug_fallback(project_dir.name) + + +def scan_claude_projects(path: str | Path) -> list[ProjectInfo]: + """Scan a ``.claude/projects/`` directory for Claude Code conversations. + + One ProjectInfo per subdir. ``has_git`` is False (the directory isn't a + repo itself) but ``total_commits`` is repurposed here as session count so + the UX surfaces a density signal for ranking. 
+ """ + root = Path(path).expanduser().resolve() + if not is_claude_projects_root(root): + return [] + + projects: dict[str, ProjectInfo] = {} + for sub in sorted(root.iterdir()): + if not (sub.is_dir() and sub.name.startswith("-")): + continue + try: + sessions = [p for p in sub.iterdir() if p.is_file() and p.suffix == ".jsonl"] + except OSError: + continue + if not sessions: + continue + + name = _resolve_project_name(sub) + session_count = len(sessions) + + proj = ProjectInfo( + name=name, + repo_root=sub, + manifest=None, + has_git=False, + total_commits=session_count, + user_commits=session_count, + is_mine=True, # Claude Code sessions are authored by the user + ) + existing = projects.get(name) + if existing is None or session_count > existing.user_commits: + projects[name] = proj + + return sorted( + projects.values(), + key=lambda p: (-p.user_commits, p.name), + ) diff --git a/tests/test_convo_scanner.py b/tests/test_convo_scanner.py new file mode 100644 index 0000000..9fcd339 --- /dev/null +++ b/tests/test_convo_scanner.py @@ -0,0 +1,199 @@ +"""Tests for mempalace.convo_scanner.""" + +import json + +from mempalace.convo_scanner import ( + _decode_slug_fallback, + _extract_cwd_from_session, + _resolve_project_name, + is_claude_projects_root, + scan_claude_projects, +) + + +# ── is_claude_projects_root ───────────────────────────────────────────── + + +def test_is_claude_projects_root_true(tmp_path): + project_dir = tmp_path / "-home-user-dev-foo" + project_dir.mkdir() + (project_dir / "abc.jsonl").write_text("{}\n") + assert is_claude_projects_root(tmp_path) + + +def test_is_claude_projects_root_false_no_dash_prefix(tmp_path): + project_dir = tmp_path / "normal-folder" + project_dir.mkdir() + (project_dir / "abc.jsonl").write_text("{}\n") + assert not is_claude_projects_root(tmp_path) + + +def test_is_claude_projects_root_false_no_jsonl(tmp_path): + project_dir = tmp_path / "-home-user-foo" + project_dir.mkdir() + (project_dir / 
"other.txt").write_text("hello") + assert not is_claude_projects_root(tmp_path) + + +def test_is_claude_projects_root_false_empty(tmp_path): + assert not is_claude_projects_root(tmp_path) + + +def test_is_claude_projects_root_false_nonexistent(tmp_path): + assert not is_claude_projects_root(tmp_path / "does-not-exist") + + +# ── cwd extraction ────────────────────────────────────────────────────── + + +def test_extract_cwd_from_session(tmp_path): + f = tmp_path / "session.jsonl" + lines = [ + json.dumps({"type": "file-history-snapshot", "messageId": "x"}), + json.dumps({"type": "user", "cwd": "/home/user/dev/myproj", "content": "hi"}), + ] + f.write_text("\n".join(lines) + "\n") + assert _extract_cwd_from_session(f) == "/home/user/dev/myproj" + + +def test_extract_cwd_from_session_skips_malformed(tmp_path): + f = tmp_path / "session.jsonl" + f.write_text( + "{not valid json\n" + json.dumps({"type": "user", "cwd": "/home/user/dev/good"}) + "\n" + ) + assert _extract_cwd_from_session(f) == "/home/user/dev/good" + + +def test_extract_cwd_from_session_none_if_absent(tmp_path): + f = tmp_path / "session.jsonl" + f.write_text(json.dumps({"type": "x", "messageId": "y"}) + "\n") + assert _extract_cwd_from_session(f) is None + + +def test_extract_cwd_from_session_none_if_file_missing(tmp_path): + assert _extract_cwd_from_session(tmp_path / "missing.jsonl") is None + + +# ── slug fallback ─────────────────────────────────────────────────────── + + +def test_decode_slug_fallback_last_segment(): + assert _decode_slug_fallback("-home-user-dev-foo") == "foo" + + +def test_decode_slug_fallback_double_dash(): + assert _decode_slug_fallback("-home-user--bentokit") == "bentokit" + + +def test_decode_slug_fallback_empty(): + assert _decode_slug_fallback("") == "" + + +def test_decode_slug_fallback_only_dashes(): + assert _decode_slug_fallback("---") == "---" + + +# ── _resolve_project_name ─────────────────────────────────────────────── + + +def 
test_resolve_project_name_uses_cwd(tmp_path): + pdir = tmp_path / "-home-user-dev-coolproj" + pdir.mkdir() + session = pdir / "a.jsonl" + session.write_text(json.dumps({"type": "user", "cwd": "/home/user/dev/cool-proj-real"}) + "\n") + assert _resolve_project_name(pdir) == "cool-proj-real" + + +def test_resolve_project_name_falls_back_when_no_cwd(tmp_path): + pdir = tmp_path / "-home-user-dev-foo" + pdir.mkdir() + (pdir / "a.jsonl").write_text(json.dumps({"type": "x"}) + "\n") + assert _resolve_project_name(pdir) == "foo" + + +def test_resolve_project_name_prefers_newer_session(tmp_path): + """Newest session's cwd wins — covers the case where user renamed the + project directory between sessions.""" + + pdir = tmp_path / "-home-user-dev-old" + pdir.mkdir() + old = pdir / "old.jsonl" + old.write_text(json.dumps({"type": "user", "cwd": "/home/user/dev/old"}) + "\n") + # Ensure distinguishable mtimes + old_mtime = old.stat().st_mtime - 100 + import os + + os.utime(old, (old_mtime, old_mtime)) + + new = pdir / "new.jsonl" + new.write_text(json.dumps({"type": "user", "cwd": "/home/user/dev/new-name"}) + "\n") + assert _resolve_project_name(pdir) == "new-name" + + +# ── scan_claude_projects ──────────────────────────────────────────────── + + +def test_scan_claude_projects_empty_dir(tmp_path): + assert scan_claude_projects(tmp_path) == [] + + +def test_scan_claude_projects_not_a_projects_root(tmp_path): + """Returns empty list if the dir doesn't look like .claude/projects/.""" + (tmp_path / "some-folder").mkdir() + (tmp_path / "some-folder" / "readme.md").write_text("hi") + assert scan_claude_projects(tmp_path) == [] + + +def test_scan_claude_projects_finds_projects(tmp_path): + p1 = tmp_path / "-home-user-dev-alpha" + p1.mkdir() + (p1 / "a.jsonl").write_text(json.dumps({"type": "user", "cwd": "/home/user/dev/alpha"}) + "\n") + (p1 / "b.jsonl").write_text(json.dumps({"type": "user", "cwd": "/home/user/dev/alpha"}) + "\n") + + p2 = tmp_path / "-home-user-dev-beta" + 
p2.mkdir() + (p2 / "x.jsonl").write_text(json.dumps({"type": "user", "cwd": "/home/user/dev/beta"}) + "\n") + + result = scan_claude_projects(tmp_path) + names = [p.name for p in result] + assert "alpha" in names + assert "beta" in names + # alpha has 2 sessions, beta has 1 — alpha ranks higher + alpha = next(p for p in result if p.name == "alpha") + beta = next(p for p in result if p.name == "beta") + assert alpha.user_commits == 2 + assert beta.user_commits == 1 + + +def test_scan_claude_projects_ignores_dirs_without_jsonl(tmp_path): + empty_proj = tmp_path / "-home-user-dev-empty" + empty_proj.mkdir() + (empty_proj / "notes.md").write_text("hi") + assert scan_claude_projects(tmp_path) == [] + + +def test_scan_claude_projects_marks_as_mine(tmp_path): + p = tmp_path / "-home-user-dev-owned" + p.mkdir() + (p / "s.jsonl").write_text(json.dumps({"type": "user", "cwd": "/home/user/dev/owned"}) + "\n") + result = scan_claude_projects(tmp_path) + assert len(result) == 1 + assert result[0].is_mine is True + + +def test_scan_claude_projects_dedup_by_name(tmp_path): + """Two encoded dirs resolving to the same project name collapse to one.""" + p1 = tmp_path / "-home-user-a-proj" + p1.mkdir() + (p1 / "s.jsonl").write_text(json.dumps({"type": "user", "cwd": "/home/user/a/proj"}) + "\n") + (p1 / "t.jsonl").write_text(json.dumps({"type": "user", "cwd": "/home/user/a/proj"}) + "\n") + + p2 = tmp_path / "-home-user-b-proj" + p2.mkdir() + (p2 / "u.jsonl").write_text(json.dumps({"type": "user", "cwd": "/home/user/b/proj"}) + "\n") + + result = scan_claude_projects(tmp_path) + # Both decode to "proj"; only one remains — the one with more sessions wins + assert len(result) == 1 + assert result[0].name == "proj" + assert result[0].user_commits == 2 From df6c7d0dc3d805f88ed4ffc1e897ab66ddaa4134 Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Fri, 24 Apr 2026 00:46:43 -0300 Subject: [PATCH 2/8] feat(llm): pluggable provider abstraction 
for entity refinement MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three providers cover the useful space while keeping the zero-API default: - `ollama` (default): local models via http://localhost:11434. Works fully offline. Tag-matching check accepts both `model` and `model:latest` forms. - `openai-compat`: any /v1/chat/completions endpoint. Covers OpenRouter, LM Studio, llama.cpp server, vLLM, Groq, Together, Fireworks, and most self-hosted frameworks. API key falls back to $OPENAI_API_KEY. Endpoint normalization is forgiving about trailing `/v1`. - `anthropic`: Messages API v2023-06-01. API key falls back to $ANTHROPIC_API_KEY. Concatenates multi-block text responses. JSON mode is normalized across providers — Ollama uses `format: "json"`, OpenAI-compat uses `response_format`, Anthropic uses prompt-level instruction. Callers request JSON once; this module handles the provider-specific plumbing. No external SDK dependency; stdlib `urllib` throughout. HTTP errors are wrapped into a single `LLMError` class so callers don't need to distinguish transport, auth, and parse failures at the call site. 26 tests, all with mocked HTTP — suite runs offline with no real provider required. --- mempalace/llm_client.py | 305 ++++++++++++++++++++++++++++++++++++ tests/test_llm_client.py | 327 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 632 insertions(+) create mode 100644 mempalace/llm_client.py create mode 100644 tests/test_llm_client.py diff --git a/mempalace/llm_client.py b/mempalace/llm_client.py new file mode 100644 index 0000000..442cf31 --- /dev/null +++ b/mempalace/llm_client.py @@ -0,0 +1,305 @@ +""" +llm_client.py — Minimal provider abstraction for LLM-assisted entity refinement. + +Three providers cover the useful space: + +- ``ollama`` (default): local models via http://localhost:11434. Works fully + offline. Honors MemPalace's "zero-API required" principle. 
+- ``openai-compat``: any OpenAI-compatible ``/v1/chat/completions`` endpoint. + Covers OpenRouter, LM Studio, llama.cpp server, vLLM, Groq, Fireworks, + Together, and most self-hosted setups. +- ``anthropic``: the official Messages API. Opt-in for users who want Haiku + quality without setting up a local model. + +All providers expose the same ``classify(system, user, json_mode)`` method and +the same ``check_available()`` probe. No external SDK dependencies — stdlib +``urllib`` only. + +JSON mode matters here: we always ask for structured output. Providers +differ on how to request it (Ollama: ``format: json``; OpenAI-compat: +``response_format``; Anthropic: prompt-level instruction) and this module +normalizes that away from the caller. +""" + +from __future__ import annotations + +import json +import os +from dataclasses import dataclass +from typing import Optional +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + + +class LLMError(RuntimeError): + """Raised for any provider failure — transport, parse, auth, missing model.""" + + +@dataclass +class LLMResponse: + text: str + model: str + provider: str + raw: dict + + +# ==================== BASE ==================== + + +class LLMProvider: + name: str = "base" + + def __init__( + self, + model: str, + endpoint: Optional[str] = None, + api_key: Optional[str] = None, + timeout: int = 120, + ): + self.model = model + self.endpoint = endpoint + self.api_key = api_key + self.timeout = timeout + + def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse: + raise NotImplementedError + + def check_available(self) -> tuple[bool, str]: + """Return ``(ok, message)``. Fast probe that the provider is reachable.""" + raise NotImplementedError + + +def _http_post_json(url: str, body: dict, headers: dict, timeout: int) -> dict: + """POST JSON and return the parsed response. 
Raises LLMError on any failure.""" + req = Request( + url, + data=json.dumps(body).encode("utf-8"), + headers={"Content-Type": "application/json", **headers}, + ) + try: + with urlopen(req, timeout=timeout) as resp: + return json.loads(resp.read()) + except HTTPError as e: + detail = "" + try: + detail = e.read().decode("utf-8", errors="replace")[:500] + except Exception: + pass + raise LLMError(f"HTTP {e.code} from {url}: {detail or e.reason}") from e + except (URLError, OSError) as e: + raise LLMError(f"Cannot reach {url}: {e}") from e + except json.JSONDecodeError as e: + raise LLMError(f"Malformed response from {url}: {e}") from e + + +# ==================== OLLAMA ==================== + + +class OllamaProvider(LLMProvider): + name = "ollama" + DEFAULT_ENDPOINT = "http://localhost:11434" + + def __init__( + self, + model: str, + endpoint: Optional[str] = None, + timeout: int = 180, + **_: object, + ): + super().__init__( + model=model, + endpoint=endpoint or self.DEFAULT_ENDPOINT, + timeout=timeout, + ) + + def check_available(self) -> tuple[bool, str]: + try: + with urlopen(f"{self.endpoint}/api/tags", timeout=5) as resp: + data = json.loads(resp.read()) + except (URLError, HTTPError, OSError, json.JSONDecodeError) as e: + return False, f"Cannot reach Ollama at {self.endpoint}: {e}" + names = {m.get("name", "") for m in data.get("models", []) or []} + # Ollama tags may or may not include ':latest' — accept either form + wanted = {self.model, f"{self.model}:latest"} + if not names & wanted: + return ( + False, + f"Model '{self.model}' not loaded in Ollama. 
" f"Run: ollama pull {self.model}", + ) + return True, "ok" + + def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse: + body: dict = { + "model": self.model, + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + "stream": False, + "options": {"temperature": 0.1}, + } + if json_mode: + body["format"] = "json" + data = _http_post_json(f"{self.endpoint}/api/chat", body, headers={}, timeout=self.timeout) + text = (data.get("message") or {}).get("content", "") + if not text: + raise LLMError(f"Empty response from Ollama (model={self.model})") + return LLMResponse(text=text, model=self.model, provider=self.name, raw=data) + + +# ==================== OPENAI-COMPAT ==================== + + +class OpenAICompatProvider(LLMProvider): + """Any OpenAI-compatible ``/v1/chat/completions`` endpoint. + + Supply ``--llm-endpoint http://host:port`` (with or without ``/v1``). + API key via ``--llm-api-key`` or the ``OPENAI_API_KEY`` env var. 
+ """ + + name = "openai-compat" + + def __init__( + self, + model: str, + endpoint: Optional[str] = None, + api_key: Optional[str] = None, + timeout: int = 120, + **_: object, + ): + resolved_key = api_key or os.environ.get("OPENAI_API_KEY") + super().__init__(model=model, endpoint=endpoint, api_key=resolved_key, timeout=timeout) + + def _resolve_url(self) -> str: + if not self.endpoint: + raise LLMError("openai-compat provider requires --llm-endpoint") + url = self.endpoint.rstrip("/") + if url.endswith("/chat/completions"): + return url + if not url.endswith("/v1"): + url = f"{url}/v1" + return f"{url}/chat/completions" + + def check_available(self) -> tuple[bool, str]: + if not self.endpoint: + return False, "no --llm-endpoint configured" + base = self.endpoint.rstrip("/") + base = base.removesuffix("/chat/completions").removesuffix("/v1") + try: + req = Request(f"{base}/v1/models") + if self.api_key: + req.add_header("Authorization", f"Bearer {self.api_key}") + with urlopen(req, timeout=5): + pass + except (URLError, HTTPError, OSError) as e: + return False, f"Cannot reach {self.endpoint}: {e}" + return True, "ok" + + def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse: + body: dict = { + "model": self.model, + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + "temperature": 0.1, + } + if json_mode: + body["response_format"] = {"type": "json_object"} + headers = {} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + data = _http_post_json(self._resolve_url(), body, headers=headers, timeout=self.timeout) + try: + text = data["choices"][0]["message"]["content"] + except (KeyError, IndexError, TypeError) as e: + raise LLMError(f"Unexpected response shape: {e}") from e + if not text: + raise LLMError(f"Empty response from {self.name} (model={self.model})") + return LLMResponse(text=text, model=self.model, provider=self.name, raw=data) + + +# ==================== 
ANTHROPIC ==================== + + +class AnthropicProvider(LLMProvider): + name = "anthropic" + DEFAULT_ENDPOINT = "https://api.anthropic.com" + API_VERSION = "2023-06-01" + + def __init__( + self, + model: str, + api_key: Optional[str] = None, + endpoint: Optional[str] = None, + timeout: int = 120, + **_: object, + ): + key = api_key or os.environ.get("ANTHROPIC_API_KEY") + super().__init__( + model=model, + endpoint=endpoint or self.DEFAULT_ENDPOINT, + api_key=key, + timeout=timeout, + ) + + def check_available(self) -> tuple[bool, str]: + if not self.api_key: + return False, "ANTHROPIC_API_KEY not set (use --llm-api-key or env)" + # Don't probe — a live request would cost money. First real call will + # surface auth errors if the key is invalid. + return True, "ok" + + def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse: + if not self.api_key: + raise LLMError("Anthropic provider requires ANTHROPIC_API_KEY env or --llm-api-key") + sys_prompt = system + if json_mode: + sys_prompt += "\n\nRespond with valid JSON only, no prose." 
+ body = { + "model": self.model, + "max_tokens": 2048, + "temperature": 0.1, + "system": sys_prompt, + "messages": [{"role": "user", "content": user}], + } + headers = { + "X-API-Key": self.api_key, + "anthropic-version": self.API_VERSION, + } + data = _http_post_json( + f"{self.endpoint}/v1/messages", body, headers=headers, timeout=self.timeout + ) + try: + text = "".join( + b.get("text", "") for b in data.get("content", []) or [] if b.get("type") == "text" + ) + except (AttributeError, TypeError) as e: + raise LLMError(f"Unexpected response shape: {e}") from e + if not text: + raise LLMError(f"Empty response from Anthropic (model={self.model})") + return LLMResponse(text=text, model=self.model, provider=self.name, raw=data) + + +# ==================== FACTORY ==================== + + +PROVIDERS: dict[str, type[LLMProvider]] = { + "ollama": OllamaProvider, + "openai-compat": OpenAICompatProvider, + "anthropic": AnthropicProvider, +} + + +def get_provider( + name: str, + model: str, + endpoint: Optional[str] = None, + api_key: Optional[str] = None, + timeout: int = 120, +) -> LLMProvider: + """Build a provider by name. Raises LLMError on unknown provider.""" + cls = PROVIDERS.get(name) + if cls is None: + raise LLMError(f"Unknown provider '{name}'. Choices: {sorted(PROVIDERS.keys())}") + return cls(model=model, endpoint=endpoint, api_key=api_key, timeout=timeout) diff --git a/tests/test_llm_client.py b/tests/test_llm_client.py new file mode 100644 index 0000000..184d100 --- /dev/null +++ b/tests/test_llm_client.py @@ -0,0 +1,327 @@ +"""Tests for mempalace.llm_client. + +HTTP is mocked throughout — these tests do not require a running Ollama +or network access. Live-provider smoke tests live outside the unit-test +suite. 
+""" + +import json +from unittest.mock import patch, MagicMock + +import pytest + +from mempalace.llm_client import ( + AnthropicProvider, + LLMError, + OllamaProvider, + OpenAICompatProvider, + _http_post_json, + get_provider, +) + + +# ── factory ───────────────────────────────────────────────────────────── + + +def test_get_provider_ollama(): + p = get_provider("ollama", "gemma4:e4b") + assert isinstance(p, OllamaProvider) + assert p.model == "gemma4:e4b" + assert p.endpoint == OllamaProvider.DEFAULT_ENDPOINT + + +def test_get_provider_openai_compat(): + p = get_provider("openai-compat", "foo", endpoint="http://localhost:1234") + assert isinstance(p, OpenAICompatProvider) + + +def test_get_provider_anthropic(): + p = get_provider("anthropic", "claude-haiku", api_key="sk-xxx") + assert isinstance(p, AnthropicProvider) + assert p.api_key == "sk-xxx" + + +def test_get_provider_unknown_raises(): + with pytest.raises(LLMError, match="Unknown provider"): + get_provider("nonsense", "x") + + +# ── _http_post_json ───────────────────────────────────────────────────── + + +def test_http_post_json_success(): + mock_resp = MagicMock() + mock_resp.read.return_value = b'{"ok": true}' + mock_resp.__enter__.return_value = mock_resp + mock_resp.__exit__.return_value = False + with patch("mempalace.llm_client.urlopen", return_value=mock_resp): + result = _http_post_json("http://x/y", {"a": 1}, {}, timeout=5) + assert result == {"ok": True} + + +def test_http_post_json_http_error_wraps_as_llm_error(): + from urllib.error import HTTPError + import io + + err = HTTPError("http://x", 404, "Not Found", {}, io.BytesIO(b"model missing")) + with patch("mempalace.llm_client.urlopen", side_effect=err): + with pytest.raises(LLMError, match="HTTP 404"): + _http_post_json("http://x", {}, {}, timeout=5) + + +def test_http_post_json_url_error_wraps_as_llm_error(): + from urllib.error import URLError + + with patch("mempalace.llm_client.urlopen", side_effect=URLError("conn refused")): + with 
pytest.raises(LLMError, match="Cannot reach"): + _http_post_json("http://x", {}, {}, timeout=5) + + +def test_http_post_json_malformed_response(): + mock_resp = MagicMock() + mock_resp.read.return_value = b"not json" + mock_resp.__enter__.return_value = mock_resp + mock_resp.__exit__.return_value = False + with patch("mempalace.llm_client.urlopen", return_value=mock_resp): + with pytest.raises(LLMError, match="Malformed"): + _http_post_json("http://x", {}, {}, timeout=5) + + +# ── OllamaProvider ────────────────────────────────────────────────────── + + +def _mock_ollama_chat_response(content: str): + mock = MagicMock() + mock.read.return_value = json.dumps({"message": {"content": content}}).encode() + mock.__enter__.return_value = mock + mock.__exit__.return_value = False + return mock + + +def test_ollama_check_available_finds_model(): + tags = {"models": [{"name": "gemma4:e4b"}, {"name": "other:latest"}]} + mock = MagicMock() + mock.read.return_value = json.dumps(tags).encode() + mock.__enter__.return_value = mock + mock.__exit__.return_value = False + with patch("mempalace.llm_client.urlopen", return_value=mock): + p = OllamaProvider(model="gemma4:e4b") + ok, msg = p.check_available() + assert ok + assert msg == "ok" + + +def test_ollama_check_available_accepts_latest_suffix(): + tags = {"models": [{"name": "mymodel:latest"}]} + mock = MagicMock() + mock.read.return_value = json.dumps(tags).encode() + mock.__enter__.return_value = mock + mock.__exit__.return_value = False + with patch("mempalace.llm_client.urlopen", return_value=mock): + p = OllamaProvider(model="mymodel") + ok, _ = p.check_available() + assert ok + + +def test_ollama_check_available_missing_model(): + tags = {"models": [{"name": "other:latest"}]} + mock = MagicMock() + mock.read.return_value = json.dumps(tags).encode() + mock.__enter__.return_value = mock + mock.__exit__.return_value = False + with patch("mempalace.llm_client.urlopen", return_value=mock): + p = OllamaProvider(model="absent") + 
ok, msg = p.check_available() + assert not ok + assert "ollama pull absent" in msg + + +def test_ollama_check_available_unreachable(): + from urllib.error import URLError + + with patch("mempalace.llm_client.urlopen", side_effect=URLError("refused")): + p = OllamaProvider(model="gemma4:e4b") + ok, msg = p.check_available() + assert not ok + assert "Cannot reach Ollama" in msg + + +def test_ollama_classify_sends_json_format(): + captured = {} + + def fake_urlopen(req, *, timeout): + captured["url"] = req.full_url + captured["body"] = json.loads(req.data.decode()) + return _mock_ollama_chat_response('{"classifications": []}') + + with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen): + p = OllamaProvider(model="gemma4:e4b") + resp = p.classify("sys", "user", json_mode=True) + + assert captured["body"]["format"] == "json" + assert captured["body"]["model"] == "gemma4:e4b" + assert captured["url"].endswith("/api/chat") + assert resp.provider == "ollama" + assert resp.text == '{"classifications": []}' + + +def test_ollama_classify_empty_content_raises(): + with patch("mempalace.llm_client.urlopen", return_value=_mock_ollama_chat_response("")): + p = OllamaProvider(model="x") + with pytest.raises(LLMError, match="Empty response"): + p.classify("s", "u") + + +# ── OpenAICompatProvider ──────────────────────────────────────────────── + + +def _mock_openai_response(content: str): + mock = MagicMock() + payload = {"choices": [{"message": {"content": content}}]} + mock.read.return_value = json.dumps(payload).encode() + mock.__enter__.return_value = mock + mock.__exit__.return_value = False + return mock + + +def test_openai_compat_resolves_url_with_v1_suffix(): + captured = {} + + def fake_urlopen(req, *, timeout): + captured["url"] = req.full_url + return _mock_openai_response('{"ok": true}') + + with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen): + p = OpenAICompatProvider(model="x", endpoint="http://h:1234") + p.classify("s", "u") + assert 
captured["url"] == "http://h:1234/v1/chat/completions" + + +def test_openai_compat_resolves_url_with_existing_v1(): + captured = {} + + def fake_urlopen(req, *, timeout): + captured["url"] = req.full_url + return _mock_openai_response('{"ok": true}') + + with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen): + p = OpenAICompatProvider(model="x", endpoint="http://h:1234/v1") + p.classify("s", "u") + assert captured["url"] == "http://h:1234/v1/chat/completions" + + +def test_openai_compat_requires_endpoint(): + p = OpenAICompatProvider(model="x") + with pytest.raises(LLMError, match="requires --llm-endpoint"): + p.classify("s", "u") + + +def test_openai_compat_sends_authorization_when_key_present(): + captured = {} + + def fake_urlopen(req, *, timeout): + captured["auth"] = req.get_header("Authorization") + return _mock_openai_response('{"ok": true}') + + with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen): + p = OpenAICompatProvider(model="x", endpoint="http://h", api_key="sk-aaa") + p.classify("s", "u") + assert captured["auth"] == "Bearer sk-aaa" + + +def test_openai_compat_uses_env_var_fallback(monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "sk-from-env") + p = OpenAICompatProvider(model="x", endpoint="http://h") + assert p.api_key == "sk-from-env" + + +def test_openai_compat_sends_response_format_json(): + captured = {} + + def fake_urlopen(req, *, timeout): + captured["body"] = json.loads(req.data.decode()) + return _mock_openai_response('{"ok": true}') + + with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen): + p = OpenAICompatProvider(model="x", endpoint="http://h") + p.classify("s", "u", json_mode=True) + assert captured["body"]["response_format"] == {"type": "json_object"} + + +def test_openai_compat_unexpected_shape_raises(): + mock = MagicMock() + mock.read.return_value = b'{"nothing": "here"}' + mock.__enter__.return_value = mock + mock.__exit__.return_value = False + with 
patch("mempalace.llm_client.urlopen", return_value=mock): + p = OpenAICompatProvider(model="x", endpoint="http://h") + with pytest.raises(LLMError, match="Unexpected response shape"): + p.classify("s", "u") + + +# ── AnthropicProvider ─────────────────────────────────────────────────── + + +def _mock_anthropic_response(text: str): + mock = MagicMock() + payload = {"content": [{"type": "text", "text": text}]} + mock.read.return_value = json.dumps(payload).encode() + mock.__enter__.return_value = mock + mock.__exit__.return_value = False + return mock + + +def test_anthropic_requires_api_key(monkeypatch): + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + p = AnthropicProvider(model="claude-haiku") + ok, msg = p.check_available() + assert not ok + assert "ANTHROPIC_API_KEY" in msg + + +def test_anthropic_reads_env_key(monkeypatch): + monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-env") + p = AnthropicProvider(model="claude-haiku") + assert p.api_key == "sk-ant-env" + ok, _ = p.check_available() + assert ok + + +def test_anthropic_classify_sends_version_and_key(): + captured = {} + + def fake_urlopen(req, *, timeout): + captured["api_key"] = req.get_header("X-api-key") + captured["version"] = req.get_header("Anthropic-version") + return _mock_anthropic_response('{"ok": true}') + + with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen): + p = AnthropicProvider(model="claude-haiku", api_key="sk-ant-abc") + resp = p.classify("s", "u") + assert captured["api_key"] == "sk-ant-abc" + assert captured["version"] == AnthropicProvider.API_VERSION + assert resp.text == '{"ok": true}' + + +def test_anthropic_joins_multiple_text_blocks(): + mock = MagicMock() + payload = { + "content": [ + {"type": "text", "text": "part one. 
"}, + {"type": "text", "text": "part two."}, + ] + } + mock.read.return_value = json.dumps(payload).encode() + mock.__enter__.return_value = mock + mock.__exit__.return_value = False + with patch("mempalace.llm_client.urlopen", return_value=mock): + p = AnthropicProvider(model="claude-haiku", api_key="sk-ant") + resp = p.classify("s", "u") + assert resp.text == "part one. part two." + + +def test_anthropic_no_key_raises_on_classify(monkeypatch): + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + p = AnthropicProvider(model="claude-haiku") + with pytest.raises(LLMError, match="requires ANTHROPIC_API_KEY"): + p.classify("s", "u") From 10a743d5d83ef561f8dfccd5079eac9cdd6fe014 Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Fri, 24 Apr 2026 00:46:59 -0300 Subject: [PATCH 3/8] feat(llm): interactive entity refinement with batching and cancellation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Takes the candidate set produced by phase-1 detection (manifests, git authors, regex on prose) and asks an LLM to reclassify each candidate as PERSON / PROJECT / TOPIC / COMMON_WORD / AMBIGUOUS. Scale approach: never feed the raw corpus to the LLM. For each candidate, collect up to 3 context lines from sampled prose, cap each at 240 chars, batch 25 candidates per call. Keeps total input around 50-100K tokens even on large corpora and completes in a few minutes on a 4B local model. Interactive UX: - Stderr progress bar with the current candidate name, updates per-batch. - Ctrl-C interrupts cleanly: returns a RefineResult with `cancelled=True` and whatever was classified before the interrupt. The partial result is safe to pass straight to confirm_entities. - Per-batch errors (transport, parse) are recorded in `errors` and don't abort the whole run. Refinement scope: only `uncertain` and low-confidence `projects` entries are sent. 
Manifest-backed projects (conf >= 0.95) and git-authored people are already authoritative and skip the LLM. Response parser is defensive — accepts `label` or `type` keys, lowercase/uppercase variants, top-level list or wrapped object, and strips markdown code fences. Unknown labels become AMBIGUOUS so the user reviews them rather than silently accepting a bad classification. `collect_corpus_text` provides a simple stratified prose sampler (recent first, capped per-file) so callers don't need to build their own corpus window. 28 tests with a FakeProvider (no network). Covers context collection, prompt building, response parsing variants, classification apply, end-to-end refine, and Ctrl-C partial-result behavior. --- mempalace/llm_refine.py | 368 ++++++++++++++++++++++++++++++++ tests/test_llm_refine.py | 446 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 814 insertions(+) create mode 100644 mempalace/llm_refine.py create mode 100644 tests/test_llm_refine.py diff --git a/mempalace/llm_refine.py b/mempalace/llm_refine.py new file mode 100644 index 0000000..91a950c --- /dev/null +++ b/mempalace/llm_refine.py @@ -0,0 +1,368 @@ +""" +llm_refine.py — Optional LLM refinement of regex-detected entities. + +Takes the candidate set produced by phase-1 detection (manifests, git +authors, regex on prose) and asks an LLM to reclassify each candidate as +PERSON / PROJECT / TOPIC / COMMON_WORD / AMBIGUOUS. + +Design constraints: +- Opt-in. Default init path never imports this module. +- Local-first by default (Ollama). +- Interactive UX: visible progress, clean cancellation (Ctrl-C returns + whatever was classified before the interrupt). +- Don't feed the raw corpus to the LLM — feed candidates + a few sampled + context lines each. Keeps total input to ~50-100K tokens even for huge + prose corpora. + +Public: + refine_entities(detected, corpus_text, provider, ...)
-> dict +""" + +from __future__ import annotations + +import json +import re +import sys +from dataclasses import dataclass + +from mempalace.llm_client import LLMError, LLMProvider + + +BATCH_SIZE = 25 # candidates per LLM call; tuned for 4B local models +CONTEXT_LINES_PER_CANDIDATE = 3 +CONTEXT_WINDOW_CHARS = 240 # max chars per context line to keep tokens bounded + +# Valid labels the LLM is allowed to return. Anything else is treated as +# AMBIGUOUS so the user reviews it. +VALID_LABELS = {"PERSON", "PROJECT", "TOPIC", "COMMON_WORD", "AMBIGUOUS"} + + +SYSTEM_PROMPT = """You are helping organize a user's memory palace by classifying capitalized tokens found in their files. + +For each candidate, pick exactly ONE label: +- PERSON: a specific real person the user knows (colleague, family, character they write about) +- PROJECT: a named product, codebase, or effort the user works on +- TOPIC: a recurring theme or subject (not a person, not a project) — cities, technologies, concepts +- COMMON_WORD: an English word, verb, or fragment that isn't a named entity at all (e.g. "Created", "Before", "Never") +- AMBIGUOUS: context is insufficient to decide between two of the above + +Use the provided context lines to disambiguate. A capitalized word that only appears in metadata ("Created: 2026-04-24") is COMMON_WORD. A name that appears with pronouns and dialogue is PERSON. + +Respond with JSON only. Schema: +{"classifications": [{"name": "", "label": "