Initial MVP
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
"""kb-search: CLI knowledge base with hybrid search."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
@@ -0,0 +1,616 @@
|
||||
"""CLI entry point for kb-search."""
|
||||
|
||||
import click
|
||||
|
||||
|
||||
# Root command group; every sub-command in this module attaches itself to it.
# The docstring below is what click prints as the top-level help text.
@click.group()
@click.version_option(package_name="kb-search")
def main():
    """Personal knowledge base with hybrid search."""
|
||||
|
||||
|
||||
@main.command()
@click.option("--model", default=None, help="Embedding model name (HuggingFace).")
@click.option("--status", is_flag=True, help="Show initialisation status.")
def init(model, status):
    """Initialise the knowledge base and download models."""
    # Parameters:
    #   model  -- embedding model name; falls back to cfg["embedding"]["model"].
    #   status -- when set, only report what exists and change nothing.
    from contextlib import closing

    from kb_search.config import get_data_dir, get_db_path, load_config
    from kb_search.database import get_connection, get_db_config, init_schema, run_migrations, set_db_config
    from kb_search.embeddings import download_model, get_model_dim

    cfg = load_config()
    data_dir = get_data_dir(cfg)
    db_path = get_db_path(cfg)
    # An explicit --model wins over the configured model.
    model_name = model or cfg["embedding"]["model"]

    if status:
        click.echo(f"Data directory: {data_dir} ({'exists' if data_dir.exists() else 'not created'})")
        click.echo(f"Database: {db_path} ({'exists' if db_path.exists() else 'not created'})")
        if db_path.exists():
            # closing() guarantees the connection is released even if a lookup raises.
            with closing(get_connection(db_path)) as conn:
                db_model = get_db_config(conn, "model_name", "not set")
                db_dim = get_db_config(conn, "embedding_dim", "not set")
            click.echo(f"Model: {db_model} ({db_dim} dim)")
        else:
            click.echo(f"Model: {model_name} (not yet initialised)")
        return

    # Create data directory
    data_dir.mkdir(parents=True, exist_ok=True)

    # Download model and get dimension
    download_model(model_name)
    dim = get_model_dim(model_name)

    # Initialise database and record which model/dimension it was built for.
    with closing(get_connection(db_path)) as conn:
        init_schema(conn, embedding_dim=dim)
        run_migrations(conn)
        set_db_config(conn, "model_name", model_name)
        set_db_config(conn, "embedding_dim", str(dim))

    click.echo(f"Knowledge base initialised at {data_dir}")
    click.echo(f"Model: {model_name} ({dim} dimensions)")
    click.echo("Ready! Add documents with `kb add`.")
|
||||
|
||||
|
||||
@main.command()
@click.argument("path", required=False)
@click.option("--note", default=None, help="Add an inline text note.")
@click.option("--title", default=None, help="Title for the note.")
@click.option("--tags", default=None, help="Comma-separated tags.")
@click.option("--type", "doc_type", default=None, type=click.Choice(["pdf", "markdown", "code", "note"]), help="Force document type.")
@click.option("--language", default=None, type=click.Choice(["python", "bash", "go"]), help="Force code language.")
@click.option("--recursive", is_flag=True, help="Recurse into directories.")
@click.option("--workers", default=None, type=int, help="Number of parallel workers.")
def add(path, note, title, tags, doc_type, language, recursive, workers):
    """Add documents to the knowledge base."""
    # Three modes, checked in this order:
    #   --note TEXT  -> store the inline text as a "note" document;
    #   PATH is dir  -> delegate to _add_directory;
    #   PATH is file -> delegate to _add_single_file and echo its status line.
    # Raises click.ClickException when the KB is not initialised, when neither
    # PATH nor --note is given, or when PATH does not exist.
    import hashlib
    from contextlib import closing
    from pathlib import Path as P

    from kb_search.config import get_db_path, load_config
    from kb_search.database import (
        get_connection, hash_exists, insert_chunk, insert_document,
        insert_embedding, tag_document,
    )
    from kb_search.embeddings import check_model_binding, embed_texts
    from kb_search.ingest.note import auto_title, chunk_note

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    # Drop empty entries so input such as "a,,b" or a trailing comma never
    # produces a blank tag.
    tag_list = [t for t in (part.strip() for part in tags.split(",")) if t] if tags else []

    # closing() releases the connection on every exit path, including the
    # ClickException raises below.
    with closing(get_connection(db_path)) as conn:
        check_model_binding(conn, cfg)
        model_name = cfg["embedding"]["model"]

        if note:
            # Inline note: deduplicate on the SHA-256 of the note text.
            content_hash = hashlib.sha256(note.encode()).hexdigest()
            if hash_exists(conn, content_hash):
                click.echo("Skipped: note (already indexed)")
                return
            note_title = title or auto_title(note)
            chunks = chunk_note(note)
            doc_id = insert_document(conn, note_title, None, content_hash, "note")
            prefix = cfg["embedding"].get("passage_prefix", "")
            for c in chunks:
                chunk_id = insert_chunk(conn, doc_id, c["chunk_index"], c["text"], metadata=c["metadata"])
                emb = embed_texts(model_name, [c["text"]], prefix=prefix)
                insert_embedding(conn, chunk_id, emb[0])
            if tag_list:
                tag_document(conn, doc_id, tag_list)
            conn.commit()
            click.echo(f"Added note: {note_title}")
            return

        if not path:
            raise click.ClickException("Provide a file/directory path or use --note.")

        file_path = P(path).expanduser().resolve()

        if file_path.is_dir():
            _add_directory(conn, file_path, cfg, model_name, tag_list, doc_type, language,
                           recursive, workers)
        elif file_path.is_file():
            result = _add_single_file(conn, file_path, cfg, model_name, tag_list, doc_type, language)
            click.echo(result)
        else:
            raise click.ClickException(f"Path not found: {file_path}")
|
||||
|
||||
|
||||
def _add_single_file(conn, file_path, cfg, model_name, tag_list, force_type, force_language):
    """Add a single file. Returns a status message.

    The message starts with "Skipped:" when the file's SHA-256 is already
    known or no chunks were extracted, and with "Added:" otherwise. The
    transaction is committed only on the "Added" path.
    """
    import hashlib

    from kb_search.database import (
        hash_exists, insert_chunk, insert_document, insert_embedding, tag_document,
    )
    from kb_search.embeddings import embed_texts
    from kb_search.ingest.detector import detect_type

    # Dedup check on the raw file bytes.
    digest = hashlib.sha256(file_path.read_bytes()).hexdigest()
    if hash_exists(conn, digest):
        return f"Skipped: {file_path.name} (already indexed)"

    doc_type, language = detect_type(file_path, force_type, force_language)
    pieces = _get_chunks(file_path, doc_type, language, cfg)
    if not pieces:
        return f"Skipped: {file_path.name} (no content extracted)"

    doc_id = insert_document(
        conn, file_path.stem, str(file_path), digest, doc_type, language=language,
    )

    # One embedding call for the whole document rather than one per chunk.
    vectors = embed_texts(
        model_name,
        [piece["text"] for piece in pieces],
        prefix=cfg["embedding"].get("passage_prefix", ""),
    )

    for piece, vector in zip(pieces, vectors):
        chunk_id = insert_chunk(
            conn, doc_id, piece["chunk_index"], piece["text"],
            token_count=piece.get("token_count"),
            metadata=piece.get("metadata", {}),
        )
        insert_embedding(conn, chunk_id, vector)

    if tag_list:
        tag_document(conn, doc_id, tag_list)

    conn.commit()
    return f"Added: {file_path.name} ({len(pieces)} chunks)"
|
||||
|
||||
|
||||
def _get_chunks(file_path, doc_type, language, cfg):
|
||||
"""Route to the correct chunking pipeline."""
|
||||
if doc_type == "pdf":
|
||||
from kb_search.ingest.docling import chunk_document
|
||||
return chunk_document(file_path, cfg)
|
||||
elif doc_type == "markdown":
|
||||
from kb_search.ingest.markdown import chunk_markdown
|
||||
text = file_path.read_text(errors="replace")
|
||||
return chunk_markdown(text, cfg)
|
||||
elif doc_type == "code":
|
||||
from kb_search.ingest.code import chunk_code
|
||||
text = file_path.read_text(errors="replace")
|
||||
return chunk_code(text, language, cfg)
|
||||
elif doc_type == "note":
|
||||
from kb_search.ingest.note import chunk_note
|
||||
text = file_path.read_text(errors="replace")
|
||||
return chunk_note(text)
|
||||
return []
|
||||
|
||||
|
||||
def _add_directory(conn, dir_path, cfg, model_name, tag_list, force_type, force_language,
                   recursive, workers):
    """Add all supported files in a directory.

    Files are processed one at a time in sorted order. A failure on one file
    is appended to ``ingest-errors.log`` in the data directory and does not
    stop the run. A summary line is echoed at the end.

    Args:
        conn: open database connection.
        dir_path: directory to scan.
        cfg: loaded configuration dict.
        model_name: embedding model name passed through to each file.
        tag_list: tags applied to every added document.
        force_type, force_language: forwarded to ``_add_single_file``.
        recursive: scan sub-directories as well when true.
        workers: accepted for interface compatibility; not used here.
    """
    from kb_search.config import get_data_dir
    from kb_search.ingest.detector import is_supported

    pattern = "**/*" if recursive else "*"
    files = sorted(f for f in dir_path.glob(pattern) if f.is_file() and is_supported(f))

    if not files:
        click.echo(f"No supported files found in {dir_path}")
        return

    added = 0
    skipped = 0
    failed = 0
    error_log_path = get_data_dir(cfg) / "ingest-errors.log"

    with click.progressbar(files, label="Ingesting", show_pos=True) as bar:
        for f in bar:
            try:
                result = _add_single_file(conn, f, cfg, model_name, tag_list,
                                          force_type, force_language)
            except Exception as e:
                failed += 1
                # Discard this file's partial inserts; otherwise the next
                # successful file's commit would persist a half-written document.
                conn.rollback()
                with open(error_log_path, "a", encoding="utf-8") as log:
                    log.write(f"{f}: {e}\n")
                continue

            # Status lines begin with "Skipped:" or "Added:". Checking the
            # prefix avoids miscounting a file whose name contains "Skipped".
            if result.startswith("Skipped"):
                skipped += 1
            else:
                added += 1

    click.echo(f"\nAdded {added} documents. {failed} failed. {skipped} skipped (already indexed).")
|
||||
|
||||
|
||||
@main.command()
@click.argument("query")
@click.option("--top", default=None, type=int, help="Number of results.")
@click.option("--tags", default=None, help="Filter by tags (comma-separated).")
@click.option("--type", "doc_type", default=None, type=click.Choice(["pdf", "markdown", "code", "note"]), help="Filter by document type.")
@click.option("--format", "fmt", default=None, type=click.Choice(["json", "human"]), help="Output format.")
@click.option("--fts-only", is_flag=True, help="Full-text search only.")
@click.option("--vec-only", is_flag=True, help="Vector search only.")
@click.option("--threshold", default=None, type=float, help="Minimum score cutoff.")
def search(query, top, tags, doc_type, fmt, fts_only, vec_only, threshold):
    """Search the knowledge base."""
    # Unset --top / --format fall back to cfg["search"] defaults.
    # Raises click.ClickException when the database file does not exist.
    from contextlib import closing

    from kb_search.config import get_db_path, load_config
    from kb_search.database import get_connection, get_db_config
    from kb_search.embeddings import check_model_binding
    from kb_search.search import hybrid_search
    from kb_search.output import format_search_results

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    top = top or cfg["search"]["default_top"]
    fmt = fmt or cfg["search"]["default_format"]
    tag_list = [t.strip() for t in tags.split(",")] if tags else None

    # closing() releases the connection even when the binding check or the
    # search itself raises.
    with closing(get_connection(db_path)) as conn:
        check_model_binding(conn, cfg)
        # Prefer the model recorded in the database over the configured one.
        model_name = get_db_config(conn, "model_name") or cfg["embedding"]["model"]
        results = hybrid_search(
            conn, query, model_name, cfg,
            top=top, tags=tag_list, doc_type=doc_type,
            fts_only=fts_only, vec_only=vec_only, threshold=threshold,
        )

    click.echo(format_search_results(results, fmt))
|
||||
|
||||
|
||||
@main.command("list")
@click.option("--type", "doc_type", default=None, type=click.Choice(["pdf", "markdown", "code", "note"]), help="Filter by document type.")
@click.option("--tags", default=None, help="Filter by tags (comma-separated).")
@click.option("--format", "fmt", default=None, type=click.Choice(["json", "human"]), help="Output format.")
def list_docs(doc_type, tags, fmt):
    """List indexed documents."""
    # Optional filters: document type, and a tag list where a document must
    # carry every listed tag. Raises click.ClickException when the database
    # file does not exist.
    from contextlib import closing

    from kb_search.config import get_db_path, load_config
    from kb_search.database import get_connection
    from kb_search.output import format_document_list

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    fmt = fmt or cfg["search"]["default_format"]

    sql = """
        SELECT d.id, d.title, d.doc_type as type, d.created_at,
               COUNT(c.id) as chunk_count
        FROM documents d
        LEFT JOIN chunks c ON d.id = c.document_id
    """
    joins = []
    where = []
    params = []

    if doc_type:
        where.append("d.doc_type = ?")
        params.append(doc_type)

    # Blank entries (from "a,,b" or a trailing comma) are dropped; a blank tag
    # would otherwise add a condition that no document can satisfy.
    tag_list = [t for t in (p.strip().lower() for p in tags.split(",")) if t] if tags else []
    for i, tag in enumerate(tag_list):
        # One aliased join pair per tag. Only the loop index is interpolated
        # into the SQL text; the tag value itself is a bound parameter.
        joins.append(f"JOIN document_tags dt{i} ON d.id = dt{i}.document_id")
        joins.append(f"JOIN tags t{i} ON dt{i}.tag_id = t{i}.id")
        where.append(f"t{i}.name = ?")
        params.append(tag)

    sql += " " + " ".join(joins)
    if where:
        sql += " WHERE " + " AND ".join(where)
    sql += " GROUP BY d.id ORDER BY d.created_at DESC"

    docs = []
    # closing() releases the connection even if a query raises.
    with closing(get_connection(db_path)) as conn:
        rows = conn.execute(sql, params).fetchall()
        for row in rows:
            tag_rows = conn.execute("""
                SELECT t.name FROM tags t
                JOIN document_tags dt ON t.id = dt.tag_id
                WHERE dt.document_id = ?
                ORDER BY t.name
            """, (row["id"],)).fetchall()
            docs.append({
                "id": row["id"],
                "title": row["title"],
                "type": row["type"],
                "tags": [r["name"] for r in tag_rows],
                "chunk_count": row["chunk_count"],
                "created_at": row["created_at"],
            })

    click.echo(format_document_list(docs, fmt))
|
||||
|
||||
|
||||
@main.command()
@click.argument("doc_id", type=int)
@click.option("--format", "fmt", default=None, type=click.Choice(["json", "human"]), help="Output format.")
def info(doc_id, fmt):
    """Show document details."""
    # Raises click.ClickException when the database file does not exist or no
    # document has the given id.
    from contextlib import closing

    from kb_search.config import get_db_path, load_config
    from kb_search.database import get_connection
    from kb_search.output import format_doc_info

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    fmt = fmt or cfg["search"]["default_format"]

    # closing() releases the connection on the "not found" path as well.
    with closing(get_connection(db_path)) as conn:
        row = conn.execute("SELECT * FROM documents WHERE id = ?", (doc_id,)).fetchone()
        if not row:
            raise click.ClickException(f"Document not found: {doc_id}")

        chunks = conn.execute(
            "SELECT chunk_index, text FROM chunks WHERE document_id = ? ORDER BY chunk_index",
            (doc_id,),
        ).fetchall()

        tag_rows = conn.execute("""
            SELECT t.name FROM tags t
            JOIN document_tags dt ON t.id = dt.tag_id
            WHERE dt.document_id = ?
            ORDER BY t.name
        """, (doc_id,)).fetchall()

        info_data = {
            "id": row["id"],
            "title": row["title"],
            "type": row["doc_type"],
            "language": row["language"],
            "path": row["source_path"],
            "content_hash": row["content_hash"],
            "created_at": row["created_at"],
            "tags": [r["name"] for r in tag_rows],
            "chunk_count": len(chunks),
            "chunks": [{"chunk_index": c["chunk_index"], "text": c["text"]} for c in chunks],
        }

    click.echo(format_doc_info(info_data, fmt))
|
||||
|
||||
|
||||
@main.command()
@click.argument("doc_id", type=int)
@click.option("--yes", is_flag=True, help="Skip confirmation prompt.")
def remove(doc_id, yes):
    """Remove a document from the knowledge base."""
    # Asks for confirmation unless --yes is given. Raises click.ClickException
    # when the database file does not exist or no document has the given id.
    from contextlib import closing

    from kb_search.config import get_db_path, load_config
    from kb_search.database import get_connection

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    # closing() releases the connection on the "not found", "cancelled" and
    # failure paths alike.
    with closing(get_connection(db_path)) as conn:
        row = conn.execute("SELECT id, title FROM documents WHERE id = ?", (doc_id,)).fetchone()
        if not row:
            raise click.ClickException(f"Document not found: {doc_id}")
        title = row["title"]

        chunk_count = conn.execute(
            "SELECT COUNT(*) FROM chunks WHERE document_id = ?", (doc_id,)
        ).fetchone()[0]

        if not yes and not click.confirm(f"Remove '{title}' and its {chunk_count} chunks?"):
            click.echo("Cancelled.")
            return

        # Delete vectors for this document's chunks first: the sub-select needs
        # the chunk rows, which disappear once the document is deleted.
        conn.execute("""
            DELETE FROM chunks_vec WHERE chunk_id IN (
                SELECT id FROM chunks WHERE document_id = ?
            )
        """, (doc_id,))
        # Cascade handles chunks, document_tags
        conn.execute("DELETE FROM documents WHERE id = ?", (doc_id,))
        conn.commit()

    click.echo(f"Removed '{title}' ({chunk_count} chunks).")
|
||||
|
||||
|
||||
@main.command("tags")
@click.option("--format", "fmt", default=None, type=click.Choice(["json", "human"]), help="Output format.")
def list_tags(fmt):
    """List all tags with document counts."""
    # Tags are ordered by descending document count, then by name.
    # Raises click.ClickException when the database file does not exist.
    from contextlib import closing

    from kb_search.config import get_db_path, load_config
    from kb_search.database import get_connection
    from kb_search.output import format_tags

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    fmt = fmt or cfg["search"]["default_format"]

    # closing() releases the connection even if the query raises.
    with closing(get_connection(db_path)) as conn:
        rows = conn.execute("""
            SELECT t.name, COUNT(dt.document_id) as count
            FROM tags t
            LEFT JOIN document_tags dt ON t.id = dt.tag_id
            GROUP BY t.id
            ORDER BY count DESC, t.name
        """).fetchall()
        tags = [{"name": r["name"], "count": r["count"]} for r in rows]

    click.echo(format_tags(tags, fmt))
|
||||
|
||||
|
||||
@main.command()
@click.argument("doc_id", type=int)
@click.option("--add", "add_tags", default=None, help="Tags to add (comma-separated).")
@click.option("--remove", "remove_tags", default=None, help="Tags to remove (comma-separated).")
def tag(doc_id, add_tags, remove_tags):
    """Manage tags on a document."""
    # --add is applied before --remove; each is committed separately.
    # Raises click.ClickException when the database file does not exist or no
    # document has the given id.
    from contextlib import closing

    from kb_search.config import get_db_path, load_config
    from kb_search.database import get_connection, tag_document, untag_document

    def _split(raw):
        # Blank entries (from "a,,b" or a trailing comma) are dropped so no
        # empty tag name is ever written or reported.
        return [t for t in (part.strip() for part in raw.split(",")) if t]

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    # closing() releases the connection on the "not found" path as well.
    with closing(get_connection(db_path)) as conn:
        row = conn.execute("SELECT id, title FROM documents WHERE id = ?", (doc_id,)).fetchone()
        if not row:
            raise click.ClickException(f"Document not found: {doc_id}")

        if add_tags:
            tags = _split(add_tags)
            if tags:
                tag_document(conn, doc_id, tags)
                conn.commit()
                click.echo(f"Added tags [{', '.join(tags)}] to '{row['title']}'")

        if remove_tags:
            tags = _split(remove_tags)
            if tags:
                untag_document(conn, doc_id, tags)
                conn.commit()
                click.echo(f"Removed tags [{', '.join(tags)}] from '{row['title']}'")
|
||||
|
||||
|
||||
@main.command()
@click.option("--format", "fmt", default=None, type=click.Choice(["json", "human"]), help="Output format.")
def status(fmt):
    """Show knowledge base status and statistics."""
    # Reports the stored model settings, the database file size, per-type
    # document counts and the total chunk count.
    # Raises click.ClickException when the database file does not exist.
    from contextlib import closing

    from kb_search.config import get_db_path, load_config
    from kb_search.database import get_connection, get_db_config
    from kb_search.output import format_status

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    fmt = fmt or cfg["search"]["default_format"]

    # closing() releases the connection even if a query raises.
    with closing(get_connection(db_path)) as conn:
        doc_counts = {}
        for row in conn.execute("SELECT doc_type, COUNT(*) as cnt FROM documents GROUP BY doc_type").fetchall():
            doc_counts[row["doc_type"]] = row["cnt"]

        total_docs = sum(doc_counts.values())
        total_chunks = conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
        db_size = db_path.stat().st_size

        status_data = {
            "model_name": get_db_config(conn, "model_name", "not set"),
            "embedding_dim": get_db_config(conn, "embedding_dim", "not set"),
            "schema_version": get_db_config(conn, "schema_version", "not set"),
            "db_size_bytes": db_size,
            "documents": doc_counts,
            "total_documents": total_docs,
            "total_chunks": total_chunks,
        }

    click.echo(format_status(status_data, fmt))
|
||||
|
||||
|
||||
@main.command()
@click.option("--model", default=None, help="Switch to a different embedding model.")
def reindex(model):
    """Re-embed all chunks (optionally with a new model)."""
    # Model resolution order: --model, then the model recorded in the
    # database, then cfg["embedding"]["model"].
    # Raises click.ClickException when the database file does not exist or the
    # embedder returns a different number of vectors than there are chunks.
    from contextlib import closing

    from kb_search.config import get_db_path, load_config
    from kb_search.database import (
        get_connection, get_db_config, insert_embedding,
        recreate_vec_table, set_db_config,
    )
    from kb_search.embeddings import download_model, embed_texts, get_model_dim

    cfg = load_config()
    db_path = get_db_path(cfg)
    if not db_path.exists():
        raise click.ClickException("Knowledge base not initialised. Run `kb init` first.")

    # closing() releases the connection on every exit path.
    with closing(get_connection(db_path)) as conn:
        model_name = model or get_db_config(conn, "model_name") or cfg["embedding"]["model"]

        # Download model if switching
        if model:
            download_model(model_name)

        dim = get_model_dim(model_name)

        # Get all chunks
        rows = conn.execute("SELECT id, text FROM chunks ORDER BY id").fetchall()
        if not rows:
            click.echo("No chunks to re-embed.")
            return

        click.echo(f"Re-embedding {len(rows)} chunks with '{model_name}' ({dim} dim)...")

        # Embed in batches
        batch_size = 256
        all_ids = [r["id"] for r in rows]
        all_texts = [r["text"] for r in rows]

        prefix = cfg["embedding"].get("passage_prefix", "")
        all_embeddings = []
        with click.progressbar(range(0, len(all_texts), batch_size), label="Embedding") as bar:
            for i in bar:
                batch = all_texts[i:i + batch_size]
                batch_embs = embed_texts(model_name, batch, prefix=prefix)
                all_embeddings.extend(batch_embs)

        # All embeddings are computed before the existing vectors are touched.
        # Refuse to replace them with an incomplete set, which zip() below
        # would otherwise truncate silently.
        if len(all_embeddings) != len(all_ids):
            raise click.ClickException(
                f"Embedding produced {len(all_embeddings)} vectors for {len(all_ids)} chunks; "
                "existing vectors were left untouched."
            )

        # Replace vectors, then record the model they were built with.
        recreate_vec_table(conn, dim)
        for chunk_id, emb in zip(all_ids, all_embeddings):
            insert_embedding(conn, chunk_id, emb)

        set_db_config(conn, "model_name", model_name)
        set_db_config(conn, "embedding_dim", str(dim))
        conn.commit()

    click.echo(f"Reindex complete. {len(rows)} chunks embedded with '{model_name}'.")
|
||||
|
||||
|
||||
@main.group(invoke_without_command=True)
@click.pass_context
def config(ctx):
    """View or modify configuration."""
    # A sub-command (e.g. `config set`) does its own work; the listing below is
    # printed only for a bare `config` invocation.
    if ctx.invoked_subcommand is not None:
        return

    from kb_search.config import config_with_sources

    entries = config_with_sources()
    # Column widths are taken from the longest key and the longest value.
    key_width = max(len(entry[0]) for entry in entries)
    val_width = max(len(entry[1]) for entry in entries)
    for key, value, source in entries:
        click.echo(f" {key:<{key_width}} {value:<{val_width}} ({source})")
|
||||
|
||||
|
||||
@config.command("set")
@click.argument("key")
@click.argument("value")
def config_set(key, value):
    """Set a configuration value."""
    from kb_search.config import get_config_path, load_config, save_config_value

    # The target file is derived from the currently effective configuration.
    config_file = get_config_path(load_config())
    save_config_value(config_file, key, value)
    click.echo(f"Set {key} = {value} in {config_file}")
|
||||
|
||||
|
||||
# NOTE(review): `config` is defined with the @main.group(...) decorator, which
# presumably already attaches it to `main`; this call looks redundant — confirm
# against the click documentation before removing.
main.add_command(config)
|
||||
@@ -0,0 +1,195 @@
|
||||
"""Configuration loading with YAML + ENV + defaults."""
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
# Built-in configuration. load_config() starts from a deep copy of this dict
# and layers the YAML file and environment variables on top of it.
DEFAULTS = {
    "data_dir": "~/.kb",
    "embedding": {"model": "all-MiniLM-L6-v2", "query_prefix": "", "passage_prefix": ""},
    "search": {"default_top": 10, "default_format": "json", "rrf_k": 60},
    "chunking": {
        "defaults": {"max_tokens": 512, "overlap_tokens": 50},
        "pdf": {"strategy": "hierarchy", "max_tokens": 1024},
        "markdown": {"strategy": "header", "min_tokens": 50, "max_tokens": 1024},
        "code": {"strategy": "ast", "include_context": True, "max_tokens": 1024},
        "note": {"strategy": "whole"},
    },
    "ingestion": {"workers": 4, "batch_size": 50, "enable_ocr": "auto"},
}
|
||||
|
||||
# Environment variable name -> dotted config key it overrides.
ENV_MAP = {
    "KB_DATA_DIR": "data_dir",
    "KB_MODEL": "embedding.model",
    "KB_DEFAULT_TOP": "search.default_top",
    "KB_DEFAULT_FORMAT": "search.default_format",
}

# Dotted config key -> callable used to convert the raw environment string.
# Keys without an entry here are kept as strings.
ENV_TYPES = {"search.default_top": int}
|
||||
|
||||
|
||||
def _deep_merge(base: dict, override: dict) -> dict:
|
||||
"""Deep merge override into base, returning a new dict."""
|
||||
result = deepcopy(base)
|
||||
for key, value in override.items():
|
||||
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
|
||||
result[key] = _deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = deepcopy(value)
|
||||
return result
|
||||
|
||||
|
||||
def _set_nested(d: dict, dotted_key: str, value):
|
||||
"""Set a value in a nested dict using a dotted key path."""
|
||||
keys = dotted_key.split(".")
|
||||
for key in keys[:-1]:
|
||||
d = d.setdefault(key, {})
|
||||
d[keys[-1]] = value
|
||||
|
||||
|
||||
def _get_nested(d: dict, dotted_key: str, default=None):
|
||||
"""Get a value from a nested dict using a dotted key path."""
|
||||
keys = dotted_key.split(".")
|
||||
for key in keys:
|
||||
if not isinstance(d, dict) or key not in d:
|
||||
return default
|
||||
d = d[key]
|
||||
return d
|
||||
|
||||
|
||||
def get_data_dir(cfg: dict) -> Path:
    """Return ``cfg["data_dir"]`` as a Path with a leading ``~`` expanded."""
    configured = cfg["data_dir"]
    return Path(configured).expanduser()
|
||||
|
||||
|
||||
def get_config_path(cfg: dict) -> Path:
    """Return the location of ``config.yaml`` inside the data directory."""
    data_dir = get_data_dir(cfg)
    return data_dir / "config.yaml"
|
||||
|
||||
|
||||
def get_db_path(cfg: dict) -> Path:
    """Return the location of the ``kb.db`` SQLite file inside the data directory."""
    data_dir = get_data_dir(cfg)
    return data_dir / "kb.db"
|
||||
|
||||
|
||||
def load_config(config_path: Path | None = None) -> dict:
    """Load config with precedence: ENV > YAML > defaults.

    CLI flags are applied by the caller after this returns.

    Args:
        config_path: YAML file to read. When omitted, ``config.yaml`` inside
            ``$KB_DATA_DIR`` (or the default data directory) is used. A
            missing file is not an error.

    Returns:
        A new dict; ``DEFAULTS`` itself is never modified.

    Raises:
        ValueError: if the YAML file's top-level value is not a mapping, or
            an environment override cannot be converted to its declared type.
    """
    cfg = deepcopy(DEFAULTS)

    # Determine config file path (ENV can override data_dir which affects path)
    if config_path is None:
        data_dir = os.environ.get("KB_DATA_DIR", DEFAULTS["data_dir"])
        config_path = Path(data_dir).expanduser() / "config.yaml"

    # Load YAML if it exists
    if config_path.is_file():
        with open(config_path, encoding="utf-8") as f:
            yaml_cfg = yaml.safe_load(f) or {}
        # A list or scalar at the top level cannot be merged key by key;
        # report it clearly instead of failing inside the merge.
        if not isinstance(yaml_cfg, dict):
            raise ValueError(f"{config_path}: top-level YAML value must be a mapping")
        cfg = _deep_merge(cfg, yaml_cfg)

    # Apply ENV overrides
    for env_name, dotted_key in ENV_MAP.items():
        env_val = os.environ.get(env_name)
        if env_val is None:
            continue
        coerce = ENV_TYPES.get(dotted_key, str)
        try:
            coerced = coerce(env_val)
        except ValueError as err:
            # Name the offending variable so the user knows what to fix.
            raise ValueError(f"Invalid value for {env_name}: {env_val!r}") from err
        _set_nested(cfg, dotted_key, coerced)

    return cfg
|
||||
|
||||
|
||||
def save_config_value(config_path: Path, dotted_key: str, value: str):
    """Write one value into the YAML config file, creating the file if needed.

    The string is stored as an int if it parses as one, otherwise as a float
    if it parses as one, otherwise as a bool when it spells true/false in any
    letter case, and otherwise unchanged.
    """
    config_path.parent.mkdir(parents=True, exist_ok=True)

    document = {}
    if config_path.is_file():
        with open(config_path) as f:
            document = yaml.safe_load(f) or {}

    # Try the numeric readings first; the for/else falls through to the
    # boolean check only when neither applies.
    coerced = value
    for caster in (int, float):
        try:
            coerced = caster(value)
        except ValueError:
            continue
        break
    else:
        lowered = value.lower()
        if lowered in ("true", "false"):
            coerced = lowered == "true"

    _set_nested(document, dotted_key, coerced)

    with open(config_path, "w") as f:
        yaml.dump(document, f, default_flow_style=False, sort_keys=False)
|
||||
|
||||
|
||||
def config_with_sources(config_path: Path | None = None) -> list[tuple[str, str, str]]:
    """Return a flat list of (dotted_key, value, source) tuples for display.

    ``value`` is the effective setting rendered with ``str()``. ``source`` is
    ``"env (NAME)"`` when the mapped environment variable is set,
    ``"config.yaml"`` when the YAML file holds a non-null value for the key,
    and ``"default"`` otherwise.
    """
    if config_path is None:
        base_dir = os.environ.get("KB_DATA_DIR", DEFAULTS["data_dir"])
        config_path = Path(base_dir).expanduser() / "config.yaml"

    yaml_cfg = {}
    if config_path.is_file():
        with open(config_path) as f:
            yaml_cfg = yaml.safe_load(f) or {}

    # Reverse of ENV_MAP: dotted key -> environment variable name.
    env_by_key = {dotted: env for env, dotted in ENV_MAP.items()}

    def _source_of(key):
        env_name = env_by_key.get(key)
        if env_name and os.environ.get(env_name) is not None:
            return f"env ({env_name})"
        if _get_nested(yaml_cfg, key) is not None:
            return "config.yaml"
        return "default"

    def _walk(node, prefix=""):
        # Depth-first, in dict order, yielding one tuple per leaf value.
        for name, val in node.items():
            key = f"{prefix}.{name}" if prefix else name
            if isinstance(val, dict):
                yield from _walk(val, key)
            else:
                yield key, str(val), _source_of(key)

    return list(_walk(load_config(config_path)))
|
||||
@@ -0,0 +1,229 @@
|
||||
"""SQLite database management with FTS5 and sqlite-vec."""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import sqlite_vec
|
||||
|
||||
SCHEMA_VERSION = 1
|
||||
|
||||
|
||||
def get_connection(db_path: Path) -> sqlite3.Connection:
    """Open the SQLite database at ``db_path`` with sqlite-vec loaded.

    The returned connection has WAL journaling and foreign-key enforcement
    switched on and yields ``sqlite3.Row`` rows.
    """
    connection = sqlite3.connect(str(db_path))

    # Extension loading is enabled only for the sqlite-vec load itself.
    connection.enable_load_extension(True)
    sqlite_vec.load(connection)
    connection.enable_load_extension(False)

    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON"):
        connection.execute(pragma)

    connection.row_factory = sqlite3.Row
    return connection
|
||||
|
||||
|
||||
def init_schema(conn: sqlite3.Connection, embedding_dim: int) -> None:
    """Create all tables, FTS, vector index, and triggers.

    Every statement uses ``IF NOT EXISTS``, so calling this on an already
    initialised database leaves existing objects untouched.

    Args:
        conn: Open connection; the ``vec0`` virtual-table module must be
            available on it for the final statement to succeed.
        embedding_dim: Vector width interpolated into the ``chunks_vec``
            definition (``FLOAT[embedding_dim]``).
    """
    # Relational tables and indexes. The script is an f-string, so the doubled
    # braces in the metadata defaults render as the literal JSON object '{}'.
    conn.executescript(f"""
        CREATE TABLE IF NOT EXISTS documents (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            title TEXT NOT NULL,
            source_path TEXT,
            content_hash TEXT NOT NULL,
            doc_type TEXT NOT NULL CHECK(doc_type IN ('pdf','markdown','code','note')),
            language TEXT,
            created_at TEXT DEFAULT (datetime('now')),
            metadata TEXT DEFAULT '{{}}'
        );

        CREATE TABLE IF NOT EXISTS chunks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            document_id INTEGER NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
            chunk_index INTEGER NOT NULL,
            text TEXT NOT NULL,
            token_count INTEGER,
            metadata TEXT DEFAULT '{{}}',
            created_at TEXT DEFAULT (datetime('now'))
        );

        CREATE TABLE IF NOT EXISTS tags (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT UNIQUE NOT NULL
        );

        CREATE TABLE IF NOT EXISTS document_tags (
            document_id INTEGER REFERENCES documents(id) ON DELETE CASCADE,
            tag_id INTEGER REFERENCES tags(id) ON DELETE CASCADE,
            PRIMARY KEY (document_id, tag_id)
        );

        CREATE TABLE IF NOT EXISTS config (
            key TEXT PRIMARY KEY,
            value TEXT NOT NULL
        );

        CREATE INDEX IF NOT EXISTS idx_chunks_document_id ON chunks(document_id);
        CREATE INDEX IF NOT EXISTS idx_documents_content_hash ON documents(content_hash);
    """)

    # FTS5 virtual table (content-sync with chunks): it indexes chunks.text and
    # uses chunks.id as its rowid rather than storing its own copy of the text.
    conn.execute("""
        CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
            text,
            content='chunks',
            content_rowid='id',
            tokenize='porter unicode61'
        )
    """)

    # FTS sync triggers: mirror every insert, delete and update on chunks into
    # chunks_fts. Deletes use the FTS5 'delete' command with the old row values;
    # an update is a delete of the old row followed by an insert of the new one.
    conn.executescript("""
        CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
            INSERT INTO chunks_fts(rowid, text) VALUES (new.id, new.text);
        END;

        CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
            INSERT INTO chunks_fts(chunks_fts, rowid, text) VALUES('delete', old.id, old.text);
        END;

        CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
            INSERT INTO chunks_fts(chunks_fts, rowid, text) VALUES('delete', old.id, old.text);
            INSERT INTO chunks_fts(rowid, text) VALUES (new.id, new.text);
        END;
    """)

    # Vector table: one fixed-width float vector per chunk id.
    # NOTE(review): embedding_dim is formatted straight into the DDL — callers
    # are presumably passing a trusted integer; confirm.
    conn.execute(f"""
        CREATE VIRTUAL TABLE IF NOT EXISTS chunks_vec USING vec0(
            chunk_id INTEGER PRIMARY KEY,
            embedding FLOAT[{embedding_dim}]
        )
    """)

    conn.commit()
|
||||
|
||||
|
||||
def get_db_config(conn: sqlite3.Connection, key: str, default: str | None = None) -> str | None:
|
||||
"""Get a value from the config table."""
|
||||
row = conn.execute("SELECT value FROM config WHERE key = ?", (key,)).fetchone()
|
||||
return row["value"] if row else default
|
||||
|
||||
|
||||
def set_db_config(conn: sqlite3.Connection, key: str, value: str):
    """Insert or overwrite *key* in the ``config`` table and commit immediately."""
    upsert_sql = "INSERT INTO config (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = ?"
    # The value is bound twice: once for the insert, once for the conflict update.
    bindings = (key, value, value)
    conn.execute(upsert_sql, bindings)
    conn.commit()
|
||||
|
||||
|
||||
def check_schema_version(conn: sqlite3.Connection) -> int | None:
    """Return the schema version recorded in the ``config`` table.

    Returns:
        The stored version as an int, ``0`` when the table exists but holds no
        ``schema_version`` row, or ``None`` when the version cannot be read
        (for example the ``config`` table is missing or the value is not an
        integer).
    """
    try:
        return int(get_db_config(conn, "schema_version", "0"))
    except (sqlite3.Error, ValueError, TypeError):
        # sqlite3.Error: the query failed (e.g. no config table yet);
        # ValueError/TypeError: the stored value is not a usable integer.
        # Anything else is a programming error and should surface.
        return None
|
||||
|
||||
|
||||
def run_migrations(conn: sqlite3.Connection):
    """Apply every registered migration newer than the stored schema version.

    Each applied migration records its own version. Afterwards, if the version
    read at entry was below ``SCHEMA_VERSION``, the stored version is set to
    ``SCHEMA_VERSION``.
    """
    starting_version = check_schema_version(conn) or 0

    # Migration registry: version -> callable
    migrations: dict[int, callable] = {
        # Future migrations go here:
        # 2: _migrate_v2,
    }

    for target in sorted(migrations):
        if target <= starting_version:
            continue
        migrations[target](conn)
        set_db_config(conn, "schema_version", str(target))

    # Compared against the version seen at entry, not the one just written.
    if starting_version < SCHEMA_VERSION:
        set_db_config(conn, "schema_version", str(SCHEMA_VERSION))
|
||||
|
||||
|
||||
def recreate_vec_table(conn: sqlite3.Connection, embedding_dim: int):
    """Replace ``chunks_vec`` with an empty table of width *embedding_dim*.

    Any embeddings stored in the old table are dropped with it; the change is
    committed before returning.
    """
    create_sql = f"""
        CREATE VIRTUAL TABLE chunks_vec USING vec0(
            chunk_id INTEGER PRIMARY KEY,
            embedding FLOAT[{embedding_dim}]
        )
    """
    conn.execute("DROP TABLE IF EXISTS chunks_vec")
    conn.execute(create_sql)
    conn.commit()
|
||||
|
||||
|
||||
def insert_document(conn: sqlite3.Connection, title: str, source_path: str | None,
|
||||
content_hash: str, doc_type: str, language: str | None = None,
|
||||
metadata: dict | None = None) -> int:
|
||||
"""Insert a document and return its ID."""
|
||||
cur = conn.execute(
|
||||
"INSERT INTO documents (title, source_path, content_hash, doc_type, language, metadata) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(title, source_path, content_hash, doc_type, language, json.dumps(metadata or {})),
|
||||
)
|
||||
return cur.lastrowid
|
||||
|
||||
|
||||
def insert_chunk(conn: sqlite3.Connection, document_id: int, chunk_index: int,
|
||||
text: str, token_count: int | None = None,
|
||||
metadata: dict | None = None) -> int:
|
||||
"""Insert a chunk and return its ID."""
|
||||
cur = conn.execute(
|
||||
"INSERT INTO chunks (document_id, chunk_index, text, token_count, metadata) "
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
(document_id, chunk_index, text, token_count, json.dumps(metadata or {})),
|
||||
)
|
||||
return cur.lastrowid
|
||||
|
||||
|
||||
def insert_embedding(conn: sqlite3.Connection, chunk_id: int, embedding: list[float]):
    """Store *embedding* for *chunk_id* in ``chunks_vec``.

    The vector is packed as consecutive 32-bit floats in the platform's native
    byte order. The insert is not committed here.
    """
    import struct

    layout = "%df" % len(embedding)
    packed = struct.pack(layout, *embedding)
    conn.execute(
        "INSERT INTO chunks_vec (chunk_id, embedding) VALUES (?, ?)",
        (chunk_id, packed),
    )
|
||||
|
||||
|
||||
def hash_exists(conn: sqlite3.Connection, content_hash: str) -> bool:
    """Report whether any document row already carries *content_hash*."""
    probe = "SELECT 1 FROM documents WHERE content_hash = ? LIMIT 1"
    match = conn.execute(probe, (content_hash,)).fetchone()
    if match is None:
        return False
    return True
|
||||
|
||||
|
||||
def get_or_create_tag(conn: sqlite3.Connection, name: str) -> int:
    """Return the id of the tag called *name*, creating the tag if necessary.

    Names are trimmed and lower-cased before lookup and storage. Existing rows
    are read by column name, so the connection needs a ``sqlite3.Row``-style
    row factory. A newly inserted tag is not committed here.
    """
    normalised = name.strip().lower()
    existing = conn.execute("SELECT id FROM tags WHERE name = ?", (normalised,)).fetchone()
    if existing is not None:
        return existing["id"]
    return conn.execute("INSERT INTO tags (name) VALUES (?)", (normalised,)).lastrowid
|
||||
|
||||
|
||||
def tag_document(conn: sqlite3.Connection, document_id: int, tag_names: list[str]):
    """Attach each tag in *tag_names* to the document, creating tags on demand.

    Links that already exist are left alone (``INSERT OR IGNORE``). Nothing is
    committed here.
    """
    link_sql = "INSERT OR IGNORE INTO document_tags (document_id, tag_id) VALUES (?, ?)"
    for tag_name in tag_names:
        conn.execute(link_sql, (document_id, get_or_create_tag(conn, tag_name)))
|
||||
|
||||
|
||||
def untag_document(conn: sqlite3.Connection, document_id: int, tag_names: list[str]):
    """Detach each tag in *tag_names* from the document.

    Names are trimmed and lower-cased before matching; unknown tags are a
    no-op. The tag rows themselves are kept. Nothing is committed here.
    """
    unlink_sql = (
        "DELETE FROM document_tags WHERE document_id = ? AND tag_id = "
        "(SELECT id FROM tags WHERE name = ?)"
    )
    for normalised in (raw.strip().lower() for raw in tag_names):
        conn.execute(unlink_sql, (document_id, normalised))
|
||||
@@ -0,0 +1,67 @@
|
||||
"""Embedding model management — download, load, and inference via ONNX."""
|
||||
|
||||
import click
|
||||
from pathlib import Path
|
||||
|
||||
_model_instance = None
|
||||
_model_name = None
|
||||
|
||||
|
||||
def load_model(model_name: str):
    """Return a SentenceTransformer for *model_name* using the ONNX backend.

    The most recently loaded model is kept in module globals, so asking for the
    same name again returns the cached instance without reloading.
    """
    global _model_instance, _model_name

    already_loaded = _model_instance is not None and _model_name == model_name
    if already_loaded:
        return _model_instance

    # Imported lazily so the dependency is only needed when a model is loaded.
    from sentence_transformers import SentenceTransformer

    click.echo(f"Loading model '{model_name}'...")
    try:
        model = SentenceTransformer(model_name, backend="onnx")
    except Exception:
        # Fallback: some models may not have pre-exported ONNX. Let sentence-transformers export.
        # NOTE(review): the retry repeats the exact same call as the first
        # attempt — confirm whether a different argument was intended.
        click.echo("Optimising model for ONNX inference (one-time)...")
        model = SentenceTransformer(model_name, backend="onnx")

    _model_instance, _model_name = model, model_name
    return _model_instance
|
||||
|
||||
|
||||
def get_model_dim(model_name: str) -> int:
    """Return the sentence-embedding width reported by the loaded model."""
    return load_model(model_name).get_sentence_embedding_dimension()
|
||||
|
||||
|
||||
def embed_texts(model_name: str, texts: list[str],
                prefix: str = "", show_progress: bool = False) -> list[list[float]]:
    """Encode *texts* with the named model and return one float list per text.

    A non-empty *prefix* is prepended to every text before encoding.
    """
    encoder = load_model(model_name)
    inputs = [prefix + item for item in texts] if prefix else texts
    vectors = encoder.encode(inputs, show_progress_bar=show_progress, convert_to_numpy=True)
    return [vector.tolist() for vector in vectors]
|
||||
|
||||
|
||||
def download_model(model_name: str):
    """Load *model_name* once up front (for kb init), reporting progress on stdout."""
    start_msg = f"Downloading embedding model '{model_name}'..."
    click.echo(start_msg)
    load_model(model_name)  # the load itself is what fetches the model
    click.echo("Embedding model ready.")
|
||||
|
||||
|
||||
def check_model_binding(conn, cfg: dict):
    """Ensure the configured embedding model is the one recorded in the database.

    Does nothing when the database has no ``model_name`` recorded yet or when
    it equals ``cfg["embedding"]["model"]``.

    Raises:
        click.ClickException: The two model names differ.
    """
    from kb_search.database import get_db_config

    stored_model = get_db_config(conn, "model_name")
    if stored_model is None:
        return  # Not yet initialised

    wanted_model = cfg["embedding"]["model"]
    if stored_model == wanted_model:
        return

    stored_dim = get_db_config(conn, "embedding_dim", "?")
    raise click.ClickException(
        f"Model mismatch: DB uses '{stored_model}' ({stored_dim} dim) but config specifies "
        f"'{wanted_model}'. Run `kb reindex --model {wanted_model}` to switch models."
    )
|
||||
@@ -0,0 +1,244 @@
|
||||
"""Code ingestion — AST/regex-based splitting for Python, Bash, Go."""
|
||||
|
||||
import ast
|
||||
import re
|
||||
|
||||
|
||||
def chunk_code(text: str, language: str | None, cfg: dict) -> list[dict]:
    """Chunk source code, preferring function/class boundaries.

    Settings come from ``cfg["chunking"]["code"]``. With ``strategy == "fixed"``,
    an unrecognised language, or when the language splitter finds nothing, the
    fixed-size chunker is used instead. Structural chunks are numbered here via
    ``chunk_index``.
    """
    settings = cfg.get("chunking", {}).get("code", {})
    with_context = settings.get("include_context", True)

    if settings.get("strategy", "ast") == "fixed":
        return _fixed_chunk(text, settings)

    splitters = {
        "python": _chunk_python,
        "bash": _chunk_bash,
        "sh": _chunk_bash,
        "go": _chunk_go,
    }
    splitter = splitters.get(language)
    pieces = splitter(text, with_context) if splitter is not None else []

    if not pieces:
        return _fixed_chunk(text, settings)

    for position, piece in enumerate(pieces):
        piece["chunk_index"] = position
    return pieces
|
||||
|
||||
|
||||
def _chunk_python(text: str, include_context: bool) -> list[dict]:
    """Split Python source into one chunk per top-level function or method.

    Only direct children of the module are considered. A class with methods
    yields one chunk per method (optionally prefixed with a ``class`` header
    and the class docstring); a class without methods yields a single chunk.
    Returns an empty list when the source does not parse.
    """
    try:
        module = ast.parse(text)
    except SyntaxError:
        return []

    source_lines = text.splitlines(keepends=True)
    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    chunks = []

    def _emit(body, symbol, node):
        """Record one chunk with the symbol name and the node's line span."""
        chunks.append({
            "text": body,
            "metadata": {
                "symbol_name": symbol,
                "line_start": node.lineno,
                "line_end": node.end_lineno,
            },
        })

    for node in ast.iter_child_nodes(module):
        if isinstance(node, function_types):
            _emit(_get_node_source(source_lines, node), node.name, node)
            continue
        if not isinstance(node, ast.ClassDef):
            continue

        methods = [
            child for child in ast.iter_child_nodes(node)
            if isinstance(child, function_types)
        ]
        if not methods:
            # Class with no methods — single chunk
            _emit(_get_node_source(source_lines, node), node.name, node)
            continue

        # Context header shared by every method chunk of this class.
        docstring = ast.get_docstring(node) or ""
        if not include_context:
            header = ""
        elif docstring:
            header = f"class {node.name}:\n \"\"\"{docstring}\"\"\"\n\n"
        else:
            header = f"class {node.name}:\n\n"

        for method in methods:
            _emit(
                header + _get_node_source(source_lines, method),
                f"{node.name}.{method.name}",
                method,
            )

    return chunks
