Initial MVP
This commit is contained in:
@@ -0,0 +1,131 @@
|
||||
"""Tests for configuration loading, merging, and ENV overrides."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from kb_search.config import (
|
||||
DEFAULTS,
|
||||
_deep_merge,
|
||||
_get_nested,
|
||||
_set_nested,
|
||||
config_with_sources,
|
||||
load_config,
|
||||
save_config_value,
|
||||
)
|
||||
|
||||
|
||||
def test_deep_merge_basic():
    """Nested override values replace base values while sibling keys survive."""
    defaults = {"a": 1, "b": {"c": 2, "d": 3}}
    overlay = {"b": {"c": 99}}
    expected = {"a": 1, "b": {"c": 99, "d": 3}}
    assert _deep_merge(defaults, overlay) == expected
|
||||
|
||||
|
||||
def test_deep_merge_new_keys():
    """Keys that exist only in the override are added to the merged result."""
    merged = _deep_merge({"a": 1}, {"b": 2})
    assert merged == {"a": 1, "b": 2}
|
||||
|
||||
|
||||
def test_deep_merge_does_not_mutate():
    """Merging leaves the nested value of the base dict untouched."""
    original = {"a": {"b": 1}}
    _deep_merge(original, {"a": {"b": 2}})
    # Only the input is inspected here; the merge result is irrelevant.
    assert original["a"]["b"] == 1
|
||||
|
||||
|
||||
def test_set_nested():
    """A dotted path creates every intermediate dict down to the leaf."""
    target = {}
    _set_nested(target, "a.b.c", 42)
    expected = {"a": {"b": {"c": 42}}}
    assert target == expected
|
||||
|
||||
|
||||
def test_get_nested():
    """Dotted lookups yield the leaf, the supplied default, or None."""
    tree = {"a": {"b": {"c": 42}}}

    existing = _get_nested(tree, "a.b.c")
    assert existing == 42

    # Missing leaf with an explicit default.
    with_default = _get_nested(tree, "a.b.x", "missing")
    assert with_default == "missing"

    # Missing branch without a default.
    without_default = _get_nested(tree, "x.y.z")
    assert without_default is None
|
||||
|
||||
|
||||
def test_load_config_defaults(tmp_path, monkeypatch):
    """With no config file, ``load_config`` returns the built-in defaults.

    The environment variables that this module's other tests use as
    overrides are removed first, so a value exported in the developer's
    shell or CI job cannot mask the defaults being asserted.
    """
    for env_name in ("KB_MODEL", "KB_DEFAULT_TOP"):
        # raising=False: the variable is usually absent, which is fine.
        monkeypatch.delenv(env_name, raising=False)

    cfg = load_config(tmp_path / "nonexistent.yaml")

    assert cfg["embedding"]["model"] == "all-MiniLM-L6-v2"
    assert cfg["search"]["default_top"] == 10
    assert cfg["chunking"]["pdf"]["strategy"] == "hierarchy"
|
||||
|
||||
|
||||
def test_load_config_yaml_override(tmp_path, monkeypatch):
    """YAML values override defaults; untouched defaults are preserved.

    Environment overrides outrank the YAML file, so the variables that
    would affect the asserted keys are cleared to keep the test hermetic.
    """
    for env_name in ("KB_MODEL", "KB_DEFAULT_TOP"):
        monkeypatch.delenv(env_name, raising=False)

    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.dump({"embedding": {"model": "nomic-embed-text"}}))

    cfg = load_config(config_path)

    assert cfg["embedding"]["model"] == "nomic-embed-text"
    # Keys absent from the YAML file keep their default values.
    assert cfg["search"]["default_top"] == 10
|
||||
|
||||
|
||||
def test_load_config_env_override(tmp_path, monkeypatch):
    """An environment variable wins over both the YAML file and the defaults."""
    yaml_file = tmp_path / "config.yaml"
    yaml_body = yaml.dump({"search": {"default_top": 20}})
    yaml_file.write_text(yaml_body)
    monkeypatch.setenv("KB_DEFAULT_TOP", "50")

    loaded = load_config(yaml_file)

    # The env value is set as a string but is expected back as an int.
    assert loaded["search"]["default_top"] == 50
|
||||
|
||||
|
||||
def test_load_config_env_model(tmp_path, monkeypatch):
    """``KB_MODEL`` overrides the embedding model when no config file exists."""
    monkeypatch.setenv("KB_MODEL", "bge-small-en-v1.5")
    missing_file = tmp_path / "nonexistent.yaml"

    loaded = load_config(missing_file)

    assert loaded["embedding"]["model"] == "bge-small-en-v1.5"
|
||||
|
||||
|
||||
def test_save_config_value(tmp_path):
    """Saving a dotted key writes a nested YAML mapping with a typed value."""
    target = tmp_path / "config.yaml"
    save_config_value(target, "chunking.pdf.max_tokens", "2048")

    with open(target) as handle:
        written = yaml.safe_load(handle)

    # The string "2048" is expected to come back as the integer 2048.
    assert written["chunking"]["pdf"]["max_tokens"] == 2048
|
||||
|
||||
|
||||
def test_save_config_value_bool(tmp_path):
    """The string ``"false"`` is persisted as a real boolean."""
    target = tmp_path / "config.yaml"
    save_config_value(target, "chunking.code.include_context", "false")

    with open(target) as handle:
        written = yaml.safe_load(handle)

    flag = written["chunking"]["code"]["include_context"]
    assert flag is False
|
||||
|
||||
|
||||
def test_save_config_preserves_existing(tmp_path):
    """Saving one key keeps unrelated keys already present in the file."""
    target = tmp_path / "config.yaml"
    existing = yaml.dump({"embedding": {"model": "custom"}})
    target.write_text(existing)

    save_config_value(target, "search.default_top", "20")

    with open(target) as handle:
        written = yaml.safe_load(handle)

    # Pre-existing entry survives, new entry is added alongside it.
    assert written["embedding"]["model"] == "custom"
    assert written["search"]["default_top"] == 20
|
||||
|
||||
|
||||
def test_config_with_sources_defaults(tmp_path, monkeypatch):
    """Without a config file or env override, the model's source is "default".

    ``KB_MODEL`` is removed first: when it is set, the source of
    ``embedding.model`` is reported as coming from the environment, which
    would make this test depend on the caller's shell.
    """
    monkeypatch.delenv("KB_MODEL", raising=False)

    entries = config_with_sources(tmp_path / "nonexistent.yaml")

    # Each entry is a (key, value, source) triple; only the source matters.
    sources = {key: source for key, _, source in entries}
    assert sources["embedding.model"] == "default"
|
||||
|
||||
|
||||
def test_config_with_sources_yaml(tmp_path, monkeypatch):
    """A value set in the YAML file is attributed to "config.yaml".

    ``KB_MODEL`` is cleared so that an ambient environment override cannot
    take precedence over the file and change the reported source.
    """
    monkeypatch.delenv("KB_MODEL", raising=False)

    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.dump({"embedding": {"model": "custom"}}))

    entries = config_with_sources(config_path)

    # Each entry is a (key, value, source) triple; only the source matters.
    sources = {key: source for key, _, source in entries}
    assert sources["embedding.model"] == "config.yaml"
|
||||
|
||||
|
||||
def test_config_with_sources_env(tmp_path, monkeypatch):
    """A value taken from ``KB_MODEL`` is attributed to that env variable."""
    monkeypatch.setenv("KB_MODEL", "from-env")
    missing_file = tmp_path / "nonexistent.yaml"

    source_by_key = {}
    for key, _value, source in config_with_sources(missing_file):
        source_by_key[key] = source

    assert source_by_key["embedding.model"] == "env (KB_MODEL)"
|
||||
@@ -0,0 +1,206 @@
|
||||
"""Tests for database schema, FTS triggers, and config helpers."""
|
||||
|
||||
import struct
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from kb_search.database import (
|
||||
SCHEMA_VERSION,
|
||||
check_schema_version,
|
||||
get_connection,
|
||||
get_db_config,
|
||||
get_or_create_tag,
|
||||
hash_exists,
|
||||
init_schema,
|
||||
insert_chunk,
|
||||
insert_document,
|
||||
insert_embedding,
|
||||
recreate_vec_table,
|
||||
run_migrations,
|
||||
set_db_config,
|
||||
tag_document,
|
||||
untag_document,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def db(tmp_path):
    """Yield an initialised database connection backed by a temporary file.

    The database is created at ``tmp_path / "test.db"``, its schema is
    initialised with ``embedding_dim=384``, and ``schema_version`` is stored
    in the DB config. The connection is closed on teardown.
    """
    conn = get_connection(tmp_path / "test.db")
    try:
        init_schema(conn, embedding_dim=384)
        set_db_config(conn, "schema_version", str(SCHEMA_VERSION))
        yield conn
    finally:
        # Runs on normal teardown and also when setup itself fails, so the
        # connection is never left open.
        conn.close()
|
||||
|
||||
|
||||
def test_schema_creation(db):
    """All core relational tables exist after schema initialisation."""
    cursor = db.execute("SELECT name FROM sqlite_master WHERE type='table'")
    present = [record[0] for record in cursor.fetchall()]
    for required in ("documents", "chunks", "tags", "document_tags", "config"):
        assert required in present
|
||||
|
||||
|
||||
def test_fts_table_exists(db):
    """The ``chunks_fts`` table is listed in ``sqlite_master``."""
    cursor = db.execute("SELECT name FROM sqlite_master WHERE type='table'")
    table_names = [record[0] for record in cursor.fetchall()]
    assert "chunks_fts" in table_names
|
||||
|
||||
|
||||
def test_vec_table_exists(db):
    """The ``chunks_vec`` table is listed in ``sqlite_master``."""
    cursor = db.execute("SELECT name FROM sqlite_master WHERE type='table'")
    table_names = [record[0] for record in cursor.fetchall()]
    assert "chunks_vec" in table_names
|
||||
|
||||
|
||||
def test_config_get_set(db):
    """A value written with ``set_db_config`` is read back unchanged."""
    set_db_config(db, "test_key", "test_value")
    stored = get_db_config(db, "test_key")
    assert stored == "test_value"
|
||||
|
||||
|
||||
def test_config_get_default(db):
    """An unknown key returns the caller-supplied fallback."""
    value = get_db_config(db, "nonexistent", "fallback")
    assert value == "fallback"
|
||||
|
||||
|
||||
def test_config_upsert(db):
    """Writing the same key twice keeps only the latest value."""
    for version in ("v1", "v2"):
        set_db_config(db, "key", version)
    assert get_db_config(db, "key") == "v2"
|
||||
|
||||
|
||||
def test_schema_version(db):
    """The version recorded by the fixture matches ``SCHEMA_VERSION``."""
    reported = check_schema_version(db)
    assert reported == SCHEMA_VERSION
|
||||
|
||||
|
||||
def test_insert_document(db):
    """An inserted document can be read back with the fields it was given."""
    new_id = insert_document(db, "Test Doc", "/path/test.pdf", "abc123", "pdf")
    db.commit()

    cursor = db.execute("SELECT * FROM documents WHERE id = ?", (new_id,))
    record = cursor.fetchone()

    expected = {"title": "Test Doc", "doc_type": "pdf", "content_hash": "abc123"}
    for column, value in expected.items():
        assert record[column] == value
|
||||
|
||||
|
||||
def test_insert_chunk_with_fts_sync(db):
    """Inserting a chunk makes its text findable through ``chunks_fts``."""
    parent_id = insert_document(db, "Doc", None, "hash1", "note")
    new_chunk = insert_chunk(db, parent_id, 0, "This is searchable text about Python programming")
    db.commit()

    # The query term is lower-case while the stored text says "Python".
    query = "SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'python'"
    hits = db.execute(query).fetchall()

    assert len(hits) == 1
    assert hits[0][0] == new_chunk
|
||||
|
||||
|
||||
def test_fts_delete_trigger(db):
    """Deleting a chunk row also removes it from the full-text index."""
    parent_id = insert_document(db, "Doc", None, "hash2", "note")
    doomed = insert_chunk(db, parent_id, 0, "unique_keyword_xyz")
    db.commit()

    db.execute("DELETE FROM chunks WHERE id = ?", (doomed,))
    db.commit()

    query = "SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'unique_keyword_xyz'"
    remaining = db.execute(query).fetchall()
    assert len(remaining) == 0
|
||||
|
||||
|
||||
def test_fts_update_trigger(db):
    """Updating chunk text swaps the old term for the new one in the index."""
    parent_id = insert_document(db, "Doc", None, "hash3", "note")
    edited = insert_chunk(db, parent_id, 0, "old_content_abc")
    db.commit()

    db.execute("UPDATE chunks SET text = 'new_content_def' WHERE id = ?", (edited,))
    db.commit()

    stale_hits = db.execute(
        "SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'old_content_abc'"
    ).fetchall()
    fresh_hits = db.execute(
        "SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'new_content_def'"
    ).fetchall()

    assert len(stale_hits) == 0
    assert len(fresh_hits) == 1
|
||||
|
||||
|
||||
def test_insert_embedding(db):
    """A stored embedding produces a row in ``chunks_vec`` for its chunk."""
    parent_id = insert_document(db, "Doc", None, "hash4", "note")
    target_chunk = insert_chunk(db, parent_id, 0, "text")
    db.commit()

    # 384 values, matching the dimension the fixture initialised.
    vector = [0.1] * 384
    insert_embedding(db, target_chunk, vector)
    db.commit()

    cursor = db.execute("SELECT * FROM chunks_vec WHERE chunk_id = ?", (target_chunk,))
    assert cursor.fetchone() is not None
|
||||
|
||||
|
||||
def test_hash_exists(db):
    """``hash_exists`` is false before and true after a document is stored."""
    content_hash = "newhash"
    assert not hash_exists(db, content_hash)

    insert_document(db, "Doc", None, content_hash, "note")
    db.commit()

    assert hash_exists(db, content_hash)
|
||||
|
||||
|
||||
def test_tag_management(db):
    """Tags attached to a document are retrievable through the join table."""
    target_doc = insert_document(db, "Doc", None, "hash5", "pdf")
    db.commit()

    tag_document(db, target_doc, ["git", "admin"])
    db.commit()

    query = (
        "SELECT t.name FROM tags t JOIN document_tags dt ON t.id = dt.tag_id "
        "WHERE dt.document_id = ? ORDER BY t.name"
    )
    tag_names = [record["name"] for record in db.execute(query, (target_doc,)).fetchall()]

    # ORDER BY makes the result alphabetical regardless of insertion order.
    assert tag_names == ["admin", "git"]
|
||||
|
||||
|
||||
def test_untag_document(db):
    """Removing one tag leaves the document's other tags in place."""
    target_doc = insert_document(db, "Doc", None, "hash6", "pdf")
    tag_document(db, target_doc, ["a", "b", "c"])
    db.commit()

    untag_document(db, target_doc, ["b"])
    db.commit()

    query = (
        "SELECT t.name FROM tags t JOIN document_tags dt ON t.id = dt.tag_id "
        "WHERE dt.document_id = ? ORDER BY t.name"
    )
    tag_names = [record["name"] for record in db.execute(query, (target_doc,)).fetchall()]

    assert tag_names == ["a", "c"]
