Add bulk operations and remove collections abstraction
- Add bulk delete, bulk tags, and bulk set-tags engine endpoints (POST /api/v1/bulk/delete, /bulk/tags, /bulk/set-tags)
- Filter-based selection: by tags, doc_type, ID list, ID range
- Safety threshold (KB_BULK_SAFETY_PERCENT, default 70%) prevents accidental mass operations unless force=true
- Synchronous execution with audit trail via jobs table
- Add kb_bulk_delete, kb_bulk_tags, kb_bulk_set_tags MCP tools
- Add kb bulk-remove, bulk-tag, bulk-set-tags CLI commands
- Remove collection abstraction from MCP server (use tags instead)
- Remove kb_set_collection MCP tool
- Update SKILL.md, MCP.md, README.md documentation

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+1
-1
@@ -1 +1 @@
|
||||
3.0.1
|
||||
3.2.0
|
||||
|
||||
@@ -20,6 +20,7 @@ class Config:
|
||||
self.ingest_device = os.environ.get("KB_INGEST_DEVICE", "auto")
|
||||
self.api_key = os.environ.get("KB_API_KEY") or None
|
||||
self.search_threshold = float(os.environ.get("KB_SEARCH_THRESHOLD", "0.01"))
|
||||
self.bulk_safety_percent = int(os.environ.get("KB_BULK_SAFETY_PERCENT", "70"))
|
||||
self.host = os.environ.get("KB_HOST", "0.0.0.0")
|
||||
self.port = int(os.environ.get("KB_PORT", "8000"))
|
||||
|
||||
|
||||
@@ -189,6 +189,11 @@ def init_schema(conn: sqlite3.Connection, embedding_dim: int) -> None:
|
||||
if "updated_at" not in doc_cols:
|
||||
conn.execute("ALTER TABLE documents ADD COLUMN updated_at TEXT")
|
||||
|
||||
# Migrate: add job_type to jobs if missing (bulk operations)
|
||||
job_cols = {row[1] for row in conn.execute("PRAGMA table_info(jobs)").fetchall()}
|
||||
if "job_type" not in job_cols:
|
||||
conn.execute("ALTER TABLE jobs ADD COLUMN job_type TEXT DEFAULT 'ingest'")
|
||||
|
||||
conn.commit()
|
||||
|
||||
|
||||
@@ -329,6 +334,92 @@ def untag_document(conn: sqlite3.Connection, document_id: int, tag_names: list[s
|
||||
conn.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bulk operation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def resolve_bulk_selection(
|
||||
conn: sqlite3.Connection,
|
||||
document_ids: list[int] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
doc_type: str | None = None,
|
||||
from_id: int | None = None,
|
||||
to_id: int | None = None,
|
||||
) -> list[int]:
|
||||
"""Return document IDs matching the bulk selection filter.
|
||||
|
||||
Filters combine with AND logic. At least one filter must be provided.
|
||||
"""
|
||||
sql = "SELECT DISTINCT d.id FROM documents d"
|
||||
joins: list[str] = []
|
||||
where: list[str] = []
|
||||
params: list = []
|
||||
|
||||
if tags:
|
||||
for i, tag in enumerate(tags):
|
||||
joins.append(f"JOIN document_tags dt{i} ON d.id = dt{i}.document_id")
|
||||
joins.append(f"JOIN tags t{i} ON dt{i}.tag_id = t{i}.id")
|
||||
where.append(f"t{i}.name = ?")
|
||||
params.append(tag)
|
||||
|
||||
if doc_type:
|
||||
where.append("d.doc_type = ?")
|
||||
params.append(doc_type)
|
||||
|
||||
if document_ids:
|
||||
placeholders = ",".join("?" for _ in document_ids)
|
||||
where.append(f"d.id IN ({placeholders})")
|
||||
params.extend(document_ids)
|
||||
|
||||
if from_id is not None:
|
||||
where.append("d.id >= ?")
|
||||
params.append(from_id)
|
||||
|
||||
if to_id is not None:
|
||||
where.append("d.id <= ?")
|
||||
params.append(to_id)
|
||||
|
||||
if joins:
|
||||
sql += " " + " ".join(joins)
|
||||
if where:
|
||||
sql += " WHERE " + " AND ".join(where)
|
||||
|
||||
rows = conn.execute(sql, params).fetchall()
|
||||
return [row["id"] for row in rows]
|
||||
|
||||
|
||||
def create_bulk_job(
    conn: sqlite3.Connection,
    job_type: str,
    filters_json: str,
    matched: int,
    succeeded: int,
    failed: int,
    errors_json: str = "[]",
) -> int:
    """Write the audit record of a finished bulk operation to ``jobs``.

    The existing job columns are reused for bulk bookkeeping: ``filename``
    stores the filter JSON, ``document_id`` the number of matched documents
    and ``chunk_count`` the number that succeeded. ``status`` is ``"done"``
    when ``failed`` is zero and ``"partial_failure"`` otherwise; ``error``
    receives ``errors_json`` only when ``failed`` is positive.

    Returns:
        The row id of the inserted job; the insert is committed before
        returning.
    """
    status = "partial_failure" if failed != 0 else "done"
    error_payload = None if failed <= 0 else errors_json
    row_values = (
        filters_json,
        status,
        job_type,
        matched,
        succeeded,
        error_payload,
    )

    inserted = conn.execute(
        """INSERT INTO jobs(filename, status, job_type, document_id, chunk_count, error, completed_at)
           VALUES (?, ?, ?, ?, ?, ?, current_timestamp)""",
        row_values,
    )
    conn.commit()
    return inserted.lastrowid
|
||||
|
||||
|
||||
def count_documents(conn: sqlite3.Connection) -> int:
    """Count the rows of the ``documents`` table.

    The result is read by column name, so ``conn`` must use a row factory
    that supports name lookup (for example ``sqlite3.Row``).
    """
    cursor = conn.execute("SELECT COUNT(*) AS cnt FROM documents")
    return cursor.fetchone()["cnt"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vec table management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,281 @@
|
||||
"""Bulk operation endpoints — delete, tag, and set-tags on multiple documents."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from main import app
|
||||
from kb.config import cfg
|
||||
from kb.database import (
|
||||
get_connection,
|
||||
resolve_bulk_selection,
|
||||
count_documents,
|
||||
create_bulk_job,
|
||||
tag_document,
|
||||
untag_document,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("kb.routes.bulk")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class BulkSelectionRequest(BaseModel):
    # Selection filters shared by every bulk endpoint. The fields mirror the
    # parameters of resolve_bulk_selection, which the endpoints pass them to.
    document_ids: Optional[list[int]] = None  # explicit document IDs
    tags: Optional[list[str]] = None  # tag names to select by
    doc_type: Optional[str] = None  # documents.doc_type value to match
    from_id: Optional[int] = None  # lower bound of the ID range
    to_id: Optional[int] = None  # upper bound of the ID range
    # Handed to the safety-threshold check, which is skipped when true.
    force: bool = False

    @model_validator(mode="after")
    def require_at_least_one_filter(self):
        """Reject a request that supplies no selection filter at all.

        Empty lists and an empty ``doc_type`` count as missing; the ID bounds
        count as present whenever they are not ``None`` (so ``0`` is valid).
        """
        has_value_filter = bool(self.document_ids or self.tags or self.doc_type)
        has_range_filter = self.from_id is not None or self.to_id is not None
        if not (has_value_filter or has_range_filter):
            raise ValueError("At least one selection filter is required")
        return self
|
||||
|
||||
|
||||
class BulkDeleteRequest(BulkSelectionRequest):
    # Request body of POST /api/v1/bulk/delete. Deleting needs nothing beyond
    # the inherited selection filters and ``force`` flag, so no fields are added.
    pass
|
||||
|
||||
|
||||
class BulkTagsRequest(BulkSelectionRequest):
    # Request body of POST /api/v1/bulk/tags: the inherited selection filters
    # plus the tag names to attach and/or detach on every matched document.
    add: Optional[list[str]] = None  # tag names to attach
    remove: Optional[list[str]] = None  # tag names to detach

    @model_validator(mode="after")
    def require_add_or_remove(self):
        """Reject a request that names no tags to add and none to remove.

        ``None`` and an empty list are both treated as "nothing to do".
        """
        if self.add or self.remove:
            return self
        raise ValueError("At least one of 'add' or 'remove' is required")
|
||||
|
||||
|
||||
class BulkSetTagsRequest(BulkSelectionRequest):
    # Request body of POST /api/v1/bulk/set-tags.
    # new_tags is required and is the complete tag set each matched document
    # ends up with: the endpoint first removes all existing tags, then applies
    # this list. An empty list therefore leaves the documents untagged.
    new_tags: list[str]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _check_safety_threshold(matched: int, total: int, force: bool) -> None:
    """Refuse a bulk operation that touches too large a share of the documents.

    The limit is ``cfg.bulk_safety_percent``. The check is skipped when that
    limit is zero or negative, when ``force`` is set, or when there are no
    documents at all.

    Args:
        matched: Number of documents the operation selected.
        total: Number of documents in the database.
        force: Caller's explicit override of the safety check.

    Raises:
        HTTPException: 409 with a structured ``detail`` when the matched
            share is strictly greater than the limit.
    """
    threshold = cfg.bulk_safety_percent
    check_enabled = threshold > 0 and not force and total != 0
    if not check_enabled:
        return

    percent = (matched / total) * 100
    if percent <= threshold:
        return

    message = (
        f"Operation would affect {matched} of {total} documents "
        f"({percent:.1f}%). Exceeds safety threshold of {threshold}%. "
        f"Use force: true to proceed."
    )
    detail = {
        "error": "safety_threshold_exceeded",
        "message": message,
        "matched": matched,
        "total": total,
        "percent": round(percent, 1),
        "threshold": threshold,
    }
    raise HTTPException(status_code=409, detail=detail)
|
||||
|
||||
|
||||
def _filters_dict(req: BulkSelectionRequest) -> str:
    """Serialize the filters the request actually supplied, for the audit log.

    Despite the name, the return value is a JSON *string*. List and string
    filters are included only when non-empty; the ID bounds are included
    whenever they are not ``None``. ``force`` is not part of the output.
    """
    candidates = (
        ("document_ids", req.document_ids, bool(req.document_ids)),
        ("tags", req.tags, bool(req.tags)),
        ("doc_type", req.doc_type, bool(req.doc_type)),
        ("from_id", req.from_id, req.from_id is not None),
        ("to_id", req.to_id, req.to_id is not None),
    )
    supplied = {key: value for key, value, given in candidates if given}
    return json.dumps(supplied)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.post("/api/v1/bulk/delete")
async def bulk_delete(req: BulkDeleteRequest):
    # Delete every document matched by the request's selection filters.
    #
    # Flow: resolve the selection, enforce the safety threshold (raises 409
    # unless req.force), delete document by document while collecting
    # per-document errors, commit, remove the stored files, then record an
    # audit job. Returns the job id together with matched / succeeded /
    # failed counts and the error list.
    conn = get_connection(cfg.db_path)
    try:
        doc_ids = resolve_bulk_selection(
            conn, req.document_ids, req.tags, req.doc_type, req.from_id, req.to_id,
        )
        total = count_documents(conn)
        _check_safety_threshold(len(doc_ids), total, req.force)

        succeeded = 0
        failed = 0
        errors: list[dict] = []
        stored_files: list[str] = []

        for doc_id in doc_ids:
            try:
                doc = conn.execute(
                    "SELECT id, stored_path FROM documents WHERE id = ?", (doc_id,)
                ).fetchone()
                if not doc:
                    failed += 1
                    errors.append({"document_id": doc_id, "error": "not found"})
                    continue

                # Delete embeddings: chunks_vec rows are removed explicitly,
                # one per chunk of this document.
                chunk_rows = conn.execute(
                    "SELECT id FROM chunks WHERE document_id = ?", (doc_id,)
                ).fetchall()
                for row in chunk_rows:
                    conn.execute("DELETE FROM chunks_vec WHERE chunk_id = ?", (row["id"],))

                # Delete the document row; chunks and document_tags rows are
                # expected to follow via cascade (schema not visible here).
                conn.execute("DELETE FROM documents WHERE id = ?", (doc_id,))
                succeeded += 1

                # Queue the stored file only now that the row is gone. Queuing
                # it earlier would unlink the file of a document whose delete
                # failed, leaving a database row that points at nothing.
                if doc["stored_path"]:
                    stored_files.append(doc["stored_path"])
            except Exception as exc:
                # Per-document boundary: record the failure and carry on with
                # the remaining documents instead of aborting the whole batch.
                logger.warning("Bulk delete failed for document %s: %s", doc_id, exc)
                failed += 1
                errors.append({"document_id": doc_id, "error": str(exc)})

        conn.commit()

        # Best-effort file cleanup after commit; an already-missing file is fine.
        for path in stored_files:
            try:
                Path(path).unlink(missing_ok=True)
            except OSError as exc:
                logger.warning("Failed to delete stored file %s: %s", path, exc)

        job_id = create_bulk_job(
            conn, "bulk_delete", _filters_dict(req),
            len(doc_ids), succeeded, failed, json.dumps(errors),
        )

        return {
            "job_id": job_id,
            "status": "done" if failed == 0 else "partial_failure",
            "matched": len(doc_ids),
            "succeeded": succeeded,
            "failed": failed,
            "errors": errors,
        }
    finally:
        conn.close()
|
||||
|
||||
|
||||
@app.post("/api/v1/bulk/tags")
async def bulk_tags(req: BulkTagsRequest):
    # Add and/or remove tags on every document matched by the selection
    # filters. The safety threshold applies (409 unless req.force); failures
    # are collected per document and the run is recorded as an audit job.
    conn = get_connection(cfg.db_path)
    try:
        selected_ids = resolve_bulk_selection(
            conn, req.document_ids, req.tags, req.doc_type, req.from_id, req.to_id,
        )
        _check_safety_threshold(len(selected_ids), count_documents(conn), req.force)

        succeeded = 0
        errors = []

        for doc_id in selected_ids:
            try:
                # Additions run before removals, so a tag named in both lists
                # ends up removed.
                if req.add:
                    tag_document(conn, doc_id, req.add)
                if req.remove:
                    untag_document(conn, doc_id, req.remove)
                conn.execute(
                    "UPDATE documents SET updated_at = current_timestamp WHERE id = ?",
                    (doc_id,),
                )
            except Exception as exc:
                errors.append({"document_id": doc_id, "error": str(exc)})
            else:
                succeeded += 1

        conn.commit()

        # Every failure contributed exactly one entry to ``errors``.
        failed = len(errors)
        status = "partial_failure" if failed else "done"

        job_id = create_bulk_job(
            conn, "bulk_tags", _filters_dict(req),
            len(selected_ids), succeeded, failed, json.dumps(errors),
        )

        return {
            "job_id": job_id,
            "status": status,
            "matched": len(selected_ids),
            "succeeded": succeeded,
            "failed": failed,
            "errors": errors,
        }
    finally:
        conn.close()
|
||||
|
||||
|
||||
@app.post("/api/v1/bulk/set-tags")
async def bulk_set_tags(req: BulkSetTagsRequest):
    # Replace the tag set of every document matched by the selection filters
    # with req.new_tags. The safety threshold applies (409 unless req.force);
    # failures are collected per document and the run is recorded as an
    # audit job.
    conn = get_connection(cfg.db_path)
    try:
        selected_ids = resolve_bulk_selection(
            conn, req.document_ids, req.tags, req.doc_type, req.from_id, req.to_id,
        )
        _check_safety_threshold(len(selected_ids), count_documents(conn), req.force)

        succeeded = 0
        errors = []

        for doc_id in selected_ids:
            try:
                # Remove all existing tags first, then apply the new set.
                # NOTE(review): if tag_document raises here, the wipe above is
                # presumably still committed later, leaving the document
                # untagged although it is reported as failed — confirm.
                conn.execute(
                    "DELETE FROM document_tags WHERE document_id = ?", (doc_id,)
                )
                if req.new_tags:
                    tag_document(conn, doc_id, req.new_tags)
                conn.execute(
                    "UPDATE documents SET updated_at = current_timestamp WHERE id = ?",
                    (doc_id,),
                )
            except Exception as exc:
                errors.append({"document_id": doc_id, "error": str(exc)})
            else:
                succeeded += 1

        conn.commit()

        # Every failure contributed exactly one entry to ``errors``.
        failed = len(errors)
        status = "partial_failure" if failed else "done"

        job_id = create_bulk_job(
            conn, "bulk_set_tags", _filters_dict(req),
            len(selected_ids), succeeded, failed, json.dumps(errors),
        )

        return {
            "job_id": job_id,
            "status": status,
            "matched": len(selected_ids),
            "succeeded": succeeded,
            "failed": failed,
            "errors": errors,
        }
    finally:
        conn.close()
|
||||
+1
-1
@@ -62,7 +62,7 @@ async def lifespan(app: FastAPI):
|
||||
app = FastAPI(title="kb-engine", version=__version__, lifespan=lifespan)
|
||||
|
||||
# Import routes after app is created
|
||||
from kb.routes import health, search, jobs, documents, tags, status, reindex, auth, notes # noqa: E402, F401
|
||||
from kb.routes import health, search, jobs, documents, tags, status, reindex, auth, notes, bulk # noqa: E402, F401
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
Reference in New Issue
Block a user