"""kb MCP server — exposes knowledge base operations as MCP tools.""" import asyncio import json import logging from mcp.server.fastmcp import FastMCP from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Mount import config import engine import uploads logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("kb.mcp") # --------------------------------------------------------------------------- # Collection helpers # --------------------------------------------------------------------------- COLLECTION_TAG_PREFIX = "collection:" DEFAULT_COLLECTION = "documents" def _collection_tag(collection: str | None) -> str: return f"{COLLECTION_TAG_PREFIX}{collection or DEFAULT_COLLECTION}" def _strip_collection_tags(tags: list[str]) -> tuple[str | None, list[str]]: """Split tags into (collection, remaining_tags).""" collection = None remaining = [] for t in tags: if t.startswith(COLLECTION_TAG_PREFIX): collection = t[len(COLLECTION_TAG_PREFIX):] else: remaining.append(t) return collection, remaining def _process_document(doc: dict) -> dict: """Strip collection tags from a document dict and add collection field.""" tags = doc.get("tags", []) collection, clean_tags = _strip_collection_tags(tags) doc["tags"] = clean_tags doc["collection"] = collection return doc def _process_search_results(results: list[dict]) -> list[dict]: """Strip collection tags from search result dicts.""" for r in results: if "tags" in r: collection, clean_tags = _strip_collection_tags(r["tags"]) r["tags"] = clean_tags r["collection"] = collection if "document" in r and "tags" in r["document"]: collection, clean_tags = _strip_collection_tags(r["document"]["tags"]) r["document"]["tags"] = clean_tags r["document"]["collection"] = collection return results async def _ensure_exclusive_collection(doc_id: int, collection: str) -> None: """Remove existing collection tags and apply the new one.""" doc = engine.get_document(doc_id) existing_collection_tags = [ t for t in doc.get("tags", []) if t.startswith(COLLECTION_TAG_PREFIX) ] new_tag = _collection_tag(collection) if existing_collection_tags == [new_tag]: return if existing_collection_tags: engine.update_tags(doc_id, remove=existing_collection_tags) engine.update_tags(doc_id, add=[new_tag]) # --------------------------------------------------------------------------- # FastMCP server # --------------------------------------------------------------------------- mcp = FastMCP( "kb", instructions=( "Knowledge base MCP server. Provides tools for searching, adding, and " "managing documents and notes. This server requires Bearer token " "authentication — all requests are authenticated via the Authorization " "header at the HTTP transport layer." ), ) @mcp.tool() async def kb_search( query: str, top: int = 10, tags: list[str] | None = None, doc_type: str | None = None, collection: str | None = None, fts_only: bool = False, ) -> str: """Search the knowledge base for relevant documents and notes. Returns ranked chunks matching the query, with text content, relevance scores, and document metadata. Args: query: The search query. Can be a natural language question or keywords. top: Maximum number of results to return (default 10). tags: Filter results to documents with ALL of these tags. doc_type: Filter by document type (e.g. "note", "pdf", "markdown", "code"). collection: Filter by collection name (e.g. "documents", "memory", "workspace"). fts_only: If true, use only full-text search (no vector similarity). Tips for complex queries: - Consider expanding into 2-3 variant phrasings and calling this tool multiple times, then deduplicating results by chunk_id. For example, search for both "pension revaluation rules" and "how are pensions revalued" to cast a wider net. - For precision, rerank the returned results using your own judgement based on relevance to the original question. """ search_tags = list(tags) if tags else [] if collection: search_tags.append(_collection_tag(collection)) result = engine.search( query=query, top=top, tags=search_tags or None, doc_type=doc_type, fts_only=fts_only, ) results_list = result if isinstance(result, list) else result.get("results", []) processed = _process_search_results(results_list) return json.dumps(processed, indent=2) @mcp.tool() async def kb_addnote( text: str, collection: str | None = None, tags: list[str] | None = None, title: str | None = None, ) -> str: """Add a text note to the knowledge base for indexing and search. The note is queued for ingestion — it will be chunked, embedded, and made searchable. Use kb_jobs to check ingestion status. Args: text: The note text content. collection: Collection to add the note to (default "documents"). Standard collections: "documents", "memory", "workspace". tags: Additional tags to apply to the note. title: Optional title (auto-derived from first line if omitted). """ all_tags = list(tags) if tags else [] all_tags.append(_collection_tag(collection)) result = engine.add_note(text=text, tags=all_tags, title=title) return json.dumps(result, indent=2) @mcp.tool() async def kb_update_note( document_id: int, text: str, ) -> str: """Update an existing note's content in place. Replaces the note text, re-chunks, and re-embeds while preserving the document ID, creation timestamp, and tags. Only works on documents with doc_type "note". Args: document_id: The ID of the note document to update. text: The new text content for the note. """ result = engine.update_note(document_id, text) return json.dumps(_process_document(result), indent=2) @mcp.tool() async def kb_get( document_id: int | None = None, source_path: str | None = None, ) -> str: """Retrieve document details from the knowledge base. Look up a document by its ID or source path. Returns full document metadata, tags, and chunk contents. Args: document_id: The numeric document ID. source_path: The document's source path (alternative to document_id). """ if document_id is not None: result = engine.get_document(document_id) return json.dumps(_process_document(result), indent=2) elif source_path is not None: docs = engine.list_documents() matches = [d for d in docs if d.get("source_path") == source_path] if not matches: return json.dumps({"error": "No document found with that source_path"}) doc = engine.get_document(matches[0]["id"]) return json.dumps(_process_document(doc), indent=2) else: return json.dumps({"error": "Provide either document_id or source_path"}) @mcp.tool() async def kb_status() -> str: """Get knowledge base engine status. Returns engine version, embedding model info, device info, document counts, database size, and ingestion queue state. """ result = engine.get_status() result["authenticated"] = bool(config.KB_MCP_API_KEY) return json.dumps(result, indent=2) @mcp.tool() async def kb_jobs( status: str | None = None, ) -> str: """List ingestion jobs and their status. Returns recent jobs showing what has been queued, is processing, completed, or failed. Args: status: Filter by job status ("queued", "processing", "done", "failed", "skipped"). """ result = engine.list_jobs(status=status) return json.dumps(result, indent=2) @mcp.tool() async def kb_upload_start( filename: str, total_size: int, tags: list[str] | None = None, collection: str | None = None, ) -> str: """Start a chunked file upload to the knowledge base. Use this for uploading files from a remote agent. The upload process is: 1. Call kb_upload_start to get an upload_id 2. Call kb_upload_chunk repeatedly with base64-encoded file chunks (recommended ~1MB each) 3. Call kb_upload_finish to submit the file for ingestion Example for a 3MB file: upload = kb_upload_start(filename="report.pdf", total_size=3145728, collection="documents") kb_upload_chunk(upload_id=upload["upload_id"], data="", chunk_index=0) kb_upload_chunk(upload_id=upload["upload_id"], data="", chunk_index=1) kb_upload_chunk(upload_id=upload["upload_id"], data="", chunk_index=2) result = kb_upload_finish(upload_id=upload["upload_id"]) Args: filename: Original filename (used for type detection). total_size: Total file size in bytes. tags: Additional tags to apply. collection: Collection name (default "documents"). """ all_tags = list(tags) if tags else [] all_tags.append(_collection_tag(collection)) upload_id = uploads.start_upload(filename, total_size, all_tags) return json.dumps({"upload_id": upload_id}) @mcp.tool() async def kb_upload_chunk( upload_id: str, data: str, chunk_index: int, ) -> str: """Upload a base64-encoded chunk of a file. Part of the chunked upload flow started by kb_upload_start. Args: upload_id: The upload ID from kb_upload_start. data: Base64-encoded file data for this chunk. chunk_index: Zero-based index of this chunk. """ try: uploads.add_chunk(upload_id, data, chunk_index) return json.dumps({"status": "ok", "chunk_index": chunk_index}) except KeyError as e: return json.dumps({"error": str(e)}) @mcp.tool() async def kb_upload_finish( upload_id: str, ) -> str: """Finish a chunked upload and submit the file for ingestion. Reassembles all uploaded chunks and forwards the complete file to the engine for processing. Returns the ingestion job ID. Args: upload_id: The upload ID from kb_upload_start. """ try: filename, file_bytes, tags = uploads.finish_upload(upload_id) result = engine.upload_file(filename, file_bytes, tags) return json.dumps(result, indent=2) except KeyError as e: return json.dumps({"error": str(e)}) # --------------------------------------------------------------------------- # Auth middleware # --------------------------------------------------------------------------- class BearerAuthMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): if not config.KB_MCP_API_KEY: return await call_next(request) auth_header = request.headers.get("authorization", "") if auth_header.startswith("Bearer ") and auth_header[7:] == config.KB_MCP_API_KEY: return await call_next(request) return JSONResponse( status_code=401, content={"error": "Unauthorized"}, ) # --------------------------------------------------------------------------- # ASGI app assembly # --------------------------------------------------------------------------- def create_app(): """Create the ASGI app with auth middleware wrapping the MCP server.""" from contextlib import asynccontextmanager mcp_app = mcp.streamable_http_app() @asynccontextmanager async def lifespan(app): uploads.start_cleanup_task() logger.info("Upload cleanup task started") # Delegate to the MCP app's lifespan if it has one if hasattr(mcp_app, 'router') and hasattr(mcp_app.router, 'lifespan_context'): async with mcp_app.router.lifespan_context(app): yield else: yield app = Starlette( routes=[Mount("/", app=mcp_app)], middleware=[Middleware(BearerAuthMiddleware)], lifespan=lifespan, ) return app # --------------------------------------------------------------------------- # Entry point # --------------------------------------------------------------------------- if __name__ == "__main__": import uvicorn logger.info( "Starting kb MCP server on port %d, engine=%s", config.KB_MCP_PORT, config.KB_ENGINE_URL, ) app = create_app() uvicorn.run(app, host="0.0.0.0", port=config.KB_MCP_PORT)