|
||||
|
||||
|
||||
def test_tags_are_lowercase(db):
    """A mixed-case tag name is stored in lower case."""
    created_id = get_or_create_tag(db, "MyTag")
    db.commit()

    cursor = db.execute("SELECT name FROM tags WHERE id = ?", (created_id,))
    stored = cursor.fetchone()

    assert stored["name"] == "mytag"
|
||||
|
||||
|
||||
def test_recreate_vec_table(db):
    """Recreating the vector table discards embeddings stored before."""
    parent_id = insert_document(db, "Doc", None, "hash7", "note")
    embedded_chunk = insert_chunk(db, parent_id, 0, "text")
    insert_embedding(db, embedded_chunk, [0.1] * 384)
    db.commit()

    # Rebuild with a different dimension than the fixture's 384.
    recreate_vec_table(db, 768)

    leftover = db.execute("SELECT * FROM chunks_vec").fetchall()
    assert len(leftover) == 0
|
||||
|
||||
|
||||
def test_cascade_delete(db):
    """Deleting a document removes its chunks and its tag links."""
    doomed_doc = insert_document(db, "Doc", None, "hash8", "pdf")
    insert_chunk(db, doomed_doc, 0, "chunk text")
    tag_document(db, doomed_doc, ["test"])
    db.commit()

    db.execute("DELETE FROM documents WHERE id = ?", (doomed_doc,))
    db.commit()

    chunk_count = db.execute(
        "SELECT COUNT(*) FROM chunks WHERE document_id = ?", (doomed_doc,)
    ).fetchone()[0]
    assert chunk_count == 0

    link_count = db.execute(
        "SELECT COUNT(*) FROM document_tags WHERE document_id = ?", (doomed_doc,)
    ).fetchone()[0]
    assert link_count == 0
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Tests for embedding model management."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import click
|
||||
import pytest
|
||||
|
||||
from kb_search.embeddings import check_model_binding
|
||||
|
||||
|
||||
@pytest.fixture
def mock_conn():
    """Provide a factory for mock DB connections with preset config values.

    The returned callable takes an optional ``config_values`` dict and builds
    a ``MagicMock`` connection whose ``execute`` answers
    ``SELECT value FROM config`` queries from that dict, keyed by the first
    query parameter. Any other statement gets a plain ``MagicMock`` back.
    """

    def make_conn(config_values=None):
        config_values = config_values or {}

        def mock_execute(sql, params=None):
            if "SELECT value FROM config" in sql and params:
                val = config_values.get(params[0])
                result = MagicMock()
                if val is None:
                    # Unknown key: behave like a query that matched no row.
                    result.fetchone.return_value = None
                else:
                    # Compare against None rather than truthiness so stored
                    # falsy values such as "" still produce a row.
                    row = MagicMock()
                    # Any index or column name yields the stored value.
                    row.__getitem__.return_value = val
                    result.fetchone.return_value = row
                return result
            return MagicMock()

        conn = MagicMock()
        conn.execute = mock_execute
        return conn

    return make_conn
