Initial MVP

This commit is contained in:
2026-03-23 20:38:42 +00:00
commit f245c24928
57 changed files with 6812 additions and 0 deletions
+3
View File
@@ -0,0 +1,3 @@
"""kb-search: CLI knowledge base with hybrid search."""
# Version string of the package.
__version__ = "0.1.0"
+616
View File
@@ -0,0 +1,616 @@
"""CLI entry point for kb-search."""
import click
# Root command group: every `@main.command()` / `@main.group()` in this module
# attaches to it. The docstring below is what click shows as the help text.
@click.group()
@click.version_option(package_name="kb-search")
def main():
    """Personal knowledge base with hybrid search."""
@main.command()
@click.option("--model", default=None, help="Embedding model name (HuggingFace).")
@click.option("--status", is_flag=True, help="Show initialisation status.")
def init(model, status):
    """Initialise the knowledge base and download models."""
    # With --status: only report what exists (data dir, database, stored model).
    # Otherwise: create the data dir, load the model to learn its dimension,
    # create the schema and record model name + dimension in the DB config table.
    from contextlib import closing

    from kb_search.config import get_data_dir, get_db_path, load_config
    from kb_search.database import (
        get_connection, get_db_config, init_schema, recreate_vec_table,
        run_migrations, set_db_config,
    )
    from kb_search.embeddings import download_model, get_model_dim

    cfg = load_config()
    data_dir = get_data_dir(cfg)
    db_path = get_db_path(cfg)
    model_name = model or cfg["embedding"]["model"]

    if status:
        click.echo(f"Data directory: {data_dir} ({'exists' if data_dir.exists() else 'not created'})")
        click.echo(f"Database: {db_path} ({'exists' if db_path.exists() else 'not created'})")
        if db_path.exists():
            # closing() guarantees the connection is released even if a query fails.
            with closing(get_connection(db_path)) as conn:
                db_model = get_db_config(conn, "model_name", "not set")
                db_dim = get_db_config(conn, "embedding_dim", "not set")
            click.echo(f"Model: {db_model} ({db_dim} dim)")
        else:
            click.echo(f"Model: {model_name} (not yet initialised)")
        return

    # Create data directory
    data_dir.mkdir(parents=True, exist_ok=True)

    # Download model and get dimension
    download_model(model_name)
    dim = get_model_dim(model_name)

    # Initialise database
    with closing(get_connection(db_path)) as conn:
        init_schema(conn, embedding_dim=dim)
        run_migrations(conn)

        # init_schema uses CREATE ... IF NOT EXISTS, so on an existing database
        # the vector table keeps its old dimension. Blindly overwriting the
        # stored model would leave config and vectors out of sync.
        previous_model = get_db_config(conn, "model_name")
        if previous_model is not None and previous_model != model_name:
            has_chunks = conn.execute("SELECT 1 FROM chunks LIMIT 1").fetchone() is not None
            if has_chunks:
                raise click.ClickException(
                    f"Knowledge base already initialised with model '{previous_model}'. "
                    f"Run `kb reindex --model {model_name}` to switch models."
                )
            # Nothing indexed yet: safe to rebuild the empty vector table.
            recreate_vec_table(conn, dim)

        set_db_config(conn, "model_name", model_name)
        set_db_config(conn, "embedding_dim", str(dim))

    click.echo(f"Knowledge base initialised at {data_dir}")
    click.echo(f"Model: {model_name} ({dim} dimensions)")
    click.echo("Ready! Add documents with `kb add`.")
@main.command()
@click.argument("path", required=False)
@click.option("--note", default=None, help="Add an inline text note.")
@click.option("--title", default=None, help="Title for the note.")
@click.option("--tags", default=None, help="Comma-separated tags.")
@click.option("--type", "doc_type", default=None, type=click.Choice(["pdf", "markdown", "code", "note"]), help="Force document type.")
@click.option("--language", default=None, type=click.Choice(["python", "bash", "go"]), help="Force code language.")
@click.option("--recursive", is_flag=True, help="Recurse into directories.")
@click.option("--workers", default=None, type=int, help="Number of parallel workers.")
def add(path, note, title, tags, doc_type, language, recursive, workers):
    """Add documents to the knowledge base."""
    # Three modes: --note (inline text), a directory path, or a single file path.
    import hashlib
    from contextlib import closing
    from pathlib import Path as P

    from kb_search.config import get_db_path, load_config
    from kb_search.database import (
        get_connection, hash_exists, insert_chunk, insert_document,
        insert_embedding, tag_document,
    )
    from kb_search.embeddings import check_model_binding, embed_texts
    from kb_search.ingest.note import auto_title, chunk_note

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    # Drop blank entries so "a,,b" or a trailing comma never creates an empty tag.
    tag_list = [t.strip() for t in tags.split(",") if t.strip()] if tags else []

    # closing() releases the connection on every exit path, including the
    # ClickException branches below.
    with closing(get_connection(db_path)) as conn:
        check_model_binding(conn, cfg)
        model_name = cfg["embedding"]["model"]

        if note:
            # Inline note: deduplicate on the SHA-256 of the note text.
            content_hash = hashlib.sha256(note.encode()).hexdigest()
            if hash_exists(conn, content_hash):
                click.echo("Skipped: note (already indexed)")
                return
            note_title = title or auto_title(note)
            chunks = chunk_note(note)

            # Embed every chunk in one batch, before any row is written.
            prefix = cfg["embedding"].get("passage_prefix", "")
            embeddings = (
                embed_texts(model_name, [c["text"] for c in chunks], prefix=prefix)
                if chunks else []
            )

            doc_id = insert_document(conn, note_title, None, content_hash, "note")
            for c, emb in zip(chunks, embeddings):
                chunk_id = insert_chunk(conn, doc_id, c["chunk_index"], c["text"], metadata=c["metadata"])
                insert_embedding(conn, chunk_id, emb)
            if tag_list:
                tag_document(conn, doc_id, tag_list)
            conn.commit()
            click.echo(f"Added note: {note_title}")
            return

        if not path:
            raise click.ClickException("Provide a file/directory path or use --note.")

        file_path = P(path).expanduser().resolve()
        if file_path.is_dir():
            _add_directory(conn, file_path, cfg, model_name, tag_list, doc_type, language,
                           recursive, workers)
        elif file_path.is_file():
            click.echo(_add_single_file(conn, file_path, cfg, model_name, tag_list, doc_type, language))
        else:
            raise click.ClickException(f"Path not found: {file_path}")
def _add_single_file(conn, file_path, cfg, model_name, tag_list, force_type, force_language):
    """Add a single file. Returns a status message.

    The file is skipped when a document with the same SHA-256 content hash
    already exists or when the chunker yields nothing. Otherwise the document,
    its chunks, embeddings and tags are written and committed together.

    Any exception raised while writing rolls the transaction back before it is
    re-raised, so a failed file never leaves a half-written document behind
    for a later commit on the same connection to persist.
    """
    import hashlib

    from kb_search.database import (
        hash_exists, insert_chunk, insert_document, insert_embedding, tag_document,
    )
    from kb_search.embeddings import embed_texts
    from kb_search.ingest.detector import detect_type

    # Dedup check
    content_hash = hashlib.sha256(file_path.read_bytes()).hexdigest()
    if hash_exists(conn, content_hash):
        return f"Skipped: {file_path.name} (already indexed)"

    doc_type, language = detect_type(file_path, force_type, force_language)
    chunks = _get_chunks(file_path, doc_type, language, cfg)
    if not chunks:
        return f"Skipped: {file_path.name} (no content extracted)"

    # Embed all chunks in one batch, before touching the database, so an
    # embedding failure cannot leave partial rows in the open transaction.
    texts = [c["text"] for c in chunks]
    prefix = cfg["embedding"].get("passage_prefix", "")
    embeddings = embed_texts(model_name, texts, prefix=prefix)

    try:
        doc_id = insert_document(conn, file_path.stem, str(file_path), content_hash, doc_type,
                                 language=language)
        for c, emb in zip(chunks, embeddings):
            chunk_id = insert_chunk(conn, doc_id, c["chunk_index"], c["text"],
                                    token_count=c.get("token_count"),
                                    metadata=c.get("metadata", {}))
            insert_embedding(conn, chunk_id, emb)
        if tag_list:
            tag_document(conn, doc_id, tag_list)
        conn.commit()
    except Exception:
        conn.rollback()
        raise

    return f"Added: {file_path.name} ({len(chunks)} chunks)"
def _get_chunks(file_path, doc_type, language, cfg):
"""Route to the correct chunking pipeline."""
if doc_type == "pdf":
from kb_search.ingest.docling import chunk_document
return chunk_document(file_path, cfg)
elif doc_type == "markdown":
from kb_search.ingest.markdown import chunk_markdown
text = file_path.read_text(errors="replace")
return chunk_markdown(text, cfg)
elif doc_type == "code":
from kb_search.ingest.code import chunk_code
text = file_path.read_text(errors="replace")
return chunk_code(text, language, cfg)
elif doc_type == "note":
from kb_search.ingest.note import chunk_note
text = file_path.read_text(errors="replace")
return chunk_note(text)
return []
def _add_directory(conn, dir_path, cfg, model_name, tag_list, force_type, force_language,
                   recursive, workers):
    """Add all supported files in a directory.

    Files are processed one by one in sorted order. A failure on one file is
    appended to ``ingest-errors.log`` in the data directory and does not stop
    the run. ``workers`` is accepted but not used here; ingestion is sequential.
    """
    from kb_search.config import get_data_dir
    from kb_search.ingest.detector import is_supported

    pattern = "**/*" if recursive else "*"
    files = sorted(f for f in dir_path.glob(pattern) if f.is_file() and is_supported(f))
    if not files:
        click.echo(f"No supported files found in {dir_path}")
        return

    added = skipped = failed = 0
    error_log_path = get_data_dir(cfg) / "ingest-errors.log"

    with click.progressbar(files, label="Ingesting", show_pos=True) as bar:
        for f in bar:
            try:
                result = _add_single_file(conn, f, cfg, model_name, tag_list,
                                          force_type, force_language)
            except Exception as e:
                # Deliberate best-effort: log and carry on with the next file.
                failed += 1
                # Discard anything the failed file left in the open transaction
                # so the next file's commit cannot persist it.
                conn.rollback()
                with open(error_log_path, "a", encoding="utf-8") as log:
                    log.write(f"{f}: {e}\n")
            else:
                # Status messages start with "Skipped:" or "Added:". A prefix
                # test avoids miscounting a file whose name contains "Skipped".
                if result.startswith("Skipped"):
                    skipped += 1
                else:
                    added += 1

    click.echo(f"\nAdded {added} documents. {failed} failed. {skipped} skipped (already indexed).")
    if failed:
        click.echo(f"See {error_log_path} for details.")
