"""Chunked upload staging management.""" import asyncio import base64 import logging import shutil import tempfile import time import uuid from dataclasses import dataclass, field from pathlib import Path logger = logging.getLogger("kb.mcp.uploads") UPLOAD_TIMEOUT_SECONDS = 600 # 10 minutes @dataclass class StagedUpload: upload_id: str filename: str total_size: int tags: list[str] staging_dir: Path created_at: float = field(default_factory=time.time) chunks: dict[int, Path] = field(default_factory=dict) _uploads: dict[str, StagedUpload] = {} _cleanup_task: asyncio.Task | None = None def start_upload(filename: str, total_size: int, tags: list[str]) -> str: upload_id = str(uuid.uuid4()) staging_dir = Path(tempfile.mkdtemp(prefix=f"kb_upload_{upload_id[:8]}_")) _uploads[upload_id] = StagedUpload( upload_id=upload_id, filename=filename, total_size=total_size, tags=tags, staging_dir=staging_dir, ) logger.info("Started upload %s for %s (%d bytes)", upload_id, filename, total_size) return upload_id def add_chunk(upload_id: str, data_b64: str, chunk_index: int) -> None: upload = _uploads.get(upload_id) if upload is None: raise KeyError(f"Upload ID not found: {upload_id}") chunk_bytes = base64.b64decode(data_b64) chunk_path = upload.staging_dir / f"chunk_{chunk_index:06d}" chunk_path.write_bytes(chunk_bytes) upload.chunks[chunk_index] = chunk_path logger.info("Added chunk %d to upload %s (%d bytes)", chunk_index, upload_id, len(chunk_bytes)) def finish_upload(upload_id: str) -> tuple[str, bytes, list[str]]: """Reassemble chunks and return (filename, file_bytes, tags).""" upload = _uploads.get(upload_id) if upload is None: raise KeyError(f"Upload ID not found: {upload_id}") try: parts = [] for idx in sorted(upload.chunks.keys()): parts.append(upload.chunks[idx].read_bytes()) file_bytes = b"".join(parts) return upload.filename, file_bytes, upload.tags finally: _cleanup_upload(upload_id) def _cleanup_upload(upload_id: str) -> None: upload = _uploads.pop(upload_id, None) if upload and upload.staging_dir.exists(): shutil.rmtree(upload.staging_dir, ignore_errors=True) async def cleanup_abandoned_uploads() -> None: """Background task that removes uploads older than the timeout.""" while True: await asyncio.sleep(60) now = time.time() expired = [ uid for uid, u in _uploads.items() if now - u.created_at > UPLOAD_TIMEOUT_SECONDS ] for uid in expired: logger.warning("Cleaning up abandoned upload %s", uid) _cleanup_upload(uid) def start_cleanup_task() -> None: global _cleanup_task if _cleanup_task is None or _cleanup_task.done(): _cleanup_task = asyncio.create_task(cleanup_abandoned_uploads())