|
||||
|
||||
|
||||
def test_model_binding_match(mock_conn):
    """No error is raised when the configured model equals the stored one."""
    stored = {"model_name": "all-MiniLM-L6-v2", "embedding_dim": "384"}
    requested = {"embedding": {"model": "all-MiniLM-L6-v2"}}
    # Passing means simply returning without an exception.
    check_model_binding(mock_conn(stored), requested)
|
||||
|
||||
|
||||
def test_model_binding_mismatch(mock_conn):
    """A model different from the stored one raises a ``ClickException``."""
    stored = {"model_name": "all-MiniLM-L6-v2", "embedding_dim": "384"}
    requested = {"embedding": {"model": "nomic-embed-text"}}
    connection = mock_conn(stored)
    with pytest.raises(click.ClickException, match="Model mismatch"):
        check_model_binding(connection, requested)
|
||||
|
||||
|
||||
def test_model_binding_no_db_model(mock_conn):
    """With no model recorded in the DB, any configured model is accepted."""
    empty_connection = mock_conn({})
    requested = {"embedding": {"model": "anything"}}
    # Passing means simply returning without an exception.
    check_model_binding(empty_connection, requested)
|
||||
@@ -0,0 +1,172 @@
|
||||
"""Tests for code chunking — Python, Bash, Go."""
|
||||
|
||||
from kb_search.ingest.code import chunk_code, _chunk_python, _chunk_bash, _chunk_go, _fixed_chunk
|
||||
|
||||
# Shared chunking configuration passed to ``chunk_code`` by the tests below.
CFG = {
    "chunking": {
        "code": {
            "strategy": "ast",
            "include_context": True,
            "max_tokens": 1024,
        },
    },
}
|
||||
|
||||
|
||||
class TestPythonChunking:
    """Chunking of Python source through ``_chunk_python``."""

    def test_functions(self):
        """Each top-level function becomes one chunk, in source order."""
        source = '''
def hello():
    """Say hello."""
    print("hello")

def goodbye():
    """Say goodbye."""
    print("bye")
'''
        result = _chunk_python(source, include_context=True)
        assert len(result) == 2
        first, second = result
        assert first["metadata"]["symbol_name"] == "hello"
        assert second["metadata"]["symbol_name"] == "goodbye"

    def test_class_with_methods(self):
        """Methods are chunked separately and named ``Class.method``."""
        source = '''
class MyClass:
    """A test class."""

    def method_a(self):
        pass

    def method_b(self):
        pass
'''
        result = _chunk_python(source, include_context=True)
        assert len(result) == 2
        first, second = result
        assert first["metadata"]["symbol_name"] == "MyClass.method_a"
        assert second["metadata"]["symbol_name"] == "MyClass.method_b"
        # With context enabled the class docstring travels with the method.
        assert "A test class" in first["text"]

    def test_class_without_methods(self):
        """A class that has no methods is emitted as a single chunk."""
        source = '''
class Config:
    """Configuration."""
    DEBUG = True
    PORT = 8080
'''
        result = _chunk_python(source, include_context=True)
        assert len(result) == 1
        assert result[0]["metadata"]["symbol_name"] == "Config"

    def test_syntax_error_returns_empty(self):
        """Unparseable source yields an empty list."""
        broken_source = "def broken(:\n pass"
        assert _chunk_python(broken_source, include_context=True) == []

    def test_no_context(self):
        """With context disabled the class docstring is left out of the chunk."""
        source = '''
class Foo:
    """Docstring."""
    def bar(self):
        pass
'''
        result = _chunk_python(source, include_context=False)
        assert len(result) == 1
        assert "Docstring" not in result[0]["text"]
