Add MCP server, note mutation endpoint, and updated_at tracking (v3.0.0)
New MCP server (mcp/) exposes kb operations as native MCP tools over
Streamable HTTP with Bearer token auth. Supports collections via tag
conventions, chunked file uploads, and agent-side search patterns.
Engine gains PATCH /api/v1/notes/{id} for in-place note updates with
transactional re-chunk/re-embed, and updated_at column on documents.
Go client adds updatenote command and Patch HTTP method.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
# Container image for the kb MCP server (server.py and its sibling modules).
FROM python:3.12-slim

WORKDIR /app

# Install dependencies before copying the code so this layer stays cached
# when only the Python sources change.
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt

COPY *.py ./

# Drop root privileges: the server only needs to read /app and write upload
# staging directories under the system temp dir. The default port (3000) is
# unprivileged; keep KB_MCP_PORT >= 1024 when overriding it.
RUN useradd --system --no-create-home kb
USER kb

# Defaults; override at `docker run` / compose time.
ENV KB_ENGINE_URL=http://engine:8000
ENV KB_API_KEY=
ENV KB_MCP_API_KEY=
ENV KB_MCP_PORT=3000

EXPOSE 3000

CMD ["python", "server.py"]
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Configuration from environment variables.

``KB_ENGINE_URL`` and ``KB_MCP_PORT`` fall back to their defaults when the
variable is unset *or* empty, so a blank value (for example from an
unresolved compose substitution) does not break start-up.
"""

import os

# Base URL of the kb engine API that the MCP server forwards requests to.
KB_ENGINE_URL = os.environ.get("KB_ENGINE_URL") or "http://localhost:8000"
# Bearer token sent to the engine; empty means no Authorization header.
KB_API_KEY = os.environ.get("KB_API_KEY", "")
# Bearer token that MCP clients must present; empty disables the auth check.
KB_MCP_API_KEY = os.environ.get("KB_MCP_API_KEY", "")
# TCP port the MCP server listens on. `or` guards against int("") on a
# blank value.
KB_MCP_PORT = int((os.environ.get("KB_MCP_PORT") or "").strip() or "3000")
|
||||
+121
@@ -0,0 +1,121 @@
|
||||
"""HTTP client for the kb engine API."""
|
||||
|
||||
import httpx
|
||||
|
||||
from config import KB_ENGINE_URL, KB_API_KEY
|
||||
|
||||
|
||||
def _auth_headers() -> dict[str, str]:
    """Return the headers used to authenticate against the engine.

    Yields an ``Authorization: Bearer ...`` header when ``KB_API_KEY`` is
    non-empty, otherwise an empty dict.
    """
    if not KB_API_KEY:
        return {}
    return {"Authorization": f"Bearer {KB_API_KEY}"}
|
||||
|
||||
|
||||
def _client() -> httpx.Client:
    """Create a new ``httpx.Client`` bound to the engine base URL.

    The client carries the auth headers and a 60 second timeout.
    """
    return httpx.Client(
        base_url=KB_ENGINE_URL,
        headers=_auth_headers(),
        timeout=60.0,
    )
|
||||
|
||||
|
||||
def search(query: str, top: int = 10, tags: list[str] | None = None,
           doc_type: str | None = None, fts_only: bool = False,
           vec_only: bool = False, threshold: float | None = None) -> dict:
    """POST a search request to ``/api/v1/search`` and return the JSON reply.

    Optional filters are only included in the request body when they are
    set; ``threshold`` is included whenever it is not ``None``.

    Raises:
        httpx.HTTPStatusError: if the engine answers with an error status.
    """
    payload: dict = {"query": query, "top": top}
    for key, value in (("tags", tags), ("doc_type", doc_type)):
        if value:
            payload[key] = value
    for flag, enabled in (("fts_only", fts_only), ("vec_only", vec_only)):
        if enabled:
            payload[flag] = True
    if threshold is not None:
        payload["threshold"] = threshold

    with _client() as client:
        response = client.post("/api/v1/search", json=payload)
        response.raise_for_status()
        return response.json()
|
||||
|
||||
|
||||
def add_note(text: str, tags: list[str] | None = None,
             title: str | None = None) -> dict:
    """Submit a note as form data to ``/api/v1/jobs`` and return the JSON reply.

    Tags are sent as a single comma-joined form field.

    Raises:
        httpx.HTTPStatusError: if the engine answers with an error status.
    """
    form = {"note": text}
    if tags:
        form["tags"] = ",".join(tags)
    if title:
        form["title"] = title

    with _client() as client:
        response = client.post("/api/v1/jobs", data=form)
        response.raise_for_status()
        return response.json()
|
||||
|
||||
|
||||
def update_note(doc_id: int, text: str) -> dict:
    """PATCH ``/api/v1/notes/{doc_id}`` with new text and return the JSON reply.

    Raises:
        httpx.HTTPStatusError: if the engine answers with an error status.
    """
    url = f"/api/v1/notes/{doc_id}"
    with _client() as client:
        response = client.patch(url, json={"text": text})
        response.raise_for_status()
        return response.json()
|
||||
|
||||
|
||||
def get_document(doc_id: int) -> dict:
    """GET ``/api/v1/documents/{doc_id}`` and return the JSON reply.

    Raises:
        httpx.HTTPStatusError: if the engine answers with an error status.
    """
    url = f"/api/v1/documents/{doc_id}"
    with _client() as client:
        response = client.get(url)
        response.raise_for_status()
        return response.json()
|
||||
|
||||
|
||||
def list_documents(doc_type: str | None = None,
                   tags: str | None = None) -> list[dict]:
    """GET ``/api/v1/documents`` and return the JSON reply.

    ``doc_type`` is sent as the ``type`` query parameter and ``tags`` as
    ``tags``; each is omitted when empty.

    Raises:
        httpx.HTTPStatusError: if the engine answers with an error status.
    """
    query = {
        name: value
        for name, value in (("type", doc_type), ("tags", tags))
        if value
    }
    with _client() as client:
        response = client.get("/api/v1/documents", params=query)
        response.raise_for_status()
        return response.json()
|
||||
|
||||
|
||||
def get_status() -> dict:
    """GET ``/api/v1/status`` and return the JSON reply.

    Raises:
        httpx.HTTPStatusError: if the engine answers with an error status.
    """
    with _client() as client:
        response = client.get("/api/v1/status")
        response.raise_for_status()
        return response.json()
|
||||
|
||||
|
||||
def list_jobs(status: str | None = None) -> list[dict]:
    """GET ``/api/v1/jobs`` and return the JSON reply.

    The ``status`` query parameter is only sent when ``status`` is non-empty.

    Raises:
        httpx.HTTPStatusError: if the engine answers with an error status.
    """
    query: dict = {"status": status} if status else {}
    with _client() as client:
        response = client.get("/api/v1/jobs", params=query)
        response.raise_for_status()
        return response.json()
|
||||
|
||||
|
||||
def update_tags(doc_id: int, add: list[str] | None = None,
                remove: list[str] | None = None) -> dict:
    """PUT tag changes to ``/api/v1/documents/{doc_id}/tags``.

    The JSON body contains an ``add`` and/or ``remove`` list; a key is left
    out when its list is empty or ``None``.

    Raises:
        httpx.HTTPStatusError: if the engine answers with an error status.
    """
    changes: dict = {
        key: values
        for key, values in (("add", add), ("remove", remove))
        if values
    }
    url = f"/api/v1/documents/{doc_id}/tags"
    with _client() as client:
        response = client.put(url, json=changes)
        response.raise_for_status()
        return response.json()
|
||||
|
||||
|
||||
def upload_file(filename: str, file_bytes: bytes,
                tags: list[str] | None = None) -> dict:
    """POST a file as multipart form data to ``/api/v1/jobs``.

    The file is sent under the ``file`` field; tags, when given, are sent as
    a single comma-joined ``tags`` form field.

    Raises:
        httpx.HTTPStatusError: if the engine answers with an error status.
    """
    form: dict = {"tags": ",".join(tags)} if tags else {}
    multipart = {"file": (filename, file_bytes)}
    with _client() as client:
        response = client.post("/api/v1/jobs", data=form, files=multipart)
        response.raise_for_status()
        return response.json()
|
||||
@@ -0,0 +1,4 @@
|
||||
mcp>=1.9.0
|
||||
httpx>=0.27
|
||||
uvicorn>=0.30
|
||||
starlette>=0.38
|
||||
+380
@@ -0,0 +1,380 @@
|
||||
"""kb MCP server — exposes knowledge base operations as MCP tools."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Mount
|
||||
|
||||
import config
|
||||
import engine
|
||||
import uploads
|
||||
|
||||
# Root logging configuration for the server process: INFO level with a
# timestamped "<time> <level> <message>" line format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
# Logger used by this module.
logger = logging.getLogger("kb.mcp")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Collection helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Collections are stored as ordinary tags of the form "collection:<name>".
COLLECTION_TAG_PREFIX: str = "collection:"
# Collection name used when the caller does not specify one.
DEFAULT_COLLECTION: str = "documents"
|
||||
|
||||
|
||||
def _collection_tag(collection: str | None) -> str:
    """Return the tag that encodes *collection*.

    Falls back to ``DEFAULT_COLLECTION`` when *collection* is ``None`` or
    empty.
    """
    name = collection if collection else DEFAULT_COLLECTION
    return f"{COLLECTION_TAG_PREFIX}{name}"
