51 lines
1.6 KiB
Python
51 lines
1.6 KiB
Python
"""Tests for embedding model management."""
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import click
|
|
import pytest
|
|
|
|
from kb_search.embeddings import check_model_binding
|
|
|
|
|
|
@pytest.fixture
def mock_conn():
    """Provide a factory that builds mock DB connections backed by a dict.

    The fixture value is ``make_conn(config_values=None)``. Each call returns
    a ``MagicMock`` connection whose ``execute`` answers
    ``SELECT value FROM config`` queries from ``config_values``.
    """

    def make_conn(config_values=None):
        """Return a mock connection serving the given config key/value pairs.

        Args:
            config_values: Mapping of config key to stored value. ``None``
                is treated as an empty mapping.
        """
        config_values = {} if config_values is None else config_values
        conn = MagicMock()

        def mock_execute(sql, params=None):
            """Emulate ``conn.execute`` for config lookups.

            For a config SELECT with parameters, the first parameter is used
            as the key. ``fetchone()`` on the result returns a row whose
            item access yields the stored value, or ``None`` when the key is
            absent. Any other statement gets a plain ``MagicMock``.
            """
            if "SELECT value FROM config" in sql and params:
                val = config_values.get(params[0])
                result = MagicMock()
                if val is None:
                    # Key not present: behave like an empty result set.
                    result.fetchone.return_value = None
                else:
                    # Explicit None check above (not truthiness) so a stored
                    # but falsy value such as "" is still returned as a row.
                    row = MagicMock()
                    # Any index/key returns the value, whichever form the
                    # code under test uses to read the row.
                    row.__getitem__.return_value = val
                    result.fetchone.return_value = row
                return result
            return MagicMock()

        conn.execute = mock_execute
        return conn

    return make_conn
|
|
|
|
|
|
def test_model_binding_match(mock_conn):
    """A configured model equal to the stored model name must be accepted."""
    stored = {"model_name": "all-MiniLM-L6-v2", "embedding_dim": "384"}
    db = mock_conn(stored)
    settings = {"embedding": {"model": "all-MiniLM-L6-v2"}}

    # Passing means simply returning without an exception.
    check_model_binding(db, settings)
|
|
|
|
|
|
def test_model_binding_mismatch(mock_conn):
    """A configured model differing from the stored one must be rejected."""
    stored = {"model_name": "all-MiniLM-L6-v2", "embedding_dim": "384"}
    db = mock_conn(stored)
    settings = {"embedding": {"model": "nomic-embed-text"}}

    # The error is expected to surface as a ClickException with this message.
    expected_error = pytest.raises(click.ClickException, match="Model mismatch")
    with expected_error:
        check_model_binding(db, settings)
|
|
|
|
|
|
def test_model_binding_no_db_model(mock_conn):
    """With no model recorded in the DB, any configured model is accepted."""
    db = mock_conn({})
    settings = {"embedding": {"model": "anything"}}

    # An uninitialised DB (no config rows) must not trigger an exception.
    check_model_binding(db, settings)
|