|
||||
|
||||
|
||||
def _get_node_source(lines: list[str], node) -> str:
    """Return the source text of *node*, starting at its first decorator if any.

    *lines* is the file split with line endings kept; trailing whitespace of
    the extracted span is removed.
    """
    decorators = getattr(node, "decorator_list", None)
    # AST line numbers are 1-based; the first decorator precedes the def line.
    first_line = decorators[0].lineno if decorators else node.lineno
    span = lines[first_line - 1:node.end_lineno]
    return "".join(span).rstrip()
|
||||
|
||||
|
||||
def _chunk_bash(text: str, include_context: bool) -> list[dict]:
    """Split a Bash script into one chunk per function definition.

    Recognised forms, each starting at the beginning of a line with the
    opening brace on the same line:

    * ``name() {``
    * ``function name() {``
    * ``function name {`` (the parentheses are optional after ``function``)

    A block of ``#`` comment lines directly above a definition belongs to that
    function's chunk. Each chunk runs up to the start of the next match, or to
    the end of the script for the last one; text before the first match is not
    emitted.

    Args:
        text: Script source.
        include_context: Accepted for signature parity with the other language
            splitters; not used here.

    Returns:
        Chunk dicts with ``text`` and ``metadata["symbol_name"]``, or an empty
        list when no function definition is found.
    """
    func_pattern = re.compile(
        r"^((?:#[^\n]*\n)*)?"  # Optional preceding comment block
        r"(?:function\s+(\w+)\s*(?:\(\s*\))?\s*\{|(\w+)\s*\(\s*\)\s*\{)",
        re.MULTILINE,
    )

    matches = list(func_pattern.finditer(text))
    if not matches:
        return []

    # Each chunk ends where the next one starts; the last runs to end of file.
    boundaries = [m.start() for m in matches] + [len(text)]

    chunks = []
    for i, match in enumerate(matches):
        # Group 2 is the `function name` form, group 3 the bare `name()` form.
        func_name = match.group(2) or match.group(3)
        chunks.append({
            "text": text[boundaries[i]:boundaries[i + 1]].rstrip(),
            "metadata": {
                "symbol_name": func_name,
            },
        })

    return chunks