|
||||
|
||||
|
||||
def _strip_collection_tags(tags: list[str]) -> tuple[str | None, list[str]]:
    """Split tags into (collection, remaining_tags).

    ``collection`` is the name carried by the last collection tag in *tags*
    (``None`` when there is none); ``remaining_tags`` keeps every other tag
    in its original order.
    """
    prefix_len = len(COLLECTION_TAG_PREFIX)
    names = [t[prefix_len:] for t in tags if t.startswith(COLLECTION_TAG_PREFIX)]
    others = [t for t in tags if not t.startswith(COLLECTION_TAG_PREFIX)]
    return (names[-1] if names else None), others
|
||||
|
||||
|
||||
def _process_document(doc: dict) -> dict:
    """Strip collection tags from a document dict and add collection field.

    The dict is modified in place and also returned. A missing or null
    ``tags`` value is treated as an empty tag list.
    """
    # `or []` also covers an explicit null, which `.get("tags", [])` would
    # pass through and fail to iterate.
    collection, clean_tags = _strip_collection_tags(doc.get("tags") or [])
    doc["tags"] = clean_tags
    doc["collection"] = collection
    return doc
|
||||
|
||||
|
||||
def _process_search_results(results: list[dict]) -> list[dict]:
    """Strip collection tags from search result dicts.

    Handles tags both on the result itself and on a nested ``document``
    dict. Entries are modified in place and the same list is returned.
    Null ``tags`` values are treated as empty, and a null or non-dict
    ``document`` is left untouched.
    """
    for result in results:
        if "tags" in result:
            collection, clean_tags = _strip_collection_tags(result["tags"] or [])
            result["tags"] = clean_tags
            result["collection"] = collection

        document = result.get("document")
        if isinstance(document, dict) and "tags" in document:
            collection, clean_tags = _strip_collection_tags(document["tags"] or [])
            document["tags"] = clean_tags
            document["collection"] = collection
    return results