|
||||
|
||||
|
||||
class TestBashChunking:
    """Chunking of shell scripts through ``_chunk_bash``."""

    def test_function_keyword(self):
        """Functions declared with the ``function`` keyword are found in order."""
        script = '''#!/bin/bash

function deploy() {
    echo "deploying"
}

function rollback() {
    echo "rolling back"
}
'''
        result = _chunk_bash(script, include_context=True)
        assert len(result) == 2
        first, second = result
        assert first["metadata"]["symbol_name"] == "deploy"
        assert second["metadata"]["symbol_name"] == "rollback"

    def test_shorthand_syntax(self):
        """Functions written as ``name() { ... }`` are recognised too."""
        script = '''
setup() {
    echo "setup"
}

cleanup() {
    echo "cleanup"
}
'''
        result = _chunk_bash(script, include_context=True)
        assert len(result) == 2

    def test_no_functions(self):
        """A script without any function definition yields no chunks."""
        script = "#!/bin/bash\necho hello\nexit 0"
        assert _chunk_bash(script, include_context=True) == []

    def test_with_preceding_comments(self):
        """Comments directly above a function are part of its chunk text."""
        script = '''
# Deploy to production
# Requires valid credentials
function deploy() {
    echo "deploying"
}
'''
        result = _chunk_bash(script, include_context=True)
        assert len(result) == 1
        assert "Deploy to production" in result[0]["text"]
|
||||
|
||||
|
||||
class TestGoChunking:
    """Chunking of Go source through ``_chunk_go``."""

    def test_basic_funcs(self):
        """Plain ``func`` declarations become chunks named after the function."""
        source = '''package main

func main() {
    fmt.Println("hello")
}

func helper() string {
    return "help"
}
'''
        result = _chunk_go(source, include_context=True)
        assert len(result) == 2
        first, second = result
        assert first["metadata"]["symbol_name"] == "main"
        assert second["metadata"]["symbol_name"] == "helper"

    def test_method_receiver(self):
        """Methods with a receiver are named by the method name alone."""
        source = '''
func (s *Server) Start() error {
    return nil
}

func (s *Server) Stop() {
}
'''
        result = _chunk_go(source, include_context=True)
        assert len(result) == 2
        assert result[0]["metadata"]["symbol_name"] == "Start"

    def test_no_funcs(self):
        """Source without any function yields no chunks."""
        source = "package main\n\nvar x = 1"
        assert _chunk_go(source, include_context=True) == []
|
||||
|
||||
|
||||
class TestFallback:
    """Behaviour of ``chunk_code`` when symbol-based chunking finds nothing."""

    def test_unknown_language_uses_fixed(self):
        """An unrecognised language still produces at least one chunk."""
        result = chunk_code("line1\nline2\nline3", "ruby", CFG)
        assert len(result) >= 1

    def test_python_no_functions_uses_fixed(self):
        """Python source without definitions still produces a chunk."""
        flat_script = "x = 1\ny = 2\nprint(x + y)"
        result = chunk_code(flat_script, "python", CFG)
        assert len(result) >= 1

    def test_fixed_strategy_config(self):
        """The ``fixed`` strategy with a small budget splits into several chunks."""
        fixed_cfg = {"chunking": {"code": {"strategy": "fixed", "max_tokens": 10}}}
        assignments = [f"x_{i} = {i}" for i in range(50)]
        result = chunk_code("\n".join(assignments), "python", fixed_cfg)
        assert len(result) > 1

    def test_empty_code(self):
        """Empty input produces no chunks."""
        result = chunk_code("", "python", CFG)
        assert len(result) == 0