|
||||
|
||||
|
||||
def _go_comment_block_start(text: str, pos: int) -> int:
    """Return the offset where the ``//`` comment block directly above *pos* begins.

    *pos* must be the start of a line. Lines above it are scanned upwards while
    they are blank or ``//`` comments; the result is the start of the topmost
    comment line seen, or *pos* itself when there is none.
    """
    boundary = pos
    cursor = pos
    while cursor > 0:
        # text[cursor - 1] is the newline that ends the line above the cursor.
        line_start = text.rfind("\n", 0, cursor - 1) + 1
        stripped = text[line_start:cursor].strip()
        if stripped.startswith("//"):
            boundary = line_start
        elif stripped:
            break  # reached code: the comment block ends here
        cursor = line_start
    return boundary


def _chunk_go(text: str, include_context: bool) -> list[dict]:
    """Split Go source into one chunk per top-level ``func`` declaration.

    Plain functions and methods with a receiver are both matched. The ``//``
    comment block above a declaration (blank lines allowed in between) is part
    of that function's chunk, so each chunk ends where the next function's
    comment block starts. Text before the first function is not emitted.

    Args:
        text: Go source.
        include_context: Accepted for signature parity with the other language
            splitters; not used here.

    Returns:
        Chunk dicts with ``text`` and ``metadata["symbol_name"]``, or an empty
        list when no ``func`` declaration is found.
    """
    func_pattern = re.compile(
        r"^func\s+(?:\([^)]*\)\s+)?(\w+)\s*\(",
        re.MULTILINE,
    )

    matches = list(func_pattern.finditer(text))
    if not matches:
        return []

    # Offsets are derived from line positions. Searching for the comment text
    # itself fails on blank lines: rfind("") returns the search end, which left
    # doc comments out of both neighbouring chunks.
    starts = [_go_comment_block_start(text, m.start()) for m in matches]
    boundaries = starts + [len(text)]

    chunks = []
    for i, match in enumerate(matches):
        chunks.append({
            "text": text[boundaries[i]:boundaries[i + 1]].rstrip(),
            "metadata": {
                "symbol_name": match.group(1),
            },
        })

    return chunks