|
||||
|
||||
|
||||
async def _ensure_exclusive_collection(doc_id: int, collection: str) -> None:
    """Make *collection* the only collection tag on the document.

    The new tag is added before stale ones are removed, so a failure between
    the two requests never leaves the document without a collection tag.
    Engine calls run in a worker thread because the engine client is
    synchronous and would otherwise block the event loop.
    """
    doc = await asyncio.to_thread(engine.get_document, doc_id)
    current = [
        t for t in (doc.get("tags") or [])
        if t.startswith(COLLECTION_TAG_PREFIX)
    ]
    new_tag = _collection_tag(collection)
    stale = [t for t in current if t != new_tag]

    if new_tag not in current:
        await asyncio.to_thread(engine.update_tags, doc_id, add=[new_tag])
    if stale:
        await asyncio.to_thread(engine.update_tags, doc_id, remove=stale)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FastMCP server
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# The FastMCP server instance that the @mcp.tool() functions register on.
mcp = FastMCP(
    "kb",
    instructions=(
        "Knowledge base MCP server. Provides tools for searching, adding, "
        "and managing documents and notes."
    ),
)
|
||||
|
||||
|
||||
@mcp.tool()
async def kb_search(
    query: str,
    top: int = 10,
    tags: list[str] | None = None,
    doc_type: str | None = None,
    collection: str | None = None,
    fts_only: bool = False,
) -> str:
    """Search the knowledge base for relevant documents and notes.

    Returns ranked chunks matching the query, with text content, relevance scores,
    and document metadata.

    Args:
        query: The search query. Can be a natural language question or keywords.
        top: Maximum number of results to return (default 10).
        tags: Filter results to documents with ALL of these tags.
        doc_type: Filter by document type (e.g. "note", "pdf", "markdown", "code").
        collection: Filter by collection name (e.g. "documents", "memory", "workspace").
        fts_only: If true, use only full-text search (no vector similarity).

    Tips for complex queries:
    - Consider expanding into 2-3 variant phrasings and calling this tool multiple
      times, then deduplicating results by chunk_id. For example, search for both
      "pension revaluation rules" and "how are pensions revalued" to cast a wider net.
    - For precision, rerank the returned results using your own judgement based on
      relevance to the original question.
    """
    # The collection filter is expressed as one more required tag.
    search_tags = list(tags) if tags else []
    if collection:
        search_tags.append(_collection_tag(collection))

    # The engine client is synchronous; run it in a worker thread so a slow
    # search does not stall the event loop for other sessions.
    result = await asyncio.to_thread(
        engine.search,
        query=query,
        top=top,
        tags=search_tags or None,
        doc_type=doc_type,
        fts_only=fts_only,
    )

    # Accept either a bare list or an object wrapping it under "results".
    results_list = result if isinstance(result, list) else result.get("results", [])
    processed = _process_search_results(results_list)
    return json.dumps(processed, indent=2)