@main.command()
@click.argument("query")
@click.option("--top", default=None, type=int, help="Number of results.")
@click.option("--tags", default=None, help="Filter by tags (comma-separated).")
@click.option("--type", "doc_type", default=None, type=click.Choice(["pdf", "markdown", "code", "note"]), help="Filter by document type.")
@click.option("--format", "fmt", default=None, type=click.Choice(["json", "human"]), help="Output format.")
@click.option("--fts-only", is_flag=True, help="Full-text search only.")
@click.option("--vec-only", is_flag=True, help="Vector search only.")
@click.option("--threshold", default=None, type=float, help="Minimum score cutoff.")
def search(query, top, tags, doc_type, fmt, fts_only, vec_only, threshold):
    """Search the knowledge base."""
    # Options left unset fall back to the `search.*` config defaults.
    from contextlib import closing

    from kb_search.config import get_db_path, load_config
    from kb_search.database import get_connection, get_db_config
    from kb_search.embeddings import check_model_binding
    from kb_search.search import hybrid_search
    from kb_search.output import format_search_results

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    top = top or cfg["search"]["default_top"]
    fmt = fmt or cfg["search"]["default_format"]
    # Blank entries ("a,,b", trailing comma) are dropped; no usable tag means
    # no tag filter at all (None), same as when --tags is omitted.
    tag_list = [t.strip() for t in (tags or "").split(",") if t.strip()] or None

    # closing() releases the connection even if the model check or search raises.
    with closing(get_connection(db_path)) as conn:
        check_model_binding(conn, cfg)
        model_name = get_db_config(conn, "model_name") or cfg["embedding"]["model"]
        results = hybrid_search(
            conn, query, model_name, cfg,
            top=top, tags=tag_list, doc_type=doc_type,
            fts_only=fts_only, vec_only=vec_only, threshold=threshold,
        )

    click.echo(format_search_results(results, fmt))
@main.command("list")
@click.option("--type", "doc_type", default=None, type=click.Choice(["pdf", "markdown", "code", "note"]), help="Filter by document type.")
@click.option("--tags", default=None, help="Filter by tags (comma-separated).")
@click.option("--format", "fmt", default=None, type=click.Choice(["json", "human"]), help="Output format.")
def list_docs(doc_type, tags, fmt):
    """List indexed documents."""
    # Builds one query with an optional type filter and one JOIN pair per
    # requested tag (so a document must carry ALL given tags), newest first.
    from contextlib import closing

    from kb_search.config import get_db_path, load_config
    from kb_search.database import get_connection
    from kb_search.output import format_document_list

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")
    fmt = fmt or cfg["search"]["default_format"]

    sql = """
SELECT d.id, d.title, d.doc_type as type, d.created_at,
COUNT(c.id) as chunk_count
FROM documents d
LEFT JOIN chunks c ON d.id = c.document_id
"""
    joins = []
    where = []
    params = []
    if doc_type:
        where.append("d.doc_type = ?")
        params.append(doc_type)

    # Tags are stored lowercase, so normalise the filter the same way. Blank
    # entries are dropped: an empty name would add a JOIN that matches nothing.
    tag_list = [t.strip().lower() for t in tags.split(",") if t.strip()] if tags else []
    for i, tag in enumerate(tag_list):
        # Only the loop index is interpolated into the SQL (for unique table
        # aliases); the tag value itself is bound as a parameter.
        joins.append(f"JOIN document_tags dt{i} ON d.id = dt{i}.document_id")
        joins.append(f"JOIN tags t{i} ON dt{i}.tag_id = t{i}.id")
        where.append(f"t{i}.name = ?")
        params.append(tag)

    sql += " " + " ".join(joins)
    if where:
        sql += " WHERE " + " AND ".join(where)
    sql += " GROUP BY d.id ORDER BY d.created_at DESC"

    docs = []
    # closing() releases the connection even if a query fails.
    with closing(get_connection(db_path)) as conn:
        rows = conn.execute(sql, params).fetchall()
        for row in rows:
            tag_rows = conn.execute("""
SELECT t.name FROM tags t
JOIN document_tags dt ON t.id = dt.tag_id
WHERE dt.document_id = ?
ORDER BY t.name
""", (row["id"],)).fetchall()
            docs.append({
                "id": row["id"],
                "title": row["title"],
                "type": row["type"],
                "tags": [r["name"] for r in tag_rows],
                "chunk_count": row["chunk_count"],
                "created_at": row["created_at"],
            })

    click.echo(format_document_list(docs, fmt))
@main.command()
@click.argument("doc_id", type=int)
@click.option("--format", "fmt", default=None, type=click.Choice(["json", "human"]), help="Output format.")
def info(doc_id, fmt):
    """Show document details."""
    # Collects the document row, its chunks (in chunk order) and its tags.
    from contextlib import closing

    from kb_search.config import get_db_path, load_config
    from kb_search.database import get_connection
    from kb_search.output import format_doc_info

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")
    fmt = fmt or cfg["search"]["default_format"]

    # closing() also covers the "Document not found" exit, which previously
    # raised without closing the connection.
    with closing(get_connection(db_path)) as conn:
        row = conn.execute("SELECT * FROM documents WHERE id = ?", (doc_id,)).fetchone()
        if not row:
            raise click.ClickException(f"Document not found: {doc_id}")

        chunks = conn.execute(
            "SELECT chunk_index, text FROM chunks WHERE document_id = ? ORDER BY chunk_index",
            (doc_id,),
        ).fetchall()
        tag_rows = conn.execute("""
SELECT t.name FROM tags t
JOIN document_tags dt ON t.id = dt.tag_id
WHERE dt.document_id = ?
ORDER BY t.name
""", (doc_id,)).fetchall()

        info_data = {
            "id": row["id"],
            "title": row["title"],
            "type": row["doc_type"],
            "language": row["language"],
            "path": row["source_path"],
            "content_hash": row["content_hash"],
            "created_at": row["created_at"],
            "tags": [r["name"] for r in tag_rows],
            "chunk_count": len(chunks),
            "chunks": [{"chunk_index": c["chunk_index"], "text": c["text"]} for c in chunks],
        }

    click.echo(format_doc_info(info_data, fmt))
@main.command()
@click.argument("doc_id", type=int)
@click.option("--yes", is_flag=True, help="Skip confirmation prompt.")
def remove(doc_id, yes):
    """Remove a document from the knowledge base."""
    # Deletes the document's vectors explicitly, then the document row; chunks
    # and tag links go with it through ON DELETE CASCADE.
    from contextlib import closing

    from kb_search.config import get_db_path, load_config
    from kb_search.database import get_connection

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    # closing() releases the connection on every exit: not found, cancelled,
    # success, or an error while deleting.
    with closing(get_connection(db_path)) as conn:
        row = conn.execute("SELECT id, title FROM documents WHERE id = ?", (doc_id,)).fetchone()
        if not row:
            raise click.ClickException(f"Document not found: {doc_id}")

        chunk_count = conn.execute(
            "SELECT COUNT(*) FROM chunks WHERE document_id = ?", (doc_id,)
        ).fetchone()[0]

        if not yes:
            if not click.confirm(f"Remove '{row['title']}' and its {chunk_count} chunks?"):
                click.echo("Cancelled.")
                return

        # Delete vectors for this document's chunks. This must happen before
        # the document row goes, because the subquery needs the chunk rows.
        conn.execute("""
DELETE FROM chunks_vec WHERE chunk_id IN (
SELECT id FROM chunks WHERE document_id = ?
)
""", (doc_id,))
        # Cascade handles chunks, document_tags
        conn.execute("DELETE FROM documents WHERE id = ?", (doc_id,))
        conn.commit()
        title = row["title"]

    click.echo(f"Removed '{title}' ({chunk_count} chunks).")
@main.command("tags")
@click.option("--format", "fmt", default=None, type=click.Choice(["json", "human"]), help="Output format.")
def list_tags(fmt):
    """List all tags with document counts."""
    # Tags with no documents are still listed (LEFT JOIN) with a count of 0.
    from kb_search.config import get_db_path, load_config
    from kb_search.database import get_connection
    from kb_search.output import format_tags

    settings = load_config()
    database_file = get_db_path(settings)
    if not database_file.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    connection = get_connection(database_file)
    if not fmt:
        fmt = settings["search"]["default_format"]

    cursor = connection.execute("""
SELECT t.name, COUNT(dt.document_id) as count
FROM tags t
LEFT JOIN document_tags dt ON t.id = dt.tag_id
GROUP BY t.id
ORDER BY count DESC, t.name
""")
    tags = []
    for record in cursor.fetchall():
        tags.append({"name": record["name"], "count": record["count"]})
    connection.close()

    click.echo(format_tags(tags, fmt))
@main.command()
@click.argument("doc_id", type=int)
@click.option("--add", "add_tags", default=None, help="Tags to add (comma-separated).")
@click.option("--remove", "remove_tags", default=None, help="Tags to remove (comma-separated).")
def tag(doc_id, add_tags, remove_tags):
    """Manage tags on a document."""
    # --add and --remove may be combined; additions are applied first.
    from contextlib import closing

    from kb_search.config import get_db_path, load_config
    from kb_search.database import get_connection, tag_document, untag_document

    def _split(raw):
        # Blank entries ("a,,b", trailing comma) would otherwise become an
        # empty tag name.
        return [t.strip() for t in (raw or "").split(",") if t.strip()]

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    to_add = _split(add_tags)
    to_remove = _split(remove_tags)

    # closing() also covers the "Document not found" exit, which previously
    # raised without closing the connection.
    with closing(get_connection(db_path)) as conn:
        row = conn.execute("SELECT id, title FROM documents WHERE id = ?", (doc_id,)).fetchone()
        if not row:
            raise click.ClickException(f"Document not found: {doc_id}")

        if not to_add and not to_remove:
            # Say so instead of exiting silently as if something had changed.
            click.echo("Nothing to do: pass --add and/or --remove with at least one tag.")
            return

        if to_add:
            tag_document(conn, doc_id, to_add)
            conn.commit()
            click.echo(f"Added tags [{', '.join(to_add)}] to '{row['title']}'")
        if to_remove:
            untag_document(conn, doc_id, to_remove)
            conn.commit()
            click.echo(f"Removed tags [{', '.join(to_remove)}] from '{row['title']}'")
@main.command()
@click.option("--format", "fmt", default=None, type=click.Choice(["json", "human"]), help="Output format.")
def status(fmt):
    """Show knowledge base status and statistics."""
    # Reports stored model settings, database file size and per-type counts.
    from kb_search.config import get_db_path, load_config
    from kb_search.database import get_connection, get_db_config
    from kb_search.output import format_status

    settings = load_config()
    database_file = get_db_path(settings)
    if not database_file.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    connection = get_connection(database_file)
    if not fmt:
        fmt = settings["search"]["default_format"]

    per_type = connection.execute(
        "SELECT doc_type, COUNT(*) as cnt FROM documents GROUP BY doc_type"
    ).fetchall()
    doc_counts = {record["doc_type"]: record["cnt"] for record in per_type}
    chunk_total = connection.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
    size_on_disk = database_file.stat().st_size

    status_data = {
        "model_name": get_db_config(connection, "model_name", "not set"),
        "embedding_dim": get_db_config(connection, "embedding_dim", "not set"),
        "schema_version": get_db_config(connection, "schema_version", "not set"),
        "db_size_bytes": size_on_disk,
        "documents": doc_counts,
        "total_documents": sum(doc_counts.values()),
        "total_chunks": chunk_total,
    }
    connection.close()

    click.echo(format_status(status_data, fmt))