|
||||
|
||||
|
||||
def _fixed_chunk(text: str, chunking_cfg: dict) -> list[dict]:
    """Chunk code line by line into windows of roughly ``max_tokens`` words.

    Tokens are approximated by whitespace-separated words. A window is closed
    before a line that would push it past ``max_tokens``; the next window then
    starts with as many trailing lines of the previous one as fit within
    ``overlap_tokens``. A single line is never split.
    """
    max_tokens = chunking_cfg.get("max_tokens", 1024)
    overlap_tokens = chunking_cfg.get("overlap_tokens", 50)

    source_lines = text.splitlines()
    if not source_lines:
        return []

    def _carry_over(window):
        """Pick the trailing lines of *window* that fit in the overlap budget."""
        kept, used = [], 0
        for earlier in reversed(window):
            cost = len(earlier.split())
            if used + cost > overlap_tokens:
                break
            kept.append(earlier)
            used += cost
        kept.reverse()
        return kept, used

    chunks = []
    window, window_tokens = [], 0

    for line in source_lines:
        cost = len(line.split())
        if window and window_tokens + cost > max_tokens:
            chunks.append({
                "text": "\n".join(window),
                "chunk_index": len(chunks),
                "metadata": {},
            })
            window, window_tokens = _carry_over(window)
        window.append(line)
        window_tokens += cost

    if window:
        chunks.append({
            "text": "\n".join(window),
            "chunk_index": len(chunks),
            "metadata": {},
        })

    return chunks
