v2 restructure: Go client, Docker engine, release tooling
- Remove v1 Python CLI (src/kb_search/, tests/, root pyproject.toml, uv.lock, .venv)
- Add Go client with cross-platform build (client/)
- Add FastAPI engine with NVIDIA and multi-stage ROCm Dockerfiles (engine/)
- Add VERSION files for client and engine, wired into builds
- Add release.sh for automated build, tag, release, and Docker push
- Update README with build/release docs and ROCm migration note
- Clean up .gitignore for v2 project structure

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.venv/
|
||||
*.egg-info/
|
||||
.pytest_cache/
|
||||
tests/
|
||||
@@ -0,0 +1,35 @@
|
||||
# Single-stage engine image on the NVIDIA CUDA runtime base.
FROM nvidia/cuda:13.0.1-runtime-ubuntu24.04

ENV DEBIAN_FRONTEND=noninteractive

# System packages; the apt lists are removed in the same layer.
RUN apt-get update \
 && apt-get install -y --no-install-recommends \
      python3.12 python3.12-venv python3.12-dev python3-pip \
      libpoppler-cpp-dev poppler-utils \
      libgl1 libglib2.0-0 \
      build-essential curl \
 && rm -rf /var/lib/apt/lists/*

# uv binary taken from the upstream uv image.
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv

WORKDIR /app

# Application sources.
COPY pyproject.toml main.py VERSION ./
COPY kb/ kb/

# Create the venv, install the project (editable), then add onnxruntime-gpu
# without pulling in its dependencies.
RUN uv venv .venv \
 && . .venv/bin/activate \
 && uv pip install -e . \
 && uv pip install --no-deps onnxruntime-gpu

# Put the venv first on PATH and set the engine's default settings.
ENV PATH="/app/.venv/bin:$PATH" \
    VIRTUAL_ENV="/app/.venv" \
    KB_DEVICE=auto \
    KB_INGEST_DEVICE=auto \
    KB_DATA_DIR=/data

EXPOSE 8000
VOLUME ["/data"]

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
@@ -0,0 +1,68 @@
|
||||
# Stage 1: Build — install Python deps with dev tools available
FROM rocm/dev-ubuntu-24.04:6.4-complete AS builder

ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update \
 && apt-get install -y --no-install-recommends \
      python3.12 python3.12-venv python3.12-dev python3-pip \
      libpoppler-cpp-dev poppler-utils \
      build-essential curl \
 && rm -rf /var/lib/apt/lists/*

# uv binary taken from the upstream uv image.
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv

WORKDIR /app

COPY pyproject.toml main.py VERSION ./
COPY kb/ kb/

# Create the venv, install the project (editable), then add onnxruntime-rocm
# without pulling in its dependencies.
RUN uv venv .venv \
 && . .venv/bin/activate \
 && uv pip install -e . \
 && uv pip install --no-deps onnxruntime-rocm

# Stage 2: Runtime — minimal ROCm runtime libs only
FROM ubuntu:24.04

ENV DEBIAN_FRONTEND=noninteractive

# Add ROCm apt repository (signed-by keyring + pin priority 600), then install
# the Python runtime, shared libraries and ROCm runtime packages from it.
RUN apt-get update \
 && apt-get install -y --no-install-recommends \
      ca-certificates curl gnupg \
 && mkdir -p /etc/apt/keyrings \
 && curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key \
      | gpg --dearmor -o /etc/apt/keyrings/rocm.gpg \
 && echo "deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/6.4.1 noble main" \
      > /etc/apt/sources.list.d/rocm.list \
 && printf 'Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600\n' \
      > /etc/apt/preferences.d/rocm-pin-600 \
 && apt-get update \
 && apt-get install -y --no-install-recommends \
      python3.12 python3.12-venv \
      libpoppler-cpp0t64 poppler-utils \
      libgl1 libglib2.0-0 \
      rocm-hip-runtime \
      rocm-hip-libraries \
      miopen-hip \
 && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Copy built venv and application from builder
COPY --from=builder /app/.venv .venv
COPY --from=builder /app/kb kb
COPY --from=builder /app/main.py /app/pyproject.toml /app/VERSION ./

# Put the venv first on PATH and set the engine's default settings.
ENV PATH="/app/.venv/bin:$PATH" \
    VIRTUAL_ENV="/app/.venv" \
    KB_DEVICE=auto \
    KB_INGEST_DEVICE=auto \
    KB_DATA_DIR=/data

EXPOSE 8000
VOLUME ["/data"]

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
@@ -0,0 +1 @@
|
||||
2.0.3
|
||||
@@ -0,0 +1,24 @@
|
||||
# Compose definition for the engine built from Dockerfile.nvidia.
services:
  kb-engine:
    build:
      context: .
      dockerfile: Dockerfile.nvidia
    runtime: nvidia
    # Reserve one NVIDIA GPU for the container.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities:
                - gpu
    # Host port is configurable; the container always listens on 8000.
    ports:
      - "${KB_PORT:-8000}:8000"
    # Host data directory (default ./data) mounted at /data.
    volumes:
      - "${KB_DATA_PATH:-./data}:/data"
    # Engine settings, each overridable from the host environment.
    environment:
      - "KB_MODEL=${KB_MODEL:-all-MiniLM-L6-v2}"
      - "KB_DEVICE=${KB_DEVICE:-auto}"
      - "KB_INGEST_DEVICE=${KB_INGEST_DEVICE:-auto}"
      - "KB_API_KEY=${KB_API_KEY:-}"
      - "KB_SEARCH_THRESHOLD=${KB_SEARCH_THRESHOLD:-0.01}"
    restart: unless-stopped
|
||||
@@ -0,0 +1,21 @@
|
||||
# Compose definition for the engine built from Dockerfile.rocm.
services:
  kb-engine:
    build:
      context: .
      dockerfile: Dockerfile.rocm
    # Device nodes passed through to the container.
    devices:
      - /dev/kfd
      - /dev/dri
    group_add:
      - video
    # Host port is configurable; the container always listens on 8000.
    ports:
      - "${KB_PORT:-8000}:8000"
    # Host data directory (default ./data) mounted at /data.
    volumes:
      - "${KB_DATA_PATH:-./data}:/data"
    # Engine settings, each overridable from the host environment.
    environment:
      - "KB_MODEL=${KB_MODEL:-all-MiniLM-L6-v2}"
      - "KB_DEVICE=${KB_DEVICE:-auto}"
      - "KB_INGEST_DEVICE=${KB_INGEST_DEVICE:-auto}"
      - "KB_API_KEY=${KB_API_KEY:-}"
      - "KB_SEARCH_THRESHOLD=${KB_SEARCH_THRESHOLD:-0.01}"
    restart: unless-stopped
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Engine configuration — all settings from environment variables."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class Config:
|
||||
data_dir: Path
|
||||
model: str
|
||||
device: str
|
||||
ingest_device: str
|
||||
api_key: str | None
|
||||
host: str
|
||||
port: int
|
||||
|
||||
def __init__(self):
|
||||
self.data_dir = Path(os.environ.get("KB_DATA_DIR", "/data"))
|
||||
self.model = os.environ.get("KB_MODEL", "all-MiniLM-L6-v2")
|
||||
self.device = os.environ.get("KB_DEVICE", "auto")
|
||||
self.ingest_device = os.environ.get("KB_INGEST_DEVICE", "auto")
|
||||
self.api_key = os.environ.get("KB_API_KEY") or None
|
||||
self.search_threshold = float(os.environ.get("KB_SEARCH_THRESHOLD", "0.01"))
|
||||
self.host = os.environ.get("KB_HOST", "0.0.0.0")
|
||||
self.port = int(os.environ.get("KB_PORT", "8000"))
|
||||
|
||||
@property
|
||||
def db_path(self) -> Path:
|
||||
return self.data_dir / "kb.db"
|
||||
|
||||
@property
|
||||
def hf_cache(self) -> Path:
|
||||
return self.data_dir / "hf_cache"
|
||||
|
||||
@property
|
||||
def staging_dir(self) -> Path:
|
||||
return self.data_dir / "staging"
|
||||
|
||||
def ensure_dirs(self):
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.hf_cache.mkdir(exist_ok=True)
|
||||
self.staging_dir.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
cfg = Config()
|
||||
@@ -0,0 +1,308 @@
|
||||
"""Database module for kb-engine v2.
|
||||
|
||||
Provides SQLite database access with WAL mode, FTS5 full-text search,
|
||||
and sqlite-vec vector storage for embeddings.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import struct
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
def get_connection(db_path: str) -> sqlite3.Connection:
    """Open *db_path* and return a connection ready for use by this module.

    The sqlite-vec extension is loaded, extension loading is switched off
    again, rows are returned as ``sqlite3.Row``, and the ``journal_mode=WAL``
    and ``foreign_keys=ON`` pragmas are issued.

    If any setup step fails the connection is closed before the exception is
    re-raised, so a failed call does not leak an open database handle.
    """
    import sqlite_vec

    conn = sqlite3.connect(db_path)
    try:
        conn.enable_load_extension(True)
        sqlite_vec.load(conn)
        conn.enable_load_extension(False)
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA foreign_keys=ON")
    except Exception:
        conn.close()
        raise
    return conn
|
||||
|
||||
|
||||
def init_schema(conn: sqlite3.Connection, embedding_dim: int) -> None:
    """Create all tables, the FTS index, and its sync triggers if missing.

    Args:
        conn: Open connection; the sqlite-vec extension must already be loaded
            for the ``chunks_vec`` table to be creatable.
        embedding_dim: Width of the ``chunks_vec.embedding`` column. Only used
            when ``chunks_vec`` does not exist yet.

    Raises:
        ValueError: If *embedding_dim* is not a positive integer. The check
            runs before any DDL, so nothing is created in that case.
    """
    # The dimension is spliced into DDL text below; coerce it to an int first
    # so nothing but digits can end up in the statement.
    dim = int(embedding_dim)
    if dim <= 0:
        raise ValueError(
            f"embedding_dim must be a positive integer, got {embedding_dim!r}"
        )

    conn.executescript("""
        CREATE TABLE IF NOT EXISTS documents (
            id INTEGER PRIMARY KEY,
            title TEXT,
            source_path TEXT,
            content_hash TEXT UNIQUE,
            doc_type TEXT,
            language TEXT,
            created_at TEXT DEFAULT current_timestamp
        );

        CREATE TABLE IF NOT EXISTS chunks (
            id INTEGER PRIMARY KEY,
            document_id INTEGER REFERENCES documents(id) ON DELETE CASCADE,
            chunk_index INTEGER,
            text TEXT,
            token_count INTEGER,
            metadata TEXT DEFAULT '{}',
            UNIQUE(document_id, chunk_index)
        );

        CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
            text,
            content=chunks,
            content_rowid=id
        );

        -- Triggers to keep FTS index in sync with chunks table
        CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
            INSERT INTO chunks_fts(rowid, text) VALUES (new.id, new.text);
        END;

        CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
            INSERT INTO chunks_fts(chunks_fts, rowid, text) VALUES ('delete', old.id, old.text);
        END;

        CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
            INSERT INTO chunks_fts(chunks_fts, rowid, text) VALUES ('delete', old.id, old.text);
            INSERT INTO chunks_fts(rowid, text) VALUES (new.id, new.text);
        END;

        CREATE TABLE IF NOT EXISTS tags (
            id INTEGER PRIMARY KEY,
            name TEXT UNIQUE COLLATE NOCASE
        );

        CREATE TABLE IF NOT EXISTS document_tags (
            document_id INTEGER REFERENCES documents(id) ON DELETE CASCADE,
            tag_id INTEGER REFERENCES tags(id) ON DELETE CASCADE,
            UNIQUE(document_id, tag_id)
        );

        CREATE TABLE IF NOT EXISTS config (
            key TEXT PRIMARY KEY,
            value TEXT
        );

        CREATE TABLE IF NOT EXISTS jobs (
            id INTEGER PRIMARY KEY,
            filename TEXT,
            status TEXT DEFAULT 'queued' CHECK(status IN ('queued','processing','done','failed','skipped')),
            doc_type TEXT,
            tags_json TEXT DEFAULT '[]',
            title TEXT,
            error TEXT,
            document_id INTEGER,
            chunk_count INTEGER DEFAULT 0,
            staging_path TEXT,
            created_at TEXT DEFAULT current_timestamp,
            completed_at TEXT
        );
    """)

    # sqlite-vec virtual table (cannot use IF NOT EXISTS with vec0, so check first)
    existing = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='chunks_vec'"
    ).fetchone()
    if not existing:
        conn.execute(
            f"CREATE VIRTUAL TABLE chunks_vec USING vec0(embedding float[{dim}], chunk_id integer)"
        )

    conn.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_db_config(conn: sqlite3.Connection, key: str, default: Optional[str] = None) -> Optional[str]:
    """Look up *key* in the ``config`` table, returning *default* when absent.

    Reads the column by name, so *conn* must use ``sqlite3.Row`` as its row
    factory.
    """
    cursor = conn.execute("SELECT value FROM config WHERE key = ?", (key,))
    record = cursor.fetchone()
    if record is None:
        return default
    return record["value"]
|
||||
|
||||
|
||||
def set_db_config(conn: sqlite3.Connection, key: str, value: str) -> None:
    """Store *value* under *key* in the ``config`` table and commit.

    An existing row for *key* is overwritten via an upsert.
    """
    upsert = "INSERT INTO config(key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value"
    conn.execute(upsert, (key, value))
    conn.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def hash_exists(conn: sqlite3.Connection, content_hash: str) -> bool:
    """Return True when a ``documents`` row already carries *content_hash*."""
    query = "SELECT 1 FROM documents WHERE content_hash = ?"
    return conn.execute(query, (content_hash,)).fetchone() is not None
|
||||
|
||||
|
||||
def insert_document(
    conn: sqlite3.Connection,
    title: str,
    source_path: str,
    content_hash: str,
    doc_type: str,
    language: Optional[str] = None,
) -> int:
    """Add one row to ``documents``, commit, and return the new row id."""
    values = (title, source_path, content_hash, doc_type, language)
    cursor = conn.execute(
        "INSERT INTO documents(title, source_path, content_hash, doc_type, language) VALUES (?, ?, ?, ?, ?)",
        values,
    )
    new_id = cursor.lastrowid
    conn.commit()
    return new_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Chunk / embedding helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def insert_chunk(
    conn: sqlite3.Connection,
    document_id: int,
    chunk_index: int,
    text: str,
    token_count: Optional[int] = None,
    metadata: Any = None,
) -> int:
    """Add one row to ``chunks``, commit, and return the new row id.

    *metadata* is stored as text: a dict is JSON-encoded, ``None`` becomes
    ``"{}"``, and any other value is passed through ``str()``.
    """
    if isinstance(metadata, dict):
        stored_metadata = json.dumps(metadata)
    elif metadata is None:
        stored_metadata = "{}"
    else:
        stored_metadata = str(metadata)

    values = (document_id, chunk_index, text, token_count, stored_metadata)
    cursor = conn.execute(
        "INSERT INTO chunks(document_id, chunk_index, text, token_count, metadata) VALUES (?, ?, ?, ?, ?)",
        values,
    )
    new_id = cursor.lastrowid
    conn.commit()
    return new_id
|
||||
|
||||
|
||||
def insert_embedding(conn: sqlite3.Connection, chunk_id: int, embedding: list[float]) -> None:
    """Store *embedding* for *chunk_id* in ``chunks_vec`` and commit.

    The vector is written as a blob of ``len(embedding)`` values packed with
    the ``struct`` ``"f"`` format.
    """
    pack_format = f"{len(embedding)}f"
    packed = struct.pack(pack_format, *embedding)
    conn.execute(
        "INSERT INTO chunks_vec(embedding, chunk_id) VALUES (?, ?)",
        (packed, chunk_id),
    )
    conn.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tagging helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def tag_document(conn: sqlite3.Connection, document_id: int, tag_names: list[str]) -> None:
    """Attach each tag in *tag_names* to *document_id*, then commit.

    Tags that do not exist yet are created; tags and associations that are
    already present are left alone (``INSERT OR IGNORE``). Reads the tag id by
    column name, so *conn* must use ``sqlite3.Row`` as its row factory.
    """
    for tag_name in tag_names:
        conn.execute("INSERT OR IGNORE INTO tags(name) VALUES (?)", (tag_name,))
        tag_row = conn.execute("SELECT id FROM tags WHERE name = ?", (tag_name,)).fetchone()
        link = (document_id, tag_row["id"])
        conn.execute(
            "INSERT OR IGNORE INTO document_tags(document_id, tag_id) VALUES (?, ?)",
            link,
        )
    conn.commit()
|
||||
|
||||
|
||||
def untag_document(conn: sqlite3.Connection, document_id: int, tag_names: list[str]) -> None:
    """Detach each tag in *tag_names* from *document_id*, then commit.

    Names with no matching ``tags`` row are skipped. Only the association is
    deleted; the tag row itself stays. Reads the tag id by column name, so
    *conn* must use ``sqlite3.Row`` as its row factory.
    """
    for tag_name in tag_names:
        tag_row = conn.execute("SELECT id FROM tags WHERE name = ?", (tag_name,)).fetchone()
        if tag_row is None:
            continue
        conn.execute(
            "DELETE FROM document_tags WHERE document_id = ? AND tag_id = ?",
            (document_id, tag_row["id"]),
        )
    conn.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vec table management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def recreate_vec_table(conn: sqlite3.Connection, dim: int) -> None:
    """Drop and recreate the ``chunks_vec`` virtual table (for reindexing).

    Args:
        conn: Open connection with the sqlite-vec extension loaded.
        dim: Width of the new ``embedding`` column.

    Raises:
        ValueError: If *dim* is not a positive integer. The check runs before
            the DROP, so an invalid dimension leaves the existing table and
            its stored vectors untouched.
    """
    # Validate first: the value is spliced into DDL text, and failing after
    # the DROP would leave the database without a chunks_vec table.
    width = int(dim)
    if width <= 0:
        raise ValueError(f"dim must be a positive integer, got {dim!r}")

    conn.execute("DROP TABLE IF EXISTS chunks_vec")
    conn.execute(
        f"CREATE VIRTUAL TABLE chunks_vec USING vec0(embedding float[{width}], chunk_id integer)"
    )
    conn.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Job CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Job statuses that mark a job as finished; update_job_status() stamps
# completed_at when a job moves into one of these.
_TERMINAL_STATUSES = {"done", "failed", "skipped"}
|
||||
|
||||
|
||||
def create_job(
    conn: sqlite3.Connection,
    filename: str,
    staging_path: str,
    doc_type: Optional[str] = None,
    tags_json: str = "[]",
    title: Optional[str] = None,
) -> int:
    """Add one row to ``jobs``, commit, and return the new job id.

    Columns not listed in the INSERT take their table defaults.
    """
    values = (filename, staging_path, doc_type, tags_json, title)
    cursor = conn.execute(
        "INSERT INTO jobs(filename, staging_path, doc_type, tags_json, title) VALUES (?, ?, ?, ?, ?)",
        values,
    )
    job_id = cursor.lastrowid
    conn.commit()
    return job_id
|
||||
|
||||
|
||||
def get_job(conn: sqlite3.Connection, job_id: int) -> Optional[sqlite3.Row]:
    """Fetch the ``jobs`` row whose id is *job_id*; None when there is none."""
    cursor = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,))
    return cursor.fetchone()
|
||||
|
||||
|
||||
def list_jobs(conn: sqlite3.Connection, status: Optional[str] = None) -> list[sqlite3.Row]:
    """Return jobs ordered newest first, optionally filtered by status.

    Args:
        conn: Open database connection.
        status: When non-empty, only jobs with exactly this status are
            returned; ``None`` or ``""`` returns every job.

    Returns:
        Matching ``jobs`` rows, ordered by ``created_at`` descending. Rows
        with equal ``created_at`` are ordered by ``id`` descending so jobs
        created in the same second still come back in a stable order.
    """
    # created_at defaults to current_timestamp, which only has one-second
    # resolution; the id tiebreak keeps same-second jobs deterministic.
    order_by = "ORDER BY created_at DESC, id DESC"
    if status:
        return conn.execute(
            f"SELECT * FROM jobs WHERE status = ? {order_by}", (status,)
        ).fetchall()
    return conn.execute(f"SELECT * FROM jobs {order_by}").fetchall()
|
||||
|
||||
|
||||
def update_job_status(
    conn: sqlite3.Connection,
    job_id: int,
    status: str,
    error: Optional[str] = None,
    document_id: Optional[int] = None,
    chunk_count: Optional[int] = None,
) -> None:
    """Set a job's status, plus any optional fields that were supplied, and commit.

    ``error``, ``document_id`` and ``chunk_count`` are only written when they
    are not None. When *status* is one of ``_TERMINAL_STATUSES`` the row's
    ``completed_at`` is set to the database's current timestamp.
    """
    assignments = ["status = ?"]
    values: list[Any] = [status]

    optional_columns = (
        ("error = ?", error),
        ("document_id = ?", document_id),
        ("chunk_count = ?", chunk_count),
    )
    for clause, value in optional_columns:
        if value is not None:
            assignments.append(clause)
            values.append(value)

    if status in _TERMINAL_STATUSES:
        # No bound parameter here: the timestamp is produced by SQLite itself.
        assignments.append("completed_at = current_timestamp")

    values.append(job_id)
    conn.execute(f"UPDATE jobs SET {', '.join(assignments)} WHERE id = ?", values)
    conn.commit()
|
||||
@@ -0,0 +1,109 @@
|
||||
"""Embedding model management and text embedding utilities."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
logger = logging.getLogger("kb.embeddings")
|
||||
|
||||
_model: Optional[SentenceTransformer] = None
|
||||
_model_dim: Optional[int] = None
|
||||
|
||||
|
||||
def _resolve_device(device: str) -> str:
|
||||
"""Resolve device string, mapping 'auto' to the best available device."""
|
||||
if device == "auto":
|
||||
resolved = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info("Auto-resolved device to '%s'", resolved)
|
||||
return resolved
|
||||
return device
|
||||
|
||||
|
||||
def load_model(model_name: str, device: str = "cpu") -> int:
    """Load a sentence-transformers model and return its embedding dimension.

    The loaded model and its dimension are stored in the module-level
    ``_model`` / ``_model_dim`` globals for use by the other helpers in this
    module. Every call builds a new ``SentenceTransformer`` and replaces
    whatever was stored before; there is no "already loaded" short-circuit.

    Args:
        model_name: HuggingFace model name or local path.
        device: Target device — "cpu", "cuda" (optionally with an index such
            as "cuda:0"), or "auto".

    Returns:
        The embedding dimension of the loaded model.
    """
    global _model, _model_dim

    resolved_device = _resolve_device(device)

    # CUDA devices use the torch backend, everything else the ONNX backend.
    # Match on the prefix so an indexed device ("cuda:0") is treated the same
    # as plain "cuda" instead of falling through to the ONNX branch.
    backend = "torch" if resolved_device.startswith("cuda") else "onnx"

    logger.info(
        "Loading model '%s' on device '%s' (backend=%s)",
        model_name,
        resolved_device,
        backend,
    )

    _model = SentenceTransformer(
        model_name,
        device=resolved_device,
        backend=backend,
    )
    _model_dim = _model.get_sentence_embedding_dimension()

    logger.info("Model loaded — embedding dimension: %d", _model_dim)
    return _model_dim
|
||||
|
||||
|
||||
def get_model_dim() -> int:
    """Return the embedding dimension recorded by the last ``load_model()`` call.

    Raises:
        RuntimeError: If no dimension has been recorded yet.
    """
    dim = _model_dim
    if dim is not None:
        return dim
    raise RuntimeError(
        "Embedding model not loaded. Call load_model() first."
    )
|
||||
|
||||
|
||||
def embed_texts(
    texts: list[str],
    prefix: str = "",
    show_progress: bool = False,
) -> list[list[float]]:
    """Encode *texts* with the module-level model.

    Args:
        texts: Strings to embed.
        prefix: When non-empty, prepended to every text before encoding.
        show_progress: Forwarded to the model as ``show_progress_bar``.

    Returns:
        The model's output converted with ``.tolist()`` — one vector per
        input text.

    Raises:
        RuntimeError: If no model has been loaded yet.
    """
    model = _model
    if model is None:
        raise RuntimeError(
            "Embedding model not loaded. Call load_model() first."
        )

    inputs = [prefix + item for item in texts] if prefix else texts
    vectors = model.encode(
        inputs,
        show_progress_bar=show_progress,
        convert_to_numpy=True,
    )
    return vectors.tolist()
|
||||
@@ -0,0 +1,206 @@
|
||||
"""Chunking pipeline for source code files."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import re
|
||||
|
||||
|
||||
def _approx_tokens(text: str) -> int:
|
||||
return len(text) // 4
|
||||
|
||||
|
||||
def _fixed_token_chunks(text: str, max_tokens: int) -> list[str]:
    """Group the lines of *text* into pieces of roughly *max_tokens* tokens.

    Lines are never split: a piece is closed as soon as adding the next line
    would push its estimated size past *max_tokens*, so a single line larger
    than the budget ends up as its own oversized piece.
    """
    pieces: list[str] = []
    bucket: list[str] = []
    bucket_tokens = 0

    for line in text.split("\n"):
        cost = _approx_tokens(line)
        if bucket and bucket_tokens + cost > max_tokens:
            # Budget exceeded: close the current piece and start a new one.
            pieces.append("\n".join(bucket))
            bucket = []
            bucket_tokens = 0
        bucket.append(line)
        bucket_tokens += cost

    if bucket:
        pieces.append("\n".join(bucket))

    return pieces
|
||||
|
||||
|
||||
def _chunk_python(text: str, max_tokens: int) -> list[dict]:
|
||||
"""Use the ast module to extract top-level classes and functions."""
|
||||
lines = text.split("\n")
|
||||
|
||||
try:
|
||||
tree = ast.parse(text)
|
||||
except SyntaxError:
|
||||
return []
|
||||
|
||||
# Collect top-level class and function definitions
|
||||
regions: list[tuple[int, int, str]] = [] # (start_line, end_line, name)
|
||||
for node in ast.iter_child_nodes(tree):
|
||||
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||
continue
|
||||
|
||||
start = node.lineno - 1 # 0-indexed
|
||||
|
||||
# Include preceding comments and decorators
|
||||
first_line = node.lineno - 1
|
||||
if node.decorator_list:
|
||||
first_line = node.decorator_list[0].lineno - 1
|
||||
# Walk backwards for comment lines
|
||||
scan = first_line - 1
|
||||
while scan >= 0 and (lines[scan].strip().startswith("#") or not lines[scan].strip()):
|
||||
if lines[scan].strip().startswith("#"):
|
||||
first_line = scan
|
||||
scan -= 1
|
||||
|
||||
start = first_line
|
||||
end = node.end_lineno # 1-indexed, inclusive
|
||||
|
||||
prefix = "class " if isinstance(node, ast.ClassDef) else "def "
|
||||
regions.append((start, end, f"{prefix}{node.name}"))
|
||||
|
||||
if not regions:
|
||||
return []
|
||||
|
||||
# Sort by start line
|
||||
regions.sort(key=lambda r: r[0])
|
||||
|
||||
chunks: list[dict] = []
|
||||
chunk_index = 0
|
||||
prev_end = 0
|
||||
|
||||
for start, end, name in regions:
|
||||
# Capture any module-level code between definitions
|
||||
if start > prev_end:
|
||||
preamble = "\n".join(lines[prev_end:start]).strip()
|
||||
if preamble and _approx_tokens(preamble) > 10:
|
||||
chunks.append({
|
||||
"text": preamble,
|
||||
"chunk_index": chunk_index,
|
||||
"metadata": {},
|
||||
})
|
||||
chunk_index += 1
|
||||
|
||||
block = "\n".join(lines[start:end]).rstrip()
|
||||
if _approx_tokens(block) > max_tokens:
|
||||
for piece in _fixed_token_chunks(block, max_tokens):
|
||||
chunks.append({
|
||||
"text": piece,
|
||||
"chunk_index": chunk_index,
|
||||
"metadata": {"name": name},
|
||||
})
|
||||
chunk_index += 1
|
||||
else:
|
||||
chunks.append({
|
||||
"text": block,
|
||||
"chunk_index": chunk_index,
|
||||
"metadata": {"name": name},
|
||||
})
|
||||
chunk_index += 1
|
||||
|
||||
prev_end = end
|
||||
|
||||
# Trailing module-level code
|
||||
if prev_end < len(lines):
|
||||
tail = "\n".join(lines[prev_end:]).strip()
|
||||
if tail and _approx_tokens(tail) > 10:
|
||||
chunks.append({
|
||||
"text": tail,
|
||||
"chunk_index": chunk_index,
|
||||
"metadata": {},
|
||||
})
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def _chunk_by_regex(text: str, pattern: str, max_tokens: int) -> list[dict]:
|
||||
"""Split source code at regex-matched function boundaries."""
|
||||
lines = text.split("\n")
|
||||
boundaries: list[tuple[int, str]] = []
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
m = re.match(pattern, line)
|
||||
if m:
|
||||
boundaries.append((i, m.group(0).strip()))
|
||||
|
||||
if not boundaries:
|
||||
return []
|
||||
|
||||
chunks: list[dict] = []
|
||||
chunk_index = 0
|
||||
|
||||
# Content before first match
|
||||
if boundaries[0][0] > 0:
|
||||
preamble = "\n".join(lines[: boundaries[0][0]]).strip()
|
||||
if preamble:
|
||||
chunks.append({
|
||||
"text": preamble,
|
||||
"chunk_index": chunk_index,
|
||||
"metadata": {},
|
||||
})
|
||||
chunk_index += 1
|
||||
|
||||
for idx, (start, name) in enumerate(boundaries):
|
||||
end = boundaries[idx + 1][0] if idx + 1 < len(boundaries) else len(lines)
|
||||
block = "\n".join(lines[start:end]).rstrip()
|
||||
|
||||
if _approx_tokens(block) > max_tokens:
|
||||
for piece in _fixed_token_chunks(block, max_tokens):
|
||||
chunks.append({
|
||||
"text": piece,
|
||||
"chunk_index": chunk_index,
|
||||
"metadata": {"name": name},
|
||||
})
|
||||
chunk_index += 1
|
||||
else:
|
||||
chunks.append({
|
||||
"text": block,
|
||||
"chunk_index": chunk_index,
|
||||
"metadata": {"name": name},
|
||||
})
|
||||
chunk_index += 1
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def chunk_code(
    text: str,
    language: str | None,
    max_tokens: int = 1024,
) -> list[dict]:
    """Split source code into chunks using language-aware strategies.

    ``"python"`` is chunked by top-level definitions, ``"bash"`` and ``"go"``
    at lines that start a function. If the language is anything else, or the
    language-aware strategy yields nothing, the text is split into line-based
    pieces of roughly *max_tokens* tokens.

    Args:
        text: Source code to chunk.
        language: Language identifier, or None when unknown.
        max_tokens: Approximate token budget per chunk.

    Returns:
        A list of chunk dicts, each containing ``text``, ``chunk_index`` and
        ``metadata``. ``chunk_index`` values are consecutive from 0.
    """
    chunks: list[dict] = []

    if language == "python":
        chunks = _chunk_python(text, max_tokens)
    elif language == "bash":
        chunks = _chunk_by_regex(
            text, r"^(?:\w+\s*\(\)|function\s+\w+)", max_tokens
        )
    elif language == "go":
        # Matches exactly the lines that `^func\s+` matches (the name part is
        # optional), but captures an optional receiver and the function name
        # so chunk metadata reads "func (s *T) Name" rather than just "func".
        chunks = _chunk_by_regex(
            text, r"^func\s+(?:(?:\([^)]*\)\s*)?\w+)?", max_tokens
        )

    # Fallback: fixed-size token chunking
    if not chunks:
        stripped_pieces = (
            piece.strip() for piece in _fixed_token_chunks(text, max_tokens)
        )
        # Number the chunks after dropping blank pieces so the indices stay
        # consecutive, matching the language-aware chunkers.
        chunks = [
            {"text": piece, "chunk_index": idx, "metadata": {}}
            for idx, piece in enumerate(p for p in stripped_pieces if p)
        ]

    return chunks
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Detect document type and language from file extension."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
SUPPORTED_EXTENSIONS: dict[str, tuple[str, str | None]] = {
|
||||
".pdf": ("pdf", None),
|
||||
".docx": ("pdf", None),
|
||||
".html": ("pdf", None),
|
||||
".md": ("markdown", None),
|
||||
".txt": ("note", None),
|
||||
".py": ("code", "python"),
|
||||
".sh": ("code", "bash"),
|
||||
".go": ("code", "go"),
|
||||
}
|
||||
|
||||
|
||||
def is_supported(path: Path) -> bool:
    """Return True when *path*'s suffix (case-insensitive) is a key of SUPPORTED_EXTENSIONS."""
    suffix = path.suffix.lower()
    return suffix in SUPPORTED_EXTENSIONS
|
||||
|
||||
|
||||
def detect_type(
    path: Path,
    force_type: str | None = None,
    force_language: str | None = None,
) -> tuple[str, str | None]:
    """Work out the (doc_type, language) pair for *path*.

    The pair is looked up from the file suffix (case-insensitive); a
    non-None *force_type* / *force_language* then replaces the corresponding
    element. The suffix check happens first, so an unsupported suffix raises
    ValueError even when overrides are given.
    """
    ext = path.suffix.lower()

    entry = SUPPORTED_EXTENSIONS.get(ext)
    if entry is None:
        raise ValueError(
            f"Unsupported file extension '{ext}'. "
            f"Supported: {', '.join(sorted(SUPPORTED_EXTENSIONS))}"
        )

    default_type, default_language = entry
    doc_type = default_type if force_type is None else force_type
    language = default_language if force_language is None else force_language
    return doc_type, language
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Chunking pipeline for PDF / DOCX / HTML documents via Docling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Suppress noisy Docling / RapidOCR logging
|
||||
# Raise these third-party loggers to WARNING so their INFO/DEBUG output
# is not emitted.
_NOISY_LOGGERS = (
    "docling",
    "docling.document_converter",
    "docling_core",
    "rapidocr",
    "rapidocr_onnxruntime",
)
for _logger_name in _NOISY_LOGGERS:
    logging.getLogger(_logger_name).setLevel(logging.WARNING)
|
||||
|
||||
from docling.datamodel.base_models import InputFormat # noqa: E402
|
||||
from docling.datamodel.pipeline_options import ( # noqa: E402
|
||||
AcceleratorOptions,
|
||||
PdfPipelineOptions,
|
||||
RapidOcrOptions,
|
||||
)
|
||||
from docling.document_converter import DocumentConverter, PdfFormatOption # noqa: E402
|
||||
from docling_core.transforms.chunker.hierarchical_chunker import ( # noqa: E402
|
||||
HierarchicalChunker,
|
||||
)
|
||||
|
||||
|
||||
def _fixed_size_chunks(text: str, max_chars: int = 2000) -> list[str]:
|
||||
"""Split text into fixed-size pieces as a fallback."""
|
||||
chunks: list[str] = []
|
||||
for i in range(0, len(text), max_chars):
|
||||
chunk = text[i : i + max_chars].strip()
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
return chunks
|
||||
|
||||
|
||||
def chunk_document(
    file_path: Path,
    ingest_device: str = "cpu",
) -> list[dict]:
    """Convert and chunk a PDF/DOCX/HTML document using Docling.

    The file is converted with a ``DocumentConverter`` whose PDF pipeline has
    OCR enabled (RapidOCR options) and runs on *ingest_device*. The result is
    chunked with ``HierarchicalChunker``; if that yields no non-empty chunk,
    the document's exported text is cut into fixed-size pieces instead.

    Args:
        file_path: Document to convert.
        ingest_device: Device string passed to Docling's ``AcceleratorOptions``.

    Returns:
        A list of chunk dicts, each containing ``text``, ``chunk_index`` and
        ``metadata``. ``chunk_index`` values are consecutive from 0 — empty
        chunks are skipped before numbering. ``metadata`` may carry ``page``,
        ``pages`` and ``section_header`` when the chunk exposes them.
    """
    accelerator = AcceleratorOptions(device=ingest_device)
    ocr_options = RapidOcrOptions()
    pipeline_options = PdfPipelineOptions(
        accelerator_options=accelerator,
        do_ocr=True,
        ocr_options=ocr_options,
        bitmap_area_threshold=0.25,
    )

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options),
        }
    )

    result = converter.convert(str(file_path))
    doc = result.document

    # Primary: hierarchical chunking
    chunker = HierarchicalChunker()
    raw_chunks = list(chunker.chunk(doc))

    chunks: list[dict] = []
    for chunk in raw_chunks:
        text = chunk.text.strip() if hasattr(chunk, "text") else str(chunk).strip()
        if not text:
            continue

        metadata: dict = {}

        # Extract page numbers from chunk metadata when available
        if hasattr(chunk, "meta") and chunk.meta:
            meta = chunk.meta
            if hasattr(meta, "page") and meta.page is not None:
                metadata["page"] = meta.page
            if hasattr(meta, "pages") and meta.pages:
                metadata["pages"] = meta.pages
            if hasattr(meta, "headings") and meta.headings:
                metadata["section_header"] = " > ".join(meta.headings)

        chunks.append({
            "text": text,
            # Number by output position, not by position in raw_chunks, so
            # skipped empty chunks do not leave gaps in the index sequence.
            "chunk_index": len(chunks),
            "metadata": metadata,
        })

    # Fallback: fixed-size chunking when hierarchy produces nothing
    if not chunks:
        full_text = doc.export_to_text() if hasattr(doc, "export_to_text") else ""
        if not full_text and hasattr(doc, "text"):
            full_text = doc.text
        for idx, piece in enumerate(_fixed_size_chunks(full_text)):
            chunks.append({
                "text": piece,
                "chunk_index": idx,
                "metadata": {},
            })

    return chunks
|
||||
@@ -0,0 +1,130 @@
|
||||
"""Chunking pipeline for Markdown documents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def _approx_tokens(text: str) -> int:
|
||||
return len(text) // 4
|
||||
|
||||
|
||||
def _split_at_paragraphs(text: str, max_tokens: int) -> list[str]:
    """Greedily pack the paragraphs of *text* into pieces of about *max_tokens*.

    Paragraphs are delimited by one or more blank lines and are never split
    themselves, so a single paragraph larger than the budget comes back as
    its own oversized piece.

    Args:
        text: Text to split.
        max_tokens: Approximate token budget per piece.

    Returns:
        The pieces in document order, each re-joined with blank lines.
    """
    pieces: list[str] = []
    bucket: list[str] = []
    bucket_tokens = 0

    for paragraph in re.split(r"\n{2,}", text):
        cost = _approx_tokens(paragraph)
        if bucket and bucket_tokens + cost > max_tokens:
            # Adding this paragraph would overflow: emit what we have so far.
            pieces.append("\n\n".join(bucket))
            bucket, bucket_tokens = [], 0
        bucket.append(paragraph)
        bucket_tokens += cost

    if bucket:
        pieces.append("\n\n".join(bucket))
    return pieces
|
||||
|
||||
|
||||
def chunk_markdown(
    text: str,
    max_tokens: int = 1024,
    min_tokens: int = 50,
) -> list[dict]:
    """Split markdown text into chunks based on heading structure.

    The text is cut at ATX headings (``#`` .. ``######``). Lines inside fenced
    code blocks are never treated as headings, so ``# comment`` lines in code
    samples stay with their section. Sections smaller than *min_tokens* absorb
    the section that follows them; sections larger than *max_tokens* are split
    at paragraph boundaries.

    Args:
        text: Markdown source.
        max_tokens: Approximate upper bound on chunk size.
        min_tokens: Sections below this size are merged with the next one.

    Returns:
        A list of chunk dicts, each containing ``text``, ``chunk_index`` and
        ``metadata``. ``metadata["section_header"]`` holds the heading
        breadcrumb (``"Parent > Child"``) when the chunk sits under a heading.
    """
    heading_re = re.compile(r"^(#{1,6})\s+(.*)")
    fence_re = re.compile(r"^ {0,3}(`{3,}|~{3,})(.*)$")

    heading_stack: list[tuple[int, str]] = []

    def _path_for(heading: str) -> str:
        """Push *heading* onto the hierarchy and return its breadcrumb."""
        match = heading_re.match(heading)
        if not match:
            return heading
        level = len(match.group(1))
        # Pop deeper or equal headings so the stack holds only ancestors.
        while heading_stack and heading_stack[-1][0] >= level:
            heading_stack.pop()
        heading_stack.append((level, match.group(2).strip()))
        return " > ".join(title for _, title in heading_stack)

    # --- split into (heading, breadcrumb, body) sections --------------------
    # The breadcrumb is computed for EVERY heading, in document order, before
    # any merging happens; otherwise headings swallowed by a merge would never
    # update the hierarchy and later chunks would get a stale parent.
    sections: list[tuple[str | None, str | None, str]] = []
    current_heading: str | None = None
    current_path: str | None = None
    current_lines: list[str] = []
    open_fence: str | None = None  # fence run that opened the current code block

    for line in text.split("\n"):
        fence = fence_re.match(line)
        if open_fence is not None:
            # Inside a code block: only a bare fence of the same character and
            # at least the same length closes it.
            if (
                fence is not None
                and fence.group(1)[0] == open_fence[0]
                and len(fence.group(1)) >= len(open_fence)
                and not fence.group(2).strip()
            ):
                open_fence = None
            current_lines.append(line)
        elif fence is not None and not (
            fence.group(1)[0] == "`" and "`" in fence.group(2)
        ):
            # Opening fence. A backtick fence whose info string contains a
            # backtick is inline code, not a fence.
            open_fence = fence.group(1)
            current_lines.append(line)
        elif heading_re.match(line):
            if current_lines or current_heading is not None:
                sections.append(
                    (current_heading, current_path, "\n".join(current_lines).strip())
                )
            current_heading = line.strip()
            current_path = _path_for(current_heading)
            current_lines = []
        else:
            current_lines.append(line)

    if current_lines or current_heading is not None:
        sections.append(
            (current_heading, current_path, "\n".join(current_lines).strip())
        )

    # --- merge small sections with the section that follows -----------------
    merged: list[tuple[str | None, str]] = []  # (breadcrumb, text)
    for heading, path, body in sections:
        section_text = f"{heading}\n{body}".strip() if heading else body
        if merged and _approx_tokens(merged[-1][1]) < min_tokens:
            prev_path, prev_text = merged[-1]
            # Keep the first section's breadcrumb; a heading-less preamble
            # adopts the breadcrumb of the section merged into it.
            kept_path = prev_path if prev_path is not None else path
            merged[-1] = (kept_path, f"{prev_text}\n\n{section_text}".strip())
        else:
            merged.append((path, section_text))

    # --- build chunks, splitting oversized sections at paragraphs -----------
    chunks: list[dict] = []
    for section_header, body in merged:
        if _approx_tokens(body) <= max_tokens:
            pieces = [body]
        else:
            pieces = _split_at_paragraphs(body, max_tokens)

        for piece in pieces:
            piece = piece.strip()
            if not piece:
                continue
            metadata: dict = {}
            if section_header:
                metadata["section_header"] = section_header
            chunks.append({
                "text": piece,
                "chunk_index": len(chunks),
                "metadata": metadata,
            })

    return chunks
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Chunking pipeline for plain-text notes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def auto_title(text: str) -> str:
    """Build a short title from the first line of *text*.

    Leading markdown ``#`` characters and surrounding whitespace are removed,
    and the result is cut to at most 80 characters.
    """
    head, _, _ = text.partition("\n")
    title = head.strip().lstrip("#").strip()
    return title if len(title) <= 80 else title[:80].rstrip()
|
||||
|
||||
|
||||
def chunk_note(text: str) -> list[dict]:
    """Wrap a whole note in a single chunk.

    Returns:
        A one-element list whose dict has the stripped ``text``,
        ``chunk_index`` 0 and empty ``metadata``.
    """
    only_chunk = {"text": text.strip(), "chunk_index": 0, "metadata": {}}
    return [only_chunk]
|
||||
@@ -0,0 +1 @@
|
||||
from kb.routes import health, search, jobs, documents, tags, status, reindex, auth
|
||||
@@ -0,0 +1,38 @@
|
||||
"""API key authentication middleware."""
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from main import app
|
||||
from kb.config import cfg
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
    """Reject requests that do not carry the configured API key.

    Authentication is disabled when ``cfg.api_key`` is ``None``; the health
    endpoint is always reachable without credentials. Every other request must
    send ``Authorization: Bearer <api key>``.
    """

    async def dispatch(self, request: Request, call_next):
        """Check the ``Authorization`` header, then forward the request.

        Returns:
            The downstream response, or a 401 JSON response when the header
            is missing or does not match the configured key.
        """
        # Local import keeps this block self-contained.
        import hmac

        # Skip auth if no API key is configured
        if cfg.api_key is None:
            return await call_next(request)

        # Always allow health checks without auth
        if request.url.path == "/api/v1/health":
            return await call_next(request)

        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return JSONResponse(
                {"detail": "Missing Authorization header."},
                status_code=401,
            )

        expected = f"Bearer {cfg.api_key}"
        # Constant-time comparison: a plain != leaks, through response timing,
        # how much of the secret prefix was guessed correctly. Compare bytes
        # because compare_digest rejects non-ASCII str arguments.
        if not hmac.compare_digest(
            auth_header.encode("utf-8"), expected.encode("utf-8")
        ):
            return JSONResponse(
                {"detail": "Invalid API key."},
                status_code=401,
            )

        return await call_next(request)
|
||||
|
||||
|
||||
app.add_middleware(AuthMiddleware)
|
||||
@@ -0,0 +1,143 @@
|
||||
"""Document management endpoints — list, view, and delete documents."""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException, Query
|
||||
|
||||
from main import app
|
||||
from kb.config import cfg
|
||||
from kb.database import get_connection
|
||||
|
||||
|
||||
@app.get("/api/v1/documents")
async def list_documents(
    type: Optional[str] = Query(None),
    tags: Optional[str] = Query(None),
):
    """List documents, newest first, with optional type and tag filters.

    Args:
        type: When given, only documents with this ``doc_type`` are returned.
        tags: Comma-separated tag names; a document must carry all of them.

    Returns:
        A list of dicts with ``id``, ``title``, ``doc_type``, ``tags``,
        ``chunk_count`` and ``created_at``.
    """
    conn = get_connection(cfg.db_path)
    try:
        join_clauses: list[str] = []
        conditions: list[str] = []
        bind_values: list = []

        if type:
            conditions.append("d.doc_type = ?")
            bind_values.append(type)

        wanted_tags = [t.strip() for t in tags.split(",") if t.strip()] if tags else []
        # One pair of aliased joins per tag enforces "has ALL tags".
        for n, tag_name in enumerate(wanted_tags):
            join_clauses.append(f"JOIN document_tags dt{n} ON d.id = dt{n}.document_id")
            join_clauses.append(f"JOIN tags t{n} ON dt{n}.tag_id = t{n}.id")
            conditions.append(f"t{n}.name = ?")
            bind_values.append(tag_name)

        query = """
            SELECT d.id, d.title, d.doc_type,
                   (SELECT COUNT(*) FROM chunks c WHERE c.document_id = d.id) AS chunk_count,
                   d.created_at
            FROM documents d
        """
        if join_clauses:
            query += " " + " ".join(join_clauses)
        if conditions:
            query += " WHERE " + " AND ".join(conditions)
        query += " ORDER BY d.created_at DESC"

        def _tag_names(document_id) -> list:
            """Return the sorted tag names attached to one document."""
            found = conn.execute(
                """
                SELECT t.name FROM tags t
                JOIN document_tags dt ON t.id = dt.tag_id
                WHERE dt.document_id = ?
                ORDER BY t.name
                """,
                (document_id,),
            ).fetchall()
            return [entry["name"] for entry in found]

        listing = []
        for record in conn.execute(query, bind_values).fetchall():
            listing.append({
                "id": record["id"],
                "title": record["title"],
                "doc_type": record["doc_type"],
                "tags": _tag_names(record["id"]),
                "chunk_count": record["chunk_count"],
                "created_at": record["created_at"],
            })
        return listing
    finally:
        conn.close()
|
||||
|
||||
|
||||
@app.get("/api/v1/documents/{doc_id}")
async def get_document(doc_id: int):
    """Return one document together with its tags and ordered chunks.

    Raises:
        HTTPException: 404 when no document has the given id.
    """
    conn = get_connection(cfg.db_path)
    try:
        document = conn.execute(
            "SELECT * FROM documents WHERE id = ?", (doc_id,)
        ).fetchone()
        if not document:
            raise HTTPException(status_code=404, detail="Document not found.")

        chunk_rows = conn.execute(
            "SELECT * FROM chunks WHERE document_id = ? ORDER BY chunk_index",
            (doc_id,),
        ).fetchall()

        tag_rows = conn.execute(
            """
            SELECT t.name FROM tags t
            JOIN document_tags dt ON t.id = dt.tag_id
            WHERE dt.document_id = ?
            ORDER BY t.name
            """,
            (doc_id,),
        ).fetchall()

        # Start from the document's own columns, then attach related rows.
        payload = dict(document)
        payload["tags"] = [entry["name"] for entry in tag_rows]
        payload["chunks"] = [dict(chunk) for chunk in chunk_rows]
        return payload
    finally:
        conn.close()
|
||||
|
||||
|
||||
@app.delete("/api/v1/documents/{doc_id}")
async def delete_document(doc_id: int):
    """Delete a document and the embeddings of its chunks.

    Raises:
        HTTPException: 404 when no document has the given id.
    """
    conn = get_connection(cfg.db_path)
    try:
        target = conn.execute(
            "SELECT id, title FROM documents WHERE id = ?", (doc_id,)
        ).fetchone()
        if not target:
            raise HTTPException(status_code=404, detail="Document not found.")

        # Embedding rows are removed explicitly, one per chunk, before the
        # document itself is deleted.
        owned_chunks = conn.execute(
            "SELECT id FROM chunks WHERE document_id = ?", (doc_id,)
        ).fetchall()
        for chunk in owned_chunks:
            conn.execute(
                "DELETE FROM chunks_vec WHERE chunk_id = ?", (chunk["id"],)
            )

        # Delete document (cascades to chunks, document_tags)
        conn.execute("DELETE FROM documents WHERE id = ?", (doc_id,))
        conn.commit()

        summary = {
            "status": "deleted",
            "document_id": doc_id,
            "title": target["title"],
        }
        return summary
    finally:
        conn.close()
|
||||
@@ -0,0 +1,12 @@
|
||||
"""Health check endpoint."""
|
||||
|
||||
import main
|
||||
from main import app
|
||||
|
||||
|
||||
@app.get("/api/v1/health")
async def health():
    """Report readiness: 200 ``healthy`` once ``main.ready`` is set, else 503."""
    if main.ready:
        return {"status": "healthy"}

    from fastapi.responses import JSONResponse

    return JSONResponse({"status": "starting"}, status_code=503)
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Job management endpoints — submit files/notes for ingestion and track progress."""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException, UploadFile, File, Form, Query
|
||||
|
||||
from main import app
|
||||
from kb.config import cfg
|
||||
from kb.database import get_connection, create_job, get_job, list_jobs
|
||||
from kb.staging import stage_file, stage_note
|
||||
|
||||
|
||||
@app.post("/api/v1/jobs", status_code=202)
async def submit_job(
    file: Optional[UploadFile] = File(None),
    note: Optional[str] = Form(None),
    title: Optional[str] = Form(None),
    tags: Optional[str] = Form(None),
    doc_type: Optional[str] = Form(None),
):
    """Stage an uploaded file or a text note and queue it for ingestion.

    When both ``file`` and ``note`` are supplied the file is used.

    Args:
        file: Uploaded document.
        note: Note text, used when no file is sent.
        title: Optional title; also names the staged note file.
        tags: Comma-separated tag names.
        doc_type: Optional document type passed through to the job.

    Returns:
        Dict with ``job_id``, ``status`` (``"queued"``) and ``filename``.

    Raises:
        HTTPException: 422 when neither ``file`` nor ``note`` is provided.
    """
    if not (file or note):
        raise HTTPException(
            status_code=422,
            detail="Either 'file' or 'note' must be provided.",
        )

    if file:
        payload = await file.read()
        staged = stage_file(cfg.staging_dir, file.filename, payload)
        display_name = file.filename
    else:
        staged = stage_note(cfg.staging_dir, title or "note", note)
        display_name = staged.name

    tag_names: list = []
    if tags:
        tag_names = [part.strip() for part in tags.split(",") if part.strip()]
    encoded_tags = json.dumps(tag_names)

    conn = get_connection(cfg.db_path)
    try:
        new_id = create_job(
            conn, display_name, str(staged), doc_type, encoded_tags, title
        )
        return {"job_id": new_id, "status": "queued", "filename": display_name}
    finally:
        conn.close()
|
||||
|
||||
|
||||
@app.get("/api/v1/jobs")
async def list_all_jobs(status: Optional[str] = Query(None)):
    """Return the jobs from ``list_jobs`` as dicts; *status* is passed through as its filter."""
    conn = get_connection(cfg.db_path)
    try:
        return [dict(job) for job in list_jobs(conn, status)]
    finally:
        conn.close()
|
||||
|
||||
|
||||
@app.get("/api/v1/jobs/{job_id}")
async def get_single_job(job_id: int):
    """Return one job as a dict.

    Raises:
        HTTPException: 404 when the job does not exist.
    """
    conn = get_connection(cfg.db_path)
    try:
        job = get_job(conn, job_id)
        if job:
            return dict(job)
        raise HTTPException(status_code=404, detail="Job not found.")
    finally:
        conn.close()
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Reindex endpoint — re-embed all chunks with the current model."""
|
||||
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from main import app
|
||||
from kb.config import cfg
|
||||
from kb.database import get_connection, recreate_vec_table
|
||||
from kb.embeddings import embed_texts, get_model_dim
|
||||
|
||||
logger = logging.getLogger("kb.routes.reindex")
|
||||
|
||||
BATCH_SIZE = 256
|
||||
|
||||
|
||||
@app.post("/api/v1/reindex")
async def reindex():
    """Rebuild the vector table by re-embedding every stored chunk.

    The vec table is recreated with the current model's dimension, then all
    chunk texts are embedded in batches of ``BATCH_SIZE`` and inserted. The
    inserts are committed once, after the last batch.

    Returns:
        Dict with ``chunks_reindexed`` and the ``model`` name.
    """
    dim = get_model_dim()

    conn = get_connection(cfg.db_path)
    try:
        chunk_rows = conn.execute("SELECT id, text FROM chunks ORDER BY id").fetchall()
        total = len(chunk_rows)

        logger.info("Reindexing %d chunks with model '%s'", total, cfg.model)

        recreate_vec_table(conn, dim)

        for start in range(0, total, BATCH_SIZE):
            batch = chunk_rows[start : start + BATCH_SIZE]
            vectors = embed_texts([entry["text"] for entry in batch])

            for entry, vector in zip(batch, vectors):
                # Vectors are stored as packed native floats.
                packed = struct.pack(f"{len(vector)}f", *vector)
                conn.execute(
                    "INSERT INTO chunks_vec(embedding, chunk_id) VALUES (?, ?)",
                    (packed, entry["id"]),
                )

        conn.commit()

        logger.info("Reindex complete: %d chunks", total)

        return {"chunks_reindexed": total, "model": cfg.model}
    finally:
        conn.close()
|
||||
@@ -0,0 +1,43 @@
|
||||
"""Search endpoint — hybrid FTS5 + vector search."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from main import app
|
||||
from kb.config import cfg
|
||||
from kb.database import get_connection
|
||||
from kb.search import hybrid_search
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
    """Request body of the search endpoint."""

    # Search text.
    query: str
    # Maximum number of results to return.
    top: int = 10
    # Optional filters.
    tags: Optional[list[str]] = None
    doc_type: Optional[str] = None
    # Restrict the search to one backend.
    fts_only: bool = False
    vec_only: bool = False
    # Per-query minimum score.
    threshold: Optional[float] = None
|
||||
|
||||
|
||||
@app.post("/api/v1/search")
async def search(req: SearchRequest):
    """Run ``hybrid_search`` for the request and return its result dict.

    Raises:
        HTTPException: 500 carrying the error text when the search fails.
    """
    conn = get_connection(cfg.db_path)
    try:
        options = {
            "top": req.top,
            "tags": req.tags,
            "doc_type": req.doc_type,
            "fts_only": req.fts_only,
            "vec_only": req.vec_only,
            "threshold": req.threshold,
        }
        return hybrid_search(conn, req.query, cfg, **options)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    finally:
        conn.close()
|
||||
@@ -0,0 +1,67 @@
|
||||
"""System status endpoint — model info, DB stats, and queue stats."""
|
||||
|
||||
import os
|
||||
|
||||
from main import app, __version__
|
||||
from kb.config import cfg
|
||||
from kb.database import get_connection
|
||||
from kb.embeddings import get_model_dim
|
||||
|
||||
|
||||
@app.get("/api/v1/status")
async def status():
    """Report version, model, device, database statistics and queue depth.

    Returns:
        Dict with ``version``, ``model_name``, ``embedding_dim``, ``device``,
        a ``db`` section (documents by type, total chunks, file size) and a
        ``queue`` section (queued / processing job counts).
    """
    # Prefer the CUDA device name when torch reports one; otherwise fall back
    # to the configured device string. Any torch problem is ignored.
    device_label = cfg.device
    try:
        import torch

        if torch.cuda.is_available():
            device_label = torch.cuda.get_device_name(0)
    except Exception:
        pass

    conn = get_connection(cfg.db_path)
    try:
        per_type = conn.execute(
            "SELECT doc_type, COUNT(*) AS count FROM documents GROUP BY doc_type"
        ).fetchall()
        documents_by_type = {entry["doc_type"]: entry["count"] for entry in per_type}

        chunk_total = conn.execute("SELECT COUNT(*) AS n FROM chunks").fetchone()["n"]

        # The size stays 0 when the database file cannot be stat'ed.
        size_bytes = 0
        try:
            size_bytes = os.path.getsize(cfg.db_path)
        except OSError:
            pass

        pending = conn.execute(
            """
            SELECT status, COUNT(*) AS count
            FROM jobs
            WHERE status IN ('queued', 'processing')
            GROUP BY status
            """
        ).fetchall()
        by_state = {entry["status"]: entry["count"] for entry in pending}

        return {
            "version": __version__,
            "model_name": cfg.model,
            "embedding_dim": get_model_dim(),
            "device": device_label,
            "db": {
                "documents_by_type": documents_by_type,
                "total_chunks": chunk_total,
                "db_size_bytes": size_bytes,
            },
            "queue": {
                "queued": by_state.get("queued", 0),
                "processing": by_state.get("processing", 0),
            },
        }
    finally:
        conn.close()
|
||||
@@ -0,0 +1,63 @@
|
||||
"""Tag management endpoints."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from main import app
|
||||
from kb.config import cfg
|
||||
from kb.database import get_connection, tag_document, untag_document
|
||||
|
||||
|
||||
@app.get("/api/v1/tags")
async def list_tags():
    """List every tag, sorted by name, with the number of documents using it."""
    conn = get_connection(cfg.db_path)
    try:
        usage = conn.execute(
            """
            SELECT t.name, COUNT(dt.document_id) AS count
            FROM tags t
            LEFT JOIN document_tags dt ON t.id = dt.tag_id
            GROUP BY t.id, t.name
            ORDER BY t.name
            """
        ).fetchall()

        listing = []
        for entry in usage:
            listing.append({"name": entry["name"], "count": entry["count"]})
        return listing
    finally:
        conn.close()
|
||||
|
||||
|
||||
class TagUpdateRequest(BaseModel):
    """Request body for changing the tags of a document."""

    # Tag names to attach.
    add: Optional[list[str]] = None
    # Tag names to detach.
    remove: Optional[list[str]] = None
|
||||
|
||||
|
||||
@app.put("/api/v1/documents/{doc_id}/tags")
async def update_document_tags(doc_id: int, req: TagUpdateRequest):
    """Add and/or remove tags on a document and return its resulting tags.

    Additions are applied before removals.

    Raises:
        HTTPException: 404 when no document has the given id.
    """
    conn = get_connection(cfg.db_path)
    try:
        exists = conn.execute(
            "SELECT id FROM documents WHERE id = ?", (doc_id,)
        ).fetchone()
        if not exists:
            raise HTTPException(status_code=404, detail="Document not found.")

        if req.add:
            tag_document(conn, doc_id, req.add)
        if req.remove:
            untag_document(conn, doc_id, req.remove)

        current = conn.execute(
            """
            SELECT t.name FROM tags t
            JOIN document_tags dt ON t.id = dt.tag_id
            WHERE dt.document_id = ?
            ORDER BY t.name
            """,
            (doc_id,),
        ).fetchall()
        names = [entry["name"] for entry in current]

        return {"document_id": doc_id, "tags": names}
    finally:
        conn.close()
|
||||
@@ -0,0 +1,290 @@
|
||||
"""Hybrid search — FTS5 + sqlite-vec with Reciprocal Rank Fusion."""
|
||||
|
||||
import json
|
||||
import struct
|
||||
import sqlite3
|
||||
|
||||
|
||||
def hybrid_search(
    conn: sqlite3.Connection,
    query: str,
    cfg,
    top: int = 10,
    tags: list[str] | None = None,
    doc_type: str | None = None,
    fts_only: bool = False,
    vec_only: bool = False,
    threshold: float | None = None,
) -> dict:
    """Search with FTS5 and/or vectors, fuse the rankings, and enrich the hits.

    Each backend is asked for ``top * 3`` candidates. With both backends
    active the rankings are fused with Reciprocal Rank Fusion; with a single
    backend its own scores are used directly.

    Args:
        conn: SQLite connection.
        query: User search query string.
        cfg: Config object; ``search_threshold`` is read from it when
            *threshold* is ``None``.
        top: Maximum number of results to return.
        tags: Optional tag filter — documents must have *all* listed tags.
        doc_type: Optional document-type filter.
        fts_only: Only use FTS5 (skip vector search).
        vec_only: Only use vector search (skip FTS5).
        threshold: Optional minimum score; results below it are dropped.

    Returns:
        Dict with keys ``query``, ``results``, ``total_matches`` (hits left
        after the threshold, before truncation to *top*) and ``returned``.
    """
    pool_size = top * 3

    lexical: dict[int, float] = {}
    semantic: dict[int, float] = {}
    if not vec_only:
        lexical = _fts_search(conn, query, pool_size, tags, doc_type)
    if not fts_only:
        semantic = _vector_search(conn, query, pool_size, tags, doc_type)

    # --- merge ---------------------------------------------------------------
    if fts_only or vec_only:
        # fts_only takes precedence when both flags are set.
        single = lexical if fts_only else semantic
        ranked = sorted(single.items(), key=lambda pair: pair[1], reverse=True)
    else:
        ranked = _rrf_merge(lexical, semantic)

    # A per-query threshold overrides the configured default.
    cutoff = cfg.search_threshold if threshold is None else threshold
    if cutoff > 0:
        ranked = [(cid, score) for cid, score in ranked if score >= cutoff]

    match_count = len(ranked)

    # --- enrich --------------------------------------------------------------
    hits = _enrich(conn, ranked[:top])

    return {
        "query": query,
        "results": hits,
        "total_matches": match_count,
        "returned": len(hits),
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _fts_search(
|
||||
conn: sqlite3.Connection,
|
||||
query: str,
|
||||
limit: int,
|
||||
tags: list[str] | None,
|
||||
doc_type: str | None,
|
||||
) -> dict[int, float]:
|
||||
"""FTS5 search on ``chunks_fts``.
|
||||
|
||||
Returns:
|
||||
{chunk_id: bm25_score} where scores are positive (higher = better).
|
||||
"""
|
||||
sql = "SELECT f.rowid AS chunk_id, bm25(chunks_fts) AS rank FROM chunks_fts f"
|
||||
joins: list[str] = []
|
||||
where: list[str] = ["chunks_fts MATCH ?"]
|
||||
params: list = [query]
|
||||
|
||||
if tags or doc_type:
|
||||
joins.append("JOIN chunks c ON f.rowid = c.id")
|
||||
joins.append("JOIN documents d ON c.document_id = d.id")
|
||||
|
||||
if doc_type:
|
||||
where.append("d.doc_type = ?")
|
||||
params.append(doc_type)
|
||||
|
||||
if tags:
|
||||
for i, tag in enumerate(tags):
|
||||
joins.append(f"JOIN document_tags dt{i} ON d.id = dt{i}.document_id")
|
||||
joins.append(f"JOIN tags t{i} ON dt{i}.tag_id = t{i}.id")
|
||||
where.append(f"t{i}.name = ?")
|
||||
params.append(tag.strip().lower())
|
||||
|
||||
sql += " " + " ".join(joins)
|
||||
sql += " WHERE " + " AND ".join(where)
|
||||
sql += " ORDER BY rank LIMIT ?"
|
||||
params.append(limit)
|
||||
|
||||
rows = conn.execute(sql, params).fetchall()
|
||||
|
||||
# BM25 returns negative values (lower = better match); negate so
|
||||
# higher = better.
|
||||
return {row[0]: -row[1] for row in rows}
|
||||
|
||||
|
||||
def _vector_search(
    conn: sqlite3.Connection,
    query: str,
    limit: int,
    tags: list[str] | None,
    doc_type: str | None,
) -> dict[int, float]:
    """Embed *query* and look up its nearest chunks in ``chunks_vec``.

    Tag and doc_type filters are applied after the nearest-neighbour lookup,
    so fewer than *limit* results may come back when filters are active.

    Returns:
        {chunk_id: similarity} where similarity = 1 / (1 + distance).
    """
    from kb.embeddings import embed_texts

    vector = embed_texts([query])[0]
    packed = struct.pack(f"{len(vector)}f", *vector)

    neighbours = conn.execute(
        """
        SELECT chunk_id, distance
        FROM chunks_vec
        WHERE embedding MATCH ?
        ORDER BY distance
        LIMIT ?
        """,
        (packed, limit),
    ).fetchall()

    filtering = bool(tags or doc_type)
    scored: dict[int, float] = {}
    for chunk_id, distance in neighbours:
        similarity = 1.0 / (1.0 + distance)
        # Post-hoc tag / doc_type filtering for vector results
        if filtering and not _passes_filters(conn, chunk_id, tags, doc_type):
            continue
        scored[chunk_id] = similarity

    return scored
|
||||
|
||||
|
||||
def _passes_filters(
|
||||
conn: sqlite3.Connection,
|
||||
chunk_id: int,
|
||||
tags: list[str] | None,
|
||||
doc_type: str | None,
|
||||
) -> bool:
|
||||
"""Return True if chunk passes tag and doc_type filters."""
|
||||
sql = """
|
||||
SELECT d.id FROM chunks c
|
||||
JOIN documents d ON c.document_id = d.id
|
||||
WHERE c.id = ?
|
||||
"""
|
||||
params: list = [chunk_id]
|
||||
|
||||
if doc_type:
|
||||
sql += " AND d.doc_type = ?"
|
||||
params.append(doc_type)
|
||||
|
||||
doc_row = conn.execute(sql, params).fetchone()
|
||||
if not doc_row:
|
||||
return False
|
||||
|
||||
if tags:
|
||||
doc_id = doc_row[0]
|
||||
placeholders = ",".join("?" * len(tags))
|
||||
normalised = [t.strip().lower() for t in tags]
|
||||
count = conn.execute(
|
||||
f"""
|
||||
SELECT COUNT(DISTINCT t.name) FROM document_tags dt
|
||||
JOIN tags t ON dt.tag_id = t.id
|
||||
WHERE dt.document_id = ? AND t.name IN ({placeholders})
|
||||
""",
|
||||
[doc_id, *normalised],
|
||||
).fetchone()[0]
|
||||
if count < len(tags):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _rrf_merge(
    fts_results: dict[int, float],
    vec_results: dict[int, float],
    k: int = 60,
) -> list[tuple[int, float]]:
    """Fuse two scored result sets with Reciprocal Rank Fusion.

    Every set is ranked on its own (best score gets rank 1). A chunk's fused
    score is the sum of ``1 / (k + rank)`` over the sets that contain it.

    Returns:
        (chunk_id, rrf_score) pairs sorted by score, highest first.
    """
    rankings = (_rank_by_score(fts_results), _rank_by_score(vec_results))
    fts_rank, vec_rank = rankings

    fused: list[tuple[int, float]] = []
    for chunk_id in set(fts_rank) | set(vec_rank):
        total = 0.0
        for ranking in rankings:
            if chunk_id in ranking:
                total += 1.0 / (k + ranking[chunk_id])
        fused.append((chunk_id, total))

    return sorted(fused, key=lambda pair: pair[1], reverse=True)
|
||||
|
||||
|
||||
def _rank_by_score(results: dict[int, float]) -> dict[int, int]:
|
||||
"""Return {id: 1-based rank} sorted by score descending."""
|
||||
ordered = sorted(results, key=results.get, reverse=True)
|
||||
return {cid: rank for rank, cid in enumerate(ordered, start=1)}
|
||||
|
||||
|
||||
def _enrich(
|
||||
conn: sqlite3.Connection,
|
||||
merged: list[tuple[int, float]],
|
||||
) -> list[dict]:
|
||||
"""Fetch chunk text, document metadata, chunk metadata, and tags."""
|
||||
results: list[dict] = []
|
||||
|
||||
for chunk_id, score in merged:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT c.id, c.text, c.chunk_index, c.metadata AS chunk_meta,
|
||||
d.id AS doc_id, d.title, d.doc_type, d.source_path,
|
||||
d.created_at
|
||||
FROM chunks c
|
||||
JOIN documents d ON c.document_id = d.id
|
||||
WHERE c.id = ?
|
||||
""",
|
||||
(chunk_id,),
|
||||
).fetchone()
|
||||
|
||||
if row is None:
|
||||
continue
|
||||
|
||||
chunk_meta = json.loads(row[3]) if row[3] else {}
|
||||
|
||||
tag_rows = conn.execute(
|
||||
"""
|
||||
SELECT t.name FROM tags t
|
||||
JOIN document_tags dt ON t.id = dt.tag_id
|
||||
WHERE dt.document_id = ?
|
||||
ORDER BY t.name
|
||||
""",
|
||||
(row[4],), # doc_id
|
||||
).fetchall()
|
||||
|
||||
results.append({
|
||||
"chunk_id": row[0],
|
||||
"score": round(score, 6),
|
||||
"text": row[1],
|
||||
"chunk_index": row[2],
|
||||
"chunk_metadata": chunk_meta,
|
||||
"title": row[5],
|
||||
"doc_type": row[6],
|
||||
"source_path": row[7],
|
||||
"created_at": row[8],
|
||||
"tags": [t[0] for t in tag_rows],
|
||||
})
|
||||
|
||||
return results
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Staging area for files awaiting ingestion."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger("kb.staging")
|
||||
|
||||
|
||||
def stage_file(staging_dir: Path, filename: str, content: bytes) -> Path:
    """Write raw bytes to a uniquely-named file in the staging directory.

    The staged file is named ``{uuid}_{basename}`` to avoid collisions, where
    ``basename`` is the final path component of *filename*. The name comes
    from the client, so directory parts (``/`` or ``\\`` separated) and NUL
    characters are dropped; otherwise a name such as ``docs/a.pdf`` would
    point into a directory that does not exist and the write would fail.

    Args:
        staging_dir: Directory to stage into; created when missing.
        filename: Client-supplied file name. May be empty or ``None``.
        content: Raw file bytes.

    Returns:
        The path to the newly created staged file.
    """
    raw_name = (filename or "").replace("\x00", "").replace("\\", "/")
    safe_name = raw_name.rsplit("/", 1)[-1].strip()
    if safe_name in ("", ".", ".."):
        # Nothing usable is left; fall back to a neutral name.
        safe_name = "upload"

    staging_dir.mkdir(parents=True, exist_ok=True)
    dest = staging_dir / f"{uuid.uuid4()}_{safe_name}"
    dest.write_bytes(content)
    logger.debug("Staged file: %s (%d bytes)", dest, len(content))
    return dest
|
||||
|
||||
|
||||
def stage_note(staging_dir: Path, title: str, text: str) -> Path:
    """Write a text note to the staging directory.

    The staged file is named ``{uuid}_{title}.note``. The title is
    user-supplied, so path separators are replaced with ``_`` and NUL
    characters are removed; otherwise a title such as ``Q1/Q2 plan`` would
    point into a directory that does not exist and the write would fail.

    Args:
        staging_dir: Directory to stage into; created when missing.
        title: Note title used in the file name.
        text: Note body, written as UTF-8.

    Returns:
        The path to the newly created staged note file.
    """
    safe_title = (
        title.replace("\x00", "").replace("/", "_").replace("\\", "_")
    )

    staging_dir.mkdir(parents=True, exist_ok=True)
    dest = staging_dir / f"{uuid.uuid4()}_{safe_title}.note"
    dest.write_text(text, encoding="utf-8")
    logger.debug("Staged note: %s (%d chars)", dest, len(text))
    return dest
|
||||
|
||||
|
||||
def cleanup(path: Path) -> None:
    """Remove a staged file when present; failures are logged, never raised."""
    try:
        if not path.exists():
            return
        path.unlink()
        logger.debug("Cleaned up staged file: %s", path)
    except OSError as exc:
        logger.warning("Failed to clean up staged file %s: %s", path, exc)
|
||||
@@ -0,0 +1,175 @@
|
||||
"""Async background worker for processing ingestion jobs."""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from kb import config, database, embeddings, staging
|
||||
from kb.ingest import detector
|
||||
|
||||
logger = logging.getLogger("kb.worker")
|
||||
|
||||
|
||||
async def ingestion_worker() -> None:
    """Poll the ``jobs`` table and run queued ingestion jobs one at a time.

    The loop runs until the task is cancelled.  On each pass it selects the
    oldest row whose status is ``'queued'``; when there is none it sleeps for
    2 seconds and tries again.  A selected job is marked *processing* and
    handed to :func:`_process_job` in a worker thread.  The outcome is then
    written back to the job row: the returned status on success, or
    ``"failed"`` with the exception text when processing raised.

    Any other unexpected error inside the loop is logged and followed by a
    2 second pause before polling resumes.
    """

    def _store_status(db_path, job_id, status, **fields) -> None:
        # Each status write uses its own short-lived connection.
        handle = database.get_connection(db_path)
        try:
            database.update_job_status(handle, job_id, status, **fields)
        finally:
            handle.close()

    logger.info("Ingestion worker started")

    while True:
        try:
            settings = config.cfg

            # Claim the oldest queued job, if there is one.
            handle = database.get_connection(settings.db_path)
            try:
                job = handle.execute(
                    "SELECT * FROM jobs WHERE status = 'queued' "
                    "ORDER BY created_at ASC LIMIT 1"
                ).fetchone()
                if job is None:
                    await asyncio.sleep(2)
                    continue

                job_id = job["id"]
                database.update_job_status(handle, job_id, "processing")
                logger.info("Processing job %d (%s)", job_id, job["filename"])
            finally:
                handle.close()

            # Run the synchronous pipeline off the event loop.
            try:
                status, doc_id, chunk_count = await asyncio.to_thread(
                    _process_job, job
                )
            except Exception as exc:
                logger.exception("Job %d failed", job_id)
                _store_status(
                    settings.db_path, job_id, "failed", error=str(exc)
                )
                continue

            _store_status(
                settings.db_path,
                job_id,
                status,
                document_id=doc_id,
                chunk_count=chunk_count,
            )
            logger.info(
                "Job %d finished: status=%s doc_id=%s chunks=%s",
                job_id, status, doc_id, chunk_count,
            )

        except asyncio.CancelledError:
            logger.info("Ingestion worker cancelled")
            raise
        except Exception:
            logger.exception("Unexpected error in ingestion worker loop")
            await asyncio.sleep(2)
|
||||
|
||||
|
||||
def _process_job(job_row) -> tuple[str, int | None, int]:
    """Synchronously process a single ingestion job.

    The steps are, in order:

    1. Hash the staged file and stop early if a document with the same
       content hash is already stored.  This runs before chunking so that a
       duplicate is skipped without first being chunked.
    2. Resolve the document type: a ``.note`` staged file is a note, otherwise
       the job's ``doc_type`` is used when set, otherwise the type (and
       language) are detected from the original filename.
    3. Chunk the staged content with the chunker for that type.
    4. Insert the document, embed all chunk texts, insert each chunk with its
       embedding, apply tags, and commit.

    The database connection is closed and the staged file is removed when the
    function exits, whether it succeeds, skips, or raises.

    Args:
        job_row: Row from the ``jobs`` table.  Read by key: ``id``,
            ``filename``, ``doc_type``, ``staging_path``, ``title`` and
            ``tags_json``.

    Returns:
        A tuple of ``(status, document_id, chunk_count)`` where *status* is
        one of ``"done"``, ``"skipped"``.  For ``"skipped"`` the document id
        is ``None`` and the chunk count is ``0``.

    Raises:
        ValueError: If the resolved document type is not supported, or if
            the number of embedding vectors does not match the number of
            chunks.
    """
    cfg = config.cfg
    # Resolve the path before opening the connection so a malformed row
    # cannot leave a connection open.
    staged_path = Path(job_row["staging_path"])
    conn = database.get_connection(cfg.db_path)

    try:
        # --- Duplicate detection via content hash -------------------------
        content_hash = hashlib.sha256(staged_path.read_bytes()).hexdigest()

        if database.hash_exists(conn, content_hash):
            logger.info("Duplicate detected for job %d, skipping", job_row["id"])
            return ("skipped", None, 0)

        # --- Determine document type and language -------------------------
        filename = job_row["filename"]
        forced_type = job_row["doc_type"]

        if staged_path.suffix == ".note":
            doc_type, language = "note", None
        elif forced_type:
            doc_type, language = forced_type, None
        else:
            doc_type, language = detector.detect_type(Path(filename))

        # --- Chunk the content --------------------------------------------
        # Chunkers are imported inside the branch that needs them, as in the
        # rest of this module, so only the selected pipeline is loaded.
        if doc_type == "note":
            from kb.ingest.note import chunk_note

            chunks = chunk_note(staged_path.read_text(encoding="utf-8"))
        elif doc_type == "pdf":
            from kb.ingest.docling_pipeline import chunk_document

            chunks = chunk_document(staged_path, cfg.ingest_device)
        elif doc_type == "markdown":
            from kb.ingest.markdown import chunk_markdown

            chunks = chunk_markdown(staged_path.read_text(encoding="utf-8"))
        elif doc_type == "code":
            from kb.ingest.code import chunk_code

            text = staged_path.read_text(encoding="utf-8")
            if not language:
                # A forced "code" type carries no language; detect it from
                # the original filename.
                _, language = detector.detect_type(Path(filename))
            chunks = chunk_code(text, language)
        else:
            raise ValueError(f"Unsupported doc_type: {doc_type}")

        # --- Persist document, chunks, and embeddings ---------------------
        title = job_row["title"] or filename
        # NOTE(review): this records the staging path, which is removed in
        # the ``finally`` below — confirm that is the intended value.
        doc_id = database.insert_document(
            conn,
            title=title,
            source_path=str(staged_path),
            content_hash=content_hash,
            doc_type=doc_type,
            language=language,
        )

        # A chunk is either a plain string or a mapping with a "text" key
        # plus optional metadata entries.
        chunk_texts = [c if isinstance(c, str) else c["text"] for c in chunks]
        vectors = embeddings.embed_texts(chunk_texts)

        # strict=True: fail loudly instead of silently dropping chunks when
        # the embedder returns a different number of vectors.
        for idx, (chunk, chunk_text, vector) in enumerate(
            zip(chunks, chunk_texts, vectors, strict=True)
        ):
            metadata = None
            if not isinstance(chunk, str):
                metadata = {k: v for k, v in chunk.items() if k != "text"} or None
            chunk_id = database.insert_chunk(
                conn,
                document_id=doc_id,
                chunk_index=idx,
                text=chunk_text,
                metadata=metadata,
            )
            database.insert_embedding(conn, chunk_id, vector)

        # --- Tags ---------------------------------------------------------
        tags = json.loads(job_row["tags_json"] or "[]")
        if tags:
            database.tag_document(conn, doc_id, tags)

        conn.commit()
        return ("done", doc_id, len(chunk_texts))

    finally:
        conn.close()
        staging.cleanup(staged_path)
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Engine entry point — FastAPI server with eager model loading."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
# Engine version: read from the VERSION file that sits next to this module,
# falling back to "dev" when that file is not present.
_version_file = Path(__file__).with_name("VERSION")
if _version_file.exists():
    __version__ = _version_file.read_text().strip()
else:
    __version__ = "dev"
|
||||
|
||||
from kb.config import cfg
|
||||
from kb.embeddings import load_model
|
||||
from kb.database import get_connection, init_schema
|
||||
from kb.worker import ingestion_worker
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
log = logging.getLogger("kb.engine")
|
||||
|
||||
# Track readiness for health endpoint
|
||||
ready = False
|
||||
worker_task: asyncio.Task | None = None
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run engine startup before the app serves and shutdown afterwards.

    Startup sets ``HF_HOME`` from the configuration, creates the configured
    directories, loads the embedding model, initialises the database schema
    with the model's dimension, and starts the background ingestion worker.
    The module-level ``ready`` flag is set once all of that has succeeded.

    Shutdown clears ``ready`` and cancels the worker task, waiting for it to
    finish.  It runs in a ``finally`` block so the worker is also stopped
    when the context body exits with an exception.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    global ready, worker_task

    # Set the HF cache location before the model is loaded
    os.environ["HF_HOME"] = str(cfg.hf_cache)

    log.info("Starting engine...")
    cfg.ensure_dirs()

    # Initialise database; close the connection even if model loading or
    # schema creation fails.
    conn = get_connection(cfg.db_path)
    try:
        model_dim = load_model(cfg.model, cfg.device)
        init_schema(conn, model_dim)
    finally:
        conn.close()

    # Start background ingestion worker
    worker_task = asyncio.create_task(ingestion_worker())

    ready = True
    log.info("Engine ready — model: %s, device: %s", cfg.model, cfg.device)

    try:
        yield
    finally:
        # Shutdown
        ready = False
        if worker_task:
            worker_task.cancel()
            try:
                await worker_task
            except asyncio.CancelledError:
                # Expected: the task was cancelled just above.
                pass
        log.info("Engine stopped.")
|
||||
|
||||
|
||||
# ASGI application object; startup and shutdown are handled by ``lifespan``.
app = FastAPI(
    title="kb-engine",
    version=__version__,
    lifespan=lifespan,
)

# Import routes after app is created
from kb.routes import health, search, jobs, documents, tags, status, reindex, auth  # noqa: E402, F401

# Allow running the engine directly with ``python main.py``.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "main:app",
        host=cfg.host,
        port=cfg.port,
        log_level="info",
    )
|
||||
@@ -0,0 +1,35 @@
|
||||
[build-system]
# `project.license` below uses the SPDX string form (PEP 639), which
# setuptools only accepts from the 77 series onward, so the build
# requirement must not allow older releases.
requires = ["setuptools>=77.0.3"]
build-backend = "setuptools.build_meta"

[project]
name = "kb-engine"
version = "2.0.0"
description = "Knowledge base engine — FastAPI server with hybrid search and async ingestion"
requires-python = ">=3.11"
license = "MIT"
dependencies = [
    "fastapi>=0.115",
    "uvicorn[standard]>=0.30",
    "python-multipart>=0.0.9",
    "click>=8.1",
    "pyyaml>=6.0",
    "sentence-transformers>=3.0",
    "sqlite-vec>=0.1.1",
    "docling>=2.0",
]

[project.optional-dependencies]
# Test tooling only; not needed at runtime.
dev = [
    "pytest>=8.0",
    "pytest-asyncio>=0.24",
    "httpx>=0.27",
]

[tool.setuptools.packages.find]
# Package only the `kb` package tree found in this directory.
where = ["."]
include = ["kb*"]

[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "auto"
|
||||
Reference in New Issue
Block a user