diff --git a/mempalace/query_sanitizer.py b/mempalace/query_sanitizer.py index f86a621..b320e9c 100644 --- a/mempalace/query_sanitizer.py +++ b/mempalace/query_sanitizer.py @@ -69,11 +69,15 @@ def sanitize_query(raw_query: str) -> dict: def _strip_wrapping_quotes(candidate: str) -> str: candidate = candidate.strip() - while candidate[:1] in {"'", '"'} and candidate[-1:] in {"'", '"'}: - candidate = candidate.strip("\"'") + while len(candidate) >= 2 and candidate[:1] in {"'", '"'} and candidate[-1:] in {"'", '"'}: + candidate = candidate[1:-1].strip() if not candidate: return "" - return candidate.strip("\"'") + if candidate[:1] in {"'", '"'}: + candidate = candidate[1:].strip() + if candidate[-1:] in {"'", '"'}: + candidate = candidate[:-1].strip() + return candidate def _trim_candidate(candidate: str) -> str: candidate = _strip_wrapping_quotes(candidate) diff --git a/tests/test_query_sanitizer.py b/tests/test_query_sanitizer.py index 015a8a1..dd96cfa 100644 --- a/tests/test_query_sanitizer.py +++ b/tests/test_query_sanitizer.py @@ -112,6 +112,7 @@ class TestTailSentence: query = ("Prefix text " * 30) + '\n"' + ("x" * 260) + '"' result = sanitize_query(query) assert result["method"] == "tail_sentence" + assert result["clean_query"] == "x" * MAX_QUERY_LENGTH assert not result["clean_query"].startswith('"') assert not result["clean_query"].endswith('"') assert len(result["clean_query"]) <= MAX_QUERY_LENGTH