Initial MVP

This commit is contained in:
2026-03-23 20:38:42 +00:00
commit f245c24928
57 changed files with 6812 additions and 0 deletions
View File
+131
View File
@@ -0,0 +1,131 @@
"""Tests for configuration loading, merging, and ENV overrides."""
import os
from pathlib import Path
import pytest
import yaml
from kb_search.config import (
DEFAULTS,
_deep_merge,
_get_nested,
_set_nested,
config_with_sources,
load_config,
save_config_value,
)
def test_deep_merge_basic():
base = {"a": 1, "b": {"c": 2, "d": 3}}
override = {"b": {"c": 99}}
result = _deep_merge(base, override)
assert result == {"a": 1, "b": {"c": 99, "d": 3}}
def test_deep_merge_new_keys():
base = {"a": 1}
override = {"b": 2}
result = _deep_merge(base, override)
assert result == {"a": 1, "b": 2}
def test_deep_merge_does_not_mutate():
base = {"a": {"b": 1}}
override = {"a": {"b": 2}}
_deep_merge(base, override)
assert base["a"]["b"] == 1
def test_set_nested():
d = {}
_set_nested(d, "a.b.c", 42)
assert d == {"a": {"b": {"c": 42}}}
def test_get_nested():
d = {"a": {"b": {"c": 42}}}
assert _get_nested(d, "a.b.c") == 42
assert _get_nested(d, "a.b.x", "missing") == "missing"
assert _get_nested(d, "x.y.z") is None
def test_load_config_defaults(tmp_path):
"""With no config file, returns defaults."""
cfg = load_config(tmp_path / "nonexistent.yaml")
assert cfg["embedding"]["model"] == "all-MiniLM-L6-v2"
assert cfg["search"]["default_top"] == 10
assert cfg["chunking"]["pdf"]["strategy"] == "hierarchy"
def test_load_config_yaml_override(tmp_path):
"""YAML values override defaults."""
config_path = tmp_path / "config.yaml"
config_path.write_text(yaml.dump({"embedding": {"model": "nomic-embed-text"}}))
cfg = load_config(config_path)
assert cfg["embedding"]["model"] == "nomic-embed-text"
# Other defaults preserved
assert cfg["search"]["default_top"] == 10
def test_load_config_env_override(tmp_path, monkeypatch):
"""ENV overrides both YAML and defaults."""
config_path = tmp_path / "config.yaml"
config_path.write_text(yaml.dump({"search": {"default_top": 20}}))
monkeypatch.setenv("KB_DEFAULT_TOP", "50")
cfg = load_config(config_path)
assert cfg["search"]["default_top"] == 50
def test_load_config_env_model(tmp_path, monkeypatch):
monkeypatch.setenv("KB_MODEL", "bge-small-en-v1.5")
cfg = load_config(tmp_path / "nonexistent.yaml")
assert cfg["embedding"]["model"] == "bge-small-en-v1.5"
def test_save_config_value(tmp_path):
config_path = tmp_path / "config.yaml"
save_config_value(config_path, "chunking.pdf.max_tokens", "2048")
with open(config_path) as f:
data = yaml.safe_load(f)
assert data["chunking"]["pdf"]["max_tokens"] == 2048
def test_save_config_value_bool(tmp_path):
config_path = tmp_path / "config.yaml"
save_config_value(config_path, "chunking.code.include_context", "false")
with open(config_path) as f:
data = yaml.safe_load(f)
assert data["chunking"]["code"]["include_context"] is False
def test_save_config_preserves_existing(tmp_path):
config_path = tmp_path / "config.yaml"
config_path.write_text(yaml.dump({"embedding": {"model": "custom"}}))
save_config_value(config_path, "search.default_top", "20")
with open(config_path) as f:
data = yaml.safe_load(f)
assert data["embedding"]["model"] == "custom"
assert data["search"]["default_top"] == 20
def test_config_with_sources_defaults(tmp_path, monkeypatch):
entries = config_with_sources(tmp_path / "nonexistent.yaml")
sources = {k: s for k, _, s in entries}
assert sources["embedding.model"] == "default"
def test_config_with_sources_yaml(tmp_path):
config_path = tmp_path / "config.yaml"
config_path.write_text(yaml.dump({"embedding": {"model": "custom"}}))
entries = config_with_sources(config_path)
sources = {k: s for k, _, s in entries}
assert sources["embedding.model"] == "config.yaml"
def test_config_with_sources_env(tmp_path, monkeypatch):
monkeypatch.setenv("KB_MODEL", "from-env")
entries = config_with_sources(tmp_path / "nonexistent.yaml")
sources = {k: s for k, _, s in entries}
assert sources["embedding.model"] == "env (KB_MODEL)"
+206
View File
@@ -0,0 +1,206 @@
"""Tests for database schema, FTS triggers, and config helpers."""
import struct
from pathlib import Path
import pytest
from kb_search.database import (
SCHEMA_VERSION,
check_schema_version,
get_connection,
get_db_config,
get_or_create_tag,
hash_exists,
init_schema,
insert_chunk,
insert_document,
insert_embedding,
recreate_vec_table,
run_migrations,
set_db_config,
tag_document,
untag_document,
)
@pytest.fixture
def db(tmp_path):
"""Provide an initialised in-memory-like DB."""
db_path = tmp_path / "test.db"
conn = get_connection(db_path)
init_schema(conn, embedding_dim=384)
set_db_config(conn, "schema_version", str(SCHEMA_VERSION))
yield conn
conn.close()
def test_schema_creation(db):
tables = [r[0] for r in db.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()]
assert "documents" in tables
assert "chunks" in tables
assert "tags" in tables
assert "document_tags" in tables
assert "config" in tables
def test_fts_table_exists(db):
tables = [r[0] for r in db.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()]
assert "chunks_fts" in tables
def test_vec_table_exists(db):
tables = [r[0] for r in db.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()]
assert "chunks_vec" in tables
def test_config_get_set(db):
set_db_config(db, "test_key", "test_value")
assert get_db_config(db, "test_key") == "test_value"
def test_config_get_default(db):
assert get_db_config(db, "nonexistent", "fallback") == "fallback"
def test_config_upsert(db):
set_db_config(db, "key", "v1")
set_db_config(db, "key", "v2")
assert get_db_config(db, "key") == "v2"
def test_schema_version(db):
assert check_schema_version(db) == SCHEMA_VERSION
def test_insert_document(db):
doc_id = insert_document(db, "Test Doc", "/path/test.pdf", "abc123", "pdf")
db.commit()
row = db.execute("SELECT * FROM documents WHERE id = ?", (doc_id,)).fetchone()
assert row["title"] == "Test Doc"
assert row["doc_type"] == "pdf"
assert row["content_hash"] == "abc123"
def test_insert_chunk_with_fts_sync(db):
doc_id = insert_document(db, "Doc", None, "hash1", "note")
chunk_id = insert_chunk(db, doc_id, 0, "This is searchable text about Python programming")
db.commit()
# FTS should find it
rows = db.execute(
"SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'python'"
).fetchall()
assert len(rows) == 1
assert rows[0][0] == chunk_id
def test_fts_delete_trigger(db):
doc_id = insert_document(db, "Doc", None, "hash2", "note")
chunk_id = insert_chunk(db, doc_id, 0, "unique_keyword_xyz")
db.commit()
db.execute("DELETE FROM chunks WHERE id = ?", (chunk_id,))
db.commit()
rows = db.execute(
"SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'unique_keyword_xyz'"
).fetchall()
assert len(rows) == 0
def test_fts_update_trigger(db):
doc_id = insert_document(db, "Doc", None, "hash3", "note")
chunk_id = insert_chunk(db, doc_id, 0, "old_content_abc")
db.commit()
db.execute("UPDATE chunks SET text = 'new_content_def' WHERE id = ?", (chunk_id,))
db.commit()
old = db.execute("SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'old_content_abc'").fetchall()
new = db.execute("SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'new_content_def'").fetchall()
assert len(old) == 0
assert len(new) == 1
def test_insert_embedding(db):
doc_id = insert_document(db, "Doc", None, "hash4", "note")
chunk_id = insert_chunk(db, doc_id, 0, "text")
db.commit()
embedding = [0.1] * 384
insert_embedding(db, chunk_id, embedding)
db.commit()
row = db.execute("SELECT * FROM chunks_vec WHERE chunk_id = ?", (chunk_id,)).fetchone()
assert row is not None
def test_hash_exists(db):
assert not hash_exists(db, "newhash")
insert_document(db, "Doc", None, "newhash", "note")
db.commit()
assert hash_exists(db, "newhash")
def test_tag_management(db):
doc_id = insert_document(db, "Doc", None, "hash5", "pdf")
db.commit()
tag_document(db, doc_id, ["git", "admin"])
db.commit()
rows = db.execute(
"SELECT t.name FROM tags t JOIN document_tags dt ON t.id = dt.tag_id "
"WHERE dt.document_id = ? ORDER BY t.name",
(doc_id,),
).fetchall()
assert [r["name"] for r in rows] == ["admin", "git"]
def test_untag_document(db):
doc_id = insert_document(db, "Doc", None, "hash6", "pdf")
tag_document(db, doc_id, ["a", "b", "c"])
db.commit()
untag_document(db, doc_id, ["b"])
db.commit()
rows = db.execute(
"SELECT t.name FROM tags t JOIN document_tags dt ON t.id = dt.tag_id "
"WHERE dt.document_id = ? ORDER BY t.name",
(doc_id,),
).fetchall()
assert [r["name"] for r in rows] == ["a", "c"]
def test_tags_are_lowercase(db):
tag_id = get_or_create_tag(db, "MyTag")
db.commit()
row = db.execute("SELECT name FROM tags WHERE id = ?", (tag_id,)).fetchone()
assert row["name"] == "mytag"
def test_recreate_vec_table(db):
doc_id = insert_document(db, "Doc", None, "hash7", "note")
chunk_id = insert_chunk(db, doc_id, 0, "text")
insert_embedding(db, chunk_id, [0.1] * 384)
db.commit()
recreate_vec_table(db, 768)
# Old data gone, new dimension
rows = db.execute("SELECT * FROM chunks_vec").fetchall()
assert len(rows) == 0
def test_cascade_delete(db):
doc_id = insert_document(db, "Doc", None, "hash8", "pdf")
insert_chunk(db, doc_id, 0, "chunk text")
tag_document(db, doc_id, ["test"])
db.commit()
db.execute("DELETE FROM documents WHERE id = ?", (doc_id,))
db.commit()
assert db.execute("SELECT COUNT(*) FROM chunks WHERE document_id = ?", (doc_id,)).fetchone()[0] == 0
assert db.execute("SELECT COUNT(*) FROM document_tags WHERE document_id = ?", (doc_id,)).fetchone()[0] == 0
+50
View File
@@ -0,0 +1,50 @@
"""Tests for embedding model management."""
from unittest.mock import MagicMock, patch
import click
import pytest
from kb_search.embeddings import check_model_binding
@pytest.fixture
def mock_conn():
"""Mock DB connection with config values."""
def make_conn(config_values=None):
config_values = config_values or {}
conn = MagicMock()
def mock_execute(sql, params=None):
if "SELECT value FROM config" in sql and params:
key = params[0]
val = config_values.get(key)
row = MagicMock()
row.__getitem__ = lambda self, k: val
result = MagicMock()
result.fetchone.return_value = row if val else None
return result
return MagicMock()
conn.execute = mock_execute
return conn
return make_conn
def test_model_binding_match(mock_conn):
conn = mock_conn({"model_name": "all-MiniLM-L6-v2", "embedding_dim": "384"})
cfg = {"embedding": {"model": "all-MiniLM-L6-v2"}}
# Should not raise
check_model_binding(conn, cfg)
def test_model_binding_mismatch(mock_conn):
conn = mock_conn({"model_name": "all-MiniLM-L6-v2", "embedding_dim": "384"})
cfg = {"embedding": {"model": "nomic-embed-text"}}
with pytest.raises(click.ClickException, match="Model mismatch"):
check_model_binding(conn, cfg)
def test_model_binding_no_db_model(mock_conn):
conn = mock_conn({})
cfg = {"embedding": {"model": "anything"}}
# Should not raise when DB not yet initialised
check_model_binding(conn, cfg)
+172
View File
@@ -0,0 +1,172 @@
"""Tests for code chunking — Python, Bash, Go."""
from kb_search.ingest.code import chunk_code, _chunk_python, _chunk_bash, _chunk_go, _fixed_chunk
CFG = {"chunking": {"code": {"strategy": "ast", "include_context": True, "max_tokens": 1024}}}
class TestPythonChunking:
def test_functions(self):
code = '''
def hello():
"""Say hello."""
print("hello")
def goodbye():
"""Say goodbye."""
print("bye")
'''
chunks = _chunk_python(code, include_context=True)
assert len(chunks) == 2
assert chunks[0]["metadata"]["symbol_name"] == "hello"
assert chunks[1]["metadata"]["symbol_name"] == "goodbye"
def test_class_with_methods(self):
code = '''
class MyClass:
"""A test class."""
def method_a(self):
pass
def method_b(self):
pass
'''
chunks = _chunk_python(code, include_context=True)
assert len(chunks) == 2
assert chunks[0]["metadata"]["symbol_name"] == "MyClass.method_a"
assert chunks[1]["metadata"]["symbol_name"] == "MyClass.method_b"
# Context should include class docstring
assert "A test class" in chunks[0]["text"]
def test_class_without_methods(self):
code = '''
class Config:
"""Configuration."""
DEBUG = True
PORT = 8080
'''
chunks = _chunk_python(code, include_context=True)
assert len(chunks) == 1
assert chunks[0]["metadata"]["symbol_name"] == "Config"
def test_syntax_error_returns_empty(self):
chunks = _chunk_python("def broken(:\n pass", include_context=True)
assert chunks == []
def test_no_context(self):
code = '''
class Foo:
"""Docstring."""
def bar(self):
pass
'''
chunks = _chunk_python(code, include_context=False)
assert len(chunks) == 1
assert "Docstring" not in chunks[0]["text"]
class TestBashChunking:
def test_function_keyword(self):
code = '''#!/bin/bash
function deploy() {
echo "deploying"
}
function rollback() {
echo "rolling back"
}
'''
chunks = _chunk_bash(code, include_context=True)
assert len(chunks) == 2
assert chunks[0]["metadata"]["symbol_name"] == "deploy"
assert chunks[1]["metadata"]["symbol_name"] == "rollback"
def test_shorthand_syntax(self):
code = '''
setup() {
echo "setup"
}
cleanup() {
echo "cleanup"
}
'''
chunks = _chunk_bash(code, include_context=True)
assert len(chunks) == 2
def test_no_functions(self):
code = "#!/bin/bash\necho hello\nexit 0"
chunks = _chunk_bash(code, include_context=True)
assert chunks == []
def test_with_preceding_comments(self):
code = '''
# Deploy to production
# Requires valid credentials
function deploy() {
echo "deploying"
}
'''
chunks = _chunk_bash(code, include_context=True)
assert len(chunks) == 1
assert "Deploy to production" in chunks[0]["text"]
class TestGoChunking:
def test_basic_funcs(self):
code = '''package main
func main() {
fmt.Println("hello")
}
func helper() string {
return "help"
}
'''
chunks = _chunk_go(code, include_context=True)
assert len(chunks) == 2
assert chunks[0]["metadata"]["symbol_name"] == "main"
assert chunks[1]["metadata"]["symbol_name"] == "helper"
def test_method_receiver(self):
code = '''
func (s *Server) Start() error {
return nil
}
func (s *Server) Stop() {
}
'''
chunks = _chunk_go(code, include_context=True)
assert len(chunks) == 2
assert chunks[0]["metadata"]["symbol_name"] == "Start"
def test_no_funcs(self):
code = "package main\n\nvar x = 1"
chunks = _chunk_go(code, include_context=True)
assert chunks == []
class TestFallback:
def test_unknown_language_uses_fixed(self):
code = "line1\nline2\nline3"
chunks = chunk_code(code, "ruby", CFG)
assert len(chunks) >= 1
def test_python_no_functions_uses_fixed(self):
code = "x = 1\ny = 2\nprint(x + y)"
chunks = chunk_code(code, "python", CFG)
assert len(chunks) >= 1
def test_fixed_strategy_config(self):
cfg = {"chunking": {"code": {"strategy": "fixed", "max_tokens": 10}}}
code = "\n".join(f"x_{i} = {i}" for i in range(50))
chunks = chunk_code(code, "python", cfg)
assert len(chunks) > 1
def test_empty_code(self):
chunks = chunk_code("", "python", CFG)
assert len(chunks) == 0
+81
View File
@@ -0,0 +1,81 @@
"""Tests for file type detection, dedup, note creation."""
from pathlib import Path
import pytest
from kb_search.ingest.detector import detect_type, is_supported
from kb_search.ingest.note import auto_title, chunk_note
class TestDetector:
def test_pdf(self, tmp_path):
assert detect_type(tmp_path / "doc.pdf") == ("pdf", None)
def test_markdown(self, tmp_path):
assert detect_type(tmp_path / "notes.md") == ("markdown", None)
def test_txt(self, tmp_path):
assert detect_type(tmp_path / "notes.txt") == ("markdown", None)
def test_python(self, tmp_path):
assert detect_type(tmp_path / "main.py") == ("code", "python")
def test_bash(self, tmp_path):
assert detect_type(tmp_path / "deploy.sh") == ("code", "bash")
def test_go(self, tmp_path):
assert detect_type(tmp_path / "main.go") == ("code", "go")
def test_unsupported(self, tmp_path):
with pytest.raises(ValueError, match="Unsupported"):
detect_type(tmp_path / "archive.zip")
def test_force_type(self, tmp_path):
assert detect_type(tmp_path / "data.txt", force_type="code", force_language="bash") == ("code", "bash")
def test_force_language_only(self, tmp_path):
doc_type, lang = detect_type(tmp_path / "script.py", force_language="go")
assert doc_type == "code"
assert lang == "go"
def test_is_supported(self, tmp_path):
assert is_supported(tmp_path / "test.pdf")
assert is_supported(tmp_path / "test.py")
assert not is_supported(tmp_path / "test.zip")
def test_case_insensitive(self, tmp_path):
assert detect_type(tmp_path / "DOC.PDF") == ("pdf", None)
def test_image_files(self, tmp_path):
assert detect_type(tmp_path / "scan.png") == ("pdf", None)
assert detect_type(tmp_path / "photo.jpg") == ("pdf", None)
def test_docx(self, tmp_path):
assert detect_type(tmp_path / "report.docx") == ("pdf", None)
class TestNote:
def test_chunk_note(self):
chunks = chunk_note("Hello world")
assert len(chunks) == 1
assert chunks[0]["text"] == "Hello world"
assert chunks[0]["chunk_index"] == 0
def test_auto_title_short(self):
assert auto_title("Short note") == "Short note"
def test_auto_title_long(self):
long_text = "This is a very long note that exceeds the maximum title length and should be truncated at a word boundary"
result = auto_title(long_text, max_len=50)
assert len(result) <= 54 # 50 + "..."
assert result.endswith("...")
def test_auto_title_multiline(self):
text = "First line\nSecond line\nThird line"
assert auto_title(text) == "First line"
def test_auto_title_no_space(self):
text = "a" * 100
result = auto_title(text, max_len=80)
assert result.endswith("...")
+33
View File
@@ -0,0 +1,33 @@
"""Tests for Docling ingestion (fixed-size chunking logic, mocked Docling)."""
from kb_search.ingest.docling import _fixed_chunk_text
class TestFixedChunkText:
def test_short_text_single_chunk(self):
chunks = _fixed_chunk_text("Hello world", {})
assert len(chunks) == 1
assert chunks[0]["text"] == "Hello world"
assert chunks[0]["chunk_index"] == 0
def test_long_text_multiple_chunks(self):
text = "word " * 2000 # ~10000 chars
chunks = _fixed_chunk_text(text, {"max_tokens": 512, "overlap_tokens": 50})
assert len(chunks) > 1
# Chunks should overlap
for i, c in enumerate(chunks):
assert c["chunk_index"] == i
def test_empty_text(self):
chunks = _fixed_chunk_text("", {})
assert len(chunks) == 0
def test_whitespace_only(self):
chunks = _fixed_chunk_text(" \n\n ", {})
assert len(chunks) == 0
def test_custom_max_tokens(self):
text = "a " * 500
chunks = _fixed_chunk_text(text, {"max_tokens": 100})
# 100 tokens * 4 chars = 400 chars window, 1000 chars total
assert len(chunks) > 1
+121
View File
@@ -0,0 +1,121 @@
"""Tests for markdown header-based splitting."""
from kb_search.ingest.markdown import (
_fixed_chunk,
_has_headers,
_merge_small_sections,
_split_at_headers,
chunk_markdown,
)
def make_cfg(**overrides):
cfg = {"chunking": {"markdown": {"strategy": "header", "min_tokens": 50, "max_tokens": 1024}}}
cfg["chunking"]["markdown"].update(overrides)
return cfg
class TestHasHeaders:
def test_with_headers(self):
assert _has_headers("## Title\nContent")
def test_without_headers(self):
assert not _has_headers("Just plain text\nNo headers here")
def test_h3(self):
assert _has_headers("### Subsection\nStuff")
class TestSplitAtHeaders:
def test_basic_split(self):
text = "## Section 1\nContent one\n\n## Section 2\nContent two"
sections = _split_at_headers(text)
assert len(sections) == 2
assert sections[0]["header_chain"] == ["Section 1"]
assert "Content one" in sections[0]["content"]
assert sections[1]["header_chain"] == ["Section 2"]
def test_nested_headers(self):
text = "## Config\nIntro\n\n### Advanced Options\nDetails"
sections = _split_at_headers(text)
assert len(sections) == 2
# The ### should have full chain
assert sections[1]["header_chain"] == ["Config", "Advanced Options"]
def test_leading_content(self):
text = "Preamble text\n\n## First Section\nContent"
sections = _split_at_headers(text)
assert len(sections) == 2
assert sections[0]["header_chain"] == []
assert "Preamble" in sections[0]["content"]
def test_header_level_reset(self):
text = "## A\n\n### B\n\n## C\n\n### D"
sections = _split_at_headers(text)
assert sections[2]["header_chain"] == ["C"]
assert sections[3]["header_chain"] == ["C", "D"]
class TestMergeSmallSections:
def test_merge_tiny_into_next(self):
sections = [
{"header_chain": ["A"], "content": "tiny"},
{"header_chain": ["B"], "content": "This is a much longer section with plenty of words " * 5},
]
merged = _merge_small_sections(sections, min_tokens=10)
assert len(merged) == 1
assert "tiny" in merged[0]["content"]
def test_no_merge_when_large_enough(self):
sections = [
{"header_chain": ["A"], "content": "word " * 100},
{"header_chain": ["B"], "content": "word " * 100},
]
merged = _merge_small_sections(sections, min_tokens=10)
assert len(merged) == 2
class TestChunkMarkdown:
def test_header_strategy(self):
text = "## Intro\nSome intro text with enough words to avoid merging. " * 5
text += "\n\n## Details\nDetailed content follows here with sufficient length. " * 5
cfg = make_cfg(min_tokens=5)
chunks = chunk_markdown(text, cfg)
assert len(chunks) >= 2
# Verify chunk_index assigned
for i, c in enumerate(chunks):
assert c["chunk_index"] == i
def test_hierarchy_context(self):
text = "## Config\nIntro\n\n### Advanced\n" + "Details " * 60
cfg = make_cfg(min_tokens=5)
chunks = chunk_markdown(text, cfg)
# Find the Advanced chunk
advanced = [c for c in chunks if "Advanced" in c["text"]]
assert len(advanced) > 0
assert "Config > Advanced" in advanced[0]["text"]
def test_plain_text_fallback(self):
text = "No headers here, just plain text. " * 200
cfg = make_cfg()
chunks = chunk_markdown(text, cfg)
assert len(chunks) >= 1
def test_empty_text(self):
chunks = chunk_markdown("", make_cfg())
assert len(chunks) == 0
class TestFixedChunk:
def test_basic(self):
text = "word " * 200
chunks = _fixed_chunk(text, {"max_tokens": 50, "overlap_tokens": 10})
assert len(chunks) > 1
def test_empty(self):
chunks = _fixed_chunk("", {})
assert len(chunks) == 0
def test_short_text(self):
chunks = _fixed_chunk("hello world", {"max_tokens": 512})
assert len(chunks) == 1
+156
View File
@@ -0,0 +1,156 @@
"""Tests for document management commands via Click test runner."""
import json
import pytest
from click.testing import CliRunner
from kb_search.cli import main
from kb_search.database import (
SCHEMA_VERSION,
get_connection,
init_schema,
insert_chunk,
insert_document,
insert_embedding,
set_db_config,
tag_document,
)
@pytest.fixture
def kb_env(tmp_path, monkeypatch):
"""Set up a test KB environment."""
data_dir = tmp_path / ".kb"
data_dir.mkdir()
db_path = data_dir / "kb.db"
conn = get_connection(db_path)
init_schema(conn, 384)
set_db_config(conn, "schema_version", str(SCHEMA_VERSION))
set_db_config(conn, "model_name", "all-MiniLM-L6-v2")
set_db_config(conn, "embedding_dim", "384")
# Add a test document
doc_id = insert_document(conn, "Test Doc", "/tmp/test.pdf", "abc123", "pdf")
insert_chunk(conn, doc_id, 0, "This is chunk zero about Python")
insert_chunk(conn, doc_id, 1, "This is chunk one about testing")
tag_document(conn, doc_id, ["test", "pdf"])
conn.commit()
conn.close()
monkeypatch.setenv("KB_DATA_DIR", str(data_dir))
return data_dir
class TestList:
def test_json_output(self, kb_env):
runner = CliRunner()
result = runner.invoke(main, ["list", "--format", "json"])
assert result.exit_code == 0
data = json.loads(result.output)
assert len(data) == 1
assert data[0]["title"] == "Test Doc"
assert data[0]["type"] == "pdf"
def test_human_output(self, kb_env):
runner = CliRunner()
result = runner.invoke(main, ["list", "--format", "human"])
assert result.exit_code == 0
assert "Test Doc" in result.output
def test_filter_type(self, kb_env):
runner = CliRunner()
result = runner.invoke(main, ["list", "--type", "markdown", "--format", "json"])
assert result.exit_code == 0
data = json.loads(result.output)
assert len(data) == 0
def test_filter_tags(self, kb_env):
runner = CliRunner()
result = runner.invoke(main, ["list", "--tags", "test", "--format", "json"])
assert result.exit_code == 0
data = json.loads(result.output)
assert len(data) == 1
class TestInfo:
def test_json_output(self, kb_env):
runner = CliRunner()
result = runner.invoke(main, ["info", "1", "--format", "json"])
assert result.exit_code == 0
data = json.loads(result.output)
assert data["title"] == "Test Doc"
assert data["chunk_count"] == 2
assert "test" in data["tags"]
def test_not_found(self, kb_env):
runner = CliRunner()
result = runner.invoke(main, ["info", "999"])
assert result.exit_code != 0
assert "not found" in result.output.lower()
class TestRemove:
def test_remove_with_yes(self, kb_env):
runner = CliRunner()
result = runner.invoke(main, ["remove", "1", "--yes"])
assert result.exit_code == 0
assert "Removed" in result.output
# Verify gone
result = runner.invoke(main, ["list", "--format", "json"])
data = json.loads(result.output)
assert len(data) == 0
def test_remove_not_found(self, kb_env):
runner = CliRunner()
result = runner.invoke(main, ["remove", "999", "--yes"])
assert result.exit_code != 0
class TestTags:
def test_list_tags(self, kb_env):
runner = CliRunner()
result = runner.invoke(main, ["tags", "--format", "json"])
assert result.exit_code == 0
data = json.loads(result.output)
names = [t["name"] for t in data]
assert "test" in names
assert "pdf" in names
def test_add_tag(self, kb_env):
runner = CliRunner()
result = runner.invoke(main, ["tag", "1", "--add", "new"])
assert result.exit_code == 0
assert "Added" in result.output
def test_remove_tag(self, kb_env):
runner = CliRunner()
result = runner.invoke(main, ["tag", "1", "--remove", "test"])
assert result.exit_code == 0
assert "Removed" in result.output
class TestStatus:
def test_json_output(self, kb_env):
runner = CliRunner()
result = runner.invoke(main, ["status", "--format", "json"])
assert result.exit_code == 0
data = json.loads(result.output)
assert data["model_name"] == "all-MiniLM-L6-v2"
assert data["total_documents"] == 1
assert data["total_chunks"] == 2
def test_human_output(self, kb_env):
runner = CliRunner()
result = runner.invoke(main, ["status", "--format", "human"])
assert result.exit_code == 0
assert "all-MiniLM-L6-v2" in result.output
def test_not_initialised(self, tmp_path, monkeypatch):
monkeypatch.setenv("KB_DATA_DIR", str(tmp_path / "nonexistent"))
runner = CliRunner()
result = runner.invoke(main, ["status"])
assert result.exit_code != 0
assert "not initialised" in result.output.lower()
+120
View File
@@ -0,0 +1,120 @@
"""Tests for output formatters."""
import json
from kb_search.output import (
_human_size,
format_doc_info,
format_document_list,
format_search_results,
format_status,
format_tags,
)
SAMPLE_SEARCH = {
"query": "install git",
"results": [
{
"chunk_id": 1,
"score": 0.031,
"score_breakdown": {"fts": 0.016, "vector": 0.015},
"text": "To install git from source...",
"source": {
"document_id": 42,
"title": "Git Admin Guide",
"path": "/docs/git.pdf",
"type": "pdf",
"page": 12,
"section_header": None,
"chunk_index": 3,
"total_chunks": 28,
"tags": ["git", "admin"],
},
}
],
"total_matches": 47,
"returned": 1,
}
class TestSearchOutput:
def test_json_format(self):
output = format_search_results(SAMPLE_SEARCH, "json")
parsed = json.loads(output)
assert parsed["query"] == "install git"
assert len(parsed["results"]) == 1
assert parsed["results"][0]["chunk_id"] == 1
assert "fts" in parsed["results"][0]["score_breakdown"]
assert "vector" in parsed["results"][0]["score_breakdown"]
def test_json_schema_fields(self):
output = format_search_results(SAMPLE_SEARCH, "json")
parsed = json.loads(output)
r = parsed["results"][0]
assert "chunk_id" in r
assert "score" in r
assert "text" in r
assert "source" in r
src = r["source"]
assert "document_id" in src
assert "title" in src
assert "type" in src
assert "tags" in src
def test_human_format(self):
output = format_search_results(SAMPLE_SEARCH, "human")
assert "install git" in output
assert "Git Admin Guide" in output
assert "p.12" in output
assert "0.031" in output
class TestDocList:
def test_json(self):
docs = [{"id": 1, "title": "Test", "type": "pdf", "tags": ["a"], "chunk_count": 5, "created_at": "2024-01-01"}]
parsed = json.loads(format_document_list(docs, "json"))
assert len(parsed) == 1
def test_human_empty(self):
assert "No documents" in format_document_list([], "human")
def test_human(self):
docs = [{"id": 1, "title": "Test", "type": "pdf", "tags": ["a"], "chunk_count": 5}]
output = format_document_list(docs, "human")
assert "Test" in output
class TestTags:
def test_json(self):
tags = [{"name": "git", "count": 15}]
parsed = json.loads(format_tags(tags, "json"))
assert parsed[0]["name"] == "git"
def test_human_empty(self):
assert "No tags" in format_tags([], "human")
class TestStatus:
def test_json(self):
status = {"model_name": "test", "embedding_dim": 384, "schema_version": 1,
"db_size_bytes": 1024, "documents": {"pdf": 5}, "total_documents": 5, "total_chunks": 50}
parsed = json.loads(format_status(status, "json"))
assert parsed["model_name"] == "test"
def test_human(self):
status = {"model_name": "test", "embedding_dim": 384, "schema_version": 1,
"db_size_bytes": 1024000, "documents": {"pdf": 5}, "total_documents": 5, "total_chunks": 50}
output = format_status(status, "human")
assert "test" in output
assert "384" in output
class TestHumanSize:
def test_bytes(self):
assert _human_size(512) == "512.0 B"
def test_kb(self):
assert _human_size(2048) == "2.0 KB"
def test_mb(self):
assert _human_size(5 * 1024 * 1024) == "5.0 MB"
+91
View File
@@ -0,0 +1,91 @@
"""Tests for hybrid search, RRF merging, and filtering."""
import pytest
from kb_search.search import (
_escape_fts_query,
_rank_results,
_rrf_merge,
_single_source_results,
)
class TestEscapeFtsQuery:
def test_plain_query(self):
assert _escape_fts_query("install git") == "install git"
def test_special_chars(self):
result = _escape_fts_query('install "git" (latest)')
assert '"' not in result
assert "(" not in result
assert ")" not in result
def test_collapses_spaces(self):
assert _escape_fts_query(" too many spaces ") == "too many spaces"
def test_empty(self):
assert _escape_fts_query("") == ""
class TestRankResults:
def test_basic_ranking(self):
results = {1: 0.9, 2: 0.5, 3: 0.7}
ranked = _rank_results(results)
assert ranked[1] == 1 # highest score = rank 1
assert ranked[3] == 2
assert ranked[2] == 3
def test_empty(self):
assert _rank_results({}) == {}
class TestRRFMerge:
def test_basic_merge(self):
fts = {1: 0.9, 2: 0.5}
vec = {1: 0.8, 3: 0.7}
merged = _rrf_merge(fts, vec, k=60)
scores = {r["chunk_id"]: r["score"] for r in merged}
# Chunk 1 appears in both — should have highest score
assert scores[1] > scores[2]
assert scores[1] > scores[3]
def test_no_overlap(self):
fts = {1: 0.9}
vec = {2: 0.8}
merged = _rrf_merge(fts, vec, k=60)
assert len(merged) == 2
def test_score_breakdown(self):
fts = {1: 0.9}
vec = {1: 0.8}
merged = _rrf_merge(fts, vec, k=60)
assert len(merged) == 1
assert merged[0]["score_breakdown"]["fts"] is not None
assert merged[0]["score_breakdown"]["vector"] is not None
def test_single_source_fts(self):
fts = {1: 0.9, 2: 0.5}
merged = _rrf_merge(fts, {}, k=60)
for r in merged:
assert r["score_breakdown"]["vector"] is None
assert r["score_breakdown"]["fts"] is not None
def test_empty_both(self):
merged = _rrf_merge({}, {}, k=60)
assert merged == []
class TestSingleSourceResults:
def test_fts_only(self):
results = _single_source_results({1: 0.9, 2: 0.5}, "fts")
assert len(results) == 2
for r in results:
assert r["score_breakdown"]["vector"] is None
assert r["score_breakdown"]["fts"] is not None
def test_vec_only(self):
results = _single_source_results({1: 0.8}, "vector")
assert len(results) == 1
assert results[0]["score_breakdown"]["fts"] is None
assert results[0]["score_breakdown"]["vector"] is not None