|
||||
@@ -0,0 +1,54 @@
|
||||
"""File type detection and routing."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
# Maps a lower-case file suffix (including the dot) to a
# (doc_type, language) pair; language is None when it does not apply.
EXTENSION_MAP = {
    # Docling-handled formats: all of these share the "pdf" doc_type,
    # including office documents, HTML and images.
    ".pdf": ("pdf", None),
    ".docx": ("pdf", None),  # Docling handles DOCX too
    ".html": ("pdf", None),
    ".htm": ("pdf", None),
    ".png": ("pdf", None),
    ".jpg": ("pdf", None),
    ".jpeg": ("pdf", None),
    ".tiff": ("pdf", None),
    ".bmp": ("pdf", None),
    ".webp": ("pdf", None),
    # Markdown / text: plain .txt files are treated as markdown.
    ".md": ("markdown", None),
    ".markdown": ("markdown", None),
    ".txt": ("markdown", None),
    # Code: the second element names the language.
    ".py": ("code", "python"),
    ".sh": ("code", "bash"),
    ".bash": ("code", "bash"),
    ".go": ("code", "go"),
}

# Every suffix that has an entry above.
SUPPORTED_EXTENSIONS = set(EXTENSION_MAP.keys())
|
||||
|
||||
|
||||
def detect_type(path: Path, force_type: str | None = None,
|
||||
force_language: str | None = None) -> tuple[str, str | None]:
|
||||
"""Detect document type and language from file extension.
|
||||
|
||||
Returns (doc_type, language) tuple.
|
||||
Raises ValueError for unsupported file types.
|
||||
"""
|
||||
if force_type:
|
||||
return force_type, force_language
|
||||
|
||||
ext = path.suffix.lower()
|
||||
if ext not in EXTENSION_MAP:
|
||||
supported = ", ".join(sorted(SUPPORTED_EXTENSIONS))
|
||||
raise ValueError(f"Unsupported file type '{ext}'. Supported: {supported}")
|
||||
|
||||
doc_type, language = EXTENSION_MAP[ext]
|
||||
if force_language:
|
||||
language = force_language
|
||||
return doc_type, language
|
||||
|
||||
|
||||
def is_supported(path: Path) -> bool:
    """Tell whether the suffix of *path* (case-insensitive) is a supported one."""
    suffix = path.suffix.lower()
    return suffix in SUPPORTED_EXTENSIONS
|
||||
@@ -0,0 +1,123 @@
|
||||
"""Docling-based ingestion for PDFs, DOCX, HTML, and images."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Suppress noisy Docling/RapidOCR logging
|
||||
logging.getLogger("RapidOCR").setLevel(logging.ERROR)
|
||||
logging.getLogger("docling.models.stages.ocr.rapid_ocr_model").setLevel(logging.ERROR)
|
||||
logging.getLogger("docling").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def chunk_document(file_path: Path, cfg: dict) -> list[dict]:
    """Convert *file_path* with Docling and split the result into chunks.

    OCR follows ``cfg["ingestion"]["enable_ocr"]``: ``"never"`` disables it,
    ``"always"`` forces full-page OCR, and any other value enables OCR with a
    bitmap-area threshold of 0.25. Chunking follows
    ``cfg["chunking"]["pdf"]["strategy"]`` (``"hierarchy"`` by default). When
    that produces nothing, the markdown export is chunked at a fixed size.
    """
    from docling.document_converter import DocumentConverter, PdfFormatOption
    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import PdfPipelineOptions, RapidOcrOptions

    # Configure PDF pipeline
    ocr_mode = cfg.get("ingestion", {}).get("enable_ocr", "auto")
    pipeline = PdfPipelineOptions()
    pipeline.do_ocr = ocr_mode != "never"
    if ocr_mode == "always":
        pipeline.ocr_options = RapidOcrOptions(force_full_page_ocr=True)
    elif ocr_mode != "never":
        # "auto" — enable OCR but only trigger on pages with significant bitmap content
        pipeline.ocr_options = RapidOcrOptions(bitmap_area_threshold=0.25)

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline),
        }
    )
    doc = converter.convert(str(file_path)).document

    chunk_settings = cfg.get("chunking", {}).get("pdf", {})
    if chunk_settings.get("strategy", "hierarchy") == "hierarchy":
        chunks = _hierarchy_chunk(doc)
    else:
        chunks = _fixed_chunk(doc, chunk_settings)
    if chunks:
        return chunks

    # Fallback: try extracting raw text
    markdown = doc.export_to_markdown()
    if markdown and markdown.strip():
        return _fixed_chunk_text(markdown, chunk_settings)
    return chunks
