Files
kb/mcp/server.py
T
steve da5b8435bc Add configurable allowed hosts for MCP remote access (KB_MCP_ALLOWED_HOSTS)
The MCP SDK's DNS rebinding protection rejects remote clients with 421
when the Host header isn't in the allowlist. Add KB_MCP_ALLOWED_HOSTS env
var (comma-separated IPs/FQDNs) to configure additional allowed hosts
while keeping localhost always permitted.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-04 12:39:43 +01:00

404 lines
14 KiB
Python

"""kb MCP server — exposes knowledge base operations as MCP tools."""
import asyncio
import hmac
import json
import logging

from mcp.server.fastmcp import FastMCP
from mcp.server.transport_security import TransportSecuritySettings
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Mount

import config
import engine
import uploads
# Configure root logging once at import time: INFO level, timestamped lines.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
# Module logger used by the app lifespan hook and the __main__ entry point.
logger = logging.getLogger("kb.mcp")
# ---------------------------------------------------------------------------
# Collection helpers
# ---------------------------------------------------------------------------
# Collections are represented as ordinary tags carrying this prefix; the
# helpers below add the prefix on the way in and strip it on the way out.
COLLECTION_TAG_PREFIX = "collection:"
# Collection used when the caller passes no (or an empty) collection name.
DEFAULT_COLLECTION = "documents"
def _collection_tag(collection: str | None) -> str:
    """Return the prefixed tag that marks membership of *collection*.

    A falsy value (``None`` or ``""``) falls back to ``DEFAULT_COLLECTION``.
    """
    name = collection if collection else DEFAULT_COLLECTION
    return f"{COLLECTION_TAG_PREFIX}{name}"
def _strip_collection_tags(tags: list[str]) -> tuple[str | None, list[str]]:
    """Separate collection tags from ordinary tags.

    Returns ``(collection, other_tags)``. ``collection`` is the name carried
    by the last collection-prefixed tag encountered, or ``None`` when there is
    none; ``other_tags`` keeps the remaining tags in their original order.
    """
    found = None
    others = []
    for tag in tags:
        if not tag.startswith(COLLECTION_TAG_PREFIX):
            others.append(tag)
            continue
        # Later collection tags overwrite earlier ones.
        found = tag.removeprefix(COLLECTION_TAG_PREFIX)
    return found, others
def _process_document(doc: dict) -> dict:
    """Expose a document's collection as its own field.

    Mutates *doc* in place: ``doc["tags"]`` loses every collection-prefixed
    tag and ``doc["collection"]`` is set to the extracted name (``None`` if the
    document had no collection tag). The same dict is returned for chaining.
    """
    name, plain_tags = _strip_collection_tags(doc.get("tags", []))
    doc.update(tags=plain_tags, collection=name)
    return doc
def _process_search_results(results: list[dict]) -> list[dict]:
    """Replace collection tags with a ``collection`` field in search hits.

    Each hit is rewritten in place, both at the top level (when it has a
    ``"tags"`` key) and inside its nested ``"document"`` dict (when that has a
    ``"tags"`` key). Hits without tags are left untouched. Returns the same
    list object that was passed in.
    """
    for hit in results:
        if "tags" in hit:
            _process_document(hit)
        if "document" in hit and "tags" in hit["document"]:
            _process_document(hit["document"])
    return results
async def _ensure_exclusive_collection(doc_id: int, collection: str) -> None:
    """Make *collection* the only collection tag on document *doc_id*.

    Fetches the document, and unless its collection tags are already exactly
    the wanted one, removes every existing collection tag and then adds the
    new tag. The engine calls are made synchronously (no ``await``).
    """
    current = engine.get_document(doc_id)
    stale = [
        tag
        for tag in current.get("tags", [])
        if tag.startswith(COLLECTION_TAG_PREFIX)
    ]
    wanted = _collection_tag(collection)
    if stale != [wanted]:
        if stale:
            engine.update_tags(doc_id, remove=stale)
        engine.update_tags(doc_id, add=[wanted])
# ---------------------------------------------------------------------------
# Transport security — DNS rebinding protection with configurable allowed hosts
# ---------------------------------------------------------------------------
# Loopback is always permitted, on any port.
_LOCALHOST_HOSTS = ["127.0.0.1:*", "localhost:*", "[::1]:*"]
_LOCALHOST_ORIGINS = ["http://127.0.0.1:*", "http://localhost:*", "http://[::1]:*"]