|
||||
@@ -0,0 +1,81 @@
|
||||
"""Tests for file type detection, dedup, note creation."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from kb_search.ingest.detector import detect_type, is_supported
|
||||
from kb_search.ingest.note import auto_title, chunk_note
|
||||
|
||||
|
||||
class TestDetector:
    """File type detection by ``detect_type`` and ``is_supported``."""

    def test_pdf(self, tmp_path):
        """A ``.pdf`` file is type ``pdf`` with no language."""
        detected = detect_type(tmp_path / "doc.pdf")
        assert detected == ("pdf", None)

    def test_markdown(self, tmp_path):
        """A ``.md`` file is type ``markdown`` with no language."""
        detected = detect_type(tmp_path / "notes.md")
        assert detected == ("markdown", None)

    def test_txt(self, tmp_path):
        """A ``.txt`` file is handled as ``markdown``."""
        detected = detect_type(tmp_path / "notes.txt")
        assert detected == ("markdown", None)

    def test_python(self, tmp_path):
        """A ``.py`` file is code in language ``python``."""
        detected = detect_type(tmp_path / "main.py")
        assert detected == ("code", "python")

    def test_bash(self, tmp_path):
        """A ``.sh`` file is code in language ``bash``."""
        detected = detect_type(tmp_path / "deploy.sh")
        assert detected == ("code", "bash")

    def test_go(self, tmp_path):
        """A ``.go`` file is code in language ``go``."""
        detected = detect_type(tmp_path / "main.go")
        assert detected == ("code", "go")

    def test_unsupported(self, tmp_path):
        """An unknown extension raises ``ValueError`` mentioning "Unsupported"."""
        archive = tmp_path / "archive.zip"
        with pytest.raises(ValueError, match="Unsupported"):
            detect_type(archive)

    def test_force_type(self, tmp_path):
        """Forced type and language take precedence over the extension."""
        detected = detect_type(
            tmp_path / "data.txt",
            force_type="code",
            force_language="bash",
        )
        assert detected == ("code", "bash")

    def test_force_language_only(self, tmp_path):
        """Forcing only the language keeps type ``code`` and swaps the language."""
        kind, language = detect_type(tmp_path / "script.py", force_language="go")
        assert kind == "code"
        assert language == "go"

    def test_is_supported(self, tmp_path):
        """``is_supported`` accepts pdf and py files and rejects zip."""
        for accepted in ("test.pdf", "test.py"):
            assert is_supported(tmp_path / accepted)
        assert not is_supported(tmp_path / "test.zip")

    def test_case_insensitive(self, tmp_path):
        """An upper-case extension is detected like its lower-case form."""
        detected = detect_type(tmp_path / "DOC.PDF")
        assert detected == ("pdf", None)

    def test_image_files(self, tmp_path):
        """PNG and JPG images are reported with type ``pdf``."""
        for image_name in ("scan.png", "photo.jpg"):
            assert detect_type(tmp_path / image_name) == ("pdf", None)

    def test_docx(self, tmp_path):
        """A ``.docx`` file is reported with type ``pdf``."""
        detected = detect_type(tmp_path / "report.docx")
        assert detected == ("pdf", None)
|
||||
|
||||
|
||||
class TestNote:
    """Note chunking and automatic title generation."""

    def test_chunk_note(self):
        """A note becomes exactly one chunk with index 0 and unchanged text."""
        result = chunk_note("Hello world")
        assert len(result) == 1
        only = result[0]
        assert only["text"] == "Hello world"
        assert only["chunk_index"] == 0

    def test_auto_title_short(self):
        """Short text is used as the title verbatim."""
        title = auto_title("Short note")
        assert title == "Short note"

    def test_auto_title_long(self):
        """Text longer than ``max_len`` is cut and ends with an ellipsis."""
        long_text = "This is a very long note that exceeds the maximum title length and should be truncated at a word boundary"
        title = auto_title(long_text, max_len=50)
        # Upper bound as asserted by the original test: max_len plus the
        # three-character ellipsis, with one character of slack.
        assert len(title) <= 54
        assert title.endswith("...")

    def test_auto_title_multiline(self):
        """Only the first line of multi-line text becomes the title."""
        lines = "First line\nSecond line\nThird line"
        assert auto_title(lines) == "First line"

    def test_auto_title_no_space(self):
        """Text without any space to break on still ends with an ellipsis."""
        unbroken = "a" * 100
        title = auto_title(unbroken, max_len=80)
        assert title.endswith("...")
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Tests for Docling ingestion (fixed-size chunking logic, mocked Docling)."""
|
||||
|
||||
from kb_search.ingest.docling import _fixed_chunk_text
|
||||
|
||||
|
||||
class TestFixedChunkText:
    """Fixed-size text chunking through ``_fixed_chunk_text``."""

    def test_short_text_single_chunk(self):
        """Short text with an empty config yields one chunk at index 0."""
        result = _fixed_chunk_text("Hello world", {})
        assert len(result) == 1
        only = result[0]
        assert only["text"] == "Hello world"
        assert only["chunk_index"] == 0

    def test_long_text_multiple_chunks(self):
        """Long text is split and chunks are numbered consecutively from 0."""
        long_text = "word " * 2000  # 10000 characters
        result = _fixed_chunk_text(long_text, {"max_tokens": 512, "overlap_tokens": 50})
        assert len(result) > 1
        # Only the numbering is verified here, not the overlap itself.
        for position, piece in enumerate(result):
            assert piece["chunk_index"] == position

    def test_empty_text(self):
        """Empty input yields no chunks."""
        result = _fixed_chunk_text("", {})
        assert len(result) == 0

    def test_whitespace_only(self):
        """Input made only of whitespace yields no chunks."""
        result = _fixed_chunk_text(" \n\n ", {})
        assert len(result) == 0

    def test_custom_max_tokens(self):
        """A small ``max_tokens`` forces 1000 characters into several chunks."""
        thousand_chars = "a " * 500
        result = _fixed_chunk_text(thousand_chars, {"max_tokens": 100})
        # Per the original test's note: 100 tokens * 4 chars = 400-char window.
        assert len(result) > 1
|
||||
@@ -0,0 +1,121 @@
|
||||
"""Tests for markdown header-based splitting."""
|
||||
|
||||
from kb_search.ingest.markdown import (
|
||||
_fixed_chunk,
|
||||
_has_headers,
|
||||
_merge_small_sections,
|
||||
_split_at_headers,
|
||||
chunk_markdown,
|
||||
)
|
||||
|
||||
|
||||
def make_cfg(**overrides):
    """Build a markdown chunking config, applying keyword overrides.

    Starts from the header strategy with ``min_tokens=50`` and
    ``max_tokens=1024``; every keyword argument replaces or adds a key in
    the ``chunking.markdown`` section. A fresh dict is returned per call.
    """
    markdown_section = {"strategy": "header", "min_tokens": 50, "max_tokens": 1024}
    for option, value in overrides.items():
        markdown_section[option] = value
    return {"chunking": {"markdown": markdown_section}}
|
||||
|
||||
|
||||
class TestHasHeaders:
    """Header detection through ``_has_headers``."""

    def test_with_headers(self):
        """Text starting with a ``##`` line counts as having headers."""
        sample = "## Title\nContent"
        assert _has_headers(sample)

    def test_without_headers(self):
        """Plain lines without ``#`` markers have no headers."""
        sample = "Just plain text\nNo headers here"
        assert not _has_headers(sample)

    def test_h3(self):
        """A ``###`` line counts as a header as well."""
        sample = "### Subsection\nStuff"
        assert _has_headers(sample)
|
||||
|
||||
|
||||
class TestSplitAtHeaders:
    """Section splitting and header chains from ``_split_at_headers``."""

    def test_basic_split(self):
        """Two sibling headers give two sections with one-element chains."""
        document = "## Section 1\nContent one\n\n## Section 2\nContent two"
        parts = _split_at_headers(document)
        assert len(parts) == 2
        first, second = parts
        assert first["header_chain"] == ["Section 1"]
        assert "Content one" in first["content"]
        assert second["header_chain"] == ["Section 2"]

    def test_nested_headers(self):
        """A deeper header carries its parent in the chain."""
        document = "## Config\nIntro\n\n### Advanced Options\nDetails"
        parts = _split_at_headers(document)
        assert len(parts) == 2
        nested = parts[1]
        assert nested["header_chain"] == ["Config", "Advanced Options"]

    def test_leading_content(self):
        """Text before the first header forms a section with an empty chain."""
        document = "Preamble text\n\n## First Section\nContent"
        parts = _split_at_headers(document)
        assert len(parts) == 2
        preamble = parts[0]
        assert preamble["header_chain"] == []
        assert "Preamble" in preamble["content"]

    def test_header_level_reset(self):
        """A new same-level header drops the previous sibling's children."""
        document = "## A\n\n### B\n\n## C\n\n### D"
        parts = _split_at_headers(document)
        third, fourth = parts[2], parts[3]
        assert third["header_chain"] == ["C"]
        assert fourth["header_chain"] == ["C", "D"]
