Files
kb/engine/kb/embeddings.py
T
steve 9aab79d49b v2 restructure: Go client, Docker engine, release tooling
- Remove v1 Python CLI (src/kb_search/, tests/, root pyproject.toml, uv.lock, .venv)
- Add Go client with cross-platform build (client/)
- Add FastAPI engine with NVIDIA and multi-stage ROCm Dockerfiles (engine/)
- Add VERSION files for client and engine, wired into builds
- Add release.sh for automated build, tag, release, and Docker push
- Update README with build/release docs and ROCm migration note
- Clean up .gitignore for v2 project structure

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 21:52:25 +00:00

110 lines
2.8 KiB
Python

"""Embedding model management and text embedding utilities."""
import logging
from typing import Optional
import torch
from sentence_transformers import SentenceTransformer
# Module-wide logger; the name is the fixed string "kb.embeddings" rather
# than one derived from __name__.
logger = logging.getLogger("kb.embeddings")

# Model state shared by the functions below: load_model() assigns both,
# get_model_dim() reads _model_dim and embed_texts() reads _model. Both start
# as None, which those readers treat as "no model loaded yet".
_model: Optional[SentenceTransformer] = None
_model_dim: Optional[int] = None
def _resolve_device(device: str) -> str:
    """Turn a requested device name into a concrete one.

    Any value other than ``"auto"`` is handed back unchanged. ``"auto"``
    becomes ``"cuda"`` when ``torch.cuda.is_available()`` reports True and
    ``"cpu"`` otherwise; the choice is logged at INFO level.

    Args:
        device: Requested device string, or ``"auto"``.

    Returns:
        The device string to use.
    """
    # Explicit requests pass straight through, with no probing or logging.
    if device != "auto":
        return device

    if torch.cuda.is_available():
        chosen = "cuda"
    else:
        chosen = "cpu"
    logger.info("Auto-resolved device to '%s'", chosen)
    return chosen
def load_model(model_name: str, device: str = "cpu") -> int:
    """Load a sentence-transformers model and return its embedding dimension.

    The model and its dimension are stored in the module-level ``_model`` and
    ``_model_dim`` globals, where ``embed_texts()`` and ``get_model_dim()``
    read them. Every call constructs a new ``SentenceTransformer`` and
    replaces whatever was stored before; there is no check for an
    already-loaded model.

    The backend follows the resolved device: CUDA devices use ``"torch"``,
    everything else uses ``"onnx"``.

    Args:
        model_name: HuggingFace model name or local path.
        device: Target device — "cpu", "cuda" (optionally indexed, e.g.
            "cuda:0"), or "auto".

    Returns:
        The embedding dimension of the loaded model.
    """
    global _model, _model_dim

    resolved_device = _resolve_device(device)

    # Compare the device *type* (the text before any ":") so an indexed name
    # such as "cuda:0" gets the same backend as plain "cuda". An exact string
    # comparison would route indexed CUDA devices to the ONNX backend.
    device_type = resolved_device.split(":", 1)[0]
    backend = "torch" if device_type == "cuda" else "onnx"

    logger.info(
        "Loading model '%s' on device '%s' (backend=%s)",
        model_name,
        resolved_device,
        backend,
    )

    model = SentenceTransformer(
        model_name,
        device=resolved_device,
        backend=backend,
    )
    dim = model.get_sentence_embedding_dimension()

    # Publish both globals only after both values exist, so a failure above
    # cannot leave _model and _model_dim describing different models.
    _model = model
    _model_dim = dim

    logger.info("Model loaded — embedding dimension: %d", _model_dim)
    return _model_dim
def get_model_dim() -> int:
    """Report the embedding dimension recorded for the loaded model.

    Returns:
        The value currently held in the module-level ``_model_dim``.

    Raises:
        RuntimeError: If ``_model_dim`` is still ``None``, i.e. no model has
            been loaded yet.
    """
    dim = _model_dim
    if dim is not None:
        return dim
    raise RuntimeError(
        "Embedding model not loaded. Call load_model() first."
    )
def embed_texts(
    texts: list[str],
    prefix: str = "",
    show_progress: bool = False,
) -> list[list[float]]:
    """Encode a batch of strings with the model held in ``_model``.

    Args:
        texts: Strings to embed.
        prefix: Text put in front of every string before encoding; an empty
            prefix leaves the inputs as they are.
        show_progress: Passed to the model's ``encode`` call as
            ``show_progress_bar``.

    Returns:
        The encoder output converted with ``tolist()`` — one vector of
        floats per input string.

    Raises:
        RuntimeError: If no model has been loaded yet.
    """
    model = _model
    if model is None:
        raise RuntimeError(
            "Embedding model not loaded. Call load_model() first."
        )

    # With a prefix, a new list is built; the caller's list is not mutated.
    inputs = [prefix + text for text in texts] if prefix else texts

    vectors = model.encode(
        inputs,
        show_progress_bar=show_progress,
        convert_to_numpy=True,
    )
    return vectors.tolist()