|
||||
|
||||
|
||||
def _hierarchy_chunk(doc) -> list[dict]:
    """Chunk a Docling document with Docling's ``HierarchicalChunker``.

    Returns:
        One dict per chunk with ``text``, ``chunk_index`` and ``metadata``.
        ``metadata["page"]`` is the first page number found in the chunk's item
        provenance, and ``metadata["section_header"]`` joins the chunk's
        headings with ``" > "``; either key is omitted when the chunk does not
        carry that information.
    """
    from docling_core.transforms.chunker import HierarchicalChunker

    chunker = HierarchicalChunker()
    chunks = []

    for i, chunk in enumerate(chunker.chunk(doc)):
        meta = {}

        if hasattr(chunk, "meta") and chunk.meta:
            # Extract page info if available: stop at the first page number
            # found. Breaking only the inner loop let later items overwrite it.
            if hasattr(chunk.meta, "doc_items"):
                page_found = False
                for item in chunk.meta.doc_items:
                    for prov in getattr(item, "prov", None) or ():
                        if hasattr(prov, "page_no"):
                            meta["page"] = prov.page_no
                            page_found = True
                            break
                    if page_found:
                        break

            # Section headers
            if hasattr(chunk.meta, "headings") and chunk.meta.headings:
                meta["section_header"] = " > ".join(chunk.meta.headings)

        chunks.append({
            "text": chunk.text,
            "chunk_index": i,
            "metadata": meta,
        })

    return chunks
|
||||
|
||||
|
||||
def _fixed_chunk(doc, chunking_cfg: dict) -> list[dict]:
    """Export the Docling document to markdown and chunk it at a fixed size."""
    return _fixed_chunk_text(doc.export_to_markdown(), chunking_cfg)
|
||||
|
||||
|
||||
def _fixed_chunk_text(text: str, chunking_cfg: dict) -> list[dict]:
    """Cut plain text into overlapping fixed-size character windows.

    Sizes are configured in tokens and converted at roughly four characters
    per token. Consecutive windows share ``overlap_tokens`` worth of text.
    Chunking stops as soon as a window reaches the end of the text, so no
    redundant tail chunk is produced.

    Args:
        text: Text to split.
        chunking_cfg: May provide ``max_tokens`` (default 1024) and
            ``overlap_tokens`` (default 50). A negative overlap is treated as
            zero; an overlap as large as the window is ignored, because it
            would prevent the window from advancing.

    Returns:
        Chunk dicts with stripped ``text``, sequential ``chunk_index`` and
        empty ``metadata``. Whitespace-only windows are skipped.

    Raises:
        ValueError: ``max_tokens`` is not positive.
    """
    max_tokens = chunking_cfg.get("max_tokens", 1024)
    overlap = chunking_cfg.get("overlap_tokens", 50)
    if max_tokens <= 0:
        raise ValueError("max_tokens must be a positive number")

    # Approximate: 1 token ~= 4 chars
    max_chars = max_tokens * 4
    overlap_chars = max(overlap, 0) * 4

    stride = max_chars - overlap_chars
    if stride <= 0:
        stride = max_chars  # guarantee forward progress

    chunks = []
    start = 0
    while start < len(text):
        end = start + max_chars
        chunk_text = text[start:end].strip()
        if chunk_text:
            chunks.append({
                "text": chunk_text,
                "chunk_index": len(chunks),
                "metadata": {},
            })
        if end >= len(text):
            # This window already covers the tail; another step would only
            # repeat the overlap region.
            break
        start += stride

    return chunks
|
||||
@@ -0,0 +1,210 @@
|
||||
"""Markdown ingestion — header-based splitting."""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def chunk_markdown(text: str, cfg: dict) -> list[dict]:
    """Chunk markdown text according to ``cfg["chunking"]["markdown"]``.

    Header-based splitting is used unless the strategy is ``"fixed"`` or the
    text contains no markdown headers, in which case fixed-size chunking is
    used.
    """
    settings = cfg.get("chunking", {}).get("markdown", {})
    wants_fixed = settings.get("strategy", "header") == "fixed"

    if wants_fixed or not _has_headers(text):
        return _fixed_chunk(text, settings)
    return _header_chunk(text, settings)
|
||||
|
||||
|
||||
def _has_headers(text: str) -> bool:
    """Report whether any line starts with one to six ``#`` followed by whitespace."""
    found = re.search(r"^#{1,6}\s+", text, re.MULTILINE)
    return found is not None
|
||||
|
||||
|
||||
def _header_chunk(text: str, chunking_cfg: dict) -> list[dict]:
    """Chunk markdown by header sections, adding the header hierarchy as context.

    Sections come from ``_split_at_headers``; those below ``min_tokens`` are
    merged with a neighbour and those above ``max_tokens`` are split further.
    Each chunk's text starts with its header chain joined by ``" > "``. Token
    counts are approximated by whitespace-separated words.

    Returns:
        Chunk dicts with ``text``, ``chunk_index`` and
        ``metadata["section_header"]`` (the innermost header, or ``None``).
        Sub-chunks of an oversized section carry the same ``section_header`` as
        the section they were cut from.
    """
    min_tokens = chunking_cfg.get("min_tokens", 50)
    max_tokens = chunking_cfg.get("max_tokens", 1024)

    sections = _split_at_headers(text)
    if not sections:
        return _fixed_chunk(text, chunking_cfg)

    # Merge small sections
    sections = _merge_small_sections(sections, min_tokens)

    chunks = []
    for section in sections:
        content = section["content"].strip()
        if not content:
            continue

        chain = section["header_chain"]
        section_header = chain[-1] if chain else None

        # Add hierarchy context
        full_text = f"{' > '.join(chain)}\n\n{content}" if chain else content

        if len(full_text.split()) > max_tokens:
            # Split large sections, keeping the header on every piece so that
            # all chunks report where they came from.
            for sub_chunk in _split_large_section(full_text, max_tokens, chunking_cfg):
                sub_chunk.setdefault("metadata", {})["section_header"] = section_header
                chunks.append(sub_chunk)
        else:
            chunks.append({
                "text": full_text,
                "metadata": {"section_header": section_header},
            })

    # Assign chunk indices
    for i, c in enumerate(chunks):
        c["chunk_index"] = i

    return chunks
|
||||
|
||||
|
||||
def _fenced_code_ranges(text: str) -> list[tuple[int, int]]:
    """Return ``(start, end)`` offsets of fenced code blocks (``` or ~~~) in *text*.

    A fence is closed by a later fence line using the same character and at
    least as many of them; an unclosed fence runs to the end of the text.
    """
    ranges = []
    open_at = None
    marker = ""
    for match in re.finditer(r"^ {0,3}(`{3,}|~{3,})", text, re.MULTILINE):
        fence = match.group(1)
        if open_at is None:
            open_at, marker = match.start(), fence
        elif fence[0] == marker[0] and len(fence) >= len(marker):
            ranges.append((open_at, match.end()))
            open_at = None
    if open_at is not None:
        ranges.append((open_at, len(text)))
    return ranges


def _split_at_headers(text: str) -> list[dict]:
    """Split markdown into sections at ATX header lines (``#`` to ``######``).

    Lines inside fenced code blocks are never treated as headers, so shell or
    Python comments in code samples stay part of the surrounding section. The
    marker and the title must be on the same line.

    Returns:
        Section dicts in document order. ``header_chain`` lists the titles from
        the outermost enclosing header down to the section's own header (empty
        for text before the first header). ``content`` is the text belonging to
        the section, excluding the header line itself.
    """
    # [ \t]+ rather than \s+, so the marker cannot consume a newline and take
    # the following line as its title.
    header_pattern = re.compile(r"^(#{1,6})[ \t]+(.*?)$", re.MULTILINE)
    fenced = _fenced_code_ranges(text)

    sections = []
    header_stack = []  # Stack of (level, title)
    last_end = 0

    for match in header_pattern.finditer(text):
        position = match.start()
        if any(lo <= position < hi for lo, hi in fenced):
            continue  # a '#' line inside a code block is not a header

        # Capture content before this header
        if last_end < position:
            content = text[last_end:position].strip()
            if content and sections:
                sections[-1]["content"] += "\n\n" + content
            elif content:
                sections.append({
                    "header_chain": [],
                    "content": content,
                })

        level = len(match.group(1))
        title = match.group(2).strip()

        # Update header stack: drop headers at the same or a deeper level so
        # the stack always describes the path to the current header.
        while header_stack and header_stack[-1][0] >= level:
            header_stack.pop()
        header_stack.append((level, title))

        sections.append({
            "header_chain": [h[1] for h in header_stack],
            "content": "",
        })
        last_end = match.end()

    # Capture trailing content
    if last_end < len(text):
        trailing = text[last_end:].strip()
        if trailing and sections:
            sections[-1]["content"] += "\n\n" + trailing
        elif trailing:
            sections.append({"header_chain": [], "content": trailing})

    return sections