|
||||
|
||||
|
||||
@mcp.tool()
async def kb_addnote(
    text: str,
    collection: str | None = None,
    tags: list[str] | None = None,
    title: str | None = None,
) -> str:
    """Add a text note to the knowledge base for indexing and search.

    The note is queued for ingestion — it will be chunked, embedded, and made
    searchable. Use kb_jobs to check ingestion status.

    Args:
        text: The note text content.
        collection: Collection to add the note to (default "documents").
            Standard collections: "documents", "memory", "workspace".
        tags: Additional tags to apply to the note.
        title: Optional title (auto-derived from first line if omitted).
    """
    # Copy so the caller's list is not mutated, then add the collection tag.
    all_tags = list(tags) if tags else []
    all_tags.append(_collection_tag(collection))

    # Synchronous engine client: run off the event loop.
    result = await asyncio.to_thread(
        engine.add_note, text=text, tags=all_tags, title=title
    )
    return json.dumps(result, indent=2)
|
||||
|
||||
|
||||
@mcp.tool()
async def kb_update_note(
    document_id: int,
    text: str,
) -> str:
    """Update an existing note's content in place.

    Replaces the note text, re-chunks, and re-embeds while preserving the
    document ID, creation timestamp, and tags. Only works on documents with
    doc_type "note".

    Args:
        document_id: The ID of the note document to update.
        text: The new text content for the note.
    """
    # Synchronous engine client: run off the event loop.
    result = await asyncio.to_thread(engine.update_note, document_id, text)
    return json.dumps(_process_document(result), indent=2)
|
||||
|
||||
|
||||
@mcp.tool()
async def kb_get(
    document_id: int | None = None,
    source_path: str | None = None,
) -> str:
    """Retrieve document details from the knowledge base.

    Look up a document by its ID or source path. Returns full document metadata,
    tags, and chunk contents.

    Args:
        document_id: The numeric document ID.
        source_path: The document's source path (alternative to document_id).
    """
    # Engine calls are synchronous, so each one runs in a worker thread to
    # keep the event loop responsive.
    if document_id is not None:
        result = await asyncio.to_thread(engine.get_document, document_id)
        return json.dumps(_process_document(result), indent=2)

    if source_path is None:
        return json.dumps({"error": "Provide either document_id or source_path"})

    # Resolve the path by scanning the document list, then fetch the first
    # match in full.
    docs = await asyncio.to_thread(engine.list_documents)
    match = next((d for d in docs if d.get("source_path") == source_path), None)
    if match is None:
        return json.dumps({"error": "No document found with that source_path"})
    doc = await asyncio.to_thread(engine.get_document, match["id"])
    return json.dumps(_process_document(doc), indent=2)
