Initial MVP
This commit is contained in:
@@ -0,0 +1,50 @@
|
||||
"""Tests for embedding model management."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import click
|
||||
import pytest
|
||||
|
||||
from kb_search.embeddings import check_model_binding
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conn():
|
||||
"""Mock DB connection with config values."""
|
||||
def make_conn(config_values=None):
|
||||
config_values = config_values or {}
|
||||
conn = MagicMock()
|
||||
def mock_execute(sql, params=None):
|
||||
if "SELECT value FROM config" in sql and params:
|
||||
key = params[0]
|
||||
val = config_values.get(key)
|
||||
row = MagicMock()
|
||||
row.__getitem__ = lambda self, k: val
|
||||
result = MagicMock()
|
||||
result.fetchone.return_value = row if val else None
|
||||
return result
|
||||
return MagicMock()
|
||||
conn.execute = mock_execute
|
||||
return conn
|
||||
return make_conn
|
||||
|
||||
|
||||
def test_model_binding_match(mock_conn):
|
||||
conn = mock_conn({"model_name": "all-MiniLM-L6-v2", "embedding_dim": "384"})
|
||||
cfg = {"embedding": {"model": "all-MiniLM-L6-v2"}}
|
||||
# Should not raise
|
||||
check_model_binding(conn, cfg)
|
||||
|
||||
|
||||
def test_model_binding_mismatch(mock_conn):
|
||||
conn = mock_conn({"model_name": "all-MiniLM-L6-v2", "embedding_dim": "384"})
|
||||
cfg = {"embedding": {"model": "nomic-embed-text"}}
|
||||
with pytest.raises(click.ClickException, match="Model mismatch"):
|
||||
check_model_binding(conn, cfg)
|
||||
|
||||
|
||||
def test_model_binding_no_db_model(mock_conn):
|
||||
conn = mock_conn({})
|
||||
cfg = {"embedding": {"model": "anything"}}
|
||||
# Should not raise when DB not yet initialised
|
||||
check_model_binding(conn, cfg)
|
||||
Reference in New Issue
Block a user