fbd0904799
Agent-Logs-Url: https://github.com/MemPalace/mempalace/sessions/3213a67a-6871-4bb2-9ae0-23fa11001a22 Co-authored-by: igorls <4753812+igorls@users.noreply.github.com>
102 lines
3.3 KiB
Python
102 lines
3.3 KiB
Python
import builtins
|
|
|
|
import pytest
|
|
|
|
import mempalace.embedding as embedding
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def clear_embedding_state():
    """Empty the embedding module's ``_EF_CACHE`` and ``_WARNED`` around every test.

    The containers are cleared once before the test body runs and once more
    after it finishes, so state written by one test is not visible to the next.
    """

    def _reset():
        # Both module-level containers are cleared together, in this order.
        embedding._EF_CACHE.clear()
        embedding._WARNED.clear()

    _reset()
    yield
    _reset()
|
|
|
|
|
|
def test_auto_picks_cuda(monkeypatch):
    """``"auto"`` resolves to the CUDA provider list when CUDA is reported available."""

    def cuda_and_cpu():
        # Fresh list per call, mirroring what a lambda literal would return.
        return ["CUDAExecutionProvider", "CPUExecutionProvider"]

    monkeypatch.setattr("onnxruntime.get_available_providers", cuda_and_cpu)

    expected = (["CUDAExecutionProvider", "CPUExecutionProvider"], "cuda")
    resolved = embedding._resolve_providers("auto")
    assert resolved == expected
|
|
|
|
|
|
def test_auto_falls_to_cpu(monkeypatch):
    """``"auto"`` resolves to CPU when only the CPU provider is reported available."""

    def cpu_only():
        return ["CPUExecutionProvider"]

    monkeypatch.setattr("onnxruntime.get_available_providers", cpu_only)

    resolved = embedding._resolve_providers("auto")
    assert resolved == (["CPUExecutionProvider"], "cpu")
|
|
|
|
|
|
def test_cuda_missing_warns_with_gpu_extra(monkeypatch, caplog):
    """Requesting ``"cuda"`` with only CPU available yields CPU and mentions ``mempalace[gpu]``.

    The second assertion checks the captured log text for the name of the
    install extra.
    """

    def cpu_only():
        return ["CPUExecutionProvider"]

    monkeypatch.setattr("onnxruntime.get_available_providers", cpu_only)

    resolved = embedding._resolve_providers("cuda")
    assert resolved == (["CPUExecutionProvider"], "cpu")

    captured = caplog.text
    assert "mempalace[gpu]" in captured
|
|
|
|
|
|
def test_coreml_missing_warns_with_coreml_extra(monkeypatch, caplog):
    """Requesting ``"coreml"`` with only CPU available yields CPU and mentions ``mempalace[coreml]``.

    The second assertion checks the captured log text for the name of the
    install extra.
    """

    def cpu_only():
        return ["CPUExecutionProvider"]

    monkeypatch.setattr("onnxruntime.get_available_providers", cpu_only)

    resolved = embedding._resolve_providers("coreml")
    assert resolved == (["CPUExecutionProvider"], "cpu")

    captured = caplog.text
    assert "mempalace[coreml]" in captured
|
|
|
|
|
|
def test_dml_missing_warns_with_dml_extra(monkeypatch, caplog):
    """Requesting ``"dml"`` with only CPU available yields CPU and mentions ``mempalace[dml]``.

    The second assertion checks the captured log text for the name of the
    install extra.
    """

    def cpu_only():
        return ["CPUExecutionProvider"]

    monkeypatch.setattr("onnxruntime.get_available_providers", cpu_only)

    resolved = embedding._resolve_providers("dml")
    assert resolved == (["CPUExecutionProvider"], "cpu")

    captured = caplog.text
    assert "mempalace[dml]" in captured
|
|
|
|
|
|
def test_unknown_device_warns_once(monkeypatch, caplog):
    """An unrecognised device name resolves to CPU, and its log message appears only once.

    The resolver is called twice with the same bogus name; the captured log
    text must contain ``"Unknown embedding_device"`` exactly one time.
    """

    def cpu_only():
        return ["CPUExecutionProvider"]

    monkeypatch.setattr("onnxruntime.get_available_providers", cpu_only)

    # Two identical lookups: both must fall back to CPU.
    for _ in range(2):
        assert embedding._resolve_providers("bogus") == (["CPUExecutionProvider"], "cpu")

    occurrences = caplog.text.count("Unknown embedding_device")
    assert occurrences == 1
|
|
|
|
|
|
def test_onnxruntime_import_error_falls_back_to_cpu(monkeypatch):
    """When importing ``onnxruntime`` raises ``ImportError``, ``"cuda"`` resolves to CPU.

    ``builtins.__import__`` is swapped for a wrapper that fails only for the
    exact module name ``"onnxruntime"`` and delegates every other import to
    the genuine implementation.
    """
    original_import = builtins.__import__

    def blocking_import(name, *args, **kwargs):
        # Everything except the target module goes through untouched.
        if name != "onnxruntime":
            return original_import(name, *args, **kwargs)
        raise ImportError("missing")

    monkeypatch.setattr(builtins, "__import__", blocking_import)

    resolved = embedding._resolve_providers("cuda")
    assert resolved == (["CPUExecutionProvider"], "cpu")
|
|
|
|
|
|
def test_get_embedding_function_caches_by_resolved_provider_tuple(monkeypatch):
    """Two device names that resolve to the same providers return the same object.

    ``_build_ef_class`` and ``_resolve_providers`` are both stubbed, so the
    test only observes what ``get_embedding_function`` does with their results.
    """

    class DummyEF:
        """Stand-in embedding function that just records its providers."""

        def __init__(self, preferred_providers):
            self.preferred_providers = preferred_providers

    def fake_build_ef_class():
        return DummyEF

    def fake_resolve(device):
        # Every requested device maps to the same CPU-only resolution.
        return (["CPUExecutionProvider"], "cpu")

    monkeypatch.setattr(embedding, "_build_ef_class", fake_build_ef_class)
    monkeypatch.setattr(embedding, "_resolve_providers", fake_resolve)

    from_cpu = embedding.get_embedding_function("cpu")
    from_auto = embedding.get_embedding_function("auto")

    assert from_cpu is from_auto
    assert from_cpu.preferred_providers == ["CPUExecutionProvider"]
|
|
|
|
|
|
def test_describe_device_uses_resolved_effective_device(monkeypatch):
    """``describe_device("auto")`` returns ``"cuda"`` when the stubbed resolver reports it."""

    def fake_resolve(device):
        # Stub: any requested device resolves to the CUDA provider list.
        return (["CUDAExecutionProvider", "CPUExecutionProvider"], "cuda")

    monkeypatch.setattr(embedding, "_resolve_providers", fake_resolve)

    described = embedding.describe_device("auto")
    assert described == "cuda"
|