|
||||
|
||||
|
||||
@mcp.tool()
async def kb_status() -> str:
    """Get knowledge base engine status.

    Returns engine version, embedding model info, device info, document counts,
    database size, and ingestion queue state.
    """
    # Synchronous engine client: run off the event loop.
    result = await asyncio.to_thread(engine.get_status)
    return json.dumps(result, indent=2)
|
||||
|
||||
|
||||
@mcp.tool()
async def kb_jobs(
    status: str | None = None,
) -> str:
    """List ingestion jobs and their status.

    Returns recent jobs showing what has been queued, is processing, completed,
    or failed.

    Args:
        status: Filter by job status ("queued", "processing", "done", "failed", "skipped").
    """
    # Synchronous engine client: run off the event loop.
    result = await asyncio.to_thread(engine.list_jobs, status=status)
    return json.dumps(result, indent=2)
|
||||
|
||||
|
||||
@mcp.tool()
async def kb_upload_start(
    filename: str,
    total_size: int,
    tags: list[str] | None = None,
    collection: str | None = None,
) -> str:
    """Start a chunked file upload to the knowledge base.

    Use this for uploading files from a remote agent. The upload process is:
    1. Call kb_upload_start to get an upload_id
    2. Call kb_upload_chunk repeatedly with base64-encoded file chunks (recommended ~1MB each)
    3. Call kb_upload_finish to submit the file for ingestion

    Example for a 3MB file:
        upload = kb_upload_start(filename="report.pdf", total_size=3145728, collection="documents")
        kb_upload_chunk(upload_id=upload["upload_id"], data="<base64 chunk 0>", chunk_index=0)
        kb_upload_chunk(upload_id=upload["upload_id"], data="<base64 chunk 1>", chunk_index=1)
        kb_upload_chunk(upload_id=upload["upload_id"], data="<base64 chunk 2>", chunk_index=2)
        result = kb_upload_finish(upload_id=upload["upload_id"])

    Args:
        filename: Original filename (used for type detection).
        total_size: Total file size in bytes.
        tags: Additional tags to apply.
        collection: Collection name (default "documents").
    """
    # New list: caller-supplied tags followed by the collection tag.
    staged_tags = [*(tags if tags else []), _collection_tag(collection)]
    new_id = uploads.start_upload(filename, total_size, staged_tags)
    return json.dumps({"upload_id": new_id})
|
||||
|
||||
|
||||
@mcp.tool()
async def kb_upload_chunk(
    upload_id: str,
    data: str,
    chunk_index: int,
) -> str:
    """Upload a base64-encoded chunk of a file.

    Part of the chunked upload flow started by kb_upload_start.

    Args:
        upload_id: The upload ID from kb_upload_start.
        data: Base64-encoded file data for this chunk.
        chunk_index: Zero-based index of this chunk.
    """
    try:
        uploads.add_chunk(upload_id, data, chunk_index)
    except KeyError as e:
        # str(KeyError) wraps the message in quotes; report the bare message.
        message = str(e.args[0]) if e.args else str(e)
        return json.dumps({"error": message})
    except ValueError as e:
        # base64 decoding failures (binascii.Error) are ValueError
        # subclasses; report them like any other per-chunk error.
        return json.dumps({"error": f"Invalid chunk data: {e}"})
    return json.dumps({"status": "ok", "chunk_index": chunk_index})
