v2 restructure: Go client, Docker engine, release tooling

- Remove v1 Python CLI (src/kb_search/, tests/, root pyproject.toml, uv.lock, .venv)
- Add Go client with cross-platform build (client/)
- Add FastAPI engine with NVIDIA and multi-stage ROCm Dockerfiles (engine/)
- Add VERSION files for client and engine, wired into builds
- Add release.sh for automated build, tag, release, and Docker push
- Update README with build/release docs and ROCm migration note
- Clean up .gitignore for v2 project structure

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 21:52:25 +00:00
parent 2030976b85
commit 9aab79d49b
98 changed files with 4526 additions and 7776 deletions
+6
View File
@@ -0,0 +1,6 @@
__pycache__/
*.pyc
.venv/
*.egg-info/
.pytest_cache/
tests/
+35
View File
@@ -0,0 +1,35 @@
FROM nvidia/cuda:13.0.1-runtime-ubuntu24.04
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.12 python3.12-venv python3.12-dev python3-pip \
libpoppler-cpp-dev poppler-utils \
libgl1 libglib2.0-0 \
build-essential curl \
&& rm -rf /var/lib/apt/lists/*
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
WORKDIR /app
COPY pyproject.toml ./
COPY kb/ kb/
COPY main.py ./
COPY VERSION ./
RUN uv venv .venv && \
. .venv/bin/activate && \
uv pip install -e . && \
uv pip install --no-deps onnxruntime-gpu
ENV PATH="/app/.venv/bin:$PATH"
ENV VIRTUAL_ENV="/app/.venv"
ENV KB_DEVICE=auto
ENV KB_INGEST_DEVICE=auto
ENV KB_DATA_DIR=/data
EXPOSE 8000
VOLUME ["/data"]
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
+68
View File
@@ -0,0 +1,68 @@
# Stage 1: Build — install Python deps with dev tools available
FROM rocm/dev-ubuntu-24.04:6.4-complete AS builder
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.12 python3.12-venv python3.12-dev python3-pip \
libpoppler-cpp-dev poppler-utils \
build-essential curl \
&& rm -rf /var/lib/apt/lists/*
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
WORKDIR /app
COPY pyproject.toml ./
COPY kb/ kb/
COPY main.py ./
COPY VERSION ./
RUN uv venv .venv && \
. .venv/bin/activate && \
uv pip install -e . && \
uv pip install --no-deps onnxruntime-rocm
# Stage 2: Runtime — minimal ROCm runtime libs only
FROM ubuntu:24.04
ENV DEBIAN_FRONTEND=noninteractive
# Add ROCm apt repository
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates curl gnupg \
&& mkdir -p /etc/apt/keyrings \
&& curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key \
| gpg --dearmor -o /etc/apt/keyrings/rocm.gpg \
&& echo "deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/6.4.1 noble main" \
> /etc/apt/sources.list.d/rocm.list \
&& printf 'Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600\n' \
> /etc/apt/preferences.d/rocm-pin-600 \
&& apt-get update && apt-get install -y --no-install-recommends \
python3.12 python3.12-venv \
libpoppler-cpp0t64 poppler-utils \
libgl1 libglib2.0-0 \
rocm-hip-runtime \
rocm-hip-libraries \
miopen-hip \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
# Copy built venv and application from builder
COPY --from=builder /app/.venv .venv
COPY --from=builder /app/kb kb
COPY --from=builder /app/main.py .
COPY --from=builder /app/pyproject.toml .
COPY --from=builder /app/VERSION .
ENV PATH="/app/.venv/bin:$PATH"
ENV VIRTUAL_ENV="/app/.venv"
ENV KB_DEVICE=auto
ENV KB_INGEST_DEVICE=auto
ENV KB_DATA_DIR=/data
EXPOSE 8000
VOLUME ["/data"]
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
+1
View File
@@ -0,0 +1 @@
2.0.3
+24
View File
@@ -0,0 +1,24 @@
services:
kb-engine:
build:
context: .
dockerfile: Dockerfile.nvidia
runtime: nvidia
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
ports:
- "${KB_PORT:-8000}:8000"
volumes:
- ${KB_DATA_PATH:-./data}:/data
environment:
- KB_MODEL=${KB_MODEL:-all-MiniLM-L6-v2}
- KB_DEVICE=${KB_DEVICE:-auto}
- KB_INGEST_DEVICE=${KB_INGEST_DEVICE:-auto}
- KB_API_KEY=${KB_API_KEY:-}
- KB_SEARCH_THRESHOLD=${KB_SEARCH_THRESHOLD:-0.01}
restart: unless-stopped
+21
View File
@@ -0,0 +1,21 @@
services:
kb-engine:
build:
context: .
dockerfile: Dockerfile.rocm
devices:
- "/dev/kfd"
- "/dev/dri"
group_add:
- "video"
ports:
- "${KB_PORT:-8000}:8000"
volumes:
- ${KB_DATA_PATH:-./data}:/data
environment:
- KB_MODEL=${KB_MODEL:-all-MiniLM-L6-v2}
- KB_DEVICE=${KB_DEVICE:-auto}
- KB_INGEST_DEVICE=${KB_INGEST_DEVICE:-auto}
- KB_API_KEY=${KB_API_KEY:-}
- KB_SEARCH_THRESHOLD=${KB_SEARCH_THRESHOLD:-0.01}
restart: unless-stopped
View File
+44
View File
@@ -0,0 +1,44 @@
"""Engine configuration — all settings from environment variables."""
import os
from pathlib import Path
class Config:
data_dir: Path
model: str
device: str
ingest_device: str
api_key: str | None
host: str
port: int
def __init__(self):
self.data_dir = Path(os.environ.get("KB_DATA_DIR", "/data"))
self.model = os.environ.get("KB_MODEL", "all-MiniLM-L6-v2")
self.device = os.environ.get("KB_DEVICE", "auto")
self.ingest_device = os.environ.get("KB_INGEST_DEVICE", "auto")
self.api_key = os.environ.get("KB_API_KEY") or None
self.search_threshold = float(os.environ.get("KB_SEARCH_THRESHOLD", "0.01"))
self.host = os.environ.get("KB_HOST", "0.0.0.0")
self.port = int(os.environ.get("KB_PORT", "8000"))
@property
def db_path(self) -> Path:
return self.data_dir / "kb.db"
@property
def hf_cache(self) -> Path:
return self.data_dir / "hf_cache"
@property
def staging_dir(self) -> Path:
return self.data_dir / "staging"
def ensure_dirs(self):
self.data_dir.mkdir(parents=True, exist_ok=True)
self.hf_cache.mkdir(exist_ok=True)
self.staging_dir.mkdir(exist_ok=True)
cfg = Config()
+308
View File
@@ -0,0 +1,308 @@
"""Database module for kb-engine v2.
Provides SQLite database access with WAL mode, FTS5 full-text search,
and sqlite-vec vector storage for embeddings.
"""
import json
import sqlite3
import struct
from typing import Any, Optional
def get_connection(db_path: str) -> sqlite3.Connection:
"""Return a sqlite3 connection with WAL mode, Row factory, and foreign keys enabled."""
import sqlite_vec
conn = sqlite3.connect(db_path)
conn.enable_load_extension(True)
sqlite_vec.load(conn)
conn.enable_load_extension(False)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON")
return conn
def init_schema(conn: sqlite3.Connection, embedding_dim: int) -> None:
"""Create all tables if they do not already exist."""
conn.executescript(f"""
CREATE TABLE IF NOT EXISTS documents (
id INTEGER PRIMARY KEY,
title TEXT,
source_path TEXT,
content_hash TEXT UNIQUE,
doc_type TEXT,
language TEXT,
created_at TEXT DEFAULT current_timestamp
);
CREATE TABLE IF NOT EXISTS chunks (
id INTEGER PRIMARY KEY,
document_id INTEGER REFERENCES documents(id) ON DELETE CASCADE,
chunk_index INTEGER,
text TEXT,
token_count INTEGER,
metadata TEXT DEFAULT '{{}}',
UNIQUE(document_id, chunk_index)
);
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
text,
content=chunks,
content_rowid=id
);
-- Triggers to keep FTS index in sync with chunks table
CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
INSERT INTO chunks_fts(rowid, text) VALUES (new.id, new.text);
END;
CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
INSERT INTO chunks_fts(chunks_fts, rowid, text) VALUES ('delete', old.id, old.text);
END;
CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
INSERT INTO chunks_fts(chunks_fts, rowid, text) VALUES ('delete', old.id, old.text);
INSERT INTO chunks_fts(rowid, text) VALUES (new.id, new.text);
END;
CREATE TABLE IF NOT EXISTS tags (
id INTEGER PRIMARY KEY,
name TEXT UNIQUE COLLATE NOCASE
);
CREATE TABLE IF NOT EXISTS document_tags (
document_id INTEGER REFERENCES documents(id) ON DELETE CASCADE,
tag_id INTEGER REFERENCES tags(id) ON DELETE CASCADE,
UNIQUE(document_id, tag_id)
);
CREATE TABLE IF NOT EXISTS config (
key TEXT PRIMARY KEY,
value TEXT
);
CREATE TABLE IF NOT EXISTS jobs (
id INTEGER PRIMARY KEY,
filename TEXT,
status TEXT DEFAULT 'queued' CHECK(status IN ('queued','processing','done','failed','skipped')),
doc_type TEXT,
tags_json TEXT DEFAULT '[]',
title TEXT,
error TEXT,
document_id INTEGER,
chunk_count INTEGER DEFAULT 0,
staging_path TEXT,
created_at TEXT DEFAULT current_timestamp,
completed_at TEXT
);
""")
# sqlite-vec virtual table (cannot use IF NOT EXISTS with vec0, so check first)
existing = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='chunks_vec'"
).fetchone()
if not existing:
conn.execute(
f"CREATE VIRTUAL TABLE chunks_vec USING vec0(embedding float[{embedding_dim}], chunk_id integer)"
)
conn.commit()
# ---------------------------------------------------------------------------
# Config helpers
# ---------------------------------------------------------------------------
def get_db_config(conn: sqlite3.Connection, key: str, default: Optional[str] = None) -> Optional[str]:
"""Retrieve a value from the config table."""
row = conn.execute("SELECT value FROM config WHERE key = ?", (key,)).fetchone()
return row["value"] if row else default
def set_db_config(conn: sqlite3.Connection, key: str, value: str) -> None:
"""Insert or update a value in the config table."""
conn.execute(
"INSERT INTO config(key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value",
(key, value),
)
conn.commit()
# ---------------------------------------------------------------------------
# Document helpers
# ---------------------------------------------------------------------------
def hash_exists(conn: sqlite3.Connection, content_hash: str) -> bool:
"""Check whether a document with the given content hash is already ingested."""
row = conn.execute(
"SELECT 1 FROM documents WHERE content_hash = ?", (content_hash,)
).fetchone()
return row is not None
def insert_document(
conn: sqlite3.Connection,
title: str,
source_path: str,
content_hash: str,
doc_type: str,
language: Optional[str] = None,
) -> int:
"""Insert a new document and return its id."""
cur = conn.execute(
"INSERT INTO documents(title, source_path, content_hash, doc_type, language) VALUES (?, ?, ?, ?, ?)",
(title, source_path, content_hash, doc_type, language),
)
conn.commit()
return cur.lastrowid
# ---------------------------------------------------------------------------
# Chunk / embedding helpers
# ---------------------------------------------------------------------------
def insert_chunk(
conn: sqlite3.Connection,
document_id: int,
chunk_index: int,
text: str,
token_count: Optional[int] = None,
metadata: Any = None,
) -> int:
"""Insert a chunk and return its id. *metadata* is JSON-serialized if it is a dict."""
if metadata is None:
metadata_str = "{}"
elif isinstance(metadata, dict):
metadata_str = json.dumps(metadata)
else:
metadata_str = str(metadata)
cur = conn.execute(
"INSERT INTO chunks(document_id, chunk_index, text, token_count, metadata) VALUES (?, ?, ?, ?, ?)",
(document_id, chunk_index, text, token_count, metadata_str),
)
conn.commit()
return cur.lastrowid
def insert_embedding(conn: sqlite3.Connection, chunk_id: int, embedding: list[float]) -> None:
"""Insert an embedding vector into chunks_vec using struct-packed floats."""
blob = struct.pack(f"{len(embedding)}f", *embedding)
conn.execute(
"INSERT INTO chunks_vec(embedding, chunk_id) VALUES (?, ?)",
(blob, chunk_id),
)
conn.commit()
# ---------------------------------------------------------------------------
# Tagging helpers
# ---------------------------------------------------------------------------
def tag_document(conn: sqlite3.Connection, document_id: int, tag_names: list[str]) -> None:
"""Create tags if needed and associate them with a document."""
for name in tag_names:
conn.execute("INSERT OR IGNORE INTO tags(name) VALUES (?)", (name,))
tag_id = conn.execute("SELECT id FROM tags WHERE name = ?", (name,)).fetchone()["id"]
conn.execute(
"INSERT OR IGNORE INTO document_tags(document_id, tag_id) VALUES (?, ?)",
(document_id, tag_id),
)
conn.commit()
def untag_document(conn: sqlite3.Connection, document_id: int, tag_names: list[str]) -> None:
"""Remove tag associations from a document."""
for name in tag_names:
row = conn.execute("SELECT id FROM tags WHERE name = ?", (name,)).fetchone()
if row:
conn.execute(
"DELETE FROM document_tags WHERE document_id = ? AND tag_id = ?",
(document_id, row["id"]),
)
conn.commit()
# ---------------------------------------------------------------------------
# Vec table management
# ---------------------------------------------------------------------------
def recreate_vec_table(conn: sqlite3.Connection, dim: int) -> None:
"""Drop and recreate the chunks_vec virtual table (for reindexing)."""
conn.execute("DROP TABLE IF EXISTS chunks_vec")
conn.execute(
f"CREATE VIRTUAL TABLE chunks_vec USING vec0(embedding float[{dim}], chunk_id integer)"
)
conn.commit()
# ---------------------------------------------------------------------------
# Job CRUD
# ---------------------------------------------------------------------------
_TERMINAL_STATUSES = {"done", "failed", "skipped"}
def create_job(
conn: sqlite3.Connection,
filename: str,
staging_path: str,
doc_type: Optional[str] = None,
tags_json: str = "[]",
title: Optional[str] = None,
) -> int:
"""Create a new ingest job and return its id."""
cur = conn.execute(
"INSERT INTO jobs(filename, staging_path, doc_type, tags_json, title) VALUES (?, ?, ?, ?, ?)",
(filename, staging_path, doc_type, tags_json, title),
)
conn.commit()
return cur.lastrowid
def get_job(conn: sqlite3.Connection, job_id: int) -> Optional[sqlite3.Row]:
"""Return a single job row by id, or None."""
return conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
def list_jobs(conn: sqlite3.Connection, status: Optional[str] = None) -> list[sqlite3.Row]:
"""Return jobs ordered newest first, optionally filtered by status."""
if status:
return conn.execute(
"SELECT * FROM jobs WHERE status = ? ORDER BY created_at DESC", (status,)
).fetchall()
return conn.execute("SELECT * FROM jobs ORDER BY created_at DESC").fetchall()
def update_job_status(
conn: sqlite3.Connection,
job_id: int,
status: str,
error: Optional[str] = None,
document_id: Optional[int] = None,
chunk_count: Optional[int] = None,
) -> None:
"""Update a job's status and related fields. Sets completed_at for terminal states."""
fields = ["status = ?"]
params: list[Any] = [status]
if error is not None:
fields.append("error = ?")
params.append(error)
if document_id is not None:
fields.append("document_id = ?")
params.append(document_id)
if chunk_count is not None:
fields.append("chunk_count = ?")
params.append(chunk_count)
if status in _TERMINAL_STATUSES:
fields.append("completed_at = current_timestamp")
params.append(job_id)
conn.execute(f"UPDATE jobs SET {', '.join(fields)} WHERE id = ?", params)
conn.commit()
+109
View File
@@ -0,0 +1,109 @@
"""Embedding model management and text embedding utilities."""
import logging
from typing import Optional
import torch
from sentence_transformers import SentenceTransformer
logger = logging.getLogger("kb.embeddings")
_model: Optional[SentenceTransformer] = None
_model_dim: Optional[int] = None
def _resolve_device(device: str) -> str:
"""Resolve device string, mapping 'auto' to the best available device."""
if device == "auto":
resolved = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Auto-resolved device to '%s'", resolved)
return resolved
return device
def load_model(model_name: str, device: str = "cpu") -> int:
"""Load a sentence-transformers model and return its embedding dimension.
The model is cached at module level so subsequent calls are no-ops unless
the module globals are cleared.
Args:
model_name: HuggingFace model name or local path.
device: Target device — "cpu", "cuda", or "auto".
Returns:
The embedding dimension of the loaded model.
"""
global _model, _model_dim
resolved_device = _resolve_device(device)
if resolved_device == "cuda":
backend = "torch"
else:
backend = "onnx"
logger.info(
"Loading model '%s' on device '%s' (backend=%s)",
model_name,
resolved_device,
backend,
)
_model = SentenceTransformer(
model_name,
device=resolved_device,
backend=backend,
)
_model_dim = _model.get_sentence_embedding_dimension()
logger.info("Model loaded — embedding dimension: %d", _model_dim)
return _model_dim
def get_model_dim() -> int:
"""Return the embedding dimension of the loaded model.
Raises:
RuntimeError: If no model has been loaded yet.
"""
if _model_dim is None:
raise RuntimeError(
"Embedding model not loaded. Call load_model() first."
)
return _model_dim
def embed_texts(
texts: list[str],
prefix: str = "",
show_progress: bool = False,
) -> list[list[float]]:
"""Embed a list of texts using the cached model.
Args:
texts: Strings to embed.
prefix: Optional prefix prepended to each text before encoding.
show_progress: Whether to display a progress bar.
Returns:
A list of embedding vectors (each a list of floats).
Raises:
RuntimeError: If no model has been loaded yet.
"""
if _model is None:
raise RuntimeError(
"Embedding model not loaded. Call load_model() first."
)
if prefix:
texts = [prefix + t for t in texts]
embeddings = _model.encode(
texts,
show_progress_bar=show_progress,
convert_to_numpy=True,
)
return embeddings.tolist()
View File
+206
View File
@@ -0,0 +1,206 @@
"""Chunking pipeline for source code files."""
from __future__ import annotations
import ast
import re
def _approx_tokens(text: str) -> int:
return len(text) // 4
def _fixed_token_chunks(text: str, max_tokens: int) -> list[str]:
"""Split text into fixed-size token chunks by lines."""
lines = text.split("\n")
pieces: list[str] = []
current: list[str] = []
current_len = 0
for line in lines:
line_tokens = _approx_tokens(line)
if current and current_len + line_tokens > max_tokens:
pieces.append("\n".join(current))
current = [line]
current_len = line_tokens
else:
current.append(line)
current_len += line_tokens
if current:
pieces.append("\n".join(current))
return pieces
def _chunk_python(text: str, max_tokens: int) -> list[dict]:
"""Use the ast module to extract top-level classes and functions."""
lines = text.split("\n")
try:
tree = ast.parse(text)
except SyntaxError:
return []
# Collect top-level class and function definitions
regions: list[tuple[int, int, str]] = [] # (start_line, end_line, name)
for node in ast.iter_child_nodes(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
continue
start = node.lineno - 1 # 0-indexed
# Include preceding comments and decorators
first_line = node.lineno - 1
if node.decorator_list:
first_line = node.decorator_list[0].lineno - 1
# Walk backwards for comment lines
scan = first_line - 1
while scan >= 0 and (lines[scan].strip().startswith("#") or not lines[scan].strip()):
if lines[scan].strip().startswith("#"):
first_line = scan
scan -= 1
start = first_line
end = node.end_lineno # 1-indexed, inclusive
prefix = "class " if isinstance(node, ast.ClassDef) else "def "
regions.append((start, end, f"{prefix}{node.name}"))
if not regions:
return []
# Sort by start line
regions.sort(key=lambda r: r[0])
chunks: list[dict] = []
chunk_index = 0
prev_end = 0
for start, end, name in regions:
# Capture any module-level code between definitions
if start > prev_end:
preamble = "\n".join(lines[prev_end:start]).strip()
if preamble and _approx_tokens(preamble) > 10:
chunks.append({
"text": preamble,
"chunk_index": chunk_index,
"metadata": {},
})
chunk_index += 1
block = "\n".join(lines[start:end]).rstrip()
if _approx_tokens(block) > max_tokens:
for piece in _fixed_token_chunks(block, max_tokens):
chunks.append({
"text": piece,
"chunk_index": chunk_index,
"metadata": {"name": name},
})
chunk_index += 1
else:
chunks.append({
"text": block,
"chunk_index": chunk_index,
"metadata": {"name": name},
})
chunk_index += 1
prev_end = end
# Trailing module-level code
if prev_end < len(lines):
tail = "\n".join(lines[prev_end:]).strip()
if tail and _approx_tokens(tail) > 10:
chunks.append({
"text": tail,
"chunk_index": chunk_index,
"metadata": {},
})
return chunks
def _chunk_by_regex(text: str, pattern: str, max_tokens: int) -> list[dict]:
"""Split source code at regex-matched function boundaries."""
lines = text.split("\n")
boundaries: list[tuple[int, str]] = []
for i, line in enumerate(lines):
m = re.match(pattern, line)
if m:
boundaries.append((i, m.group(0).strip()))
if not boundaries:
return []
chunks: list[dict] = []
chunk_index = 0
# Content before first match
if boundaries[0][0] > 0:
preamble = "\n".join(lines[: boundaries[0][0]]).strip()
if preamble:
chunks.append({
"text": preamble,
"chunk_index": chunk_index,
"metadata": {},
})
chunk_index += 1
for idx, (start, name) in enumerate(boundaries):
end = boundaries[idx + 1][0] if idx + 1 < len(boundaries) else len(lines)
block = "\n".join(lines[start:end]).rstrip()
if _approx_tokens(block) > max_tokens:
for piece in _fixed_token_chunks(block, max_tokens):
chunks.append({
"text": piece,
"chunk_index": chunk_index,
"metadata": {"name": name},
})
chunk_index += 1
else:
chunks.append({
"text": block,
"chunk_index": chunk_index,
"metadata": {"name": name},
})
chunk_index += 1
return chunks
def chunk_code(
text: str,
language: str | None,
max_tokens: int = 1024,
) -> list[dict]:
"""Split source code into chunks using language-aware strategies.
Returns a list of chunk dicts, each containing:
text, chunk_index, metadata
"""
chunks: list[dict] = []
if language == "python":
chunks = _chunk_python(text, max_tokens)
elif language == "bash":
chunks = _chunk_by_regex(
text, r"^(?:\w+\s*\(\)|function\s+\w+)", max_tokens
)
elif language == "go":
chunks = _chunk_by_regex(text, r"^func\s+", max_tokens)
# Fallback: fixed-size token chunking
if not chunks:
for idx, piece in enumerate(_fixed_token_chunks(text, max_tokens)):
piece = piece.strip()
if piece:
chunks.append({
"text": piece,
"chunk_index": idx,
"metadata": {},
})
return chunks
+47
View File
@@ -0,0 +1,47 @@
"""Detect document type and language from file extension."""
from pathlib import Path
SUPPORTED_EXTENSIONS: dict[str, tuple[str, str | None]] = {
".pdf": ("pdf", None),
".docx": ("pdf", None),
".html": ("pdf", None),
".md": ("markdown", None),
".txt": ("note", None),
".py": ("code", "python"),
".sh": ("code", "bash"),
".go": ("code", "go"),
}
def is_supported(path: Path) -> bool:
"""Check if the file extension is supported for ingestion."""
return path.suffix.lower() in SUPPORTED_EXTENSIONS
def detect_type(
path: Path,
force_type: str | None = None,
force_language: str | None = None,
) -> tuple[str, str | None]:
"""Return (doc_type, language) for the given file path.
Uses force_type / force_language when provided, otherwise falls back to
extension-based lookup. Raises ValueError for unsupported extensions.
"""
ext = path.suffix.lower()
if ext not in SUPPORTED_EXTENSIONS:
raise ValueError(
f"Unsupported file extension '{ext}'. "
f"Supported: {', '.join(sorted(SUPPORTED_EXTENSIONS))}"
)
doc_type, language = SUPPORTED_EXTENSIONS[ext]
if force_type is not None:
doc_type = force_type
if force_language is not None:
language = force_language
return doc_type, language
+107
View File
@@ -0,0 +1,107 @@
"""Chunking pipeline for PDF / DOCX / HTML documents via Docling."""
from __future__ import annotations
import logging
from pathlib import Path
# Suppress noisy Docling / RapidOCR logging
for _logger_name in (
"docling",
"docling.document_converter",
"docling_core",
"rapidocr",
"rapidocr_onnxruntime",
):
logging.getLogger(_logger_name).setLevel(logging.WARNING)
from docling.datamodel.base_models import InputFormat # noqa: E402
from docling.datamodel.pipeline_options import ( # noqa: E402
AcceleratorOptions,
PdfPipelineOptions,
RapidOcrOptions,
)
from docling.document_converter import DocumentConverter, PdfFormatOption # noqa: E402
from docling_core.transforms.chunker.hierarchical_chunker import ( # noqa: E402
HierarchicalChunker,
)
def _fixed_size_chunks(text: str, max_chars: int = 2000) -> list[str]:
"""Split text into fixed-size pieces as a fallback."""
chunks: list[str] = []
for i in range(0, len(text), max_chars):
chunk = text[i : i + max_chars].strip()
if chunk:
chunks.append(chunk)
return chunks
def chunk_document(
file_path: Path,
ingest_device: str = "cpu",
) -> list[dict]:
"""Convert and chunk a PDF/DOCX/HTML document using Docling.
Returns a list of chunk dicts, each containing:
text, chunk_index, metadata
"""
accelerator = AcceleratorOptions(device=ingest_device)
ocr_options = RapidOcrOptions()
pipeline_options = PdfPipelineOptions(
accelerator_options=accelerator,
do_ocr=True,
ocr_options=ocr_options,
bitmap_area_threshold=0.25,
)
converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options),
}
)
result = converter.convert(str(file_path))
doc = result.document
# Primary: hierarchical chunking
chunker = HierarchicalChunker()
raw_chunks = list(chunker.chunk(doc))
chunks: list[dict] = []
for idx, chunk in enumerate(raw_chunks):
text = chunk.text.strip() if hasattr(chunk, "text") else str(chunk).strip()
if not text:
continue
metadata: dict = {}
# Extract page numbers from chunk metadata when available
if hasattr(chunk, "meta") and chunk.meta:
meta = chunk.meta
if hasattr(meta, "page") and meta.page is not None:
metadata["page"] = meta.page
if hasattr(meta, "pages") and meta.pages:
metadata["pages"] = meta.pages
if hasattr(meta, "headings") and meta.headings:
metadata["section_header"] = " > ".join(meta.headings)
chunks.append({
"text": text,
"chunk_index": idx,
"metadata": metadata,
})
# Fallback: fixed-size chunking when hierarchy produces nothing
if not chunks:
full_text = doc.export_to_text() if hasattr(doc, "export_to_text") else ""
if not full_text and hasattr(doc, "text"):
full_text = doc.text
for idx, piece in enumerate(_fixed_size_chunks(full_text)):
chunks.append({
"text": piece,
"chunk_index": idx,
"metadata": {},
})
return chunks
+130
View File
@@ -0,0 +1,130 @@
"""Chunking pipeline for Markdown documents."""
from __future__ import annotations
import re
def _approx_tokens(text: str) -> int:
return len(text) // 4
def _split_at_paragraphs(text: str, max_tokens: int) -> list[str]:
"""Split text into pieces that fit within max_tokens at paragraph boundaries."""
paragraphs = re.split(r"\n{2,}", text)
pieces: list[str] = []
current: list[str] = []
current_len = 0
for para in paragraphs:
para_tokens = _approx_tokens(para)
if current and current_len + para_tokens > max_tokens:
pieces.append("\n\n".join(current))
current = [para]
current_len = para_tokens
else:
current.append(para)
current_len += para_tokens
if current:
pieces.append("\n\n".join(current))
return pieces
def chunk_markdown(
text: str,
max_tokens: int = 1024,
min_tokens: int = 50,
) -> list[dict]:
"""Split markdown text into chunks based on heading structure.
Returns a list of chunk dicts, each containing:
text, chunk_index, metadata
"""
lines = text.split("\n")
# Split into sections by headings
sections: list[tuple[str | None, str]] = [] # (heading, body)
current_heading: str | None = None
current_lines: list[str] = []
for line in lines:
if re.match(r"^#{1,6}\s+", line):
# Flush previous section
if current_lines or current_heading is not None:
sections.append((current_heading, "\n".join(current_lines).strip()))
current_heading = line.strip()
current_lines = []
else:
current_lines.append(line)
# Flush last section
if current_lines or current_heading is not None:
sections.append((current_heading, "\n".join(current_lines).strip()))
# If there was content before any heading, capture it
if not sections:
sections.append((None, text.strip()))
# Build heading hierarchy for context
heading_stack: list[tuple[int, str]] = []
def _update_stack(heading: str | None) -> str | None:
if heading is None:
return " > ".join(h for _, h in heading_stack) if heading_stack else None
match = re.match(r"^(#{1,6})\s+(.*)", heading)
if not match:
return heading
level = len(match.group(1))
title = match.group(2).strip()
# Pop deeper or equal headings
while heading_stack and heading_stack[-1][0] >= level:
heading_stack.pop()
heading_stack.append((level, title))
return " > ".join(h for _, h in heading_stack)
# Merge small sections with the next section
merged: list[tuple[str | None, str]] = []
for heading, body in sections:
section_text = f"{heading}\n{body}".strip() if heading else body
if merged and _approx_tokens(merged[-1][1]) < min_tokens:
prev_heading, prev_body = merged[-1]
merged[-1] = (prev_heading, f"{prev_body}\n\n{section_text}".strip())
else:
merged.append((heading, section_text))
# Build chunks, splitting large sections at paragraph boundaries
chunks: list[dict] = []
chunk_index = 0
for heading, body in merged:
section_header = _update_stack(heading)
tokens = _approx_tokens(body)
if tokens <= max_tokens:
if body:
metadata: dict = {}
if section_header:
metadata["section_header"] = section_header
chunks.append({
"text": body,
"chunk_index": chunk_index,
"metadata": metadata,
})
chunk_index += 1
else:
# Split large section at paragraph boundaries
for piece in _split_at_paragraphs(body, max_tokens):
if piece.strip():
metadata = {}
if section_header:
metadata["section_header"] = section_header
chunks.append({
"text": piece.strip(),
"chunk_index": chunk_index,
"metadata": metadata,
})
chunk_index += 1
return chunks
+30
View File
@@ -0,0 +1,30 @@
"""Chunking pipeline for plain-text notes."""
from __future__ import annotations
def auto_title(text: str) -> str:
"""Derive a short title from the first line of text.
Strips leading markdown ``#`` characters and trims to 80 characters.
"""
first_line = text.split("\n", 1)[0].strip()
first_line = first_line.lstrip("#").strip()
if len(first_line) > 80:
first_line = first_line[:80].rstrip()
return first_line
def chunk_note(text: str) -> list[dict]:
"""Return the entire text as a single chunk.
Returns a list with one chunk dict containing:
text, chunk_index, metadata
"""
return [
{
"text": text.strip(),
"chunk_index": 0,
"metadata": {},
}
]
+1
View File
@@ -0,0 +1 @@
from kb.routes import health, search, jobs, documents, tags, status, reindex, auth
+38
View File
@@ -0,0 +1,38 @@
"""API key authentication middleware."""
from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from main import app
from kb.config import cfg
class AuthMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# Skip auth if no API key is configured
if cfg.api_key is None:
return await call_next(request)
# Always allow health checks without auth
if request.url.path == "/api/v1/health":
return await call_next(request)
auth_header = request.headers.get("Authorization")
if not auth_header:
return JSONResponse(
{"detail": "Missing Authorization header."},
status_code=401,
)
expected = f"Bearer {cfg.api_key}"
if auth_header != expected:
return JSONResponse(
{"detail": "Invalid API key."},
status_code=401,
)
return await call_next(request)
app.add_middleware(AuthMiddleware)
+143
View File
@@ -0,0 +1,143 @@
"""Document management endpoints — list, view, and delete documents."""
import json
from typing import Optional
from fastapi import HTTPException, Query
from main import app
from kb.config import cfg
from kb.database import get_connection
@app.get("/api/v1/documents")
async def list_documents(
type: Optional[str] = Query(None),
tags: Optional[str] = Query(None),
):
conn = get_connection(cfg.db_path)
try:
sql = """
SELECT d.id, d.title, d.doc_type,
(SELECT COUNT(*) FROM chunks c WHERE c.document_id = d.id) AS chunk_count,
d.created_at
FROM documents d
"""
joins: list[str] = []
where: list[str] = []
params: list = []
if type:
where.append("d.doc_type = ?")
params.append(type)
if tags:
tag_list = [t.strip() for t in tags.split(",") if t.strip()]
for i, tag in enumerate(tag_list):
joins.append(f"JOIN document_tags dt{i} ON d.id = dt{i}.document_id")
joins.append(f"JOIN tags t{i} ON dt{i}.tag_id = t{i}.id")
where.append(f"t{i}.name = ?")
params.append(tag)
if joins:
sql += " " + " ".join(joins)
if where:
sql += " WHERE " + " AND ".join(where)
sql += " ORDER BY d.created_at DESC"
rows = conn.execute(sql, params).fetchall()
results = []
for row in rows:
doc_id = row["id"]
tag_rows = conn.execute(
"""
SELECT t.name FROM tags t
JOIN document_tags dt ON t.id = dt.tag_id
WHERE dt.document_id = ?
ORDER BY t.name
""",
(doc_id,),
).fetchall()
results.append({
"id": row["id"],
"title": row["title"],
"doc_type": row["doc_type"],
"tags": [t["name"] for t in tag_rows],
"chunk_count": row["chunk_count"],
"created_at": row["created_at"],
})
return results
finally:
conn.close()
@app.get("/api/v1/documents/{doc_id}")
async def get_document(doc_id: int):
conn = get_connection(cfg.db_path)
try:
doc = conn.execute(
"SELECT * FROM documents WHERE id = ?", (doc_id,)
).fetchone()
if not doc:
raise HTTPException(status_code=404, detail="Document not found.")
chunks = conn.execute(
"SELECT * FROM chunks WHERE document_id = ? ORDER BY chunk_index",
(doc_id,),
).fetchall()
tag_rows = conn.execute(
"""
SELECT t.name FROM tags t
JOIN document_tags dt ON t.id = dt.tag_id
WHERE dt.document_id = ?
ORDER BY t.name
""",
(doc_id,),
).fetchall()
return {
**dict(doc),
"tags": [t["name"] for t in tag_rows],
"chunks": [dict(c) for c in chunks],
}
finally:
conn.close()
@app.delete("/api/v1/documents/{doc_id}")
async def delete_document(doc_id: int):
conn = get_connection(cfg.db_path)
try:
doc = conn.execute(
"SELECT id, title FROM documents WHERE id = ?", (doc_id,)
).fetchone()
if not doc:
raise HTTPException(status_code=404, detail="Document not found.")
# Get chunk IDs for embedding cleanup
chunk_ids = conn.execute(
"SELECT id FROM chunks WHERE document_id = ?", (doc_id,)
).fetchall()
# Delete embeddings from vec table
for row in chunk_ids:
conn.execute(
"DELETE FROM chunks_vec WHERE chunk_id = ?", (row["id"],)
)
# Delete document (cascades to chunks, document_tags)
conn.execute("DELETE FROM documents WHERE id = ?", (doc_id,))
conn.commit()
return {
"status": "deleted",
"document_id": doc_id,
"title": doc["title"],
}
finally:
conn.close()
+12
View File
@@ -0,0 +1,12 @@
"""Health check endpoint."""
import main
from main import app
@app.get("/api/v1/health")
async def health():
if not main.ready:
from fastapi.responses import JSONResponse
return JSONResponse({"status": "starting"}, status_code=503)
return {"status": "healthy"}
+66
View File
@@ -0,0 +1,66 @@
"""Job management endpoints — submit files/notes for ingestion and track progress."""
import json
from typing import Optional
from fastapi import HTTPException, UploadFile, File, Form, Query
from main import app
from kb.config import cfg
from kb.database import get_connection, create_job, get_job, list_jobs
from kb.staging import stage_file, stage_note
@app.post("/api/v1/jobs", status_code=202)
async def submit_job(
file: Optional[UploadFile] = File(None),
note: Optional[str] = Form(None),
title: Optional[str] = Form(None),
tags: Optional[str] = Form(None),
doc_type: Optional[str] = Form(None),
):
if not file and not note:
raise HTTPException(
status_code=422,
detail="Either 'file' or 'note' must be provided.",
)
if file:
content = await file.read()
staging_path = stage_file(cfg.staging_dir, file.filename, content)
filename = file.filename
else:
staging_path = stage_note(cfg.staging_dir, title or "note", note)
filename = staging_path.name
tags_list = [t.strip() for t in tags.split(",") if t.strip()] if tags else []
tags_json = json.dumps(tags_list)
conn = get_connection(cfg.db_path)
try:
job_id = create_job(conn, filename, str(staging_path), doc_type, tags_json, title)
return {"job_id": job_id, "status": "queued", "filename": filename}
finally:
conn.close()
@app.get("/api/v1/jobs")
async def list_all_jobs(status: Optional[str] = Query(None)):
conn = get_connection(cfg.db_path)
try:
rows = list_jobs(conn, status)
return [dict(row) for row in rows]
finally:
conn.close()
@app.get("/api/v1/jobs/{job_id}")
async def get_single_job(job_id: int):
conn = get_connection(cfg.db_path)
try:
row = get_job(conn, job_id)
if not row:
raise HTTPException(status_code=404, detail="Job not found.")
return dict(row)
finally:
conn.close()
+55
View File
@@ -0,0 +1,55 @@
"""Reindex endpoint — re-embed all chunks with the current model."""
import logging
import struct
from main import app
from kb.config import cfg
from kb.database import get_connection, recreate_vec_table
from kb.embeddings import embed_texts, get_model_dim
logger = logging.getLogger("kb.routes.reindex")
BATCH_SIZE = 256
@app.post("/api/v1/reindex")
async def reindex():
dim = get_model_dim()
conn = get_connection(cfg.db_path)
try:
# Fetch all chunks
rows = conn.execute("SELECT id, text FROM chunks ORDER BY id").fetchall()
chunk_ids = [row["id"] for row in rows]
chunk_texts = [row["text"] for row in rows]
logger.info("Reindexing %d chunks with model '%s'", len(chunk_ids), cfg.model)
# Recreate the vec table
recreate_vec_table(conn, dim)
# Embed and insert in batches
for i in range(0, len(chunk_ids), BATCH_SIZE):
batch_ids = chunk_ids[i : i + BATCH_SIZE]
batch_texts = chunk_texts[i : i + BATCH_SIZE]
embeddings = embed_texts(batch_texts)
for chunk_id, embedding in zip(batch_ids, embeddings):
blob = struct.pack(f"{len(embedding)}f", *embedding)
conn.execute(
"INSERT INTO chunks_vec(embedding, chunk_id) VALUES (?, ?)",
(blob, chunk_id),
)
conn.commit()
logger.info("Reindex complete: %d chunks", len(chunk_ids))
return {
"chunks_reindexed": len(chunk_ids),
"model": cfg.model,
}
finally:
conn.close()
+43
View File
@@ -0,0 +1,43 @@
"""Search endpoint — hybrid FTS5 + vector search."""
from typing import Optional
from fastapi import HTTPException
from pydantic import BaseModel
from main import app
from kb.config import cfg
from kb.database import get_connection
from kb.search import hybrid_search
class SearchRequest(BaseModel):
query: str
top: int = 10
tags: Optional[list[str]] = None
doc_type: Optional[str] = None
fts_only: bool = False
vec_only: bool = False
threshold: Optional[float] = None
@app.post("/api/v1/search")
async def search(req: SearchRequest):
conn = get_connection(cfg.db_path)
try:
result = hybrid_search(
conn,
req.query,
cfg,
top=req.top,
tags=req.tags,
doc_type=req.doc_type,
fts_only=req.fts_only,
vec_only=req.vec_only,
threshold=req.threshold,
)
return result
except Exception as exc:
raise HTTPException(status_code=500, detail=str(exc))
finally:
conn.close()
+67
View File
@@ -0,0 +1,67 @@
"""System status endpoint — model info, DB stats, and queue stats."""
import os
from main import app, __version__
from kb.config import cfg
from kb.database import get_connection
from kb.embeddings import get_model_dim
@app.get("/api/v1/status")
async def status():
# Device info
device_name = cfg.device
try:
import torch
if torch.cuda.is_available():
device_name = torch.cuda.get_device_name(0)
except Exception:
pass
conn = get_connection(cfg.db_path)
try:
# Document counts by type
type_rows = conn.execute(
"SELECT doc_type, COUNT(*) AS count FROM documents GROUP BY doc_type"
).fetchall()
doc_counts = {row["doc_type"]: row["count"] for row in type_rows}
# Total chunks
total_chunks = conn.execute("SELECT COUNT(*) AS n FROM chunks").fetchone()["n"]
# DB file size
db_size = 0
try:
db_size = os.path.getsize(cfg.db_path)
except OSError:
pass
# Queue stats
queue_rows = conn.execute(
"""
SELECT status, COUNT(*) AS count
FROM jobs
WHERE status IN ('queued', 'processing')
GROUP BY status
"""
).fetchall()
queue_stats = {row["status"]: row["count"] for row in queue_rows}
return {
"version": __version__,
"model_name": cfg.model,
"embedding_dim": get_model_dim(),
"device": device_name,
"db": {
"documents_by_type": doc_counts,
"total_chunks": total_chunks,
"db_size_bytes": db_size,
},
"queue": {
"queued": queue_stats.get("queued", 0),
"processing": queue_stats.get("processing", 0),
},
}
finally:
conn.close()
+63
View File
@@ -0,0 +1,63 @@
"""Tag management endpoints."""
from typing import Optional
from fastapi import HTTPException
from pydantic import BaseModel
from main import app
from kb.config import cfg
from kb.database import get_connection, tag_document, untag_document
@app.get("/api/v1/tags")
async def list_tags():
conn = get_connection(cfg.db_path)
try:
rows = conn.execute(
"""
SELECT t.name, COUNT(dt.document_id) AS count
FROM tags t
LEFT JOIN document_tags dt ON t.id = dt.tag_id
GROUP BY t.id, t.name
ORDER BY t.name
"""
).fetchall()
return [{"name": row["name"], "count": row["count"]} for row in rows]
finally:
conn.close()
class TagUpdateRequest(BaseModel):
add: Optional[list[str]] = None
remove: Optional[list[str]] = None
@app.put("/api/v1/documents/{doc_id}/tags")
async def update_document_tags(doc_id: int, req: TagUpdateRequest):
conn = get_connection(cfg.db_path)
try:
doc = conn.execute(
"SELECT id FROM documents WHERE id = ?", (doc_id,)
).fetchone()
if not doc:
raise HTTPException(status_code=404, detail="Document not found.")
if req.add:
tag_document(conn, doc_id, req.add)
if req.remove:
untag_document(conn, doc_id, req.remove)
tag_rows = conn.execute(
"""
SELECT t.name FROM tags t
JOIN document_tags dt ON t.id = dt.tag_id
WHERE dt.document_id = ?
ORDER BY t.name
""",
(doc_id,),
).fetchall()
return {"document_id": doc_id, "tags": [t["name"] for t in tag_rows]}
finally:
conn.close()
+290
View File
@@ -0,0 +1,290 @@
"""Hybrid search — FTS5 + sqlite-vec with Reciprocal Rank Fusion."""
import json
import struct
import sqlite3
def hybrid_search(
conn: sqlite3.Connection,
query: str,
cfg,
top: int = 10,
tags: list[str] | None = None,
doc_type: str | None = None,
fts_only: bool = False,
vec_only: bool = False,
threshold: float | None = None,
) -> dict:
"""Run hybrid search and return merged, enriched results.
Args:
conn: SQLite connection (with row_factory = sqlite3.Row).
query: User search query string.
cfg: Config object with ``model`` and ``device`` attributes.
top: Maximum number of results to return.
tags: Optional tag filter documents must have *all* listed tags.
doc_type: Optional document-type filter.
fts_only: Only use FTS5 (skip vector search).
vec_only: Only use vector search (skip FTS5).
threshold: Optional minimum score; results below are dropped.
Returns:
Dict with keys: query, results, total_matches, returned.
"""
candidate_count = top * 3
fts_results: dict[int, float] = {}
vec_results: dict[int, float] = {}
if not vec_only:
fts_results = _fts_search(conn, query, candidate_count, tags, doc_type)
if not fts_only:
vec_results = _vector_search(conn, query, candidate_count, tags, doc_type)
# --- merge ---------------------------------------------------------------
if fts_only:
merged = sorted(fts_results.items(), key=lambda x: x[1], reverse=True)
elif vec_only:
merged = sorted(vec_results.items(), key=lambda x: x[1], reverse=True)
else:
merged = _rrf_merge(fts_results, vec_results)
# Apply threshold filter — use config default if not specified per-query
effective_threshold = threshold if threshold is not None else cfg.search_threshold
if effective_threshold > 0:
merged = [(cid, score) for cid, score in merged if score >= effective_threshold]
total_matches = len(merged)
merged = merged[:top]
# --- enrich --------------------------------------------------------------
results = _enrich(conn, merged)
return {
"query": query,
"results": results,
"total_matches": total_matches,
"returned": len(results),
}
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _fts_search(
conn: sqlite3.Connection,
query: str,
limit: int,
tags: list[str] | None,
doc_type: str | None,
) -> dict[int, float]:
"""FTS5 search on ``chunks_fts``.
Returns:
{chunk_id: bm25_score} where scores are positive (higher = better).
"""
sql = "SELECT f.rowid AS chunk_id, bm25(chunks_fts) AS rank FROM chunks_fts f"
joins: list[str] = []
where: list[str] = ["chunks_fts MATCH ?"]
params: list = [query]
if tags or doc_type:
joins.append("JOIN chunks c ON f.rowid = c.id")
joins.append("JOIN documents d ON c.document_id = d.id")
if doc_type:
where.append("d.doc_type = ?")
params.append(doc_type)
if tags:
for i, tag in enumerate(tags):
joins.append(f"JOIN document_tags dt{i} ON d.id = dt{i}.document_id")
joins.append(f"JOIN tags t{i} ON dt{i}.tag_id = t{i}.id")
where.append(f"t{i}.name = ?")
params.append(tag.strip().lower())
sql += " " + " ".join(joins)
sql += " WHERE " + " AND ".join(where)
sql += " ORDER BY rank LIMIT ?"
params.append(limit)
rows = conn.execute(sql, params).fetchall()
# BM25 returns negative values (lower = better match); negate so
# higher = better.
return {row[0]: -row[1] for row in rows}
def _vector_search(
conn: sqlite3.Connection,
query: str,
limit: int,
tags: list[str] | None,
doc_type: str | None,
) -> dict[int, float]:
"""Embed *query* and search ``chunks_vec`` via sqlite-vec.
Returns:
{chunk_id: similarity} where similarity = 1 / (1 + distance).
"""
from kb.embeddings import embed_texts
query_embedding = embed_texts([query])[0]
blob = struct.pack(f"{len(query_embedding)}f", *query_embedding)
rows = conn.execute(
"""
SELECT chunk_id, distance
FROM chunks_vec
WHERE embedding MATCH ?
ORDER BY distance
LIMIT ?
""",
(blob, limit),
).fetchall()
results: dict[int, float] = {}
for row in rows:
chunk_id = row[0]
distance = row[1]
similarity = 1.0 / (1.0 + distance)
# Post-hoc tag / doc_type filtering for vector results
if tags or doc_type:
if not _passes_filters(conn, chunk_id, tags, doc_type):
continue
results[chunk_id] = similarity
return results
def _passes_filters(
conn: sqlite3.Connection,
chunk_id: int,
tags: list[str] | None,
doc_type: str | None,
) -> bool:
"""Return True if chunk passes tag and doc_type filters."""
sql = """
SELECT d.id FROM chunks c
JOIN documents d ON c.document_id = d.id
WHERE c.id = ?
"""
params: list = [chunk_id]
if doc_type:
sql += " AND d.doc_type = ?"
params.append(doc_type)
doc_row = conn.execute(sql, params).fetchone()
if not doc_row:
return False
if tags:
doc_id = doc_row[0]
placeholders = ",".join("?" * len(tags))
normalised = [t.strip().lower() for t in tags]
count = conn.execute(
f"""
SELECT COUNT(DISTINCT t.name) FROM document_tags dt
JOIN tags t ON dt.tag_id = t.id
WHERE dt.document_id = ? AND t.name IN ({placeholders})
""",
[doc_id, *normalised],
).fetchone()[0]
if count < len(tags):
return False
return True
def _rrf_merge(
fts_results: dict[int, float],
vec_results: dict[int, float],
k: int = 60,
) -> list[tuple[int, float]]:
"""Reciprocal Rank Fusion over two scored result sets.
Each set is ranked independently (highest score first, rank starts at 1).
RRF score for a document = sum of 1/(k + rank) across sets it appears in.
Returns:
Sorted list of (chunk_id, rrf_score), highest first.
"""
fts_ranked = _rank_by_score(fts_results)
vec_ranked = _rank_by_score(vec_results)
all_ids = set(fts_ranked) | set(vec_ranked)
scores: list[tuple[int, float]] = []
for chunk_id in all_ids:
rrf = 0.0
if chunk_id in fts_ranked:
rrf += 1.0 / (k + fts_ranked[chunk_id])
if chunk_id in vec_ranked:
rrf += 1.0 / (k + vec_ranked[chunk_id])
scores.append((chunk_id, rrf))
scores.sort(key=lambda x: x[1], reverse=True)
return scores
def _rank_by_score(results: dict[int, float]) -> dict[int, int]:
"""Return {id: 1-based rank} sorted by score descending."""
ordered = sorted(results, key=results.get, reverse=True)
return {cid: rank for rank, cid in enumerate(ordered, start=1)}
def _enrich(
conn: sqlite3.Connection,
merged: list[tuple[int, float]],
) -> list[dict]:
"""Fetch chunk text, document metadata, chunk metadata, and tags."""
results: list[dict] = []
for chunk_id, score in merged:
row = conn.execute(
"""
SELECT c.id, c.text, c.chunk_index, c.metadata AS chunk_meta,
d.id AS doc_id, d.title, d.doc_type, d.source_path,
d.created_at
FROM chunks c
JOIN documents d ON c.document_id = d.id
WHERE c.id = ?
""",
(chunk_id,),
).fetchone()
if row is None:
continue
chunk_meta = json.loads(row[3]) if row[3] else {}
tag_rows = conn.execute(
"""
SELECT t.name FROM tags t
JOIN document_tags dt ON t.id = dt.tag_id
WHERE dt.document_id = ?
ORDER BY t.name
""",
(row[4],), # doc_id
).fetchall()
results.append({
"chunk_id": row[0],
"score": round(score, 6),
"text": row[1],
"chunk_index": row[2],
"chunk_metadata": chunk_meta,
"title": row[5],
"doc_type": row[6],
"source_path": row[7],
"created_at": row[8],
"tags": [t[0] for t in tag_rows],
})
return results
+47
View File
@@ -0,0 +1,47 @@
"""Staging area for files awaiting ingestion."""
import logging
import uuid
from pathlib import Path
logger = logging.getLogger("kb.staging")
def stage_file(staging_dir: Path, filename: str, content: bytes) -> Path:
"""Write raw bytes to a uniquely-named file in the staging directory.
The staged file is named ``{uuid}_{filename}`` to avoid collisions.
Returns:
The path to the newly created staged file.
"""
staging_dir.mkdir(parents=True, exist_ok=True)
dest = staging_dir / f"{uuid.uuid4()}_{filename}"
dest.write_bytes(content)
logger.debug("Staged file: %s (%d bytes)", dest, len(content))
return dest
def stage_note(staging_dir: Path, title: str, text: str) -> Path:
"""Write a text note to the staging directory.
The staged file is named ``{uuid}_{title}.note``.
Returns:
The path to the newly created staged note file.
"""
staging_dir.mkdir(parents=True, exist_ok=True)
dest = staging_dir / f"{uuid.uuid4()}_{title}.note"
dest.write_text(text, encoding="utf-8")
logger.debug("Staged note: %s (%d chars)", dest, len(text))
return dest
def cleanup(path: Path) -> None:
"""Delete a staged file if it exists. Logs a warning on failure."""
try:
if path.exists():
path.unlink()
logger.debug("Cleaned up staged file: %s", path)
except OSError as exc:
logger.warning("Failed to clean up staged file %s: %s", path, exc)
+175
View File
@@ -0,0 +1,175 @@
"""Async background worker for processing ingestion jobs."""
import asyncio
import hashlib
import json
import logging
from pathlib import Path
from kb import config, database, embeddings, staging
from kb.ingest import detector
logger = logging.getLogger("kb.worker")
async def ingestion_worker() -> None:
"""Main background loop that processes queued ingestion jobs.
Runs indefinitely until cancelled. Every 2 seconds it checks for the
oldest queued job, marks it as *processing*, and delegates to
:func:`_process_job` in a thread pool.
"""
logger.info("Ingestion worker started")
while True:
try:
cfg = config.cfg
conn = database.get_connection(cfg.db_path)
try:
row = conn.execute(
"SELECT * FROM jobs WHERE status = 'queued' "
"ORDER BY created_at ASC LIMIT 1"
).fetchone()
if row is None:
await asyncio.sleep(2)
continue
job_id = row["id"]
database.update_job_status(conn, job_id, "processing")
logger.info("Processing job %d (%s)", job_id, row["filename"])
finally:
conn.close()
try:
status, doc_id, chunk_count = await asyncio.to_thread(
_process_job, row
)
except Exception as exc:
logger.exception("Job %d failed", job_id)
conn = database.get_connection(cfg.db_path)
try:
database.update_job_status(
conn, job_id, "failed", error=str(exc)
)
finally:
conn.close()
continue
conn = database.get_connection(cfg.db_path)
try:
database.update_job_status(
conn,
job_id,
status,
document_id=doc_id,
chunk_count=chunk_count,
)
finally:
conn.close()
logger.info(
"Job %d finished: status=%s doc_id=%s chunks=%s",
job_id, status, doc_id, chunk_count,
)
except asyncio.CancelledError:
logger.info("Ingestion worker cancelled")
raise
except Exception:
logger.exception("Unexpected error in ingestion worker loop")
await asyncio.sleep(2)
def _process_job(job_row) -> tuple[str, int | None, int]:
"""Synchronously process a single ingestion job.
Returns:
A tuple of ``(status, document_id, chunk_count)`` where *status* is
one of ``"done"``, ``"skipped"``.
"""
cfg = config.cfg
conn = database.get_connection(cfg.db_path)
staged_path = Path(job_row["staging_path"])
try:
# --- Determine document type and language -------------------------
filename = job_row["filename"]
forced_type = job_row["doc_type"]
if staged_path.suffix == ".note":
doc_type = "note"
language = None
elif forced_type:
doc_type = forced_type
language = None
else:
doc_type, language = detector.detect_type(Path(filename))
# --- Chunk the content --------------------------------------------
if doc_type == "note":
text = staged_path.read_text(encoding="utf-8")
from kb.ingest.note import chunk_note
chunks = chunk_note(text)
elif doc_type == "pdf":
from kb.ingest.docling_pipeline import chunk_document
chunks = chunk_document(staged_path, cfg.ingest_device)
elif doc_type == "markdown":
text = staged_path.read_text(encoding="utf-8")
from kb.ingest.markdown import chunk_markdown
chunks = chunk_markdown(text)
elif doc_type == "code":
text = staged_path.read_text(encoding="utf-8")
if not language:
_, language = detector.detect_type(Path(filename))
from kb.ingest.code import chunk_code
chunks = chunk_code(text, language)
else:
raise ValueError(f"Unsupported doc_type: {doc_type}")
# --- Duplicate detection via content hash -------------------------
raw_bytes = staged_path.read_bytes()
content_hash = hashlib.sha256(raw_bytes).hexdigest()
if database.hash_exists(conn, content_hash):
logger.info("Duplicate detected for job %d, skipping", job_row["id"])
return ("skipped", None, 0)
# --- Persist document, chunks, and embeddings ---------------------
title = job_row["title"] or filename
doc_id = database.insert_document(
conn,
title=title,
source_path=str(staged_path),
content_hash=content_hash,
doc_type=doc_type,
language=language,
)
chunk_texts = [c if isinstance(c, str) else c["text"] for c in chunks]
vectors = embeddings.embed_texts(chunk_texts)
for idx, (chunk_text, vector) in enumerate(zip(chunk_texts, vectors)):
metadata = None
if not isinstance(chunks[idx], str):
metadata = {
k: v for k, v in chunks[idx].items() if k != "text"
} or None
chunk_id = database.insert_chunk(
conn,
document_id=doc_id,
chunk_index=idx,
text=chunk_text,
metadata=metadata,
)
database.insert_embedding(conn, chunk_id, vector)
# --- Tags ---------------------------------------------------------
tags = json.loads(job_row["tags_json"] or "[]")
if tags:
database.tag_document(conn, doc_id, tags)
conn.commit()
return ("done", doc_id, len(chunk_texts))
finally:
conn.close()
staging.cleanup(staged_path)
+69
View File
@@ -0,0 +1,69 @@
"""Engine entry point — FastAPI server with eager model loading."""
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI
_version_file = Path(__file__).parent / "VERSION"
__version__ = _version_file.read_text().strip() if _version_file.exists() else "dev"
from kb.config import cfg
from kb.embeddings import load_model
from kb.database import get_connection, init_schema
from kb.worker import ingestion_worker
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger("kb.engine")
# Track readiness for health endpoint
ready = False
worker_task: asyncio.Task | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global ready, worker_task
# Set HF cache before any model imports
os.environ["HF_HOME"] = str(cfg.hf_cache)
log.info("Starting engine...")
cfg.ensure_dirs()
# Initialise database
conn = get_connection(cfg.db_path)
model_dim = load_model(cfg.model, cfg.device)
init_schema(conn, model_dim)
conn.close()
# Start background ingestion worker
worker_task = asyncio.create_task(ingestion_worker())
ready = True
log.info("Engine ready — model: %s, device: %s", cfg.model, cfg.device)
yield
# Shutdown
ready = False
if worker_task:
worker_task.cancel()
try:
await worker_task
except asyncio.CancelledError:
pass
log.info("Engine stopped.")
app = FastAPI(title="kb-engine", version=__version__, lifespan=lifespan)
# Import routes after app is created
from kb.routes import health, search, jobs, documents, tags, status, reindex, auth # noqa: E402, F401
if __name__ == "__main__":
import uvicorn
uvicorn.run("main:app", host=cfg.host, port=cfg.port, log_level="info")
+35
View File
@@ -0,0 +1,35 @@
[build-system]
requires = ["setuptools>=68.0"]
build-backend = "setuptools.build_meta"
[project]
name = "kb-engine"
version = "2.0.0"
description = "Knowledge base engine — FastAPI server with hybrid search and async ingestion"
requires-python = ">=3.11"
license = "MIT"
dependencies = [
"fastapi>=0.115",
"uvicorn[standard]>=0.30",
"python-multipart>=0.0.9",
"click>=8.1",
"pyyaml>=6.0",
"sentence-transformers>=3.0",
"sqlite-vec>=0.1.1",
"docling>=2.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0",
"pytest-asyncio>=0.24",
"httpx>=0.27",
]
[tool.setuptools.packages.find]
where = ["."]
include = ["kb*"]
[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "auto"