"""Tests for database schema, FTS triggers, and config helpers."""

import struct
from contextlib import closing
from pathlib import Path

import pytest

from kb_search.database import (
    SCHEMA_VERSION,
    check_schema_version,
    get_connection,
    get_db_config,
    get_or_create_tag,
    hash_exists,
    init_schema,
    insert_chunk,
    insert_document,
    insert_embedding,
    recreate_vec_table,
    run_migrations,
    set_db_config,
    tag_document,
    untag_document,
)


@pytest.fixture
def db(tmp_path):
    """Yield a freshly initialised database connection for a single test.

    The database is a real file (``test.db``) inside pytest's ``tmp_path``
    directory, not an in-memory database. The schema is created with
    ``embedding_dim=384`` and the ``schema_version`` config entry is set to
    the current ``SCHEMA_VERSION``.

    The connection is closed on teardown and also when schema setup itself
    raises, so a failing setup does not leak an open connection.
    """
    with closing(get_connection(tmp_path / "test.db")) as conn:
        init_schema(conn, embedding_dim=384)
        set_db_config(conn, "schema_version", str(SCHEMA_VERSION))
        yield conn


def test_schema_creation(db):
    """Check that the initialised schema contains every core table.

    The expected names are compared in one step so that a failure lists all
    missing tables instead of stopping at the first one.
    """
    expected = {"documents", "chunks", "tags", "document_tags", "config"}
    cursor = db.execute("SELECT name FROM sqlite_master WHERE type='table'")
    present = {row[0] for row in cursor.fetchall()}
    missing = expected - present
    assert not missing, f"missing tables: {sorted(missing)}"


def test_fts_table_exists(db):
    """The ``chunks_fts`` table must be listed in ``sqlite_master``."""
    cursor = db.execute("SELECT name FROM sqlite_master WHERE type='table'")
    table_names = {record[0] for record in cursor.fetchall()}
    assert "chunks_fts" in table_names


def test_vec_table_exists(db):
    """The ``chunks_vec`` table must be listed in ``sqlite_master``."""
    listing = db.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
    assert any(entry[0] == "chunks_vec" for entry in listing)


def test_config_get_set(db):
    """A value written with ``set_db_config`` is read back by ``get_db_config``."""
    key, value = "test_key", "test_value"
    set_db_config(db, key, value)
    stored = get_db_config(db, key)
    assert stored == value


def test_config_get_default(db):
    """A key that was never set yields the caller-supplied default."""
    result = get_db_config(db, "nonexistent", "fallback")
    assert result == "fallback"


def test_config_upsert(db):
    """Setting the same key twice leaves the most recent value in place."""
    for value in ("v1", "v2"):
        set_db_config(db, "key", value)
    assert get_db_config(db, "key") == "v2"


def test_schema_version(db):
    """``check_schema_version`` returns ``SCHEMA_VERSION`` for the test database."""
    reported = check_schema_version(db)
    assert reported == SCHEMA_VERSION


def test_insert_document(db):
    """An inserted document can be read back with its title, type and hash."""
    new_id = insert_document(db, "Test Doc", "/path/test.pdf", "abc123", "pdf")
    db.commit()

    stored = db.execute("SELECT * FROM documents WHERE id = ?", (new_id,)).fetchone()
    expected = {"title": "Test Doc", "doc_type": "pdf", "content_hash": "abc123"}
    for column, value in expected.items():
        assert stored[column] == value


def test_insert_chunk_with_fts_sync(db):
    """Inserting a chunk makes its text searchable through ``chunks_fts``."""
    document_id = insert_document(db, "Doc", None, "hash1", "note")
    new_chunk_id = insert_chunk(
        db, document_id, 0, "This is searchable text about Python programming"
    )
    db.commit()

    # A full-text query for a word from the chunk must hit exactly that chunk.
    query = "SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'python'"
    matches = db.execute(query).fetchall()
    assert len(matches) == 1
    assert matches[0][0] == new_chunk_id


def test_fts_delete_trigger(db):
    """Deleting a chunk row removes its text from the ``chunks_fts`` index."""
    owner_id = insert_document(db, "Doc", None, "hash2", "note")
    doomed_id = insert_chunk(db, owner_id, 0, "unique_keyword_xyz")
    db.commit()

    db.execute("DELETE FROM chunks WHERE id = ?", (doomed_id,))
    db.commit()

    # The keyword existed only in the deleted chunk, so nothing may match now.
    search = "SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'unique_keyword_xyz'"
    leftovers = db.execute(search).fetchall()
    assert not leftovers


