Files
kb/tests/test_database.py
T
2026-03-23 20:38:42 +00:00

207 lines
5.9 KiB
Python

"""Tests for database schema, FTS triggers, and config helpers."""
import struct
from pathlib import Path
import pytest
from kb_search.database import (
SCHEMA_VERSION,
check_schema_version,
get_connection,
get_db_config,
get_or_create_tag,
hash_exists,
init_schema,
insert_chunk,
insert_document,
insert_embedding,
recreate_vec_table,
run_migrations,
set_db_config,
tag_document,
untag_document,
)
@pytest.fixture
def db(tmp_path):
    """Yield an open connection to a freshly initialised database file.

    The database is a real file at ``tmp_path / "test.db"`` (not an in-memory
    database). Its schema is created with ``embedding_dim=384`` and the
    ``schema_version`` config entry is set to ``SCHEMA_VERSION`` before the
    connection is handed to the test. The connection is closed afterwards.
    """
    conn = get_connection(tmp_path / "test.db")
    # pytest does not run the code after ``yield`` when fixture setup raises,
    # so the setup calls live inside the ``try`` to guarantee the connection
    # is closed even if schema initialisation fails.
    try:
        init_schema(conn, embedding_dim=384)
        set_db_config(conn, "schema_version", str(SCHEMA_VERSION))
        yield conn
    finally:
        conn.close()
def test_schema_creation(db):
    """Check that every core table is listed in ``sqlite_master``."""
    cursor = db.execute("SELECT name FROM sqlite_master WHERE type='table'")
    found = {row[0] for row in cursor.fetchall()}
    # Checked one at a time so a failure names the missing table.
    for expected in ("documents", "chunks", "tags", "document_tags", "config"):
        assert expected in found
def test_fts_table_exists(db):
    """Check that ``chunks_fts`` is listed as a table in ``sqlite_master``."""
    listing = db.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
    assert any(entry[0] == "chunks_fts" for entry in listing)
def test_vec_table_exists(db):
    """Check that ``chunks_vec`` is listed as a table in ``sqlite_master``."""
    table_names = set()
    for entry in db.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall():
        table_names.add(entry[0])
    assert "chunks_vec" in table_names
def test_config_get_set(db):
    """A value stored with ``set_db_config`` is returned by ``get_db_config``."""
    key, value = "test_key", "test_value"
    set_db_config(db, key, value)
    stored = get_db_config(db, key)
    assert stored == value
def test_config_get_default(db):
    """``get_db_config`` returns the supplied default for a key never set."""
    result = get_db_config(db, "nonexistent", "fallback")
    assert result == "fallback"
def test_config_upsert(db):
    """Setting the same key twice leaves the most recent value in place."""
    for value in ("v1", "v2"):
        set_db_config(db, "key", value)
    latest = get_db_config(db, "key")
    assert latest == "v2"
def test_schema_version(db):
    """``check_schema_version`` reports ``SCHEMA_VERSION`` for the fixture DB."""
    reported = check_schema_version(db)
    assert reported == SCHEMA_VERSION
def test_insert_document(db):
    """An inserted document can be read back with the values it was given."""
    doc_id = insert_document(db, "Test Doc", "/path/test.pdf", "abc123", "pdf")
    db.commit()

    stored = db.execute("SELECT * FROM documents WHERE id = ?", (doc_id,)).fetchone()

    # Column -> expected value, checked in the same order as before.
    expected = {"title": "Test Doc", "doc_type": "pdf", "content_hash": "abc123"}
    for column, value in expected.items():
        assert stored[column] == value
def test_insert_chunk_with_fts_sync(db):
    """A newly inserted chunk is the single full-text match for its content."""
    doc_id = insert_document(db, "Doc", None, "hash1", "note")
    chunk_id = insert_chunk(db, doc_id, 0, "This is searchable text about Python programming")
    db.commit()

    # FTS should find it, and only it, under the chunk's own id.
    fts_query = "SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'python'"
    matched_rowids = [hit[0] for hit in db.execute(fts_query).fetchall()]
    assert matched_rowids == [chunk_id]
def test_fts_delete_trigger(db):
    """Deleting a chunk row removes its text from the ``chunks_fts`` index."""
    doc_id = insert_document(db, "Doc", None, "hash2", "note")
    chunk_id = insert_chunk(db, doc_id, 0, "unique_keyword_xyz")
    db.commit()

    db.execute("DELETE FROM chunks WHERE id = ?", (chunk_id,))
    db.commit()

    fts_query = "SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'unique_keyword_xyz'"
    leftovers = db.execute(fts_query).fetchall()
    assert not leftovers