# Additional hosts supplied by configuration (presumably parsed from the
# KB_MCP_ALLOWED_HOSTS environment variable -- see config.parse_allowed_hosts).
_extra_hosts = config.parse_allowed_hosts()

# Each extra host is registered both bare and with a wildcard port: a Host
# header sent without a port (default-port or reverse-proxied requests) does
# not match a "host:*" pattern, so the bare form has to be listed explicitly.
_allowed_hosts = _LOCALHOST_HOSTS + [
    pattern
    for h in _extra_hosts
    for pattern in (h, f"{h}:*")
]

# Origins get the same bare/wildcard treatment, for both http and https, so
# that clients reaching the server through a TLS-terminating proxy are not
# rejected for sending an "https://host" Origin.
_allowed_origins = _LOCALHOST_ORIGINS + [
    f"{scheme}://{h}{port}"
    for h in _extra_hosts
    for scheme in ("http", "https")
    for port in ("", ":*")
]

_transport_security = TransportSecuritySettings(
    enable_dns_rebinding_protection=True,
    allowed_hosts=_allowed_hosts,
    allowed_origins=_allowed_origins,
)
# ---------------------------------------------------------------------------
# FastMCP server
# ---------------------------------------------------------------------------
# Module-level FastMCP instance; the @mcp.tool() functions below register on it
# and create_app() mounts its streamable-HTTP ASGI app.
# NOTE(review): the instructions text says Bearer authentication is required,
# but BearerAuthMiddleware lets every request through when
# config.KB_MCP_API_KEY is unset -- confirm the wording is intended.
mcp = FastMCP(
    "kb",
    instructions=(
        "Knowledge base MCP server. Provides tools for searching, adding, and "
        "managing documents and notes. This server requires Bearer token "
        "authentication — all requests are authenticated via the Authorization "
        "header at the HTTP transport layer."
    ),
    transport_security=_transport_security,
)
@mcp.tool()
async def kb_search(
    query: str,
    top: int = 10,
    tags: list[str] | None = None,
    doc_type: str | None = None,
    collection: str | None = None,
    fts_only: bool = False,
) -> str:
    """Search the knowledge base for relevant documents and notes.
    Returns ranked chunks matching the query, with text content, relevance scores,
    and document metadata.
    Args:
        query: The search query. Can be a natural language question or keywords.
        top: Maximum number of results to return (default 10).
        tags: Filter results to documents with ALL of these tags.
        doc_type: Filter by document type (e.g. "note", "pdf", "markdown", "code").
        collection: Filter by collection name (e.g. "documents", "memory", "workspace").
        fts_only: If true, use only full-text search (no vector similarity).
    Tips for complex queries:
        - Consider expanding into 2-3 variant phrasings and calling this tool multiple
          times, then deduplicating results by chunk_id. For example, search for both
          "pension revaluation rules" and "how are pensions revalued" to cast a wider net.
        - For precision, rerank the returned results using your own judgement based on
          relevance to the original question.
    """
    # Copy the caller's tag list so appending the collection tag never mutates it.
    tag_filter = [*tags] if tags else []
    if collection:
        # A collection filter is expressed as one more required tag.
        tag_filter.append(_collection_tag(collection))

    raw = engine.search(
        query=query,
        top=top,
        tags=tag_filter or None,
        doc_type=doc_type,
        fts_only=fts_only,
    )

    # The engine may answer with a bare list or with a dict wrapping "results".
    if isinstance(raw, list):
        hits = raw
    else:
        hits = raw.get("results", [])

    return json.dumps(_process_search_results(hits), indent=2)
@mcp.tool()
async def kb_addnote(
    text: str,
    collection: str | None = None,
    tags: list[str] | None = None,
    title: str | None = None,
) -> str:
    """Add a text note to the knowledge base for indexing and search.
    The note is queued for ingestion — it will be chunked, embedded, and made
    searchable. Use kb_jobs to check ingestion status.
    Args:
        text: The note text content.
        collection: Collection to add the note to (default "documents").
            Standard collections: "documents", "memory", "workspace".
        tags: Additional tags to apply to the note.
        title: Optional title (auto-derived from first line if omitted).
    """
    # Caller tags first, then the collection tag (default collection if omitted).
    note_tags = [*(tags or []), _collection_tag(collection)]
    created = engine.add_note(text=text, tags=note_tags, title=title)
    return json.dumps(created, indent=2)