@main.command()
@click.option("--model", default=None, help="Switch to a different embedding model.")
def reindex(model):
    """Re-embed all chunks (optionally with a new model)."""
    # Model resolution order: --model, then the model stored in the DB, then config.
    from contextlib import closing

    from kb_search.config import get_db_path, load_config
    from kb_search.database import (
        get_connection, get_db_config, insert_embedding,
        recreate_vec_table, set_db_config,
    )
    from kb_search.embeddings import download_model, embed_texts, get_model_dim

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    # closing() releases the connection even if embedding or a query fails.
    with closing(get_connection(db_path)) as conn:
        model_name = model or get_db_config(conn, "model_name") or cfg["embedding"]["model"]

        # Download model if switching
        if model:
            download_model(model_name)
        dim = get_model_dim(model_name)

        # Get all chunks
        rows = conn.execute("SELECT id, text FROM chunks ORDER BY id").fetchall()
        if not rows and not model:
            click.echo("No chunks to re-embed.")
            return

        all_ids = [r["id"] for r in rows]
        all_embeddings = []
        if rows:
            click.echo(f"Re-embedding {len(rows)} chunks with '{model_name}' ({dim} dim)...")

            # Embed in batches
            batch_size = 256
            all_texts = [r["text"] for r in rows]
            prefix = cfg["embedding"].get("passage_prefix", "")
            with click.progressbar(range(0, len(all_texts), batch_size), label="Embedding") as bar:
                for i in bar:
                    batch = all_texts[i:i + batch_size]
                    all_embeddings.extend(embed_texts(model_name, batch, prefix=prefix))

        # All embeddings are computed before the old table is dropped, so an
        # embedding failure leaves the existing vectors untouched. The swap
        # itself is not transactional: recreate_vec_table commits the drop.
        # With --model and no chunks this still runs, so an empty knowledge
        # base can switch model and dimension.
        recreate_vec_table(conn, dim)
        for chunk_id, emb in zip(all_ids, all_embeddings):
            insert_embedding(conn, chunk_id, emb)
        set_db_config(conn, "model_name", model_name)
        set_db_config(conn, "embedding_dim", str(dim))
        conn.commit()

    if rows:
        click.echo(f"Reindex complete. {len(rows)} chunks embedded with '{model_name}'.")
    else:
        click.echo(f"No chunks to re-embed. Knowledge base now uses '{model_name}' ({dim} dim).")

    # Commands that call check_model_binding reject a config/DB model mismatch,
    # so point the user at the remaining step.
    config_model = cfg["embedding"]["model"]
    if config_model != model_name:
        click.echo(
            f"Note: config still specifies '{config_model}'. "
            f"Run `kb config set embedding.model {model_name}` to match the database."
        )
@main.group(invoke_without_command=True)
@click.pass_context
def config(ctx):
    """View or modify configuration."""
    # A subcommand (e.g. `config set`) handles its own output.
    if ctx.invoked_subcommand is not None:
        return

    from kb_search.config import config_with_sources

    # Bare `config`: print every effective setting, column-aligned, with its source.
    entries = config_with_sources()
    key_width = max(len(entry[0]) for entry in entries)
    value_width = max(len(entry[1]) for entry in entries)
    for name, shown, origin in entries:
        click.echo(f" {name:<{key_width}} {shown:<{value_width}} ({origin})")
@config.command("set")
@click.argument("key")
@click.argument("value")
def config_set(key, value):
    """Set a configuration value."""
    # KEY is a dotted path (e.g. `search.default_top`) into the YAML config file.
    from kb_search.config import get_config_path, load_config, save_config_value

    target = get_config_path(load_config())
    save_config_value(target, key, value)
    click.echo(f"Set {key} = {value} in {target}")
# NOTE(review): `config` is declared with `@main.group(...)`, which should
# already register it on `main`; this call looks redundant. Confirm, then drop.
main.add_command(config)
+195
View File
@@ -0,0 +1,195 @@
"""Configuration loading with YAML + ENV + defaults."""
import os
from copy import deepcopy
from pathlib import Path
import yaml
# Built-in configuration. load_config() deep-copies this, merges the YAML file
# over it, then applies environment overrides.
DEFAULTS = {
    # Root directory; config.yaml and kb.db are resolved underneath it.
    "data_dir": "~/.kb",
    "embedding": {
        "model": "all-MiniLM-L6-v2",
        # presumably prepended to search queries before embedding; verify in the search module
        "query_prefix": "",
        # Prepended to chunk text before embedding during ingest and reindex.
        "passage_prefix": "",
    },
    "search": {
        # Fallbacks used by the CLI when --top / --format are not given.
        "default_top": 10,
        "default_format": "json",
        # presumably the reciprocal-rank-fusion constant; verify in the search module
        "rrf_k": 60,
    },
    # Per-document-type chunker settings, passed to the chunkers via `cfg`.
    "chunking": {
        "defaults": {
            "max_tokens": 512,
            "overlap_tokens": 50,
        },
        "pdf": {
            "strategy": "hierarchy",
            "max_tokens": 1024,
        },
        "markdown": {
            "strategy": "header",
            "min_tokens": 50,
            "max_tokens": 1024,
        },
        "code": {
            "strategy": "ast",
            "include_context": True,
            "max_tokens": 1024,
        },
        "note": {
            "strategy": "whole",
        },
    },
    "ingestion": {
        "workers": 4,
        "batch_size": 50,
        "enable_ocr": "auto",
    },
}
# ENV variable mapping: ENV_NAME -> config dotted key
ENV_MAP = {
    "KB_DATA_DIR": "data_dir",
    "KB_MODEL": "embedding.model",
    "KB_DEFAULT_TOP": "search.default_top",
    "KB_DEFAULT_FORMAT": "search.default_format",
}
# Type coercions for ENV values (keys not listed here stay strings)
ENV_TYPES = {
    "search.default_top": int,
}
def _deep_merge(base: dict, override: dict) -> dict:
"""Deep merge override into base, returning a new dict."""
result = deepcopy(base)
for key, value in override.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = _deep_merge(result[key], value)
else:
result[key] = deepcopy(value)
return result
def _set_nested(d: dict, dotted_key: str, value):
"""Set a value in a nested dict using a dotted key path."""
keys = dotted_key.split(".")
for key in keys[:-1]:
d = d.setdefault(key, {})
d[keys[-1]] = value
def _get_nested(d: dict, dotted_key: str, default=None):
"""Get a value from a nested dict using a dotted key path."""
keys = dotted_key.split(".")
for key in keys:
if not isinstance(d, dict) or key not in d:
return default
d = d[key]
return d
def get_data_dir(cfg: dict) -> Path:
    """Return the configured ``data_dir`` as a path with ``~`` expanded."""
    configured = cfg["data_dir"]
    return Path(configured).expanduser()
def get_config_path(cfg: dict) -> Path:
    """Return the location of ``config.yaml`` inside the data directory."""
    base_dir = get_data_dir(cfg)
    return base_dir / "config.yaml"
def get_db_path(cfg: dict) -> Path:
    """Return the location of the SQLite file ``kb.db`` inside the data directory."""
    base_dir = get_data_dir(cfg)
    return base_dir / "kb.db"
def load_config(config_path: Path | None = None) -> dict:
    """Load config with precedence: ENV > YAML > defaults.

    CLI flags are applied by the caller after this returns.

    Args:
        config_path: YAML file to read. When omitted, ``config.yaml`` under
            ``KB_DATA_DIR`` (or the default data dir) is used. A ``data_dir``
            set inside the YAML file does not change where the file is
            looked up.

    Returns:
        A new nested dict; ``DEFAULTS`` is never mutated.

    Raises:
        ValueError: if the YAML root is not a mapping, or an environment
            override cannot be converted to the type its key requires.
    """
    cfg = deepcopy(DEFAULTS)

    # Determine config file path (ENV can override data_dir which affects path)
    if config_path is None:
        data_dir = os.environ.get("KB_DATA_DIR", DEFAULTS["data_dir"])
        config_path = Path(data_dir).expanduser() / "config.yaml"

    # Load YAML if it exists. The encoding is pinned so the result does not
    # depend on the platform's locale.
    if config_path.is_file():
        with open(config_path, encoding="utf-8") as f:
            yaml_cfg = yaml.safe_load(f) or {}
        if not isinstance(yaml_cfg, dict):
            raise ValueError(
                f"{config_path}: expected a YAML mapping at the top level, "
                f"got {type(yaml_cfg).__name__}"
            )
        cfg = _deep_merge(cfg, yaml_cfg)

    # Apply ENV overrides
    for env_name, dotted_key in ENV_MAP.items():
        env_val = os.environ.get(env_name)
        if env_val is None:
            continue
        coerce = ENV_TYPES.get(dotted_key, str)
        try:
            coerced = coerce(env_val)
        except ValueError as err:
            # Same exception type as before, but naming the offending variable.
            raise ValueError(
                f"Invalid value for {env_name}: {env_val!r} (expected {coerce.__name__})"
            ) from err
        _set_nested(cfg, dotted_key, coerced)

    return cfg
def save_config_value(config_path: Path, dotted_key: str, value: str):
    """Set a single value in the YAML config file.

    The file and its parent directory are created when missing. Other keys
    already in the file are kept. ``value`` arrives as a string and is stored
    as an int, a float or a bool when it parses as one, otherwise unchanged.

    Raises:
        ValueError: if the existing file's top level is not a YAML mapping.
    """
    config_path.parent.mkdir(parents=True, exist_ok=True)

    existing = {}
    if config_path.is_file():
        # The encoding is pinned so reading does not depend on the locale.
        with open(config_path, encoding="utf-8") as f:
            existing = yaml.safe_load(f) or {}
        if not isinstance(existing, dict):
            # Refuse to continue rather than crash inside _set_nested.
            raise ValueError(
                f"{config_path}: expected a YAML mapping at the top level, "
                f"got {type(existing).__name__}"
            )

    # Try numeric coercion, then booleans; anything else stays a string.
    try:
        value = int(value)
    except ValueError:
        try:
            value = float(value)
        except ValueError:
            if value.lower() in ("true", "false"):
                value = value.lower() == "true"

    _set_nested(existing, dotted_key, value)

    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(existing, f, default_flow_style=False, sort_keys=False)
def config_with_sources(config_path: Path | None = None) -> list[tuple[str, str, str]]:
    """Flatten the effective config into ``(dotted_key, value, source)`` rows.

    ``value`` is the stringified effective value. ``source`` is
    ``"env (<NAME>)"`` when the mapped environment variable is set,
    ``"config.yaml"`` when the YAML file holds a non-null value for the key,
    and ``"default"`` otherwise.
    """
    if config_path is None:
        data_dir = os.environ.get("KB_DATA_DIR", DEFAULTS["data_dir"])
        config_path = Path(data_dir).expanduser() / "config.yaml"

    yaml_cfg = {}
    if config_path.is_file():
        with open(config_path) as f:
            yaml_cfg = yaml.safe_load(f) or {}

    # Reverse ENV map: dotted key -> environment variable name
    env_for_key = {dotted: env for env, dotted in ENV_MAP.items()}

    def _source_of(dotted):
        env_name = env_for_key.get(dotted)
        if env_name and os.environ.get(env_name) is not None:
            return f"env ({env_name})"
        if _get_nested(yaml_cfg, dotted) is not None:
            return "config.yaml"
        return "default"

    def _walk(node, prefix=""):
        # Depth-first, in dict order, yielding one row per leaf value.
        for name, val in node.items():
            dotted = f"{prefix}.{name}" if prefix else name
            if isinstance(val, dict):
                yield from _walk(val, dotted)
            else:
                yield (dotted, str(val), _source_of(dotted))

    return list(_walk(load_config(config_path)))
+229
View File
@@ -0,0 +1,229 @@
"""SQLite database management with FTS5 and sqlite-vec."""
import json
import sqlite3
from pathlib import Path
import sqlite_vec
# Schema version that run_migrations() records in the config table once the
# database is up to date.
SCHEMA_VERSION = 1
def get_connection(db_path: Path) -> sqlite3.Connection:
    """Open a SQLite connection with sqlite-vec loaded.

    The connection uses WAL journaling, enforces foreign keys and returns
    rows as ``sqlite3.Row``. If loading the extension or applying a PRAGMA
    fails, the connection is closed before the error propagates.
    """
    conn = sqlite3.connect(str(db_path))
    try:
        conn.enable_load_extension(True)
        try:
            sqlite_vec.load(conn)
        finally:
            # Never leave extension loading enabled, even when the load fails.
            conn.enable_load_extension(False)
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA foreign_keys=ON")
    except Exception:
        conn.close()
        raise
    conn.row_factory = sqlite3.Row
    return conn
