9aab79d49b
- Remove v1 Python CLI (src/kb_search/, tests/, root pyproject.toml, uv.lock, .venv) - Add Go client with cross-platform build (client/) - Add FastAPI engine with NVIDIA and multi-stage ROCm Dockerfiles (engine/) - Add VERSION files for client and engine, wired into builds - Add release.sh for automated build, tag, release, and Docker push - Update README with build/release docs and ROCm migration note - Clean up .gitignore for v2 project structure Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
110 lines
2.8 KiB
Python
110 lines
2.8 KiB
Python
"""Embedding model management and text embedding utilities."""
|
|
|
|
import logging
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
# Module logger; the explicit "kb.embeddings" name is used instead of __name__.
logger = logging.getLogger("kb.embeddings")

# Module-level cache of the loaded model. Written by load_model() and read by
# embed_texts(); None until a model has been loaded.
_model: Optional[SentenceTransformer] = None

# Embedding dimension of the cached model. Written by load_model() and read by
# get_model_dim(); None until a model has been loaded.
_model_dim: Optional[int] = None
|
|
|
|
|
|
def _resolve_device(device: str) -> str:
    """Translate a requested device string into a concrete device name.

    Any value other than "auto" is returned unchanged. For "auto", "cuda" is
    chosen when ``torch.cuda.is_available()`` reports True and "cpu"
    otherwise; the choice is logged at INFO level.

    Args:
        device: Requested device string, or "auto".

    Returns:
        The device string to hand to the model constructor.
    """
    # Explicit requests are passed through without validation.
    if device != "auto":
        return device

    if torch.cuda.is_available():
        chosen = "cuda"
    else:
        chosen = "cpu"

    logger.info("Auto-resolved device to '%s'", chosen)
    return chosen
|
|
|
|
|
|
def load_model(model_name: str, device: str = "cpu") -> int:
    """Load a sentence-transformers model and return its embedding dimension.

    The model is cached at module level. Calling this again with the same
    ``model_name`` and the same resolved device is a no-op that returns the
    cached dimension. A different model or device, or cleared module globals,
    triggers a fresh load that replaces the cache.

    The "torch" backend is used when the resolved device is exactly "cuda";
    every other device uses the "onnx" backend.

    Args:
        model_name: HuggingFace model name or local path.
        device: Target device — "cpu", "cuda", or "auto".

    Returns:
        The embedding dimension of the loaded model.
    """
    global _model, _model_dim

    resolved_device = _resolve_device(device)
    requested = (model_name, resolved_device)

    # Honour the documented caching contract: skip the expensive reload when
    # this exact model/device pair is already loaded. The key lives on the
    # function object so no additional module-level state is required.
    if (
        _model is not None
        and _model_dim is not None
        and getattr(load_model, "_loaded_key", None) == requested
    ):
        logger.debug(
            "Model '%s' already loaded on device '%s'; reusing cached instance",
            model_name,
            resolved_device,
        )
        return _model_dim

    backend = "torch" if resolved_device == "cuda" else "onnx"

    logger.info(
        "Loading model '%s' on device '%s' (backend=%s)",
        model_name,
        resolved_device,
        backend,
    )

    # Build into locals first so a failure part-way through never leaves the
    # module with a model set but no dimension (or a stale cache key).
    model = SentenceTransformer(
        model_name,
        device=resolved_device,
        backend=backend,
    )
    dim = model.get_sentence_embedding_dimension()

    _model = model
    _model_dim = dim
    load_model._loaded_key = requested

    logger.info("Model loaded — embedding dimension: %d", _model_dim)
    return _model_dim
|
|
|
|
|
|
def get_model_dim() -> int:
    """Report the embedding dimension recorded by the last model load.

    Returns:
        The cached embedding dimension.

    Raises:
        RuntimeError: If no model has been loaded yet.
    """
    dim = _model_dim
    if dim is not None:
        return dim

    raise RuntimeError("Embedding model not loaded. Call load_model() first.")
|
|
|
|
|
|
def embed_texts(
    texts: list[str],
    prefix: str = "",
    show_progress: bool = False,
) -> list[list[float]]:
    """Encode texts into embedding vectors with the cached model.

    Args:
        texts: Strings to embed.
        prefix: Optional string prepended to every text before encoding;
            an empty prefix leaves the inputs untouched.
        show_progress: Whether the encoder should display a progress bar.

    Returns:
        One embedding per input text, each as a plain list of floats.

    Raises:
        RuntimeError: If no model has been loaded yet.
    """
    model = _model
    if model is None:
        raise RuntimeError("Embedding model not loaded. Call load_model() first.")

    # Build a new list rather than mutating the caller's list in place.
    inputs = [prefix + text for text in texts] if prefix else texts

    vectors = model.encode(
        inputs,
        show_progress_bar=show_progress,
        convert_to_numpy=True,
    )
    return vectors.tolist()
|