@mcp.tool()
async def kb_update_note(
    document_id: int,
    text: str,
) -> str:
    """Update an existing note's content in place.
    Replaces the note text, re-chunks, and re-embeds while preserving the
    document ID, creation timestamp, and tags. Only works on documents with
    doc_type "note".
    Args:
        document_id: The ID of the note document to update.
        text: The new text content for the note.
    """
    updated = engine.update_note(document_id, text)
    # Present the collection as a separate field rather than as a raw tag.
    payload = _process_document(updated)
    return json.dumps(payload, indent=2)
@mcp.tool()
async def kb_get(
    document_id: int | None = None,
    source_path: str | None = None,
) -> str:
    """Retrieve document details from the knowledge base.
    Look up a document by its ID or source path. Returns full document metadata,
    tags, and chunk contents.
    Args:
        document_id: The numeric document ID.
        source_path: The document's source path (alternative to document_id).
    """
    if document_id is None and source_path is None:
        return json.dumps({"error": "Provide either document_id or source_path"})

    # document_id wins when both arguments are supplied.
    if document_id is not None:
        resolved_id = document_id
    else:
        # No direct path lookup is used here: scan the full listing and take
        # the first document whose source_path matches exactly.
        candidates = [
            entry
            for entry in engine.list_documents()
            if entry.get("source_path") == source_path
        ]
        if not candidates:
            return json.dumps({"error": "No document found with that source_path"})
        resolved_id = candidates[0]["id"]

    details = engine.get_document(resolved_id)
    return json.dumps(_process_document(details), indent=2)
@mcp.tool()
async def kb_status() -> str:
    """Get knowledge base engine status.
    Returns engine version, embedding model info, device info, document counts,
    database size, and ingestion queue state.
    """
    status = engine.get_status()
    # "authenticated" reports whether an API key is configured on this server,
    # not whether the current request carried one.
    status.update(authenticated=bool(config.KB_MCP_API_KEY))
    return json.dumps(status, indent=2)
@mcp.tool()
async def kb_jobs(
    status: str | None = None,
) -> str:
    """List ingestion jobs and their status.
    Returns recent jobs showing what has been queued, is processing, completed,
    or failed.
    Args:
        status: Filter by job status ("queued", "processing", "done", "failed", "skipped").
    """
    # The filter is forwarded unchanged; None means "no filter" to the engine call.
    jobs = engine.list_jobs(status=status)
    return json.dumps(jobs, indent=2)
@mcp.tool()
async def kb_upload_start(
    filename: str,
    total_size: int,
    tags: list[str] | None = None,
    collection: str | None = None,
) -> str:
    """Start a chunked file upload to the knowledge base.
    Use this for uploading files from a remote agent. The upload process is:
    1. Call kb_upload_start to get an upload_id
    2. Call kb_upload_chunk repeatedly with base64-encoded file chunks (recommended ~1MB each)
    3. Call kb_upload_finish to submit the file for ingestion
    Example for a 3MB file:
        upload = kb_upload_start(filename="report.pdf", total_size=3145728, collection="documents")
        kb_upload_chunk(upload_id=upload["upload_id"], data="<base64 chunk 0>", chunk_index=0)
        kb_upload_chunk(upload_id=upload["upload_id"], data="<base64 chunk 1>", chunk_index=1)
        kb_upload_chunk(upload_id=upload["upload_id"], data="<base64 chunk 2>", chunk_index=2)
        result = kb_upload_finish(upload_id=upload["upload_id"])
    Args:
        filename: Original filename (used for type detection).
        total_size: Total file size in bytes.
        tags: Additional tags to apply.
        collection: Collection name (default "documents").
    """
    # Tags are fixed at start time; the collection tag is appended after the
    # caller's own tags.
    upload_tags = [*(tags or []), _collection_tag(collection)]
    new_id = uploads.start_upload(filename, total_size, upload_tags)
    return json.dumps({"upload_id": new_id})
@mcp.tool()
async def kb_upload_chunk(
    upload_id: str,
    data: str,
    chunk_index: int,
) -> str:
    """Upload a base64-encoded chunk of a file.
    Part of the chunked upload flow started by kb_upload_start.
    Args:
        upload_id: The upload ID from kb_upload_start.
        data: Base64-encoded file data for this chunk.
        chunk_index: Zero-based index of this chunk.
    """
    # Returns a JSON string: {"status": "ok", "chunk_index": N} on success, or
    # {"error": "<message>"} when uploads.add_chunk raises KeyError. Any other
    # exception propagates to the caller.
    try:
        uploads.add_chunk(upload_id, data, chunk_index)
    except KeyError as exc:
        # str(KeyError("msg")) is the repr of the argument ("'msg'"), which
        # would leak stray quotes into the error text; use the argument itself
        # when there is exactly one.
        message = str(exc.args[0]) if len(exc.args) == 1 else str(exc)
        return json.dumps({"error": message})
    return json.dumps({"status": "ok", "chunk_index": chunk_index})