|
||||
|
||||
|
||||
class TestMergeSmallSections:
    """Merging of undersized sections by ``_merge_small_sections``."""

    def test_merge_tiny_into_next(self):
        """A section below ``min_tokens`` is folded together with its neighbour."""
        tiny = {"header_chain": ["A"], "content": "tiny"}
        large = {
            "header_chain": ["B"],
            "content": "This is a much longer section with plenty of words " * 5,
        }
        merged = _merge_small_sections([tiny, large], min_tokens=10)
        assert len(merged) == 1
        # The tiny section's text must not be lost by the merge.
        assert "tiny" in merged[0]["content"]

    def test_no_merge_when_large_enough(self):
        """Sections above ``min_tokens`` are kept separate."""
        sections = [
            {"header_chain": [name], "content": "word " * 100}
            for name in ("A", "B")
        ]
        merged = _merge_small_sections(sections, min_tokens=10)
        assert len(merged) == 2
|
||||
|
||||
|
||||
class TestChunkMarkdown:
    """End-to-end markdown chunking through ``chunk_markdown``."""

    def test_header_strategy(self):
        """Header-based chunking yields several consecutively indexed chunks."""
        # NOTE: the repetition applies to the whole literal, header included,
        # so "## Details" appears five times at the start of a line.
        intro_part = "## Intro\nSome intro text with enough words to avoid merging. " * 5
        details_part = "\n\n## Details\nDetailed content follows here with sufficient length. " * 5
        document = intro_part + details_part
        result = chunk_markdown(document, make_cfg(min_tokens=5))
        assert len(result) >= 2
        for position, piece in enumerate(result):
            assert piece["chunk_index"] == position

    def test_hierarchy_context(self):
        """A nested section's chunk text contains its ``Parent > Child`` path."""
        document = "## Config\nIntro\n\n### Advanced\n" + "Details " * 60
        result = chunk_markdown(document, make_cfg(min_tokens=5))
        matching = [piece for piece in result if "Advanced" in piece["text"]]
        assert len(matching) > 0
        assert "Config > Advanced" in matching[0]["text"]

    def test_plain_text_fallback(self):
        """Text without headers still produces at least one chunk."""
        document = "No headers here, just plain text. " * 200
        result = chunk_markdown(document, make_cfg())
        assert len(result) >= 1

    def test_empty_text(self):
        """Empty input produces no chunks."""
        result = chunk_markdown("", make_cfg())
        assert len(result) == 0
|
||||
|
||||
|
||||
class TestFixedChunk:
    """Fixed-size fallback chunking through ``_fixed_chunk``."""

    def test_basic(self):
        """Text larger than the token budget is split into several chunks."""
        options = {"max_tokens": 50, "overlap_tokens": 10}
        result = _fixed_chunk("word " * 200, options)
        assert len(result) > 1

    def test_empty(self):
        """Empty input produces no chunks."""
        result = _fixed_chunk("", {})
        assert len(result) == 0

    def test_short_text(self):
        """Text well under the budget stays in a single chunk."""
        result = _fixed_chunk("hello world", {"max_tokens": 512})
        assert len(result) == 1