|
||||
|
||||
|
||||
@mcp.tool()
async def kb_upload_finish(
    upload_id: str,
) -> str:
    """Finish a chunked upload and submit the file for ingestion.

    Reassembles all uploaded chunks and forwards the complete file to the
    engine for processing. Returns the ingestion job ID.

    Args:
        upload_id: The upload ID from kb_upload_start.
    """
    # Only the staging lookup is guarded, so a KeyError raised while talking
    # to the engine is not misreported as an unknown upload ID.
    try:
        filename, file_bytes, tags = uploads.finish_upload(upload_id)
    except KeyError as e:
        # str(KeyError) wraps the message in quotes; report the bare message.
        message = str(e.args[0]) if e.args else str(e)
        return json.dumps({"error": message})

    # Synchronous engine client: forward the file from a worker thread so the
    # transfer does not block the event loop.
    result = await asyncio.to_thread(engine.upload_file, filename, file_bytes, tags)
    return json.dumps(result, indent=2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class BearerAuthMiddleware(BaseHTTPMiddleware):
    """Require ``Authorization: Bearer <token>`` on every request.

    The check is skipped entirely when ``config.KB_MCP_API_KEY`` is empty.
    Requests with a missing or wrong token get a 401 JSON response.
    """

    async def dispatch(self, request: Request, call_next):
        # Function-scope import keeps this block self-contained; the module
        # is cached by the interpreter after first use.
        import hmac

        expected = config.KB_MCP_API_KEY
        if not expected:
            return await call_next(request)

        header = request.headers.get("authorization", "")
        scheme, _, token = header.partition(" ")
        # Auth scheme names are case-insensitive (RFC 7235). The token is
        # compared in constant time so response timing does not leak how
        # much of a guessed key matched. Bytes are used because
        # compare_digest rejects non-ASCII str arguments.
        if scheme.lower() == "bearer" and hmac.compare_digest(
            token.encode("utf-8", "surrogateescape"),
            expected.encode("utf-8", "surrogateescape"),
        ):
            return await call_next(request)

        return JSONResponse(
            status_code=401,
            content={"error": "Unauthorized"},
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ASGI app assembly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_app():
    """Create the ASGI app with auth middleware wrapping the MCP server."""
    from contextlib import asynccontextmanager

    mcp_app = mcp.streamable_http_app()

    @asynccontextmanager
    async def lifespan(app):
        uploads.start_cleanup_task()
        logger.info("Upload cleanup task started")

        # Delegate to the MCP app's lifespan if it has one
        delegates = hasattr(mcp_app, 'router') and hasattr(mcp_app.router, 'lifespan_context')
        if not delegates:
            yield
            return
        async with mcp_app.router.lifespan_context(app):
            yield

    routes = [Mount("/", app=mcp_app)]
    middleware = [Middleware(BearerAuthMiddleware)]
    return Starlette(routes=routes, middleware=middleware, lifespan=lifespan)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run as a script.
    import uvicorn

    listen_port = config.KB_MCP_PORT
    logger.info(
        "Starting kb MCP server on port %d, engine=%s",
        listen_port,
        config.KB_ENGINE_URL,
    )

    app = create_app()
    # Bind on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|
||||
@@ -0,0 +1,96 @@
|
||||
"""Chunked upload staging management."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
# Logger used by this module.
logger = logging.getLogger("kb.mcp.uploads")

# Age in seconds after which the cleanup task discards a staged upload.
UPLOAD_TIMEOUT_SECONDS = 10 * 60
|
||||
|
||||
|
||||
@dataclass
class StagedUpload:
    """In-memory record of one chunked upload that is still in progress."""

    # Identifier handed back to the client by start_upload.
    upload_id: str
    # Filename supplied by the client.
    filename: str
    # Total size in bytes as declared by the client.
    total_size: int
    # Tags to submit together with the finished file.
    tags: list[str]
    # Temporary directory holding this upload's chunk files.
    staging_dir: Path
    # time.time() value captured when the record was created.
    created_at: float = field(default_factory=time.time)
    # Chunk index -> path of the file holding that chunk's bytes.
    chunks: dict[int, Path] = field(default_factory=dict)
|
||||
|
||||
|
||||
# In-progress uploads keyed by upload ID.
_uploads: dict[str, StagedUpload] = {}
# Handle of the background cleanup task; None until start_cleanup_task runs.
_cleanup_task: asyncio.Task | None = None
|
||||
|
||||
|
||||
def start_upload(filename: str, total_size: int, tags: list[str]) -> str:
    """Register a new staged upload and return its ID.

    Creates a fresh temporary directory for the upload's chunks and records
    the upload in the module-level registry.
    """
    upload_id = str(uuid.uuid4())
    # The directory name carries the first 8 characters of the ID.
    dir_prefix = f"kb_upload_{upload_id[:8]}_"
    staged = StagedUpload(
        upload_id=upload_id,
        filename=filename,
        total_size=total_size,
        tags=tags,
        staging_dir=Path(tempfile.mkdtemp(prefix=dir_prefix)),
    )
    _uploads[upload_id] = staged
    logger.info("Started upload %s for %s (%d bytes)", upload_id, filename, total_size)
    return upload_id
|
||||
|
||||
|
||||
def add_chunk(upload_id: str, data_b64: str, chunk_index: int) -> None:
    """Decode one base64 chunk and store it in the upload's staging directory.

    Sending the same ``chunk_index`` again overwrites the earlier chunk.

    Raises:
        KeyError: if ``upload_id`` is not a known in-progress upload.
        ValueError: if ``chunk_index`` is negative, or if the data is not
            valid base64 (``binascii.Error`` is a ``ValueError`` subclass).
    """
    upload = _uploads.get(upload_id)
    if upload is None:
        raise KeyError(f"Upload ID not found: {upload_id}")
    if chunk_index < 0:
        # Chunks are reassembled in ascending index order; a negative index
        # would silently be placed ahead of chunk 0.
        raise ValueError(f"chunk_index must be >= 0, got {chunk_index}")

    chunk_bytes = base64.b64decode(data_b64)
    chunk_path = upload.staging_dir / f"chunk_{chunk_index:06d}"
    chunk_path.write_bytes(chunk_bytes)
    upload.chunks[chunk_index] = chunk_path
    # The cleanup task expires uploads by the age of `created_at`. Refresh it
    # on every chunk so an upload that is still making progress is not
    # removed mid-transfer; only idle uploads time out.
    upload.created_at = time.time()
    logger.info("Added chunk %d to upload %s (%d bytes)", chunk_index, upload_id, len(chunk_bytes))
|
||||
|
||||
|
||||
def finish_upload(upload_id: str) -> tuple[str, bytes, list[str]]:
    """Reassemble chunks and return (filename, file_bytes, tags).

    Chunks are concatenated in ascending index order. The upload's registry
    entry and staging directory are removed whether or not reassembly
    succeeds. Gaps in the chunk indices and a mismatch with the declared
    total size are logged as warnings rather than rejected, so existing
    callers keep working while likely-corrupt uploads become visible.

    Raises:
        KeyError: if ``upload_id`` is not a known in-progress upload.
    """
    upload = _uploads.get(upload_id)
    if upload is None:
        raise KeyError(f"Upload ID not found: {upload_id}")
    try:
        indices = sorted(upload.chunks)
        if indices != list(range(len(indices))):
            logger.warning(
                "Upload %s has non-contiguous chunk indices (%d chunks, highest index %d)",
                upload_id, len(indices), indices[-1],
            )
        file_bytes = b"".join(upload.chunks[idx].read_bytes() for idx in indices)
        if len(file_bytes) != upload.total_size:
            logger.warning(
                "Upload %s assembled to %d bytes but %d bytes were declared",
                upload_id, len(file_bytes), upload.total_size,
            )
        return upload.filename, file_bytes, upload.tags
    finally:
        _cleanup_upload(upload_id)
|
||||
|
||||
|
||||
def _cleanup_upload(upload_id: str) -> None:
    """Forget an upload and delete its staging directory.

    Unknown IDs are ignored, and errors while deleting the directory are
    suppressed.
    """
    staged = _uploads.pop(upload_id, None)
    if not staged:
        return
    if staged.staging_dir.exists():
        shutil.rmtree(staged.staging_dir, ignore_errors=True)
|
||||
|
||||
|
||||
async def cleanup_abandoned_uploads() -> None:
    """Background task that removes uploads older than the timeout.

    Runs forever: sleeps 60 seconds, then cleans up every upload whose
    ``created_at`` is more than ``UPLOAD_TIMEOUT_SECONDS`` in the past.
    """
    while True:
        await asyncio.sleep(60)
        now = time.time()
        # Iterate over a snapshot, since cleanup removes registry entries.
        for uid, staged in list(_uploads.items()):
            if now - staged.created_at > UPLOAD_TIMEOUT_SECONDS:
                logger.warning("Cleaning up abandoned upload %s", uid)
                _cleanup_upload(uid)
|
||||
|
||||
|
||||
def start_cleanup_task() -> None:
    """Start the background cleanup task unless one is already running.

    Calls ``asyncio.create_task``, so it must be invoked while an event loop
    is running.
    """
    global _cleanup_task
    still_running = _cleanup_task is not None and not _cleanup_task.done()
    if still_running:
        return
    _cleanup_task = asyncio.create_task(cleanup_abandoned_uploads())
|
||||
Reference in New Issue
Block a user