@mcp.tool()
async def kb_upload_finish(
    upload_id: str,
) -> str:
    """Finish a chunked upload and submit the file for ingestion.
    Reassembles all uploaded chunks and forwards the complete file to the
    engine for processing. Returns the ingestion job ID.
    Args:
        upload_id: The upload ID from kb_upload_start.
    """
    # Returns the engine's upload result as a JSON string, or
    # {"error": "<message>"} when a KeyError is raised while finishing or
    # submitting the upload. Any other exception propagates to the caller.
    try:
        filename, file_bytes, tags = uploads.finish_upload(upload_id)
        result = engine.upload_file(filename, file_bytes, tags)
    except KeyError as exc:
        # str(KeyError("msg")) is the repr of the argument ("'msg'"), which
        # would leak stray quotes into the error text; use the argument itself
        # when there is exactly one.
        message = str(exc.args[0]) if len(exc.args) == 1 else str(exc)
        return json.dumps({"error": message})
    return json.dumps(result, indent=2)
# ---------------------------------------------------------------------------
# Auth middleware
# ---------------------------------------------------------------------------
class BearerAuthMiddleware(BaseHTTPMiddleware):
    """Require ``Authorization: Bearer <key>`` on every request.

    The expected key is ``config.KB_MCP_API_KEY``. When that value is empty or
    unset, authentication is disabled and every request is passed through.
    """

    async def dispatch(self, request: Request, call_next):
        """Forward authorized requests; answer everything else with HTTP 401.

        Args:
            request: The incoming Starlette request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response, or a JSON 401 response when the bearer
            token is missing or does not match the configured key.
        """
        expected = config.KB_MCP_API_KEY
        if not expected:
            # No key configured: the server runs unauthenticated.
            return await call_next(request)

        auth_header = request.headers.get("authorization", "")
        if auth_header.startswith("Bearer "):
            presented = auth_header[7:]
            # Constant-time comparison so response timing does not reveal how
            # much of a guessed key was correct. Both sides are encoded to
            # bytes because compare_digest rejects non-ASCII str arguments.
            if hmac.compare_digest(
                presented.encode("utf-8"), expected.encode("utf-8")
            ):
                return await call_next(request)

        return JSONResponse(
            status_code=401,
            content={"error": "Unauthorized"},
            # Tell clients which authentication scheme is expected.
            headers={"WWW-Authenticate": "Bearer"},
        )
# ---------------------------------------------------------------------------
# ASGI app assembly
# ---------------------------------------------------------------------------
def create_app():
    """Build the Starlette ASGI app: bearer-auth middleware around the MCP app.

    The MCP streamable-HTTP app is mounted at ``/``. The outer app's lifespan
    starts the upload cleanup task and then hands over to the MCP app's own
    lifespan context when it exposes one.
    """
    from contextlib import asynccontextmanager

    inner_app = mcp.streamable_http_app()

    @asynccontextmanager
    async def lifespan(app):
        uploads.start_cleanup_task()
        logger.info("Upload cleanup task started")

        # A mounted sub-app's lifespan is delegated to explicitly here, so the
        # MCP app's startup/shutdown runs inside the outer app's lifespan.
        delegates = hasattr(inner_app, "router") and hasattr(
            inner_app.router, "lifespan_context"
        )
        if not delegates:
            yield
            return
        async with inner_app.router.lifespan_context(app):
            yield

    return Starlette(
        routes=[Mount("/", app=inner_app)],
        middleware=[Middleware(BearerAuthMiddleware)],
        lifespan=lifespan,
    )
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    listen_port = config.KB_MCP_PORT
    logger.info(
        "Starting kb MCP server on port %d, engine=%s",
        listen_port,
        config.KB_ENGINE_URL,
    )
    app = create_app()
    # Binds on all interfaces; access control is left to the bearer-auth
    # middleware and the transport-security host allowlist.
    uvicorn.run(app, host="0.0.0.0", port=listen_port)