|
||||
|
||||
|
||||
def _merge_small_sections(sections: list[dict], min_tokens: int) -> list[dict]:
    """Fold sections with fewer than *min_tokens* words into the next section.

    A small section's content is prepended to the following section, which
    inherits the small section's header chain only when it has none of its
    own. A small section at the very end is appended to the last kept section,
    or kept as is when nothing else was kept. The section dicts are modified
    in place.
    """
    if not sections:
        return sections

    kept = []
    carry = None  # small section waiting to be folded into the next one

    for current in sections:
        if carry is not None:
            current["content"] = f'{carry["content"]}\n\n{current["content"]}'
            if carry["header_chain"] and not current["header_chain"]:
                current["header_chain"] = carry["header_chain"]
            carry = None

        if len(current["content"].split()) >= min_tokens:
            kept.append(current)
        else:
            carry = current

    if carry is None:
        return kept

    if kept:
        kept[-1]["content"] += "\n\n" + carry["content"]
    else:
        kept.append(carry)
    return kept
|
||||
|
||||
|
||||
def _split_large_section(text: str, max_tokens: int, chunking_cfg: dict) -> list[dict]:
    """Break an oversized section into paragraph-aligned chunks with overlap.

    Paragraphs (separated by blank lines) are accumulated until the next one
    would exceed *max_tokens* words. The following chunk then starts with as
    many trailing paragraphs of the previous chunk as fit within
    ``overlap_tokens``. A single paragraph is never split.
    """
    overlap_tokens = chunking_cfg.get("overlap_tokens",
                                      cfg_defaults().get("overlap_tokens", 50))

    pieces = []
    window, window_tokens = [], 0

    for para in re.split(r"\n\n+", text):
        size = len(para.split())
        if window and window_tokens + size > max_tokens:
            pieces.append({"text": "\n\n".join(window), "metadata": {}})

            # Keep overlap: walk backwards while paragraphs fit the budget.
            tail, tail_tokens = [], 0
            for earlier in reversed(window):
                cost = len(earlier.split())
                if tail_tokens + cost > overlap_tokens:
                    break
                tail.append(earlier)
                tail_tokens += cost
            tail.reverse()
            window, window_tokens = tail, tail_tokens

        window.append(para)
        window_tokens += size

    if window:
        pieces.append({"text": "\n\n".join(window), "metadata": {}})

    return pieces
|
||||
|
||||
|
||||
def cfg_defaults():
    """Build a fresh dict holding the default chunking limits, in tokens."""
    return dict(max_tokens=1024, overlap_tokens=50, min_tokens=50)