|
||||
@@ -0,0 +1,156 @@
|
||||
"""Tests for document management commands via Click test runner."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from kb_search.cli import main
|
||||
from kb_search.database import (
|
||||
SCHEMA_VERSION,
|
||||
get_connection,
|
||||
init_schema,
|
||||
insert_chunk,
|
||||
insert_document,
|
||||
insert_embedding,
|
||||
set_db_config,
|
||||
tag_document,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def kb_env(tmp_path, monkeypatch):
    """Create a seeded knowledge-base directory and point ``KB_DATA_DIR`` at it.

    A ``.kb`` directory is created under ``tmp_path`` with a ``kb.db``
    database whose config records the schema version, model name and
    embedding dimension. One PDF document with two chunks and the tags
    ``test`` and ``pdf`` is inserted.

    Returns:
        The path of the created data directory.
    """
    data_dir = tmp_path / ".kb"
    data_dir.mkdir()
    db_path = data_dir / "kb.db"

    conn = get_connection(db_path)
    try:
        init_schema(conn, 384)
        set_db_config(conn, "schema_version", str(SCHEMA_VERSION))
        set_db_config(conn, "model_name", "all-MiniLM-L6-v2")
        set_db_config(conn, "embedding_dim", "384")

        # Seed one document so list/info/remove/tag commands have a target.
        doc_id = insert_document(conn, "Test Doc", "/tmp/test.pdf", "abc123", "pdf")
        insert_chunk(conn, doc_id, 0, "This is chunk zero about Python")
        insert_chunk(conn, doc_id, 1, "This is chunk one about testing")
        tag_document(conn, doc_id, ["test", "pdf"])
        conn.commit()
    finally:
        # Close even when seeding fails so the database file is not left
        # open behind a failing test.
        conn.close()

    monkeypatch.setenv("KB_DATA_DIR", str(data_dir))
    return data_dir
|
||||
|
||||
|
||||
class TestList:
    """The ``list`` command against the seeded knowledge base."""

    def test_json_output(self, kb_env):
        """JSON output contains the single seeded document."""
        outcome = CliRunner().invoke(main, ["list", "--format", "json"])
        assert outcome.exit_code == 0
        documents = json.loads(outcome.output)
        assert len(documents) == 1
        only = documents[0]
        assert only["title"] == "Test Doc"
        assert only["type"] == "pdf"

    def test_human_output(self, kb_env):
        """Human-readable output mentions the document title."""
        outcome = CliRunner().invoke(main, ["list", "--format", "human"])
        assert outcome.exit_code == 0
        assert "Test Doc" in outcome.output

    def test_filter_type(self, kb_env):
        """Filtering by a type no document has returns an empty list."""
        arguments = ["list", "--type", "markdown", "--format", "json"]
        outcome = CliRunner().invoke(main, arguments)
        assert outcome.exit_code == 0
        documents = json.loads(outcome.output)
        assert len(documents) == 0

    def test_filter_tags(self, kb_env):
        """Filtering by an attached tag returns the tagged document."""
        arguments = ["list", "--tags", "test", "--format", "json"]
        outcome = CliRunner().invoke(main, arguments)
        assert outcome.exit_code == 0
        documents = json.loads(outcome.output)
        assert len(documents) == 1
|
||||
|
||||
|
||||
class TestInfo:
    """The ``info`` command against the seeded knowledge base."""

    def test_json_output(self, kb_env):
        """JSON output reports title, chunk count and tags of document 1."""
        outcome = CliRunner().invoke(main, ["info", "1", "--format", "json"])
        assert outcome.exit_code == 0
        details = json.loads(outcome.output)
        assert details["title"] == "Test Doc"
        assert details["chunk_count"] == 2
        assert "test" in details["tags"]

    def test_not_found(self, kb_env):
        """An unknown document id fails with a "not found" message."""
        outcome = CliRunner().invoke(main, ["info", "999"])
        assert outcome.exit_code != 0
        message = outcome.output.lower()
        assert "not found" in message
|
||||
|
||||
|
||||
class TestRemove:
    """The ``remove`` command against the seeded knowledge base."""

    def test_remove_with_yes(self, kb_env):
        """Removing document 1 with ``--yes`` succeeds and empties the list."""
        runner = CliRunner()
        result = runner.invoke(main, ["remove", "1", "--yes"])
        assert result.exit_code == 0
        assert "Removed" in result.output

        # Verify the document is gone. Check the exit code first so that a
        # failing ``list`` shows up as such instead of as a JSON decode error.
        result = runner.invoke(main, ["list", "--format", "json"])
        assert result.exit_code == 0
        data = json.loads(result.output)
        assert len(data) == 0

    def test_remove_not_found(self, kb_env):
        """Removing an unknown document id exits with a non-zero status."""
        runner = CliRunner()
        result = runner.invoke(main, ["remove", "999", "--yes"])
        assert result.exit_code != 0
|
||||
|
||||
|
||||
class TestTags:
    """The ``tags`` and ``tag`` commands against the seeded knowledge base."""

    def test_list_tags(self, kb_env):
        """JSON tag listing includes both tags of the seeded document."""
        outcome = CliRunner().invoke(main, ["tags", "--format", "json"])
        assert outcome.exit_code == 0
        listed = [entry["name"] for entry in json.loads(outcome.output)]
        for expected in ("test", "pdf"):
            assert expected in listed

    def test_add_tag(self, kb_env):
        """Adding a tag to document 1 succeeds and reports "Added"."""
        outcome = CliRunner().invoke(main, ["tag", "1", "--add", "new"])
        assert outcome.exit_code == 0
        assert "Added" in outcome.output

    def test_remove_tag(self, kb_env):
        """Removing a tag from document 1 succeeds and reports "Removed"."""
        outcome = CliRunner().invoke(main, ["tag", "1", "--remove", "test"])
        assert outcome.exit_code == 0
        assert "Removed" in outcome.output
|
||||
|
||||
|
||||
class TestStatus:
    """The ``status`` command with and without an initialised knowledge base."""

    def test_json_output(self, kb_env):
        """JSON status reports the model name and the seeded totals."""
        outcome = CliRunner().invoke(main, ["status", "--format", "json"])
        assert outcome.exit_code == 0
        report = json.loads(outcome.output)
        assert report["model_name"] == "all-MiniLM-L6-v2"
        assert report["total_documents"] == 1
        assert report["total_chunks"] == 2

    def test_human_output(self, kb_env):
        """Human-readable status mentions the model name."""
        outcome = CliRunner().invoke(main, ["status", "--format", "human"])
        assert outcome.exit_code == 0
        assert "all-MiniLM-L6-v2" in outcome.output

    def test_not_initialised(self, tmp_path, monkeypatch):
        """A missing data directory fails with a "not initialised" message."""
        missing_dir = tmp_path / "nonexistent"
        monkeypatch.setenv("KB_DATA_DIR", str(missing_dir))
        outcome = CliRunner().invoke(main, ["status"])
        assert outcome.exit_code != 0
        message = outcome.output.lower()
        assert "not initialised" in message
|
||||
@@ -0,0 +1,120 @@
|
||||
"""Tests for output formatters."""
|
||||
|
||||
import json
|
||||
|
||||
from kb_search.output import (
|
||||
_human_size,
|
||||
format_doc_info,
|
||||
format_document_list,
|
||||
format_search_results,
|
||||
format_status,
|
||||
format_tags,
|
||||
)
|
||||
|
||||
# Provenance record attached to the single sample hit below.
_SAMPLE_SOURCE = {
    "document_id": 42,
    "title": "Git Admin Guide",
    "path": "/docs/git.pdf",
    "type": "pdf",
    "page": 12,
    "section_header": None,
    "chunk_index": 3,
    "total_chunks": 28,
    "tags": ["git", "admin"],
}

# One search hit carrying both an FTS and a vector score component.
_SAMPLE_HIT = {
    "chunk_id": 1,
    "score": 0.031,
    "score_breakdown": {"fts": 0.016, "vector": 0.015},
    "text": "To install git from source...",
    "source": _SAMPLE_SOURCE,
}

# Search payload shared by the formatter tests in this module.
SAMPLE_SEARCH = {
    "query": "install git",
    "results": [_SAMPLE_HIT],
    "total_matches": 47,
    "returned": 1,
}
|
||||
|
||||
|
||||
class TestSearchOutput:
    """Tests for ``format_search_results`` in JSON and human modes."""

    def test_json_format(self):
        """JSON output round-trips the query, the hit and both score components."""
        decoded = json.loads(format_search_results(SAMPLE_SEARCH, "json"))
        assert decoded["query"] == "install git"
        assert len(decoded["results"]) == 1

        first_hit = decoded["results"][0]
        assert first_hit["chunk_id"] == 1
        assert "fts" in first_hit["score_breakdown"]
        assert "vector" in first_hit["score_breakdown"]

    def test_json_schema_fields(self):
        """Each JSON hit exposes the expected top-level and ``source`` keys."""
        decoded = json.loads(format_search_results(SAMPLE_SEARCH, "json"))
        first_hit = decoded["results"][0]

        for field in ("chunk_id", "score", "text", "source"):
            assert field in first_hit

        provenance = first_hit["source"]
        for field in ("document_id", "title", "type", "tags"):
            assert field in provenance

    def test_human_format(self):
        """Human output contains the query, title, page marker and score text."""
        rendered = format_search_results(SAMPLE_SEARCH, "human")
        for fragment in ("install git", "Git Admin Guide", "p.12", "0.031"):
            assert fragment in rendered
|
||||
|
||||
|
||||
class TestDocList:
    """Tests for ``format_document_list``."""

    def test_json(self):
        """A one-document list serialises to a JSON array of length one."""
        document = {
            "id": 1,
            "title": "Test",
            "type": "pdf",
            "tags": ["a"],
            "chunk_count": 5,
            "created_at": "2024-01-01",
        }
        decoded = json.loads(format_document_list([document], "json"))
        assert len(decoded) == 1

    def test_human_empty(self):
        """An empty list renders a message containing "No documents"."""
        rendered = format_document_list([], "human")
        assert "No documents" in rendered

    def test_human(self):
        """Human output includes the document title."""
        document = {
            "id": 1,
            "title": "Test",
            "type": "pdf",
            "tags": ["a"],
            "chunk_count": 5,
        }
        rendered = format_document_list([document], "human")
        assert "Test" in rendered
|
||||
|
||||
|
||||
class TestTags:
    """Tests for ``format_tags``."""

    def test_json(self):
        """JSON output preserves the tag name of the first entry."""
        rendered = format_tags([{"name": "git", "count": 15}], "json")
        decoded = json.loads(rendered)
        assert decoded[0]["name"] == "git"

    def test_human_empty(self):
        """An empty tag list renders a message containing "No tags"."""
        rendered = format_tags([], "human")
        assert "No tags" in rendered
|
||||
|
||||
|
||||
class TestStatus:
    """Tests for ``format_status``."""

    def test_json(self):
        """JSON output preserves ``model_name``."""
        snapshot = {
            "model_name": "test",
            "embedding_dim": 384,
            "schema_version": 1,
            "db_size_bytes": 1024,
            "documents": {"pdf": 5},
            "total_documents": 5,
            "total_chunks": 50,
        }
        decoded = json.loads(format_status(snapshot, "json"))
        assert decoded["model_name"] == "test"

    def test_human(self):
        """Human output mentions the model name and the embedding dimension."""
        snapshot = {
            "model_name": "test",
            "embedding_dim": 384,
            "schema_version": 1,
            "db_size_bytes": 1024000,
            "documents": {"pdf": 5},
            "total_documents": 5,
            "total_chunks": 50,
        }
        rendered = format_status(snapshot, "human")
        assert "test" in rendered
        assert "384" in rendered
|
||||
|
||||
|
||||
class TestHumanSize:
    """Tests for the ``_human_size`` byte-count formatter."""

    def test_bytes(self):
        """512 is rendered in bytes."""
        rendered = _human_size(512)
        assert rendered == "512.0 B"

    def test_kb(self):
        """2048 is rendered in kilobytes."""
        rendered = _human_size(2048)
        assert rendered == "2.0 KB"

    def test_mb(self):
        """5 * 1024 * 1024 is rendered in megabytes."""
        five_mebibytes = 5 * 1024 * 1024
        rendered = _human_size(five_mebibytes)
        assert rendered == "5.0 MB"
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Tests for hybrid search, RRF merging, and filtering."""
|
||||
|
||||
import pytest
|
||||
|
||||
from kb_search.search import (
|
||||
_escape_fts_query,
|
||||
_rank_results,
|
||||
_rrf_merge,
|
||||
_single_source_results,
|
||||
)
|
||||
|
||||
|
||||
class TestEscapeFtsQuery:
    """Tests for ``_escape_fts_query``."""

    def test_plain_query(self):
        """A query without special characters is returned unchanged."""
        assert _escape_fts_query("install git") == "install git"

    def test_special_chars(self):
        """Double quotes and parentheses do not survive escaping."""
        result = _escape_fts_query('install "git" (latest)')
        assert '"' not in result
        assert "(" not in result
        assert ")" not in result

    def test_collapses_spaces(self):
        """Leading, trailing and repeated inner spaces are normalised."""
        # The input must contain runs of several spaces; with single spaces
        # only the outer strip would be exercised, not the collapsing.
        assert _escape_fts_query("  too   many   spaces  ") == "too many spaces"

    def test_empty(self):
        """An empty query stays empty."""
        assert _escape_fts_query("") == ""