def test_fts_update_trigger(db):
    """Updating a chunk's text replaces the old text in ``chunks_fts``."""
    parent_id = insert_document(db, "Doc", None, "hash3", "note")
    target_id = insert_chunk(db, parent_id, 0, "old_content_abc")
    db.commit()

    db.execute("UPDATE chunks SET text = 'new_content_def' WHERE id = ?", (target_id,))
    db.commit()

    stale_hits = db.execute(
        "SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'old_content_abc'"
    ).fetchall()
    fresh_hits = db.execute(
        "SELECT rowid FROM chunks_fts WHERE chunks_fts MATCH 'new_content_def'"
    ).fetchall()

    # The old term must be gone and the new term must match exactly once.
    assert not stale_hits
    assert len(fresh_hits) == 1


def test_insert_embedding(db):
    """A stored embedding shows up as a row in ``chunks_vec``."""
    parent_id = insert_document(db, "Doc", None, "hash4", "note")
    chunk_id = insert_chunk(db, parent_id, 0, "text")
    db.commit()

    # 384 values, matching the embedding_dim the fixture passes to init_schema.
    vector = [0.1 for _ in range(384)]
    insert_embedding(db, chunk_id, vector)
    db.commit()

    lookup = db.execute("SELECT * FROM chunks_vec WHERE chunk_id = ?", (chunk_id,))
    assert lookup.fetchone() is not None


def test_hash_exists(db):
    """``hash_exists`` is false before a document with that hash is inserted
    and true afterwards."""
    content_hash = "newhash"

    known_before = hash_exists(db, content_hash)
    assert not known_before

    insert_document(db, "Doc", None, content_hash, "note")
    db.commit()

    known_after = hash_exists(db, content_hash)
    assert known_after


def test_tag_management(db):
    """Tagging a document links it to each given tag name."""
    doc_id = insert_document(db, "Doc", None, "hash5", "pdf")
    db.commit()

    tag_document(db, doc_id, ["git", "admin"])
    db.commit()

    # Read the names back sorted so the comparison is order-independent of
    # the order in which the tags were supplied.
    tag_query = (
        "SELECT t.name FROM tags t JOIN document_tags dt ON t.id = dt.tag_id "
        "WHERE dt.document_id = ? ORDER BY t.name"
    )
    names = [record["name"] for record in db.execute(tag_query, (doc_id,)).fetchall()]
    assert names == ["admin", "git"]


def test_untag_document(db):
    """Untagging removes only the named tag and keeps the others linked."""
    doc_id = insert_document(db, "Doc", None, "hash6", "pdf")
    tag_document(db, doc_id, ["a", "b", "c"])
    db.commit()

    untag_document(db, doc_id, ["b"])
    db.commit()

    cursor = db.execute(
        "SELECT t.name FROM tags t JOIN document_tags dt ON t.id = dt.tag_id "
        "WHERE dt.document_id = ? ORDER BY t.name",
        (doc_id,),
    )
    remaining = []
    for link in cursor.fetchall():
        remaining.append(link["name"])
    assert remaining == ["a", "c"]


def test_tags_are_lowercase(db):
    """A tag created from a mixed-case name is stored in lowercase."""
    created_id = get_or_create_tag(db, "MyTag")
    db.commit()

    lookup = db.execute("SELECT name FROM tags WHERE id = ?", (created_id,))
    stored = lookup.fetchone()
    assert stored["name"] == "mytag"


def test_recreate_vec_table(db):
    """Recreating the vector table discards previously stored embeddings.

    NOTE(review): only the emptiness of ``chunks_vec`` is asserted; the new
    dimension (768) passed to ``recreate_vec_table`` is not verified here.
    """
    parent_id = insert_document(db, "Doc", None, "hash7", "note")
    stored_chunk = insert_chunk(db, parent_id, 0, "text")
    insert_embedding(db, stored_chunk, [0.1] * 384)
    db.commit()

    recreate_vec_table(db, 768)

    # The embedding inserted before the rebuild must no longer be present.
    survivors = db.execute("SELECT * FROM chunks_vec").fetchall()
    assert not survivors


def test_cascade_delete(db):
    """Deleting a document leaves no chunks or tag links that reference it."""
    doc_id = insert_document(db, "Doc", None, "hash8", "pdf")
    insert_chunk(db, doc_id, 0, "chunk text")
    tag_document(db, doc_id, ["test"])
    db.commit()

    db.execute("DELETE FROM documents WHERE id = ?", (doc_id,))
    db.commit()

    # Both dependent tables are checked in turn for rows still pointing at
    # the removed document.
    dependent_counts = (
        "SELECT COUNT(*) FROM chunks WHERE document_id = ?",
        "SELECT COUNT(*) FROM document_tags WHERE document_id = ?",
    )
    for count_query in dependent_counts:
        orphaned = db.execute(count_query, (doc_id,)).fetchone()[0]
        assert orphaned == 0