def init_schema(conn: sqlite3.Connection, embedding_dim: int):
    """Create all tables, FTS, vector index, and triggers.

    Every statement uses ``IF NOT EXISTS``, so calling this on an existing
    database leaves existing objects as they are. In particular an existing
    ``chunks_vec`` table keeps the dimension it was created with, whatever
    ``embedding_dim`` is passed now. Commits before returning.
    """
    # Core relational tables plus lookup indexes. `{{}}` is an escaped literal
    # `{}` inside the f-string (the JSON default for the metadata columns).
    conn.executescript(f"""
CREATE TABLE IF NOT EXISTS documents (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL,
source_path TEXT,
content_hash TEXT NOT NULL,
doc_type TEXT NOT NULL CHECK(doc_type IN ('pdf','markdown','code','note')),
language TEXT,
created_at TEXT DEFAULT (datetime('now')),
metadata TEXT DEFAULT '{{}}'
);
CREATE TABLE IF NOT EXISTS chunks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
document_id INTEGER NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
chunk_index INTEGER NOT NULL,
text TEXT NOT NULL,
token_count INTEGER,
metadata TEXT DEFAULT '{{}}',
created_at TEXT DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS tags (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL
);
CREATE TABLE IF NOT EXISTS document_tags (
document_id INTEGER REFERENCES documents(id) ON DELETE CASCADE,
tag_id INTEGER REFERENCES tags(id) ON DELETE CASCADE,
PRIMARY KEY (document_id, tag_id)
);
CREATE TABLE IF NOT EXISTS config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_chunks_document_id ON chunks(document_id);
CREATE INDEX IF NOT EXISTS idx_documents_content_hash ON documents(content_hash);
""")
    # FTS5 virtual table (content-sync with chunks)
    conn.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
text,
content='chunks',
content_rowid='id',
tokenize='porter unicode61'
)
""")
    # FTS sync triggers: mirror every insert, delete and update on `chunks`
    # into `chunks_fts`, keyed by the chunk id.
    conn.executescript("""
CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
INSERT INTO chunks_fts(rowid, text) VALUES (new.id, new.text);
END;
CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
INSERT INTO chunks_fts(chunks_fts, rowid, text) VALUES('delete', old.id, old.text);
END;
CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
INSERT INTO chunks_fts(chunks_fts, rowid, text) VALUES('delete', old.id, old.text);
INSERT INTO chunks_fts(rowid, text) VALUES (new.id, new.text);
END;
""")
    # Vector table: one embedding of `embedding_dim` floats per chunk id.
    conn.execute(f"""
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_vec USING vec0(
chunk_id INTEGER PRIMARY KEY,
embedding FLOAT[{embedding_dim}]
)
""")
    conn.commit()
def get_db_config(conn: sqlite3.Connection, key: str, default: str | None = None) -> str | None:
"""Get a value from the config table."""
row = conn.execute("SELECT value FROM config WHERE key = ?", (key,)).fetchone()
return row["value"] if row else default
def set_db_config(conn: sqlite3.Connection, key: str, value: str):
    """Insert or overwrite ``key`` in the config table, then commit."""
    upsert = "INSERT INTO config (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = ?"
    # The value is bound twice: once for the insert, once for the update branch.
    bindings = (key, value, value)
    conn.execute(upsert, bindings)
    conn.commit()
def check_schema_version(conn: sqlite3.Connection) -> int | None:
    """Return the schema version recorded in the config table.

    Returns 0 when the config table exists but holds no ``schema_version``
    row. Returns None when the version cannot be read, for example because
    the config table does not exist yet or the stored value is not an integer.
    """
    try:
        return int(get_db_config(conn, "schema_version", "0"))
    except (sqlite3.Error, ValueError):
        # Only "cannot read/parse the version" means "not initialised";
        # unrelated programming errors should surface rather than be swallowed.
        return None
def run_migrations(conn: sqlite3.Connection):
    """Apply every registered migration newer than the stored schema version.

    Each applied migration records its own version. Afterwards the stored
    version is raised to ``SCHEMA_VERSION`` if the starting version was lower.
    """
    starting_version = check_schema_version(conn) or 0

    # Migration registry: version -> callable
    migrations: dict[int, callable] = {
        # Future migrations go here:
        # 2: _migrate_v2,
    }

    pending = [v for v in sorted(migrations.keys()) if starting_version < v]
    for version in pending:
        migrations[version](conn)
        set_db_config(conn, "schema_version", str(version))

    if starting_version < SCHEMA_VERSION:
        set_db_config(conn, "schema_version", str(SCHEMA_VERSION))
def recreate_vec_table(conn: sqlite3.Connection, embedding_dim: int):
    """Replace ``chunks_vec`` with an empty table of the given dimension.

    All stored embeddings are discarded. Commits before returning.
    """
    statements = (
        "DROP TABLE IF EXISTS chunks_vec",
        f"""