|
||||
|
||||
|
||||
class TestRankResults:
    """Tests for ``_rank_results``."""

    def test_basic_ranking(self):
        """Ranks are assigned in descending score order, starting at 1."""
        rank_by_id = _rank_results({1: 0.9, 2: 0.5, 3: 0.7})
        # (chunk id, expected rank): the highest score gets rank 1.
        for chunk_id, expected_rank in ((1, 1), (3, 2), (2, 3)):
            assert rank_by_id[chunk_id] == expected_rank

    def test_empty(self):
        """An empty score mapping produces an empty rank mapping."""
        rank_by_id = _rank_results({})
        assert rank_by_id == {}
|
||||
|
||||
|
||||
class TestRRFMerge:
    """Tests for ``_rrf_merge``, which combines FTS and vector score maps."""

    def test_basic_merge(self):
        """A chunk present in both inputs outscores chunks present in only one."""
        fts = {1: 0.9, 2: 0.5}
        vec = {1: 0.8, 3: 0.7}
        merged = _rrf_merge(fts, vec, k=60)

        scores = {r["chunk_id"]: r["score"] for r in merged}
        # Chunk 1 appears in both — should have highest score
        assert scores[1] > scores[2]
        assert scores[1] > scores[3]

    def test_no_overlap(self):
        """Disjoint inputs yield one merged entry per chunk."""
        fts = {1: 0.9}
        vec = {2: 0.8}
        merged = _rrf_merge(fts, vec, k=60)
        assert len(merged) == 2

    def test_score_breakdown(self):
        """A chunk found by both sources carries both breakdown components."""
        fts = {1: 0.9}
        vec = {1: 0.8}
        merged = _rrf_merge(fts, vec, k=60)
        assert len(merged) == 1
        assert merged[0]["score_breakdown"]["fts"] is not None
        assert merged[0]["score_breakdown"]["vector"] is not None

    def test_single_source_fts(self):
        """With an empty vector map, every entry has only an FTS component."""
        fts = {1: 0.9, 2: 0.5}
        merged = _rrf_merge(fts, {}, k=60)
        # Guard against a vacuous pass: without this, an empty result would
        # skip the loop below and the test would succeed without checking anything.
        assert len(merged) == 2
        for r in merged:
            assert r["score_breakdown"]["vector"] is None
            assert r["score_breakdown"]["fts"] is not None

    def test_empty_both(self):
        """Two empty inputs merge to an empty list."""
        merged = _rrf_merge({}, {}, k=60)
        assert merged == []
|
||||
|
||||
|
||||
class TestSingleSourceResults:
    """Tests for ``_single_source_results``."""

    def test_fts_only(self):
        """FTS-only input yields entries with an FTS score and no vector score."""
        entries = _single_source_results({1: 0.9, 2: 0.5}, "fts")
        assert len(entries) == 2
        for entry in entries:
            breakdown = entry["score_breakdown"]
            assert breakdown["vector"] is None
            assert breakdown["fts"] is not None

    def test_vec_only(self):
        """Vector-only input yields an entry with a vector score and no FTS score."""
        entries = _single_source_results({1: 0.8}, "vector")
        assert len(entries) == 1
        breakdown = entries[0]["score_breakdown"]
        assert breakdown["fts"] is None
        assert breakdown["vector"] is not None
|
||||
Reference in New Issue
Block a user