// announce.go — POST /api/agents/announce: agent without a token // announces itself with a freshly-minted Ed25519 public key, server // stashes a pending_hosts row, admin compares fingerprints in the // UI before accepting (P2-18a). // // Guards (per spec): // - Per-source-IP token-bucket rate limit (10/min). // - Global cap of 100 in-flight pending rows; further announces // get 503 with a hint. // - Public key must be exactly 32 bytes (Ed25519). Anything else // 400-rejected. // // Hostname collisions are NOT rejected — multiple announces with // the same hostname can be legitimate (re-running install on the // same box). The UI flags collisions for the admin to disambiguate. package http import ( "crypto/ed25519" "encoding/base64" "encoding/json" stdhttp "net/http" "strings" "sync" "time" "github.com/oklog/ulid/v2" "gitea.dcglab.co.uk/steve/restic-manager/internal/store" ) // Tunables — exposed as vars so tests can lower them. Defaults mirror // the spec's recommendations. var ( announceMaxPerMin = 10 announceGlobalCap = 100 ) // announceRequest is the wire shape POST /api/agents/announce takes. // PublicKey is base64-std (no padding strip — stdlib decoder is // lenient on padding for both forms). type announceRequest struct { Hostname string `json:"hostname"` OS string `json:"os"` Arch string `json:"arch"` AgentVersion string `json:"agent_version"` ResticVersion string `json:"restic_version"` PublicKey string `json:"public_key"` // base64 } // announceResponse is what the agent gets back. Fingerprint is the // canonical "SHA256:hex" the operator compares against the UI. // HostnameCollision warns the install script that another pending // row already uses the same hostname. type announceResponse struct { PendingID string `json:"pending_id"` Fingerprint string `json:"fingerprint"` HostnameCollision bool `json:"hostname_collision"` } // rateBucket is a tiny per-IP token-bucket. last is the timestamp of // the most recent refill; tokens is the current bucket level. Refill // rate is announceMaxPerMin tokens/minute, burst = announceMaxPerMin. type rateBucket struct { tokens float64 last time.Time } // announceLimiter holds one bucket per source IP. Buckets are reaped // lazily by a tiny grace period — we don't need true LRU cleanup // because the bucket count is bounded by unique IPs in any given // few minutes (small). type announceLimiter struct { mu sync.Mutex buckets map[string]*rateBucket } func newAnnounceLimiter() *announceLimiter { return &announceLimiter{buckets: map[string]*rateBucket{}} } // allow returns true and consumes a token if the IP's bucket has at // least one token, else returns false. Capacity = announceMaxPerMin. func (l *announceLimiter) allow(ip string, now time.Time) bool { l.mu.Lock() defer l.mu.Unlock() cap := float64(announceMaxPerMin) b, ok := l.buckets[ip] if !ok { b = &rateBucket{tokens: cap, last: now} l.buckets[ip] = b } // Refill at cap tokens per minute. elapsed := now.Sub(b.last).Seconds() if elapsed > 0 { b.tokens += (elapsed / 60.0) * cap if b.tokens > cap { b.tokens = cap } b.last = now } if b.tokens < 1.0 { return false } b.tokens-- return true } // handleAnnounce is the public POST handler. Public — no auth. func (s *Server) handleAnnounce(w stdhttp.ResponseWriter, r *stdhttp.Request) { now := time.Now().UTC() // Rate limit by source IP. Strip port — the limit is per host, // not per outbound source port. ip := remoteIP(r) if !s.announceRL.allow(ip, now) { w.Header().Set("Retry-After", "60") writeJSONError(w, stdhttp.StatusTooManyRequests, "rate_limited", "too many announces from this source; retry in a minute") return } var req announceRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error()) return } if req.Hostname == "" || req.OS == "" || req.Arch == "" || req.PublicKey == "" { writeJSONError(w, stdhttp.StatusBadRequest, "missing_field", "hostname, os, arch, public_key are required") return } keyBytes, err := base64.StdEncoding.DecodeString(req.PublicKey) if err != nil { // Try URL-safe / no-padding flavors before giving up. if k2, e2 := base64.RawStdEncoding.DecodeString(req.PublicKey); e2 == nil { keyBytes = k2 } else { writeJSONError(w, stdhttp.StatusBadRequest, "invalid_public_key", "public_key must be base64") return } } if len(keyBytes) != ed25519.PublicKeySize { writeJSONError(w, stdhttp.StatusBadRequest, "invalid_public_key", "public_key must be 32 bytes (Ed25519)") return } // Global cap (cheap query — index on expires_at). count, err := s.deps.Store.CountPendingHosts(r.Context(), now) if err != nil { writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error()) return } if count >= announceGlobalCap { writeJSONError(w, stdhttp.StatusServiceUnavailable, "pending_cap_reached", "too many in-flight pending hosts; ask an admin to clear the queue") return } // Hostname collision flag (informational). colls, err := s.deps.Store.CountPendingHostsByHostname(r.Context(), req.Hostname, now) if err != nil { writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error()) return } ph := &store.PendingHost{ ID: ulid.Make().String(), Hostname: req.Hostname, OS: req.OS, Arch: req.Arch, AgentVersion: req.AgentVersion, ResticVersion: req.ResticVersion, PublicKey: keyBytes, Fingerprint: store.FingerprintForKey(keyBytes), AnnouncedFromIP: ip, FirstSeenAt: now, LastSeenAt: now, ExpiresAt: now.Add(time.Hour), } if err := s.deps.Store.CreatePendingHost(r.Context(), ph); err != nil { writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error()) return } writeJSON(w, stdhttp.StatusOK, announceResponse{ PendingID: ph.ID, Fingerprint: ph.Fingerprint, HostnameCollision: colls > 0, }) } // remoteIP returns r.RemoteAddr stripped of any :port suffix, plus // the X-Forwarded-For chain's first hop when behind a trusted proxy // (RM_TRUSTED_PROXY in the deployment doc). Trust-proxy lookup // matches the framework's existing behavior elsewhere. func remoteIP(r *stdhttp.Request) string { if xff := r.Header.Get("X-Forwarded-For"); xff != "" { // Take the first IP in the chain (closest to the original // client) — same convention chi uses. Trim whitespace. parts := strings.Split(xff, ",") return strings.TrimSpace(parts[0]) } addr := r.RemoteAddr if i := strings.LastIndex(addr, ":"); i >= 0 { return addr[:i] } return addr }