test: cover embedding device fallback and bounded upserts
Agent-Logs-Url: https://github.com/MemPalace/mempalace/sessions/3213a67a-6871-4bb2-9ae0-23fa11001a22
Co-authored-by: igorls <4753812+igorls@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
a4868a3589
commit
fbd0904799
@@ -55,6 +55,7 @@ CONVO_EXTENSIONS = {
|
|||||||
|
|
||||||
MIN_CHUNK_SIZE = 30
|
MIN_CHUNK_SIZE = 30
|
||||||
CHUNK_SIZE = 800 # chars per drawer — align with miner.py
|
CHUNK_SIZE = 800 # chars per drawer — align with miner.py
|
||||||
|
DRAWER_UPSERT_BATCH_SIZE = 1000
|
||||||
MAX_FILE_SIZE = 500 * 1024 * 1024 # 500 MB — skip files larger than this.
|
MAX_FILE_SIZE = 500 * 1024 * 1024 # 500 MB — skip files larger than this.
|
||||||
# Matches miner.py at 500 MB. Long Claude Code sessions, multi-year
|
# Matches miner.py at 500 MB. Long Claude Code sessions, multi-year
|
||||||
# ChatGPT exports, and lifetime Slack dumps routinely exceed 10 MB; the
|
# ChatGPT exports, and lifetime Slack dumps routinely exceed 10 MB; the
|
||||||
@@ -332,15 +333,16 @@ def _file_chunks_locked(collection, source_file, chunks, wing, room, agent, extr
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Batch the whole file into one upsert so the embedding model runs
|
# Batch chunks into bounded upserts so large transcripts keep most of
|
||||||
# a single forward pass for all chunks — dramatically faster than
|
# the embedding speedup without one huge Chroma/SQLite request. Keep
|
||||||
# one call per chunk, especially on GPU where per-call overhead
|
# one filed_at per source file so all transcript drawers share an
|
||||||
# dominates over the actual matmul.
|
# ingest timestamp.
|
||||||
|
filed_at = datetime.now().isoformat()
|
||||||
|
for batch_start in range(0, len(chunks), DRAWER_UPSERT_BATCH_SIZE):
|
||||||
batch_docs: list = []
|
batch_docs: list = []
|
||||||
batch_ids: list = []
|
batch_ids: list = []
|
||||||
batch_metas: list = []
|
batch_metas: list = []
|
||||||
filed_at = datetime.now().isoformat()
|
for chunk in chunks[batch_start : batch_start + DRAWER_UPSERT_BATCH_SIZE]:
|
||||||
for chunk in chunks:
|
|
||||||
chunk_room = chunk.get("memory_type", room) if extract_mode == "general" else room
|
chunk_room = chunk.get("memory_type", room) if extract_mode == "general" else room
|
||||||
if extract_mode == "general":
|
if extract_mode == "general":
|
||||||
room_counts_delta[chunk_room] += 1
|
room_counts_delta[chunk_room] += 1
|
||||||
@@ -361,15 +363,13 @@ def _file_chunks_locked(collection, source_file, chunks, wing, room, agent, extr
|
|||||||
"normalize_version": NORMALIZE_VERSION,
|
"normalize_version": NORMALIZE_VERSION,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_docs:
|
|
||||||
try:
|
try:
|
||||||
collection.upsert(
|
collection.upsert(
|
||||||
documents=batch_docs,
|
documents=batch_docs,
|
||||||
ids=batch_ids,
|
ids=batch_ids,
|
||||||
metadatas=batch_metas,
|
metadatas=batch_metas,
|
||||||
)
|
)
|
||||||
drawers_added = len(batch_docs)
|
drawers_added += len(batch_docs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "already exists" not in str(e).lower():
|
if "already exists" not in str(e).lower():
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -32,6 +32,12 @@ _PROVIDER_MAP = {
|
|||||||
"dml": ["DmlExecutionProvider", "CPUExecutionProvider"],
|
"dml": ["DmlExecutionProvider", "CPUExecutionProvider"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_DEVICE_EXTRA = {
|
||||||
|
"cuda": "mempalace[gpu]",
|
||||||
|
"coreml": "mempalace[coreml]",
|
||||||
|
"dml": "mempalace[dml]",
|
||||||
|
}
|
||||||
|
|
||||||
_AUTO_ORDER = [
|
_AUTO_ORDER = [
|
||||||
("CUDAExecutionProvider", "cuda"),
|
("CUDAExecutionProvider", "cuda"),
|
||||||
("CoreMLExecutionProvider", "coreml"),
|
("CoreMLExecutionProvider", "coreml"),
|
||||||
@@ -76,11 +82,13 @@ def _resolve_providers(device: str) -> tuple[list, str]:
|
|||||||
|
|
||||||
if preferred not in available:
|
if preferred not in available:
|
||||||
if device not in _WARNED:
|
if device not in _WARNED:
|
||||||
|
extra = _DEVICE_EXTRA.get(device, "the matching mempalace extra for your device")
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"embedding_device=%r requested but %s is not installed — "
|
"embedding_device=%r requested but %s is not installed — "
|
||||||
"falling back to CPU. Install mempalace[gpu] for CUDA.",
|
"falling back to CPU. Install %s.",
|
||||||
device,
|
device,
|
||||||
preferred,
|
preferred,
|
||||||
|
extra,
|
||||||
)
|
)
|
||||||
_WARNED.add(device)
|
_WARNED.add(device)
|
||||||
return (["CPUExecutionProvider"], "cpu")
|
return (["CPUExecutionProvider"], "cpu")
|
||||||
|
|||||||
+9
-9
@@ -65,6 +65,7 @@ SKIP_FILENAMES = {
|
|||||||
CHUNK_SIZE = 800 # chars per drawer
|
CHUNK_SIZE = 800 # chars per drawer
|
||||||
CHUNK_OVERLAP = 100 # overlap between chunks
|
CHUNK_OVERLAP = 100 # overlap between chunks
|
||||||
MIN_CHUNK_SIZE = 50 # skip tiny chunks
|
MIN_CHUNK_SIZE = 50 # skip tiny chunks
|
||||||
|
DRAWER_UPSERT_BATCH_SIZE = 1000
|
||||||
MAX_FILE_SIZE = 500 * 1024 * 1024 # 500 MB — skip files larger than this.
|
MAX_FILE_SIZE = 500 * 1024 * 1024 # 500 MB — skip files larger than this.
|
||||||
# Long Claude Code sessions and large transcript exports routinely exceed
|
# Long Claude Code sessions and large transcript exports routinely exceed
|
||||||
# 10 MB. The cap exists as a defensive rail against pathological binary
|
# 10 MB. The cap exists as a defensive rail against pathological binary
|
||||||
@@ -748,19 +749,21 @@ def process_file(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Batch all chunks of this file into a single upsert so the embedding
|
# Batch chunks into bounded upserts so the embedding model sees many
|
||||||
# model runs one forward pass over the whole file instead of N passes
|
# chunks per forward pass without building one huge Chroma/SQLite
|
||||||
# of one chunk each. On CPU this is typically a 10-30x speedup; on
|
# request for pathological files. A bad chunk can fail its sub-batch;
|
||||||
# GPU the speedup is larger because per-call overhead dominates.
|
# that is the deliberate trade-off for amortizing embedding overhead.
|
||||||
try:
|
try:
|
||||||
source_mtime = os.path.getmtime(source_file)
|
source_mtime = os.path.getmtime(source_file)
|
||||||
except OSError:
|
except OSError:
|
||||||
source_mtime = None
|
source_mtime = None
|
||||||
|
|
||||||
|
drawers_added = 0
|
||||||
|
for batch_start in range(0, len(chunks), DRAWER_UPSERT_BATCH_SIZE):
|
||||||
batch_docs: list = []
|
batch_docs: list = []
|
||||||
batch_ids: list = []
|
batch_ids: list = []
|
||||||
batch_metas: list = []
|
batch_metas: list = []
|
||||||
for chunk in chunks:
|
for chunk in chunks[batch_start : batch_start + DRAWER_UPSERT_BATCH_SIZE]:
|
||||||
drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((source_file + str(chunk['chunk_index'])).encode()).hexdigest()[:24]}"
|
drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((source_file + str(chunk['chunk_index'])).encode()).hexdigest()[:24]}"
|
||||||
batch_docs.append(chunk["content"])
|
batch_docs.append(chunk["content"])
|
||||||
batch_ids.append(drawer_id)
|
batch_ids.append(drawer_id)
|
||||||
@@ -775,15 +778,12 @@ def process_file(
|
|||||||
source_mtime,
|
source_mtime,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
drawers_added = 0
|
|
||||||
if batch_docs:
|
|
||||||
collection.upsert(
|
collection.upsert(
|
||||||
documents=batch_docs,
|
documents=batch_docs,
|
||||||
ids=batch_ids,
|
ids=batch_ids,
|
||||||
metadatas=batch_metas,
|
metadatas=batch_metas,
|
||||||
)
|
)
|
||||||
drawers_added = len(batch_docs)
|
drawers_added += len(batch_docs)
|
||||||
|
|
||||||
# Build closet — the searchable index pointing to these drawers.
|
# Build closet — the searchable index pointing to these drawers.
|
||||||
# Purge first: a re-mine (mtime change or normalize_version bump) must
|
# Purge first: a re-mine (mtime change or normalize_version bump) must
|
||||||
|
|||||||
@@ -20,6 +20,32 @@ def test_config_from_file():
|
|||||||
assert cfg.palace_path == "/custom/palace"
|
assert cfg.palace_path == "/custom/palace"
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_device_defaults_to_auto():
|
||||||
|
cfg = MempalaceConfig(config_dir=tempfile.mkdtemp())
|
||||||
|
assert cfg.embedding_device == "auto"
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_device_from_config_is_normalized():
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
with open(os.path.join(tmpdir, "config.json"), "w") as f:
|
||||||
|
json.dump({"embedding_device": " CUDA "}, f)
|
||||||
|
|
||||||
|
cfg = MempalaceConfig(config_dir=tmpdir)
|
||||||
|
assert cfg.embedding_device == "cuda"
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_device_env_overrides_config():
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
with open(os.path.join(tmpdir, "config.json"), "w") as f:
|
||||||
|
json.dump({"embedding_device": "cpu"}, f)
|
||||||
|
os.environ["MEMPALACE_EMBEDDING_DEVICE"] = " CoreML "
|
||||||
|
try:
|
||||||
|
cfg = MempalaceConfig(config_dir=tmpdir)
|
||||||
|
assert cfg.embedding_device == "coreml"
|
||||||
|
finally:
|
||||||
|
del os.environ["MEMPALACE_EMBEDDING_DEVICE"]
|
||||||
|
|
||||||
|
|
||||||
def test_env_override():
|
def test_env_override():
|
||||||
raw = "/env/palace"
|
raw = "/env/palace"
|
||||||
os.environ["MEMPALACE_PALACE_PATH"] = raw
|
os.environ["MEMPALACE_PALACE_PATH"] = raw
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Unit tests for convo_miner pure functions (no chromadb needed)."""
|
"""Unit tests for convo_miner pure functions (no chromadb needed)."""
|
||||||
|
|
||||||
from mempalace.convo_miner import (
|
from mempalace.convo_miner import (
|
||||||
|
_file_chunks_locked,
|
||||||
chunk_exchanges,
|
chunk_exchanges,
|
||||||
detect_convo_room,
|
detect_convo_room,
|
||||||
scan_convos,
|
scan_convos,
|
||||||
@@ -111,3 +112,38 @@ class TestScanConvos:
|
|||||||
def test_scan_empty_dir(self, tmp_path):
|
def test_scan_empty_dir(self, tmp_path):
|
||||||
files = scan_convos(str(tmp_path))
|
files = scan_convos(str(tmp_path))
|
||||||
assert files == []
|
assert files == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileChunksLocked:
|
||||||
|
def test_uses_bounded_upsert_batches(self, monkeypatch):
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
import mempalace.convo_miner as convo_miner
|
||||||
|
|
||||||
|
class FakeCol:
|
||||||
|
def __init__(self):
|
||||||
|
self.batch_sizes = []
|
||||||
|
|
||||||
|
def delete(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def upsert(self, documents, ids, metadatas):
|
||||||
|
self.batch_sizes.append(len(documents))
|
||||||
|
|
||||||
|
chunks = [{"content": f"chunk {i} " * 20, "chunk_index": i} for i in range(5)]
|
||||||
|
col = FakeCol()
|
||||||
|
monkeypatch.setattr(convo_miner, "DRAWER_UPSERT_BATCH_SIZE", 2)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
convo_miner, "file_already_mined", lambda collection, source_file: False
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(convo_miner, "mine_lock", lambda source_file: contextlib.nullcontext())
|
||||||
|
monkeypatch.setattr(convo_miner, "_detect_hall_cached", lambda content: "conversations")
|
||||||
|
|
||||||
|
drawers, room_counts, skipped = _file_chunks_locked(
|
||||||
|
col, "chat.txt", chunks, "wing", "general", "agent", "exchange"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert drawers == 5
|
||||||
|
assert dict(room_counts) == {}
|
||||||
|
assert skipped is False
|
||||||
|
assert col.batch_sizes == [2, 2, 1]
|
||||||
|
|||||||
@@ -0,0 +1,101 @@
|
|||||||
|
import builtins
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mempalace.embedding as embedding
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_embedding_state():
|
||||||
|
embedding._EF_CACHE.clear()
|
||||||
|
embedding._WARNED.clear()
|
||||||
|
yield
|
||||||
|
embedding._EF_CACHE.clear()
|
||||||
|
embedding._WARNED.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_picks_cuda(monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"onnxruntime.get_available_providers",
|
||||||
|
lambda: ["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embedding._resolve_providers("auto") == (
|
||||||
|
["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||||
|
"cuda",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_falls_to_cpu(monkeypatch):
|
||||||
|
monkeypatch.setattr("onnxruntime.get_available_providers", lambda: ["CPUExecutionProvider"])
|
||||||
|
|
||||||
|
assert embedding._resolve_providers("auto") == (["CPUExecutionProvider"], "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cuda_missing_warns_with_gpu_extra(monkeypatch, caplog):
|
||||||
|
monkeypatch.setattr("onnxruntime.get_available_providers", lambda: ["CPUExecutionProvider"])
|
||||||
|
|
||||||
|
assert embedding._resolve_providers("cuda") == (["CPUExecutionProvider"], "cpu")
|
||||||
|
assert "mempalace[gpu]" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_coreml_missing_warns_with_coreml_extra(monkeypatch, caplog):
|
||||||
|
monkeypatch.setattr("onnxruntime.get_available_providers", lambda: ["CPUExecutionProvider"])
|
||||||
|
|
||||||
|
assert embedding._resolve_providers("coreml") == (["CPUExecutionProvider"], "cpu")
|
||||||
|
assert "mempalace[coreml]" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_dml_missing_warns_with_dml_extra(monkeypatch, caplog):
|
||||||
|
monkeypatch.setattr("onnxruntime.get_available_providers", lambda: ["CPUExecutionProvider"])
|
||||||
|
|
||||||
|
assert embedding._resolve_providers("dml") == (["CPUExecutionProvider"], "cpu")
|
||||||
|
assert "mempalace[dml]" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_device_warns_once(monkeypatch, caplog):
|
||||||
|
monkeypatch.setattr("onnxruntime.get_available_providers", lambda: ["CPUExecutionProvider"])
|
||||||
|
|
||||||
|
assert embedding._resolve_providers("bogus") == (["CPUExecutionProvider"], "cpu")
|
||||||
|
assert embedding._resolve_providers("bogus") == (["CPUExecutionProvider"], "cpu")
|
||||||
|
assert caplog.text.count("Unknown embedding_device") == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_onnxruntime_import_error_falls_back_to_cpu(monkeypatch):
|
||||||
|
real_import = builtins.__import__
|
||||||
|
|
||||||
|
def fake_import(name, *args, **kwargs):
|
||||||
|
if name == "onnxruntime":
|
||||||
|
raise ImportError("missing")
|
||||||
|
return real_import(name, *args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(builtins, "__import__", fake_import)
|
||||||
|
|
||||||
|
assert embedding._resolve_providers("cuda") == (["CPUExecutionProvider"], "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_function_caches_by_resolved_provider_tuple(monkeypatch):
|
||||||
|
class DummyEF:
|
||||||
|
def __init__(self, preferred_providers):
|
||||||
|
self.preferred_providers = preferred_providers
|
||||||
|
|
||||||
|
monkeypatch.setattr(embedding, "_build_ef_class", lambda: DummyEF)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
embedding, "_resolve_providers", lambda device: (["CPUExecutionProvider"], "cpu")
|
||||||
|
)
|
||||||
|
|
||||||
|
first = embedding.get_embedding_function("cpu")
|
||||||
|
second = embedding.get_embedding_function("auto")
|
||||||
|
|
||||||
|
assert first is second
|
||||||
|
assert first.preferred_providers == ["CPUExecutionProvider"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_describe_device_uses_resolved_effective_device(monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
embedding,
|
||||||
|
"_resolve_providers",
|
||||||
|
lambda device: (["CUDAExecutionProvider", "CPUExecutionProvider"], "cuda"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embedding.describe_device("auto") == "cuda"
|
||||||
@@ -383,6 +383,46 @@ def test_status_handles_none_metadata_without_crash(tmp_path, capsys):
|
|||||||
assert "WING: proj" in out
|
assert "WING: proj" in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_file_uses_bounded_upsert_batches(tmp_path, monkeypatch):
|
||||||
|
from mempalace import miner
|
||||||
|
|
||||||
|
class FakeCol:
|
||||||
|
def __init__(self):
|
||||||
|
self.batch_sizes = []
|
||||||
|
|
||||||
|
def get(self, *args, **kwargs):
|
||||||
|
return {"ids": []}
|
||||||
|
|
||||||
|
def delete(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def upsert(self, documents, ids, metadatas):
|
||||||
|
self.batch_sizes.append(len(documents))
|
||||||
|
|
||||||
|
source = tmp_path / "src.py"
|
||||||
|
source.write_text("print('hello')\n" * 20, encoding="utf-8")
|
||||||
|
chunks = [{"content": f"chunk {i} " * 20, "chunk_index": i} for i in range(5)]
|
||||||
|
col = FakeCol()
|
||||||
|
monkeypatch.setattr(miner, "DRAWER_UPSERT_BATCH_SIZE", 2)
|
||||||
|
monkeypatch.setattr(miner, "chunk_text", lambda content, source_file: chunks)
|
||||||
|
monkeypatch.setattr(miner, "detect_hall", lambda content: "code")
|
||||||
|
monkeypatch.setattr(miner, "_extract_entities_for_metadata", lambda content: "")
|
||||||
|
|
||||||
|
drawers, room = miner.process_file(
|
||||||
|
source,
|
||||||
|
tmp_path,
|
||||||
|
col,
|
||||||
|
"wing",
|
||||||
|
[{"name": "general", "description": "General"}],
|
||||||
|
"agent",
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert drawers == 5
|
||||||
|
assert room == "general"
|
||||||
|
assert col.batch_sizes == [2, 2, 1]
|
||||||
|
|
||||||
|
|
||||||
# ── normalize_version schema gate ───────────────────────────────────────
|
# ── normalize_version schema gate ───────────────────────────────────────
|
||||||
#
|
#
|
||||||
# When the normalization pipeline changes shape (e.g., strip_noise lands),
|
# When the normalization pipeline changes shape (e.g., strip_noise lands),
|
||||||
|
|||||||
Reference in New Issue
Block a user