def test_fts_update_trigger(db):
    """Updating a chunk's text replaces its old content in ``chunks_fts``."""
    doc_id = insert_document(db, "Doc", None, "hash3", "note")
    chunk_id = insert_chunk(db, doc_id, 0, "old_content_abc")
    db.commit()

    db.execute("UPDATE chunks SET text = 'new_content_def' WHERE id = ?", (chunk_id,))
    db.commit()

    stale_hits = db.execute(
        "SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'old_content_abc'"
    ).fetchall()
    fresh_hits = db.execute(
        "SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'new_content_def'"
    ).fetchall()

    # The old text must no longer match; the new text must match exactly once.
    assert not stale_hits
    assert len(fresh_hits) == 1
def test_insert_embedding(db):
    """An embedding inserted for a chunk produces a row in ``chunks_vec``."""
    doc_id = insert_document(db, "Doc", None, "hash4", "note")
    chunk_id = insert_chunk(db, doc_id, 0, "text")
    db.commit()

    # 384 values, matching the ``embedding_dim`` the fixture passes to init_schema.
    insert_embedding(db, chunk_id, [0.1] * 384)
    db.commit()

    stored = db.execute("SELECT * FROM chunks_vec WHERE chunk_id = ?", (chunk_id,)).fetchone()
    assert stored is not None
def test_hash_exists(db):
    """``hash_exists`` is falsy before a document with that hash is inserted and truthy after."""
    content_hash = "newhash"
    assert not hash_exists(db, content_hash)

    insert_document(db, "Doc", None, content_hash, "note")
    db.commit()

    assert hash_exists(db, content_hash)
def test_tag_management(db):
    """Tags attached with ``tag_document`` are linked to the document."""
    doc_id = insert_document(db, "Doc", None, "hash5", "pdf")
    db.commit()

    tag_document(db, doc_id, ["git", "admin"])
    db.commit()

    tag_query = (
        "SELECT t.name FROM tags t JOIN document_tags dt ON t.id = dt.tag_id "
        "WHERE dt.document_id = ? ORDER BY t.name"
    )
    linked_names = [entry["name"] for entry in db.execute(tag_query, (doc_id,)).fetchall()]
    # The query orders by name, so the result is alphabetical, not insertion order.
    assert linked_names == ["admin", "git"]
def test_untag_document(db):
    """``untag_document`` removes only the named tags from a document."""
    doc_id = insert_document(db, "Doc", None, "hash6", "pdf")
    tag_document(db, doc_id, ["a", "b", "c"])
    db.commit()

    untag_document(db, doc_id, ["b"])
    db.commit()

    tag_query = (
        "SELECT t.name FROM tags t JOIN document_tags dt ON t.id = dt.tag_id "
        "WHERE dt.document_id = ? ORDER BY t.name"
    )
    remaining = [entry["name"] for entry in db.execute(tag_query, (doc_id,)).fetchall()]
    assert remaining == ["a", "c"]
def test_tags_are_lowercase(db):
    """A tag created from mixed-case input is stored under its lowercase name."""
    tag_id = get_or_create_tag(db, "MyTag")
    db.commit()

    stored_name = db.execute("SELECT name FROM tags WHERE id = ?", (tag_id,)).fetchone()["name"]
    assert stored_name == "mytag"
def test_recreate_vec_table(db):
    """Recreating ``chunks_vec`` drops stored vectors and adopts the new dimension."""
    doc_id = insert_document(db, "Doc", None, "hash7", "note")
    chunk_id = insert_chunk(db, doc_id, 0, "text")
    insert_embedding(db, chunk_id, [0.1] * 384)
    db.commit()

    recreate_vec_table(db, 768)

    # Old data gone.
    rows = db.execute("SELECT * FROM chunks_vec").fetchall()
    assert len(rows) == 0

    # New dimension: the test previously only checked emptiness, so a
    # recreation that kept the old 384 dimension would have passed. Storing a
    # 768-value embedding checks the dimension change too.
    # NOTE(review): assumes recreate_vec_table leaves the ``chunks`` row in
    # place so ``chunk_id`` is still valid -- confirm against implementation.
    insert_embedding(db, chunk_id, [0.1] * 768)
    db.commit()
    stored = db.execute("SELECT * FROM chunks_vec WHERE chunk_id = ?", (chunk_id,)).fetchone()
    assert stored is not None
def test_cascade_delete(db):
    """Deleting a document also removes its chunks and its tag links."""
    doc_id = insert_document(db, "Doc", None, "hash8", "pdf")
    insert_chunk(db, doc_id, 0, "chunk text")
    tag_document(db, doc_id, ["test"])
    db.commit()

    db.execute("DELETE FROM documents WHERE id = ?", (doc_id,))
    db.commit()

    chunk_count = db.execute(
        "SELECT COUNT(*) FROM chunks WHERE document_id = ?", (doc_id,)
    ).fetchone()[0]
    assert chunk_count == 0

    link_count = db.execute(
        "SELECT COUNT(*) FROM document_tags WHERE document_id = ?", (doc_id,)
    ).fetchone()[0]
    assert link_count == 0