|
||||
|
||||
|
||||
def _fixed_chunk(text: str, chunking_cfg: dict) -> list[dict]:
    """Chunk plain text into overlapping windows of whitespace-separated words.

    Args:
        text: Text to split; all whitespace is normalised to single spaces.
        chunking_cfg: May provide ``max_tokens`` (default 512) and
            ``overlap_tokens`` (default 50), both counted in words. A negative
            overlap is treated as zero; an overlap as large as the window is
            ignored, because it would prevent the window from advancing.

    Returns:
        Chunk dicts with ``text``, sequential ``chunk_index`` and empty
        ``metadata``; an empty list for blank input.

    Raises:
        ValueError: ``max_tokens`` is not positive.
    """
    max_tokens = chunking_cfg.get("max_tokens", 512)
    overlap_tokens = chunking_cfg.get("overlap_tokens", 50)
    if max_tokens <= 0:
        raise ValueError("max_tokens must be a positive number")

    stride = max_tokens - max(overlap_tokens, 0)
    if stride <= 0:
        stride = max_tokens  # guarantee forward progress

    words = text.split()
    chunks = []
    for start in range(0, len(words), stride):
        end = start + max_tokens
        chunks.append({
            "text": " ".join(words[start:end]),
            "chunk_index": len(chunks),
            "metadata": {},
        })
        if end >= len(words):
            break  # this window already reaches the last word

    return chunks
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Note ingestion — whole-document chunks."""
|
||||
|
||||
|
||||
def chunk_note(text: str) -> list[dict]:
    """Wrap the whole note in a single chunk with index 0 and empty metadata."""
    whole_note = {"text": text, "metadata": {}, "chunk_index": 0}
    return [whole_note]
|
||||
|
||||
|
||||
def auto_title(text: str, max_len: int = 80) -> str:
    """Derive a title from the first line of *text*.

    A first line of at most *max_len* characters is returned unchanged. A
    longer one is cut at *max_len*, backed up to the last space if there is
    one, and suffixed with ``"..."``.
    """
    heading = text.strip().partition("\n")[0].strip()
    if len(heading) <= max_len:
        return heading

    clipped = heading[:max_len]
    cut = clipped.rfind(" ")
    # Back up to a word boundary when one exists inside the clipped part.
    if cut > 0:
        clipped = clipped[:cut]
    return f"{clipped}..."
|
||||
@@ -0,0 +1,144 @@
|
||||
"""Output formatters — JSON and human-readable."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
def format_search_results(data: dict, fmt: str = "json") -> str:
    """Render search results as indented JSON, or as text for any other *fmt*."""
    if fmt != "json":
        return _human_search(data)
    return json.dumps(data, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def _human_search(data: dict) -> str:
    """Render search results as plain text.

    The output is a summary line followed by one entry per result: a ranked
    line with score, title, page or section, type and tags, then a preview of
    at most 200 characters of the chunk text.
    """
    summary = (
        f'Search: "{data["query"]}" '
        f'({data["total_matches"]} matches, showing top {data["returned"]})'
    )
    out = [summary, ""]

    for rank, result in enumerate(data["results"], 1):
        src = result["source"]
        score = result["score"]

        # Title with page/section: a page number wins over a section header.
        if src.get("page"):
            location = f" (p.{src['page']})"
        elif src.get("section_header"):
            location = f" \u00a7{src['section_header']}"
        else:
            location = ""

        # Tags
        tag_str = " [" + ", ".join(src["tags"]) + "]" if src.get("tags") else ""

        out.append(f" {rank:2d}. [{score:.3f}] {src['title']}{location} [{src['type']}]{tag_str}")

        # Text preview (first 200 chars)
        body = result["text"]
        preview = body[:200].replace("\n", " ").strip()
        if len(body) > 200:
            preview += "..."
        out.extend([f" {preview}", ""])

    return "\n".join(out)
|
||||
|
||||
|
||||
def format_document_list(docs: list[dict], fmt: str = "json") -> str:
    """Render the document list as indented JSON or, for any other ``fmt``, as a text table."""
    if fmt != "json":
        return _human_doc_list(docs)
    return json.dumps(docs, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def _human_doc_list(docs: list[dict]) -> str:
|
||||
"""Human-readable document list."""
|
||||
if not docs:
|
||||
return "No documents indexed. Run `kb add` to get started."
|
||||
|
||||
lines = [f"{'ID':>5} {'Type':<10} {'Chunks':>6} {'Title':<40} {'Tags'}"]
|
||||
lines.append("-" * 80)
|
||||
|
||||
for d in docs:
|
||||
tags = ", ".join(d.get("tags", []))
|
||||
title = d["title"][:40]
|
||||
lines.append(f"{d['id']:>5} {d['type']:<10} {d['chunk_count']:>6} {title:<40} {tags}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_tags(tags: list[dict], fmt: str = "json") -> str:
    """Render the tag list as indented JSON or as a two-column text table.

    In text mode an empty list yields a hint on how to create tags.
    """
    if fmt == "json":
        return json.dumps(tags, indent=2, ensure_ascii=False)
    if not tags:
        return "No tags. Use `kb add --tags` or `kb tag` to add tags."

    table = [f"{'Tag':<30} {'Documents':>10}", "-" * 42]
    table.extend(f"{entry['name']:<30} {entry['count']:>10}" for entry in tags)
    return "\n".join(table)
|
||||
|
||||
|
||||
def format_doc_info(info: dict, fmt: str = "json") -> str:
    """Render one document's details as indented JSON or as a text report.

    The text report lists the document fields (language, path and tags only
    when present), a blank line, then one line per chunk with a preview of at
    most 100 characters.
    """
    if fmt == "json":
        return json.dumps(info, indent=2, ensure_ascii=False)

    report = [
        f"Document #{info['id']}: {info['title']}",
        f" Type: {info['type']}",
    ]
    if info.get("language"):
        report.append(f" Language: {info['language']}")
    if info.get("path"):
        report.append(f" Path: {info['path']}")
    report.append(f" Hash: {info['content_hash']}")
    report.append(f" Created: {info['created_at']}")
    if info.get("tags"):
        joined_tags = ", ".join(info["tags"])
        report.append(f" Tags: {joined_tags}")
    report.append(f" Chunks: {info['chunk_count']}")
    report.append("")

    for piece in info.get("chunks", []):
        body = piece["text"]
        snippet = body[:100].replace("\n", " ").strip()
        if len(body) > 100:
            snippet += "..."
        report.append(f" [{piece['chunk_index']}] {snippet}")

    return "\n".join(report)
|
||||
|
||||
|
||||
def format_status(status: dict, fmt: str = "json") -> str:
    """Render knowledge-base status as indented JSON or as a text summary.

    The text summary shows model, embedding dimension, schema version and
    database size, followed by per-type document counts and totals.
    """
    if fmt == "json":
        return json.dumps(status, indent=2, ensure_ascii=False)

    report = [
        "Knowledge Base Status",
        "=" * 40,
        f" Model: {status['model_name']}",
        f" Embedding dim: {status['embedding_dim']}",
        f" Schema version: {status['schema_version']}",
        f" DB size: {_human_size(status['db_size_bytes'])}",
        "",
        " Documents:",
    ]
    # One row per document type, then the grand totals
    report.extend(
        f" {dtype:<12} {count:>5}"
        for dtype, count in status.get("documents", {}).items()
    )
    report.append(f" {'total':<12} {status['total_documents']:>5}")
    report.append(f" Total chunks: {status['total_chunks']}")

    return "\n".join(report)
|
||||
|
||||
|
||||
def _human_size(size_bytes: int) -> str:
|
||||
"""Format bytes as human-readable."""
|
||||
for unit in ("B", "KB", "MB", "GB"):
|
||||
if size_bytes < 1024:
|
||||
return f"{size_bytes:.1f} {unit}"
|
||||
size_bytes /= 1024
|
||||
return f"{size_bytes:.1f} TB"
|
||||
@@ -0,0 +1,261 @@
|
||||
"""Hybrid search — FTS5 + vector with Reciprocal Rank Fusion."""
|
||||
|
||||
import json
|
||||
import re
|
||||
import struct
|
||||
import sqlite3
|
||||
|
||||
|
||||
def hybrid_search(conn: sqlite3.Connection, query: str, model_name: str, cfg: dict,
                  top: int = 10, tags: list[str] | None = None,
                  doc_type: str | None = None, fts_only: bool = False,
                  vec_only: bool = False, threshold: float | None = None) -> dict:
    """Search chunks by keyword and/or embedding and return enriched hits.

    Args:
        conn: Open database connection; result rows are read by column name.
        query: Free-text search query.
        model_name: Embedding model passed through to the vector search.
        cfg: Configuration dict; ``cfg["search"]["rrf_k"]`` (default 60) is
            the RRF constant.
        top: Maximum number of hits to return.
        tags: Optional tag filter forwarded to both searches.
        doc_type: Optional document-type filter forwarded to both searches.
        fts_only: Skip the vector search and score by the FTS result alone.
        vec_only: Skip the FTS search and score by the vector result alone.
        threshold: When given, candidates scoring below it are discarded.

    Returns:
        Dict with ``query``, ``results`` (one entry per hit with score, score
        breakdown, text and source metadata), ``total_matches`` (candidates
        left after thresholding, before the ``top`` cut) and ``returned``.
    """
    # Over-fetch from each source so the fusion step has enough candidates
    pool_size = top * 3

    keyword_hits = {} if vec_only else _fts_search(conn, query, pool_size, tags, doc_type)
    semantic_hits = (
        {} if fts_only
        else _vector_search(conn, query, model_name, cfg, pool_size, tags, doc_type)
    )

    rrf_k = cfg.get("search", {}).get("rrf_k", 60)

    if fts_only:
        candidates = _single_source_results(keyword_hits, "fts")
    elif vec_only:
        candidates = _single_source_results(semantic_hits, "vector")
    else:
        candidates = _rrf_merge(keyword_hits, semantic_hits, rrf_k)

    if threshold is not None:
        candidates = [cand for cand in candidates if cand["score"] >= threshold]

    # Best score first; the total is taken before cutting down to `top`
    ordered = sorted(candidates, key=lambda cand: cand["score"], reverse=True)
    total = len(ordered)

    results = []
    for cand in ordered[:top]:
        row = conn.execute("""
            SELECT c.id, c.text, c.chunk_index, c.metadata as chunk_meta,
                   d.id as doc_id, d.title, d.source_path, d.doc_type,
                   d.language, d.metadata as doc_meta
            FROM chunks c
            JOIN documents d ON c.document_id = d.id
            WHERE c.id = ?
        """, (cand["chunk_id"],)).fetchone()

        # The candidate may reference a chunk that no longer has a row
        if not row:
            continue

        chunk_meta = json.loads(row["chunk_meta"]) if row["chunk_meta"] else {}
        doc_id = row["doc_id"]

        tag_rows = conn.execute("""
            SELECT t.name FROM tags t
            JOIN document_tags dt ON t.id = dt.tag_id
            WHERE dt.document_id = ?
            ORDER BY t.name
        """, (doc_id,)).fetchall()

        total_chunks = conn.execute(
            "SELECT COUNT(*) FROM chunks WHERE document_id = ?", (doc_id,)
        ).fetchone()[0]

        source = {
            "document_id": doc_id,
            "title": row["title"],
            "path": row["source_path"],
            "type": row["doc_type"],
            "page": chunk_meta.get("page"),
            "section_header": chunk_meta.get("section_header"),
            "chunk_index": row["chunk_index"],
            "total_chunks": total_chunks,
            "tags": [tag_row["name"] for tag_row in tag_rows],
        }
        results.append({
            "chunk_id": row["id"],
            "score": round(cand["score"], 6),
            "score_breakdown": cand["score_breakdown"],
            "text": row["text"],
            "source": source,
        })

    return {
        "query": query,
        "results": results,
        "total_matches": total,
        "returned": len(results),
    }
|
||||
|
||||
|
||||
def _fts_search(conn: sqlite3.Connection, query: str, limit: int,
                tags: list[str] | None, doc_type: str | None) -> dict[int, float]:
    """Run the FTS5 keyword search and return ``{chunk_id: score}``.

    The score is the negated ``bm25()`` value, so a higher number means a
    better match. A query that is empty after escaping returns ``{}`` without
    touching the database. Every tag in ``tags`` must be attached to the
    chunk's document (tags are compared stripped and lower-cased).
    """
    match_expr = _escape_fts_query(query)
    if not match_expr.strip():
        return {}

    join_clauses = []
    conditions = ["chunks_fts MATCH ?"]
    bindings = [match_expr]

    # Filters need the owning document, so join through chunks -> documents
    if tags or doc_type:
        join_clauses += [
            "JOIN chunks c ON f.rowid = c.id",
            "JOIN documents d ON c.document_id = d.id",
        ]

    if doc_type:
        conditions.append("d.doc_type = ?")
        bindings.append(doc_type)

    # One join pair per tag makes the tag filter an AND over all tags
    for idx, tag in enumerate(tags or []):
        join_clauses.append(f"JOIN document_tags dt{idx} ON d.id = dt{idx}.document_id")
        join_clauses.append(f"JOIN tags t{idx} ON dt{idx}.tag_id = t{idx}.id")
        conditions.append(f"t{idx}.name = ?")
        bindings.append(tag.strip().lower())

    sql = (
        """
        SELECT f.rowid as chunk_id, bm25(chunks_fts) as score
        FROM chunks_fts f
        """
        + " " + " ".join(join_clauses)
        + " WHERE " + " AND ".join(conditions)
        + " ORDER BY score LIMIT ?"
    )
    bindings.append(limit)

    rows = conn.execute(sql, bindings).fetchall()

    # bm25() is negative with lower = better; flip the sign so higher = better
    return {row["chunk_id"]: -row["score"] for row in rows}
|
||||
|
||||
|
||||
def _vector_search(conn: sqlite3.Connection, query: str, model_name: str,
                   cfg: dict, limit: int, tags: list[str] | None,
                   doc_type: str | None) -> dict[int, float]:
    """Run vector similarity search, return {chunk_id: similarity_score}.

    The query is embedded with ``embed_texts`` (using the configured
    ``embedding.query_prefix``), the ``limit`` nearest rows are read from
    ``chunks_vec``, and the ``doc_type`` / ``tags`` filters are then applied
    to those rows one at a time. Because filtering happens after the LIMIT,
    a filtered search can return fewer than ``limit`` entries.

    Args:
        conn: Open database connection; result rows are read by column name.
        query: Free-text query to embed.
        model_name: Embedding model name passed to ``embed_texts``.
        cfg: Configuration dict; only ``cfg["embedding"]["query_prefix"]`` is read.
        limit: Number of nearest neighbours requested from the vector table.
        tags: Optional tags that must ALL be attached to the chunk's document.
            Compared stripped and lower-cased; duplicates are ignored.
        doc_type: Optional required document type.

    Returns:
        Mapping of chunk id to ``max(0, 1 - distance)``.
    """
    from kb_search.embeddings import embed_texts

    prefix = cfg.get("embedding", {}).get("query_prefix", "")
    query_emb = embed_texts(model_name, [query], prefix=prefix)[0]
    blob = struct.pack(f"{len(query_emb)}f", *query_emb)

    # sqlite-vec returns results ordered by distance (lower = more similar)
    rows = conn.execute("""
        SELECT chunk_id, distance
        FROM chunks_vec
        WHERE embedding MATCH ?
        ORDER BY distance
        LIMIT ?
    """, (blob, limit)).fetchall()

    # Normalise the requested tags once, outside the row loop. De-duplication
    # matters: the filter compares a match count against the number of wanted
    # tags, so a repeated tag (e.g. ["a", "A"]) could otherwise never match.
    wanted_tags = sorted({t.strip().lower() for t in tags}) if tags else []
    tag_sql = """
        SELECT COUNT(DISTINCT t.name) FROM chunks c
        JOIN documents d ON c.document_id = d.id
        JOIN document_tags dt ON d.id = dt.document_id
        JOIN tags t ON dt.tag_id = t.id
        WHERE c.id = ? AND t.name IN ({})
    """.format(",".join("?" * len(wanted_tags)))

    results = {}
    for row in rows:
        chunk_id = row["chunk_id"]

        # Filters are applied post-hoc because the KNN query cannot join
        if doc_type:
            type_ok = conn.execute("""
                SELECT 1 FROM chunks c
                JOIN documents d ON c.document_id = d.id
                WHERE c.id = ? AND d.doc_type = ?
            """, (chunk_id, doc_type)).fetchone()
            if not type_ok:
                continue

        if wanted_tags:
            matched = conn.execute(tag_sql, (chunk_id, *wanted_tags)).fetchone()[0]
            # Every wanted tag must be present (also rejects unknown chunks)
            if matched < len(wanted_tags):
                continue

        # Convert distance to similarity (1 - distance for cosine)
        # NOTE(review): assumes chunks_vec uses a cosine metric - confirm in schema
        results[chunk_id] = max(0, 1 - row["distance"])

    return results
|
||||
|
||||
|
||||
def _rrf_merge(fts_results: dict[int, float], vec_results: dict[int, float],
               k: int = 60) -> list[dict]:
    """Fuse FTS and vector results with Reciprocal Rank Fusion.

    Each source contributes ``1 / (k + rank)`` for every chunk it returned;
    a chunk's score is the sum of its contributions. The per-source
    contributions (rounded to 6 places, ``None`` when the source did not
    return the chunk) are reported under ``score_breakdown``. The returned
    list is not sorted.
    """
    fts_ranked = _rank_results(fts_results)
    vec_ranked = _rank_results(vec_results)

    merged = []
    for chunk_id in set(fts_ranked) | set(vec_ranked):
        fts_part = None
        vec_part = None
        score = 0

        if chunk_id in fts_ranked:
            fts_part = 1 / (k + fts_ranked[chunk_id])
            score += fts_part
        if chunk_id in vec_ranked:
            vec_part = 1 / (k + vec_ranked[chunk_id])
            score += vec_part

        merged.append({
            "chunk_id": chunk_id,
            "score": score,
            "score_breakdown": {
                "fts": None if fts_part is None else round(fts_part, 6),
                "vector": None if vec_part is None else round(vec_part, 6),
            },
        })

    return merged
|
||||
|
||||
|
||||
def _single_source_results(results: dict[int, float], source: str) -> list[dict]:
    """Wrap the scores of one search source in the merged-result shape.

    Entries come out best score first. Each keeps its raw score as ``score``
    and reports it (rounded to 6 places) under ``score_breakdown[source]``,
    with the other source left as ``None``.
    """
    merged = []
    for chunk_id in _rank_results(results):
        raw_score = results[chunk_id]
        merged.append({
            "chunk_id": chunk_id,
            "score": raw_score,
            "score_breakdown": {"fts": None, "vector": None, source: round(raw_score, 6)},
        })
    return merged
|
||||
|
||||
|
||||
def _rank_results(results: dict[int, float]) -> dict[int, int]:
|
||||
"""Rank results by score (1-indexed, higher score = lower rank number)."""
|
||||
sorted_ids = sorted(results.keys(), key=lambda x: results[x], reverse=True)
|
||||
return {chunk_id: rank + 1 for rank, chunk_id in enumerate(sorted_ids)}
|
||||
|
||||
|
||||
def _escape_fts_query(query: str) -> str:
|
||||
"""Escape special FTS5 characters in a query."""
|
||||
# Remove FTS5 operators that could cause syntax errors
|
||||
query = re.sub(r'["\(\)\*\:\^]', " ", query)
|
||||
# Collapse multiple spaces
|
||||
query = re.sub(r"\s+", " ", query).strip()
|
||||
return query
|
||||
Reference in New Issue
Block a user