CREATE VIRTUAL TABLE chunks_vec USING vec0(
chunk_id INTEGER PRIMARY KEY,
embedding FLOAT[{embedding_dim}]
)
""",
    )
    for statement in statements:
        conn.execute(statement)
    conn.commit()
def insert_document(conn: sqlite3.Connection, title: str, source_path: str | None,
content_hash: str, doc_type: str, language: str | None = None,
metadata: dict | None = None) -> int:
"""Insert a document and return its ID."""
cur = conn.execute(
"INSERT INTO documents (title, source_path, content_hash, doc_type, language, metadata) "
"VALUES (?, ?, ?, ?, ?, ?)",
(title, source_path, content_hash, doc_type, language, json.dumps(metadata or {})),
)
return cur.lastrowid
def insert_chunk(conn: sqlite3.Connection, document_id: int, chunk_index: int,
text: str, token_count: int | None = None,
metadata: dict | None = None) -> int:
"""Insert a chunk and return its ID."""
cur = conn.execute(
"INSERT INTO chunks (document_id, chunk_index, text, token_count, metadata) "
"VALUES (?, ?, ?, ?, ?)",
(document_id, chunk_index, text, token_count, json.dumps(metadata or {})),
)
return cur.lastrowid
def insert_embedding(conn: sqlite3.Connection, chunk_id: int, embedding: list[float]):
    """Store ``embedding`` for ``chunk_id`` in ``chunks_vec``.

    The vector is packed as consecutive 4-byte floats in native byte order.
    Does not commit.
    """
    import struct

    layout = "%df" % len(embedding)
    packed = struct.pack(layout, *embedding)
    conn.execute(
        "INSERT INTO chunks_vec (chunk_id, embedding) VALUES (?, ?)",
        (chunk_id, packed),
    )
def hash_exists(conn: sqlite3.Connection, content_hash: str) -> bool:
    """Return True when some document already has this content hash."""
    probe = "SELECT 1 FROM documents WHERE content_hash = ? LIMIT 1"
    return conn.execute(probe, (content_hash,)).fetchone() is not None
def get_or_create_tag(conn: sqlite3.Connection, name: str) -> int:
    """Return the id of the tag called ``name``, inserting it when missing.

    The name is stripped and lowercased before lookup and storage. Does not commit.
    """
    normalized = name.strip().lower()
    existing = conn.execute("SELECT id FROM tags WHERE name = ?", (normalized,)).fetchone()
    if existing is not None:
        return existing["id"]
    return conn.execute("INSERT INTO tags (name) VALUES (?)", (normalized,)).lastrowid
def tag_document(conn: sqlite3.Connection, document_id: int, tag_names: list[str]):
    """Link ``document_id`` to each named tag, creating tags as needed.

    Links that already exist are left alone. Does not commit.
    """
    link_sql = "INSERT OR IGNORE INTO document_tags (document_id, tag_id) VALUES (?, ?)"
    for tag_name in tag_names:
        conn.execute(link_sql, (document_id, get_or_create_tag(conn, tag_name)))
def untag_document(conn: sqlite3.Connection, document_id: int, tag_names: list[str]):
    """Unlink ``document_id`` from each named tag.

    Names are stripped and lowercased before matching. Unknown tags are
    ignored and the tag rows themselves are kept. Does not commit.
    """
    unlink_sql = (
        "DELETE FROM document_tags WHERE document_id = ? AND tag_id = "
        "(SELECT id FROM tags WHERE name = ?)"
    )
    for raw_name in tag_names:
        conn.execute(unlink_sql, (document_id, raw_name.strip().lower()))
+67
View File
@@ -0,0 +1,67 @@
"""Embedding model management — download, load, and inference via ONNX."""
import click
from pathlib import Path
# In-process cache used by load_model(): the most recently loaded model and
# the name it was loaded under. Only one model is kept at a time.
_model_instance = None
_model_name = None
def load_model(model_name: str):
    """Load a sentence-transformers model with ONNX backend. Caches in-process.

    Repeated calls with the same name return the cached instance; asking for
    a different name loads it and replaces the cache.

    Status messages go to stderr so that commands printing machine-readable
    output on stdout (e.g. JSON search results) are not corrupted by them.
    """
    global _model_instance, _model_name
    if _model_instance is not None and _model_name == model_name:
        return _model_instance

    from sentence_transformers import SentenceTransformer

    click.echo(f"Loading model '{model_name}'...", err=True)
    try:
        model = SentenceTransformer(model_name, backend="onnx")
    except Exception:
        # Fallback: some models may not have pre-exported ONNX. Ask
        # sentence-transformers to export explicitly; the previous code just
        # repeated the identical call here.
        click.echo("Optimising model for ONNX inference (one-time)...", err=True)
        model = SentenceTransformer(model_name, backend="onnx", model_kwargs={"export": True})

    # Update the cache only after a successful load, so a failure leaves the
    # previously cached model and name consistent with each other.
    _model_instance = model
    _model_name = model_name
    return model
def get_model_dim(model_name: str) -> int:
    """Return the sentence-embedding dimension reported by the model.

    The model is obtained through ``load_model``, so this may trigger a
    load if the model is not cached yet.
    """
    return load_model(model_name).get_sentence_embedding_dimension()
def embed_texts(model_name: str, texts: list[str],
                prefix: str = "", show_progress: bool = False) -> list[list[float]]:
    """Encode ``texts`` with the named model and return plain float lists.

    Args:
        model_name: Model to load (or reuse) via ``load_model``.
        texts: Strings to embed.
        prefix: When non-empty, prepended to every text before encoding.
        show_progress: Forwarded to ``encode`` as ``show_progress_bar``.

    Returns:
        One list of floats per input text, in input order.
    """
    encoder = load_model(model_name)
    inputs = [prefix + text for text in texts] if prefix else texts
    vectors = encoder.encode(
        inputs, show_progress_bar=show_progress, convert_to_numpy=True
    )
    return [vector.tolist() for vector in vectors]
def download_model(model_name: str):
    """Load the model once, up front, so `kb init` can report it as ready.

    Prints a message before and after; the loading itself is delegated to
    ``load_model``.
    """
    announce = f"Downloading embedding model '{model_name}'..."
    click.echo(announce)
    load_model(model_name)
    click.echo("Embedding model ready.")
def check_model_binding(conn, cfg: dict):
    """Ensure the configured embedding model is the one recorded in the DB.

    Does nothing when the DB has no ``model_name`` entry yet.

    Raises:
        click.ClickException: when the DB's ``model_name`` differs from
            ``cfg["embedding"]["model"]``.
    """
    from kb_search.database import get_db_config

    stored_model = get_db_config(conn, "model_name")
    if stored_model is None:
        # Not yet initialised — nothing to compare against.
        return

    configured_model = cfg["embedding"]["model"]
    if stored_model == configured_model:
        return

    stored_dim = get_db_config(conn, "embedding_dim", "?")
    raise click.ClickException(
        f"Model mismatch: DB uses '{stored_model}' ({stored_dim} dim) but config specifies "
        f"'{configured_model}'. Run `kb reindex --model {configured_model}` to switch models."
    )
View File
+244
View File
@@ -0,0 +1,244 @@
"""Code ingestion — AST/regex-based splitting for Python, Bash, Go."""
import ast
import re
def chunk_code(text: str, language: str | None, cfg: dict) -> list[dict]:
    """Split source code into chunks, preferring function/class boundaries.

    Settings come from ``cfg["chunking"]["code"]``: ``strategy`` ("ast" by
    default, "fixed" to force fixed-size chunks) and ``include_context``
    (default True). Python, Bash ("bash"/"sh") and Go have dedicated
    splitters. Any other language, or a splitter that finds nothing, falls
    back to fixed-size chunking.

    Returns:
        Chunk dicts; each carries a ``chunk_index``.
    """
    code_cfg = cfg.get("chunking", {}).get("code", {})
    strategy = code_cfg.get("strategy", "ast")
    include_context = code_cfg.get("include_context", True)

    if strategy == "fixed":
        return _fixed_chunk(text, code_cfg)

    splitters = {
        "python": _chunk_python,
        "bash": _chunk_bash,
        "sh": _chunk_bash,
        "go": _chunk_go,
    }
    splitter = splitters.get(language)
    chunks = splitter(text, include_context) if splitter is not None else []

    if not chunks:
        return _fixed_chunk(text, code_cfg)

    for position, chunk in enumerate(chunks):
        chunk["chunk_index"] = position
    return chunks
def _chunk_python(text: str, include_context: bool) -> list[dict]:
    """Split Python source into per-function / per-method chunks via ``ast``.

    Only top-level nodes are visited:

    * a top-level function becomes one chunk;
    * a class with methods yields one chunk per method, optionally prefixed
      with a ``class Name:`` line (plus the class docstring, if any) when
      ``include_context`` is true;
    * a class without methods becomes a single chunk.

    Other top-level statements produce no chunks.

    Args:
        text: Python source code.
        include_context: Whether method chunks get the class-header prefix.

    Returns:
        Chunk dicts with ``text`` and ``metadata`` (``symbol_name``,
        ``line_start``, ``line_end``). An empty list is returned when the
        source cannot be parsed, which lets the caller fall back to
        fixed-size chunking.
    """
    try:
        tree = ast.parse(text)
    except (SyntaxError, ValueError):
        # ValueError covers source containing NUL bytes on Python < 3.12;
        # previously it escaped and aborted ingestion instead of falling back.
        return []

    lines = text.splitlines(keepends=True)
    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)

    def make_chunk(chunk_text: str, symbol_name: str, node) -> dict:
        # Line numbers are those of the def/class statement itself.
        return {
            "text": chunk_text,
            "metadata": {
                "symbol_name": symbol_name,
                "line_start": node.lineno,
                "line_end": node.end_lineno,
            },
        }

    chunks = []
    for node in ast.iter_child_nodes(tree):
        if isinstance(node, function_types):
            chunks.append(make_chunk(_get_node_source(lines, node), node.name, node))
            continue
        if not isinstance(node, ast.ClassDef):
            continue

        methods = [
            child for child in ast.iter_child_nodes(node)
            if isinstance(child, function_types)
        ]
        if not methods:
            # Class with no methods — single chunk
            chunks.append(make_chunk(_get_node_source(lines, node), node.name, node))
            continue

        class_docstring = ast.get_docstring(node) or ""
        if not include_context:
            prefix = ""
        elif class_docstring:
            prefix = f"class {node.name}:\n \"\"\"{class_docstring}\"\"\"\n\n"
        else:
            prefix = f"class {node.name}:\n\n"

        # Each method becomes a chunk
        for method in methods:
            chunks.append(make_chunk(
                prefix + _get_node_source(lines, method),
                f"{node.name}.{method.name}",
                method,
            ))
    return chunks
def _get_node_source(lines: list[str], node) -> str:
    """Return the source text of ``node``, starting at its first decorator.

    ``lines`` is the file split with line endings kept. The slice runs from
    the first decorator line (or the node's own line when undecorated)
    through ``node.end_lineno``, with trailing whitespace removed.
    """
    decorators = getattr(node, "decorator_list", None)
    first_line = decorators[0].lineno if decorators else node.lineno
    # AST line numbers are 1-based; end_lineno is inclusive.
    return "".join(lines[first_line - 1:node.end_lineno]).rstrip()
def _chunk_bash(text: str, include_context: bool) -> list[dict]:
    """Split a Bash script at function definitions using a regex.

    Recognised forms, each at the start of a line:

    * ``name() {``
    * ``function name() {``
    * ``function name {`` — parentheses are optional after ``function``

    A block of ``#`` comment lines directly above a definition is kept with
    that function. Each chunk runs up to the next definition (or end of
    file). Anything before the first definition is not returned.

    Args:
        text: Script contents.
        include_context: Accepted for signature parity with the other
            splitters; not used here.

    Returns:
        Chunk dicts with ``text`` and ``metadata["symbol_name"]``, or an
        empty list when no function definition is found.
    """
    func_pattern = re.compile(
        r"^((?:#[^\n]*\n)*)?"  # Optional preceding comment block
        r"(?:function\s+(\w+)\s*(?:\(\s*\))?\s*\{|(\w+)\s*\(\s*\)\s*\{)",
        re.MULTILINE,
    )

    matches = list(func_pattern.finditer(text))
    if not matches:
        return []

    # Each chunk ends where the next definition (with its comments) starts.
    starts = [match.start() for match in matches]
    ends = starts[1:] + [len(text)]

    chunks = []
    for match, start, end in zip(matches, starts, ends):
        chunks.append({
            "text": text[start:end].rstrip(),
            "metadata": {
                "symbol_name": match.group(2) or match.group(3),
            },
        })
    return chunks
def _go_doc_comment_start(text: str, decl_start: int) -> int:
    """Return the offset where the ``//`` comment block above a decl begins.

    Walks backwards line by line from ``decl_start`` (which must be the
    start of a line), skipping blank lines and ``//`` comment lines, and
    stops at the first line holding anything else. Returns the offset of
    the earliest comment line seen, or ``decl_start`` when there is none.
    """
    comment_start = decl_start
    pos = decl_start
    while pos > 0:
        # text[pos - 1] is the newline that ends the previous line.
        line_start = text.rfind("\n", 0, pos - 1) + 1
        stripped = text[line_start:pos - 1].strip()
        if stripped.startswith("//"):
            comment_start = line_start
        elif stripped:
            break
        pos = line_start
    return comment_start


def _chunk_go(text: str, include_context: bool) -> list[dict]:
    """Split Go source at top-level ``func`` declarations using a regex.

    Matches plain functions, methods with a receiver, and generic functions
    (``func Name[T any](...)``). Each chunk starts at the ``//`` comment
    block preceding its declaration, if any, and ends where the next
    declaration's chunk starts (or at end of file). Anything before the
    first declaration's chunk is not returned.

    Earlier versions located the comment block with ``text.rfind(line)``.
    For a blank line that returns the search end, so a blank line above the
    doc comment kept the comment out of every chunk.

    Args:
        text: Go source code.
        include_context: Accepted for signature parity with the other
            splitters; not used here.

    Returns:
        Chunk dicts with ``text`` and ``metadata["symbol_name"]``, or an
        empty list when no ``func`` declaration is found.
    """
    func_pattern = re.compile(
        r"^func\s+(?:\([^)]*\)\s+)?(\w+)\s*[(\[]",
        re.MULTILINE,
    )

    matches = list(func_pattern.finditer(text))
    if not matches:
        return []

    # A chunk owns its own leading comments, so the previous chunk must stop
    # exactly where those comments begin.
    starts = [_go_doc_comment_start(text, match.start()) for match in matches]
    ends = starts[1:] + [len(text)]

    return [
        {
            "text": text[start:end].rstrip(),
            "metadata": {
                "symbol_name": match.group(1),
            },
        }
        for match, start, end in zip(matches, starts, ends)
    ]
def _fixed_chunk(text: str, chunking_cfg: dict) -> list[dict]:
    """Fallback splitter: group whole lines into size-limited chunks.

    Token counts are approximated by whitespace-separated words per line.
    A chunk is closed when adding the next line would exceed ``max_tokens``
    (default 1024). The next chunk is seeded with as many trailing lines of
    the previous one as fit in ``overlap_tokens`` (default 50).

    Returns:
        Chunk dicts with ``text``, ``chunk_index`` and empty ``metadata``;
        an empty list for empty input.
    """
    max_tokens = chunking_cfg.get("max_tokens", 1024)
    overlap_tokens = chunking_cfg.get("overlap_tokens", 50)

    source_lines = text.splitlines()
    if not source_lines:
        return []

    def emit(window: list[str]) -> None:
        chunks.append({
            "text": "\n".join(window),
            "chunk_index": len(chunks),
            "metadata": {},
        })

    chunks = []
    window = []
    window_tokens = 0
    for source_line in source_lines:
        # Approximate tokens as words
        size = len(source_line.split())
        if window and window_tokens + size > max_tokens:
            emit(window)
            # Keep some overlap: walk back from the end of the window while
            # the lines still fit in the overlap budget.
            keep_from = len(window)
            kept_tokens = 0
            while keep_from > 0:
                previous_size = len(window[keep_from - 1].split())
                if kept_tokens + previous_size > overlap_tokens:
                    break
                kept_tokens += previous_size
                keep_from -= 1
            window = window[keep_from:]
            window_tokens = kept_tokens
        window.append(source_line)
        window_tokens += size

    if window:
        emit(window)
    return chunks
+54
View File
@@ -0,0 +1,54 @@
"""File type detection and routing."""
from pathlib import Path
# Routing table: lower-case file suffix -> (doc_type, language).
# ``language`` is only set for code files.
EXTENSION_MAP = {
    # Docling-handled formats (all routed as doc_type "pdf")
    ".pdf": ("pdf", None),
    ".docx": ("pdf", None),  # Docling handles DOCX too
    ".html": ("pdf", None),
    ".htm": ("pdf", None),
    ".png": ("pdf", None),
    ".jpg": ("pdf", None),
    ".jpeg": ("pdf", None),
    ".tiff": ("pdf", None),
    ".bmp": ("pdf", None),
    ".webp": ("pdf", None),
    # Markdown / text
    ".md": ("markdown", None),
    ".markdown": ("markdown", None),
    ".txt": ("markdown", None),
    # Code
    ".py": ("code", "python"),
    ".sh": ("code", "bash"),
    ".bash": ("code", "bash"),
    ".go": ("code", "go"),
}

# Every suffix the router knows about.
SUPPORTED_EXTENSIONS = set(EXTENSION_MAP)
def detect_type(path: Path, force_type: str | None = None,
force_language: str | None = None) -> tuple[str, str | None]:
"""Detect document type and language from file extension.
Returns (doc_type, language) tuple.
Raises ValueError for unsupported file types.
"""
if force_type:
return force_type, force_language
ext = path.suffix.lower()
if ext not in EXTENSION_MAP:
supported = ", ".join(sorted(SUPPORTED_EXTENSIONS))
raise ValueError(f"Unsupported file type '{ext}'. Supported: {supported}")
doc_type, language = EXTENSION_MAP[ext]
if force_language:
language = force_language
return doc_type, language
def is_supported(path: Path) -> bool:
    """Return True when the file's suffix (case-insensitive) is supported."""
    suffix = path.suffix.lower()
    return suffix in SUPPORTED_EXTENSIONS
+123
View File
@@ -0,0 +1,123 @@
"""Docling-based ingestion for PDFs, DOCX, HTML, and images."""
import logging
from pathlib import Path
# Suppress noisy Docling/RapidOCR logging. This runs at import time: the two
# OCR loggers only emit ERROR and above, the top-level "docling" logger
# WARNING and above.
logging.getLogger("RapidOCR").setLevel(logging.ERROR)
logging.getLogger("docling.models.stages.ocr.rapid_ocr_model").setLevel(logging.ERROR)
logging.getLogger("docling").setLevel(logging.WARNING)
def chunk_document(file_path: Path, cfg: dict) -> list[dict]:
    """Convert a document with Docling and split the result into chunks.

    OCR is controlled by ``cfg["ingestion"]["enable_ocr"]``:

    * ``"never"``  — OCR disabled;
    * ``"always"`` — OCR with ``force_full_page_ocr=True``;
    * anything else (default ``"auto"``) — OCR with
      ``bitmap_area_threshold=0.25``.

    Chunking follows ``cfg["chunking"]["pdf"]["strategy"]``: "hierarchy"
    (the default) or fixed-size for any other value. If the chosen strategy
    yields nothing, the document's markdown export is chunked with the
    fixed-size text splitter as a last resort.
    """
    from docling.document_converter import DocumentConverter, PdfFormatOption
    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import PdfPipelineOptions, RapidOcrOptions

    # Configure PDF pipeline
    ocr_setting = cfg.get("ingestion", {}).get("enable_ocr", "auto")
    pdf_opts = PdfPipelineOptions()
    if ocr_setting == "never":
        pdf_opts.do_ocr = False
    else:
        pdf_opts.do_ocr = True
        if ocr_setting == "always":
            pdf_opts.ocr_options = RapidOcrOptions(force_full_page_ocr=True)
        else:
            # "auto" — enable OCR but only trigger on pages with significant bitmap content
            pdf_opts.ocr_options = RapidOcrOptions(bitmap_area_threshold=0.25)

    pdf_format = PdfFormatOption(pipeline_options=pdf_opts)
    converter = DocumentConverter(format_options={InputFormat.PDF: pdf_format})

    # Convert
    doc = converter.convert(str(file_path)).document

    chunking_cfg = cfg.get("chunking", {}).get("pdf", {})
    if chunking_cfg.get("strategy", "hierarchy") == "hierarchy":
        chunks = _hierarchy_chunk(doc)
    else:
        chunks = _fixed_chunk(doc, chunking_cfg)
    if chunks:
        return chunks

    # Fallback: try extracting raw text
    text = doc.export_to_markdown()
    if text and text.strip():
        return _fixed_chunk_text(text, chunking_cfg)
    return chunks
def _hierarchy_chunk(doc) -> list[dict]:
    """Chunk a Docling document with ``HierarchicalChunker``.

    Each chunk dict gets ``text``, ``chunk_index`` and ``metadata``.
    Metadata may hold ``page`` (taken from item provenance) and
    ``section_header`` (the chunk's headings joined with " > ") when the
    chunker supplies them.
    """
    from docling_core.transforms.chunker import HierarchicalChunker

    chunker = HierarchicalChunker()
    results = []
    for position, piece in enumerate(chunker.chunk(doc)):
        metadata = {}
        piece_meta = getattr(piece, "meta", None)
        if piece_meta:
            # Extract page info if available.
            # NOTE(review): the break only leaves the provenance loop, so the
            # page recorded is from the last item that has provenance, not the
            # first — confirm this is intended.
            if hasattr(piece_meta, "doc_items"):
                for item in piece_meta.doc_items:
                    for prov in getattr(item, "prov", None) or ():
                        if hasattr(prov, "page_no"):
                            metadata["page"] = prov.page_no
                            break
            # Section headers
            headings = getattr(piece_meta, "headings", None)
            if headings:
                metadata["section_header"] = " > ".join(headings)
        results.append({
            "text": piece.text,
            "chunk_index": position,
            "metadata": metadata,
        })
    return results
def _fixed_chunk(doc, chunking_cfg: dict) -> list[dict]:
    """Export the Docling document to markdown and chunk it at fixed size."""
    return _fixed_chunk_text(doc.export_to_markdown(), chunking_cfg)
def _fixed_chunk_text(text: str, chunking_cfg: dict) -> list[dict]:
    """Split plain text into fixed-size, overlapping character windows.

    Sizes are configured in tokens (``max_tokens``, default 1024;
    ``overlap_tokens``, default 50) and converted with the approximation
    1 token ~= 4 characters.

    Two defects of the earlier loop are fixed:

    * once a window reached the end of the text, one more chunk holding
      only the overlap region was emitted;
    * an overlap >= the window size (or a non-positive ``max_tokens``)
      never advanced the window, looping forever.

    An overlap that is negative or not smaller than the window is treated
    as no overlap.

    Returns:
        Chunk dicts with ``text`` (stripped), ``chunk_index`` and empty
        ``metadata``. Windows that are only whitespace are skipped.

    Raises:
        ValueError: if ``max_tokens`` is not positive.
    """
    max_tokens = chunking_cfg.get("max_tokens", 1024)
    overlap = chunking_cfg.get("overlap_tokens", 50)
    if max_tokens <= 0:
        raise ValueError(f"max_tokens must be positive, got {max_tokens}")

    # Approximate: 1 token ~= 4 chars
    max_chars = max_tokens * 4
    overlap_chars = overlap * 4
    if not 0 <= overlap_chars < max_chars:
        overlap_chars = 0
    step = max_chars - overlap_chars

    chunks = []
    for start in range(0, len(text), step):
        end = start + max_chars
        chunk_text = text[start:end].strip()
        if chunk_text:
            chunks.append({
                "text": chunk_text,
                "chunk_index": len(chunks),
                "metadata": {},
            })
        if end >= len(text):
            # This window already covers the tail; a further window would
            # only repeat the overlap.
            break
    return chunks
+210
View File
@@ -0,0 +1,210 @@
"""Markdown ingestion — header-based splitting."""
import re
def chunk_markdown(text: str, cfg: dict) -> list[dict]:
    """Chunk markdown text, by headers when possible.

    Reads ``cfg["chunking"]["markdown"]``. Fixed-size chunking is used when
    ``strategy`` is "fixed" or the text has no markdown headers; otherwise
    the text is split at header boundaries with hierarchy context.
    """
    md_cfg = cfg.get("chunking", {}).get("markdown", {})
    use_fixed = md_cfg.get("strategy", "header") == "fixed" or not _has_headers(text)
    splitter = _fixed_chunk if use_fixed else _header_chunk
    return splitter(text, md_cfg)
def _has_headers(text: str) -> bool:
    """Return True if any line starts with 1-6 ``#`` followed by whitespace."""
    header_start = re.compile(r"^#{1,6}\s+", re.MULTILINE)
    return header_start.search(text) is not None
def _header_chunk(text: str, chunking_cfg: dict) -> list[dict]:
    """Split markdown at header boundaries (levels 1-6) with hierarchy context.

    Pipeline: split at headers, merge sections smaller than ``min_tokens``
    (default 50) into their successor, prefix each section with its header
    chain ("A > B > C"), then split any section above ``max_tokens``
    (default 1024, counted as whitespace-separated words) at paragraph
    boundaries.

    Every chunk carries ``metadata["section_header"]``: the innermost
    header, or None. This now includes the sub-chunks of an oversized
    section, which previously had empty metadata.

    Returns:
        Chunk dicts with ``text``, ``metadata`` and ``chunk_index``.
        Fixed-size chunking is used when no sections are found.
    """
    min_tokens = chunking_cfg.get("min_tokens", 50)
    max_tokens = chunking_cfg.get("max_tokens", 1024)

    sections = _split_at_headers(text)
    if not sections:
        return _fixed_chunk(text, chunking_cfg)

    # Merge small sections
    sections = _merge_small_sections(sections, min_tokens)

    chunks = []
    for section in sections:
        content = section["content"].strip()
        if not content:
            continue

        header_chain = section["header_chain"]
        section_header = header_chain[-1] if header_chain else None

        # Add hierarchy context
        if header_chain:
            context = " > ".join(header_chain)
            full_text = f"{context}\n\n{content}"
        else:
            full_text = content

        approx_tokens = len(full_text.split())
        if approx_tokens > max_tokens:
            # Split large sections, keeping the section header on every piece
            # so search results can still show where the piece came from.
            sub_chunks = _split_large_section(full_text, max_tokens, chunking_cfg)
            for sub_chunk in sub_chunks:
                sub_chunk.setdefault("metadata", {})["section_header"] = section_header
            chunks.extend(sub_chunks)
        else:
            chunks.append({
                "text": full_text,
                "metadata": {"section_header": section_header},
            })

    # Assign chunk indices
    for position, chunk in enumerate(chunks):
        chunk["chunk_index"] = position
    return chunks
def _split_at_headers(text: str) -> list[dict]:
    """Split markdown into sections at header lines.

    A header is a line of 1-6 ``#`` characters followed by whitespace and a
    title. Each header opens a new section whose ``header_chain`` lists the
    titles of its ancestors plus itself; a header closes any open header of
    the same or deeper level. Text before the first header becomes a
    section with an empty chain.

    Lines inside fenced code blocks (``` or ~~~) are never treated as
    headers. Previously a shell or Python comment such as ``# install``
    inside a fence started a bogus section and corrupted the chain.

    Returns:
        Section dicts with ``header_chain`` (list of titles) and
        ``content``. The header line itself is not part of ``content``.
    """
    header_pattern = re.compile(r"^(#{1,6})\s+(.*?)$", re.MULTILINE)
    fence_pattern = re.compile(r"^ {0,3}(`{3,}|~{3,})", re.MULTILINE)

    # Collect [start, end) spans covered by fenced code blocks. A fence is
    # closed by a fence of the same character that is at least as long; an
    # unclosed fence runs to the end of the text.
    fenced_spans = []
    open_start = None
    open_marker = ""
    for fence in fence_pattern.finditer(text):
        marker = fence.group(1)
        if open_start is None:
            open_start, open_marker = fence.start(), marker
        elif marker[0] == open_marker[0] and len(marker) >= len(open_marker):
            fenced_spans.append((open_start, fence.end()))
            open_start = None
    if open_start is not None:
        fenced_spans.append((open_start, len(text)))

    def in_fence(offset: int) -> bool:
        return any(start <= offset < end for start, end in fenced_spans)

    sections = []
    header_stack = []  # Stack of (level, title)
    last_end = 0

    for match in header_pattern.finditer(text):
        if in_fence(match.start()):
            # Not a real header; it stays part of the surrounding content.
            continue

        # Capture content before this header
        if last_end < match.start():
            content = text[last_end:match.start()].strip()
            if content and sections:
                sections[-1]["content"] += "\n\n" + content
            elif content:
                sections.append({
                    "header_chain": [],
                    "content": content,
                })

        level = len(match.group(1))
        title = match.group(2).strip()

        # Update header stack
        while header_stack and header_stack[-1][0] >= level:
            header_stack.pop()
        header_stack.append((level, title))

        sections.append({
            "header_chain": [entry[1] for entry in header_stack],
            "content": "",
        })
        last_end = match.end()

    # Capture trailing content
    if last_end < len(text):
        trailing = text[last_end:].strip()
        if trailing and sections:
            sections[-1]["content"] += "\n\n" + trailing
        elif trailing:
            sections.append({"header_chain": [], "content": trailing})

    return sections
def _merge_small_sections(sections: list[dict], min_tokens: int) -> list[dict]:
    """Fold sections shorter than ``min_tokens`` into the section after them.

    Size is the number of whitespace-separated words in ``content``. A
    short section's content is prepended to the next section. The next
    section keeps its own header chain, inheriting the short section's
    chain only when it has none. A short final section is appended to the
    last kept section, or kept as-is when nothing else was kept.

    The section dicts are modified in place and reused in the result.
    """
    if not sections:
        return sections

    kept = []
    carry = None  # short section waiting to be merged into its successor
    for current in sections:
        if carry is not None:
            current["content"] = carry["content"] + "\n\n" + current["content"]
            if carry["header_chain"] and not current["header_chain"]:
                current["header_chain"] = carry["header_chain"]
            carry = None

        if len(current["content"].split()) < min_tokens:
            carry = current
        else:
            kept.append(current)

    if carry is None:
        return kept
    if kept:
        kept[-1]["content"] += "\n\n" + carry["content"]
    else:
        kept.append(carry)
    return kept
def _split_large_section(text: str, max_tokens: int, chunking_cfg: dict) -> list[dict]:
    """Break an oversized section into chunks at blank-line paragraph breaks.

    Paragraphs are accumulated until adding the next one would exceed
    ``max_tokens`` (whitespace-separated words). The finished chunk is then
    emitted and the next one is seeded with as many trailing paragraphs as
    fit in ``overlap_tokens``. A single paragraph is never split.

    Returns:
        Chunk dicts with ``text`` and empty ``metadata``.
    """
    overlap_tokens = chunking_cfg.get("overlap_tokens",
                                      cfg_defaults().get("overlap_tokens", 50))

    pieces = []
    window = []
    window_tokens = 0
    for paragraph in re.split(r"\n\n+", text):
        size = len(paragraph.split())
        if window and window_tokens + size > max_tokens:
            pieces.append({"text": "\n\n".join(window), "metadata": {}})

            # Keep overlap: carry trailing paragraphs while they fit.
            carried = []
            carried_tokens = 0
            for earlier in reversed(window):
                earlier_size = len(earlier.split())
                if carried_tokens + earlier_size > overlap_tokens:
                    break
                carried.insert(0, earlier)
                carried_tokens += earlier_size
            window, window_tokens = carried, carried_tokens

        window.append(paragraph)
        window_tokens += size

    if window:
        pieces.append({"text": "\n\n".join(window), "metadata": {}})
    return pieces
def cfg_defaults():
    """Return a fresh dict of the default chunking settings."""
    return dict(max_tokens=1024, overlap_tokens=50, min_tokens=50)
def _fixed_chunk(text: str, chunking_cfg: dict) -> list[dict]:
    """Fixed-size fallback: split text into overlapping windows of words.

    The text is split on whitespace and re-joined with single spaces, so
    original line breaks are not preserved. Windows hold ``max_tokens``
    words (default 512) and share ``overlap_tokens`` words (default 50)
    with their predecessor.

    The earlier loop never advanced when ``overlap_tokens >= max_tokens``
    (or ``max_tokens <= 0``) and appended chunks until memory ran out. An
    overlap that is negative or not smaller than the window is now treated
    as no overlap.

    Returns:
        Chunk dicts with ``text``, ``chunk_index`` and empty ``metadata``;
        an empty list for blank input.

    Raises:
        ValueError: if ``max_tokens`` is not positive.
    """
    max_tokens = chunking_cfg.get("max_tokens", 512)
    overlap_tokens = chunking_cfg.get("overlap_tokens", 50)

    words = text.split()
    if not words:
        return []
    if max_tokens <= 0:
        raise ValueError(f"max_tokens must be positive, got {max_tokens}")
    if not 0 <= overlap_tokens < max_tokens:
        overlap_tokens = 0
    step = max_tokens - overlap_tokens

    chunks = []
    for start in range(0, len(words), step):
        end = min(start + max_tokens, len(words))
        chunks.append({
            "text": " ".join(words[start:end]),
            "chunk_index": len(chunks),
            "metadata": {},
        })
        if end == len(words):
            # The tail is covered; another window would only repeat overlap.
            break
    return chunks
+19
View File
@@ -0,0 +1,19 @@
"""Note ingestion — whole-document chunks."""
def chunk_note(text: str) -> list[dict]:
    """Wrap the whole note in a single chunk (index 0, empty metadata)."""
    whole_note = {"text": text, "metadata": {}, "chunk_index": 0}
    return [whole_note]
def auto_title(text: str, max_len: int = 80) -> str:
    """Derive a title from the first line of ``text``.

    The first line (after stripping the text) is returned unchanged when it
    fits in ``max_len`` characters. Otherwise it is cut to ``max_len``,
    trimmed back to the last space inside that prefix if there is one, and
    suffixed with "...".
    """
    first_line = text.strip().partition("\n")[0].strip()
    if len(first_line) <= max_len:
        return first_line

    head = first_line[:max_len]
    cut = head.rfind(" ")
    if cut > 0:
        # Truncate at last space so a word is not cut in half.
        head = head[:cut]
    return head + "..."
+144
View File
@@ -0,0 +1,144 @@
"""Output formatters — JSON and human-readable."""
import json
import sys
def format_search_results(data: dict, fmt: str = "json") -> str:
    """Render search results as indented JSON (default) or readable text."""
    if fmt != "json":
        return _human_search(data)
    return json.dumps(data, indent=2, ensure_ascii=False)
def _human_search(data: dict) -> str:
    """Render search results as plain text.

    Output is a summary line, a blank line, then per result a header line
    (rank, score, title, location, type, tags) followed by a preview of up
    to 200 characters and a blank line. The location is the page when
    present, otherwise the section header.
    """
    total = data["total_matches"]
    returned = data["returned"]
    out = [
        f'Search: "{data["query"]}" ({total} matches, showing top {returned})',
        "",
    ]

    for position, hit in enumerate(data["results"], 1):
        source = hit["source"]
        score = hit["score"]

        # Title with page/section
        if source.get("page"):
            location = f" (p.{source['page']})"
        elif source.get("section_header"):
            location = f" \u00a7{source['section_header']}"
        else:
            location = ""

        # Tags
        tag_names = source.get("tags")
        tag_str = " [" + ", ".join(tag_names) + "]" if tag_names else ""

        out.append(
            f" {position:2d}. [{score:.3f}] {source['title']}{location} [{source['type']}]{tag_str}"
        )

        # Text preview (first 200 chars)
        body = hit["text"]
        preview = body[:200].replace("\n", " ").strip()
        if len(body) > 200:
            preview += "..."
        out.append(f" {preview}")
        out.append("")

    return "\n".join(out)
def format_document_list(docs: list[dict], fmt: str = "json") -> str:
    """Render the document list as indented JSON (default) or a text table."""
    if fmt != "json":
        return _human_doc_list(docs)
    return json.dumps(docs, indent=2, ensure_ascii=False)
def _human_doc_list(docs: list[dict]) -> str:
    """Render documents as a fixed-width text table.

    Columns: ID, Type, Chunks, Title (cut to 40 characters), Tags. An empty
    list yields a hint to run `kb add` instead of a table.
    """
    if not docs:
        return "No documents indexed. Run `kb add` to get started."

    table = [
        f"{'ID':>5} {'Type':<10} {'Chunks':>6} {'Title':<40} {'Tags'}",
        "-" * 80,
    ]
    for doc in docs:
        tag_column = ", ".join(doc.get("tags", []))
        short_title = doc["title"][:40]
        table.append(
            f"{doc['id']:>5} {doc['type']:<10} {doc['chunk_count']:>6} {short_title:<40} {tag_column}"
        )
    return "\n".join(table)
def format_tags(tags: list[dict], fmt: str = "json") -> str:
    """Render tags as indented JSON (default) or a two-column text table.

    In text mode an empty list yields a hint on how to add tags; otherwise
    each row shows the tag ``name`` and its document ``count``.
    """
    if fmt == "json":
        return json.dumps(tags, indent=2, ensure_ascii=False)
    if not tags:
        return "No tags. Use `kb add --tags` or `kb tag` to add tags."

    table = [f"{'Tag':<30} {'Documents':>10}", "-" * 42]
    table.extend(f"{tag['name']:<30} {tag['count']:>10}" for tag in tags)
    return "\n".join(table)
def format_doc_info(info: dict, fmt: str = "json") -> str:
    """Render one document's details as indented JSON (default) or text.

    The text form lists id/title, type, optional language and path, hash,
    creation time, optional tags and the chunk count. After a blank line
    comes one line per chunk with its index and a preview of up to 100
    characters.
    """
    if fmt == "json":
        return json.dumps(info, indent=2, ensure_ascii=False)

    out = [
        f"Document #{info['id']}: {info['title']}",
        f" Type: {info['type']}",
    ]
    if info.get("language"):
        out.append(f" Language: {info['language']}")
    if info.get("path"):
        out.append(f" Path: {info['path']}")
    out += [
        f" Hash: {info['content_hash']}",
        f" Created: {info['created_at']}",
    ]
    if info.get("tags"):
        out.append(f" Tags: {', '.join(info['tags'])}")
    out += [f" Chunks: {info['chunk_count']}", ""]

    for chunk in info.get("chunks", []):
        body = chunk["text"]
        preview = body[:100].replace("\n", " ").strip()
        if len(body) > 100:
            preview += "..."
        out.append(f" [{chunk['chunk_index']}] {preview}")

    return "\n".join(out)
def format_status(status: dict, fmt: str = "json") -> str:
    """Render the knowledge-base status as indented JSON (default) or text.

    The text form shows model, embedding dimension, schema version and DB
    size, then the per-type document counts, the document total and the
    chunk total.
    """
    if fmt == "json":
        return json.dumps(status, indent=2, ensure_ascii=False)

    out = [
        "Knowledge Base Status",
        "=" * 40,
        f" Model: {status['model_name']}",
        f" Embedding dim: {status['embedding_dim']}",
        f" Schema version: {status['schema_version']}",
        f" DB size: {_human_size(status['db_size_bytes'])}",
        "",
        " Documents:",
    ]
    for dtype, count in status.get("documents", {}).items():
        out.append(f" {dtype:<12} {count:>5}")
    out.append(f" {'total':<12} {status['total_documents']:>5}")
    out.append(f" Total chunks: {status['total_chunks']}")
    return "\n".join(out)
def _human_size(size_bytes: int) -> str:
    """Format a byte count with one decimal and a B/KB/MB/GB/TB unit.

    The value is divided by 1024 per unit step; anything of 1024 GB or more
    is reported in TB.
    """
    value = size_bytes
    for unit in ("B", "KB", "MB", "GB"):
        if value < 1024:
            return f"{value:.1f} {unit}"
        value /= 1024
    return f"{value:.1f} TB"
View File
+261
View File
@@ -0,0 +1,261 @@
"""Hybrid search — FTS5 + vector with Reciprocal Rank Fusion."""
import json
import re
import struct
import sqlite3
def hybrid_search(conn: sqlite3.Connection, query: str, model_name: str, cfg: dict,
                  top: int = 10, tags: list[str] | None = None,
                  doc_type: str | None = None, fts_only: bool = False,
                  vec_only: bool = False, threshold: float | None = None) -> dict:
    """Run FTS and/or vector search and return enriched, ranked results.

    Args:
        conn: Database connection; rows are read by column name.
        query: The user's search text.
        model_name: Embedding model used for the vector side.
        cfg: Config dict; ``cfg["search"]["rrf_k"]`` (default 60) tunes RRF.
        top: Maximum number of results to return.
        tags: Optional tag filter passed to both searches.
        doc_type: Optional document-type filter passed to both searches.
        fts_only: Skip the vector search and report raw FTS scores.
        vec_only: Skip the FTS search and report raw vector scores.
        threshold: When given, drop results scoring below it.

    Returns:
        A dict with ``query``, ``results`` (chunk text, score, score
        breakdown and source metadata), ``total_matches`` (candidates left
        after thresholding, before the ``top`` cut) and ``returned``.
    """
    # Fetch more candidates than requested so RRF has something to fuse.
    candidate_count = top * 3

    fts_results = {} if vec_only else _fts_search(
        conn, query, candidate_count, tags, doc_type
    )
    vec_results = {} if fts_only else _vector_search(
        conn, query, model_name, cfg, candidate_count, tags, doc_type
    )

    # Merge via RRF (or pass a single source straight through)
    rrf_k = cfg.get("search", {}).get("rrf_k", 60)
    if fts_only:
        merged = _single_source_results(fts_results, "fts")
    elif vec_only:
        merged = _single_source_results(vec_results, "vector")
    else:
        merged = _rrf_merge(fts_results, vec_results, rrf_k)

    # Apply threshold
    if threshold is not None:
        merged = [hit for hit in merged if hit["score"] >= threshold]

    # Sort and limit
    ranked = sorted(merged, key=lambda hit: hit["score"], reverse=True)
    total = len(ranked)
    ranked = ranked[:top]

    chunk_sql = """
        SELECT c.id, c.text, c.chunk_index, c.metadata as chunk_meta,
               d.id as doc_id, d.title, d.source_path, d.doc_type,
               d.language, d.metadata as doc_meta
        FROM chunks c
        JOIN documents d ON c.document_id = d.id
        WHERE c.id = ?
    """
    tags_sql = """
        SELECT t.name FROM tags t
        JOIN document_tags dt ON t.id = dt.tag_id
        WHERE dt.document_id = ?
        ORDER BY t.name
    """
    count_sql = "SELECT COUNT(*) FROM chunks WHERE document_id = ?"

    # Enrich with document metadata
    results = []
    for hit in ranked:
        row = conn.execute(chunk_sql, (hit["chunk_id"],)).fetchone()
        if not row:
            # The chunk is in an index but has no chunk/document row.
            continue

        chunk_meta = json.loads(row["chunk_meta"]) if row["chunk_meta"] else {}
        doc_id = row["doc_id"]

        # Get tags for this document
        tag_rows = conn.execute(tags_sql, (doc_id,)).fetchall()

        # Count total chunks for this document
        total_chunks = conn.execute(count_sql, (doc_id,)).fetchone()[0]

        source = {
            "document_id": doc_id,
            "title": row["title"],
            "path": row["source_path"],
            "type": row["doc_type"],
            "page": chunk_meta.get("page"),
            "section_header": chunk_meta.get("section_header"),
            "chunk_index": row["chunk_index"],
            "total_chunks": total_chunks,
            "tags": [tag_row["name"] for tag_row in tag_rows],
        }
        results.append({
            "chunk_id": row["id"],
            "score": round(hit["score"], 6),
            "score_breakdown": hit["score_breakdown"],
            "text": row["text"],
            "source": source,
        })

    return {
        "query": query,
        "results": results,
        "total_matches": total,
        "returned": len(results),
    }
def _fts_search(conn: sqlite3.Connection, query: str, limit: int,
                tags: list[str] | None, doc_type: str | None) -> dict[int, float]:
    """Run the FTS5 keyword search.

    The query is sanitised with ``_escape_fts_query``; an empty result
    short-circuits to ``{}``. With ``doc_type`` and/or ``tags``, the FTS
    table is joined to chunks/documents. Each tag adds its own join, so a
    document must carry all requested tags.

    Returns:
        ``{chunk_id: score}`` where score is the negated ``bm25()`` value,
        so that higher means better.
    """
    match_expr = _escape_fts_query(query)
    if not match_expr.strip():
        return {}

    select_clause = """
        SELECT f.rowid as chunk_id, bm25(chunks_fts) as score
        FROM chunks_fts f
    """
    join_clauses = []
    conditions = ["chunks_fts MATCH ?"]
    bindings = [match_expr]

    if tags or doc_type:
        join_clauses.append("JOIN chunks c ON f.rowid = c.id")
        join_clauses.append("JOIN documents d ON c.document_id = d.id")
    if doc_type:
        conditions.append("d.doc_type = ?")
        bindings.append(doc_type)
    for position, tag in enumerate(tags or ()):
        join_clauses.append(
            f"JOIN document_tags dt{position} ON d.id = dt{position}.document_id"
        )
        join_clauses.append(f"JOIN tags t{position} ON dt{position}.tag_id = t{position}.id")
        conditions.append(f"t{position}.name = ?")
        bindings.append(tag.strip().lower())

    sql = (
        select_clause
        + " " + " ".join(join_clauses)
        + " WHERE " + " AND ".join(conditions)
        + " ORDER BY score LIMIT ?"
    )
    bindings.append(limit)

    # BM25 scores are negative (lower = better); negate so higher = better.
    return {
        row["chunk_id"]: -row["score"]
        for row in conn.execute(sql, bindings).fetchall()
    }
def _vector_search(conn: sqlite3.Connection, query: str, model_name: str,
                   cfg: dict, limit: int, tags: list[str] | None,
                   doc_type: str | None) -> dict[int, float]:
    """Run the vector similarity search.

    The query is embedded (with ``cfg["embedding"]["query_prefix"]`` if
    set), packed as floats and matched against ``chunks_vec``. Tag and
    document-type filters are applied after the nearest-neighbour query, so
    fewer than ``limit`` results may come back.

    Tags are normalised (stripped, lower-cased) and de-duplicated before
    filtering. Previously a repeated tag such as ``["python", "Python"]``
    made the "has every tag" check impossible to satisfy and silently
    dropped every vector hit.

    Returns:
        ``{chunk_id: similarity}`` with similarity = ``max(0, 1 - distance)``.
        NOTE(review): this assumes the vec table reports cosine distance —
        confirm against the schema.
    """
    from kb_search.embeddings import embed_texts

    prefix = cfg.get("embedding", {}).get("query_prefix", "")
    query_emb = embed_texts(model_name, [query], prefix=prefix)[0]
    blob = struct.pack(f"{len(query_emb)}f", *query_emb)

    # sqlite-vec returns results ordered by distance (lower = more similar)
    rows = conn.execute("""
        SELECT chunk_id, distance
        FROM chunks_vec
        WHERE embedding MATCH ?
        ORDER BY distance
        LIMIT ?
    """, (blob, limit)).fetchall()

    # Filter statements are loop-invariant; build them once.
    wanted_tags = sorted({tag.strip().lower() for tag in tags}) if tags else []
    filtering = bool(tags or doc_type)
    exists_sql = """
        SELECT 1 FROM chunks c
        JOIN documents d ON c.document_id = d.id
        WHERE c.id = ?
    """ + (" AND d.doc_type = ?" if doc_type else "")
    type_params = (doc_type,) if doc_type else ()
    tag_sql = """
        SELECT COUNT(DISTINCT t.name) FROM chunks c
        JOIN documents d ON c.document_id = d.id
        JOIN document_tags dt ON d.id = dt.document_id
        JOIN tags t ON dt.tag_id = t.id
        WHERE c.id = ? AND t.name IN ({})
    """.format(",".join("?" * len(wanted_tags)))

    results = {}
    for row in rows:
        chunk_id = row["chunk_id"]

        # Apply filters post-hoc for vector search
        if filtering:
            exists = conn.execute(exists_sql, (chunk_id,) + type_params).fetchone()
            if not exists:
                continue
            if wanted_tags:
                matched = conn.execute(tag_sql, (chunk_id, *wanted_tags)).fetchone()[0]
                if matched < len(wanted_tags):
                    continue

        # Convert distance to similarity (1 - distance for cosine)
        results[chunk_id] = max(0, 1 - row["distance"])

    return results
def _rrf_merge(fts_results: dict[int, float], vec_results: dict[int, float],
               k: int = 60) -> list[dict]:
    """Fuse FTS and vector results with Reciprocal Rank Fusion.

    Each source is ranked on its own (1 = best). A chunk's fused score is
    the sum of ``1 / (k + rank)`` over the sources it appears in.

    Returns:
        Unsorted dicts with ``chunk_id``, ``score`` and ``score_breakdown``
        (each source's contribution rounded to 6 places, or None when the
        chunk is absent from that source).
    """
    rank_tables = (
        ("fts", _rank_results(fts_results)),
        ("vector", _rank_results(vec_results)),
    )
    all_ids = set(rank_tables[0][1]) | set(rank_tables[1][1])

    merged = []
    for chunk_id in all_ids:
        score = 0
        breakdown = {}
        for source, ranks in rank_tables:
            rank = ranks.get(chunk_id)
            if rank is None:
                breakdown[source] = None
                continue
            contribution = 1 / (k + rank)
            score += contribution
            breakdown[source] = round(contribution, 6)
        merged.append({
            "chunk_id": chunk_id,
            "score": score,
            "score_breakdown": breakdown,
        })
    return merged
def _single_source_results(results: dict[int, float], source: str) -> list[dict]:
    """Put one source's results into the same shape ``_rrf_merge`` returns.

    The raw score is kept as ``score``, and its rounded value is stored
    under ``source`` in the breakdown. Entries follow rank order (best
    first).
    """
    merged = []
    for chunk_id in _rank_results(results):
        raw_score = results[chunk_id]
        breakdown = dict(fts=None, vector=None)
        breakdown[source] = round(raw_score, 6)
        merged.append({
            "chunk_id": chunk_id,
            "score": raw_score,
            "score_breakdown": breakdown,
        })
    return merged
def _rank_results(results: dict[int, float]) -> dict[int, int]:
    """Map each chunk ID to its 1-based rank; the highest score gets rank 1."""
    ordered = sorted(results, key=results.get, reverse=True)
    return dict(zip(ordered, range(1, len(ordered) + 1)))
def _escape_fts_query(query: str) -> str:
    """Turn free-form user text into an FTS5 query that cannot fail to parse.

    The characters ``" ( ) * : ^`` are replaced by spaces as before. Every
    remaining whitespace-separated term is then wrapped in double quotes,
    making it an FTS5 string rather than a bareword. Terms with no letter
    or digit are dropped.

    Stripping those six characters alone was not enough. Barewords
    containing ``-``, ``+``, ``'`` or ``.`` (``self-hosted``, ``c++``,
    ``don't``) and stray ``AND``/``OR``/``NOT`` keywords still raised FTS5
    syntax errors. Quoted terms are combined with FTS5's implicit AND, and
    operator keywords typed by the user are matched as ordinary words.

    Returns:
        The sanitised query, or an empty string when nothing searchable
        remains.
    """
    cleaned = re.sub(r'["\(\)\*\:\^]', " ", query)
    terms = [
        term for term in cleaned.split()
        if any(ch.isalnum() for ch in term)
    ]
    return " ".join(f'"{term}"' for term in terms)