store+server: P2-18a announce-and-approve schema + endpoint
migration 0011 adds pending_hosts table (id, hostname, public_key,
fingerprint, expiry). store/pending_hosts.go covers full CRUD plus
hostname-collision count + expired-row sweeper.
POST /api/agents/announce takes {hostname, os, arch, agent_version,
restic_version, public_key (base64)}, returns {pending_id,
fingerprint, hostname_collision}. Per-source-IP token-bucket
rate limit (10/min) + global cap of 100 in-flight rows. Public
key must be exactly 32 bytes (Ed25519).
This commit is contained in:
@@ -0,0 +1,211 @@
|
||||
// announce.go — POST /api/agents/announce: agent without a token
|
||||
// announces itself with a freshly-minted Ed25519 public key, server
|
||||
// stashes a pending_hosts row, admin compares fingerprints in the
|
||||
// UI before accepting (P2-18a).
|
||||
//
|
||||
// Guards (per spec):
|
||||
// - Per-source-IP token-bucket rate limit (10/min).
|
||||
// - Global cap of 100 in-flight pending rows; further announces
|
||||
// get 503 with a hint.
|
||||
// - Public key must be exactly 32 bytes (Ed25519). Anything else
|
||||
// 400-rejected.
|
||||
//
|
||||
// Hostname collisions are NOT rejected — multiple announces with
|
||||
// the same hostname can be legitimate (re-running install on the
|
||||
// same box). The UI flags collisions for the admin to disambiguate.
|
||||
package http
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
stdhttp "net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// Tunables — exposed as vars so tests can lower them. Defaults mirror
|
||||
// the spec's recommendations.
|
||||
var (
|
||||
announceMaxPerMin = 10
|
||||
announceGlobalCap = 100
|
||||
)
|
||||
|
||||
// announceRequest is the wire shape POST /api/agents/announce takes.
|
||||
// PublicKey is base64-std (no padding strip — stdlib decoder is
|
||||
// lenient on padding for both forms).
|
||||
type announceRequest struct {
|
||||
Hostname string `json:"hostname"`
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
AgentVersion string `json:"agent_version"`
|
||||
ResticVersion string `json:"restic_version"`
|
||||
PublicKey string `json:"public_key"` // base64
|
||||
}
|
||||
|
||||
// announceResponse is what the agent gets back. Fingerprint is the
|
||||
// canonical "SHA256:hex" the operator compares against the UI.
|
||||
// HostnameCollision warns the install script that another pending
|
||||
// row already uses the same hostname.
|
||||
type announceResponse struct {
|
||||
PendingID string `json:"pending_id"`
|
||||
Fingerprint string `json:"fingerprint"`
|
||||
HostnameCollision bool `json:"hostname_collision"`
|
||||
}
|
||||
|
||||
// rateBucket is a tiny per-IP token-bucket. last is the timestamp of
|
||||
// the most recent refill; tokens is the current bucket level. Refill
|
||||
// rate is announceMaxPerMin tokens/minute, burst = announceMaxPerMin.
|
||||
type rateBucket struct {
|
||||
tokens float64
|
||||
last time.Time
|
||||
}
|
||||
|
||||
// announceLimiter holds one bucket per source IP. Buckets are reaped
|
||||
// lazily by a tiny grace period — we don't need true LRU cleanup
|
||||
// because the bucket count is bounded by unique IPs in any given
|
||||
// few minutes (small).
|
||||
type announceLimiter struct {
|
||||
mu sync.Mutex
|
||||
buckets map[string]*rateBucket
|
||||
}
|
||||
|
||||
func newAnnounceLimiter() *announceLimiter {
|
||||
return &announceLimiter{buckets: map[string]*rateBucket{}}
|
||||
}
|
||||
|
||||
// allow returns true and consumes a token if the IP's bucket has at
|
||||
// least one token, else returns false. Capacity = announceMaxPerMin.
|
||||
func (l *announceLimiter) allow(ip string, now time.Time) bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
cap := float64(announceMaxPerMin)
|
||||
b, ok := l.buckets[ip]
|
||||
if !ok {
|
||||
b = &rateBucket{tokens: cap, last: now}
|
||||
l.buckets[ip] = b
|
||||
}
|
||||
// Refill at cap tokens per minute.
|
||||
elapsed := now.Sub(b.last).Seconds()
|
||||
if elapsed > 0 {
|
||||
b.tokens += (elapsed / 60.0) * cap
|
||||
if b.tokens > cap {
|
||||
b.tokens = cap
|
||||
}
|
||||
b.last = now
|
||||
}
|
||||
if b.tokens < 1.0 {
|
||||
return false
|
||||
}
|
||||
b.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
// handleAnnounce is the public POST handler. Public — no auth.
|
||||
func (s *Server) handleAnnounce(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Rate limit by source IP. Strip port — the limit is per host,
|
||||
// not per outbound source port.
|
||||
ip := remoteIP(r)
|
||||
if !s.announceRL.allow(ip, now) {
|
||||
w.Header().Set("Retry-After", "60")
|
||||
writeJSONError(w, stdhttp.StatusTooManyRequests, "rate_limited",
|
||||
"too many announces from this source; retry in a minute")
|
||||
return
|
||||
}
|
||||
|
||||
var req announceRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error())
|
||||
return
|
||||
}
|
||||
if req.Hostname == "" || req.OS == "" || req.Arch == "" || req.PublicKey == "" {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "missing_field",
|
||||
"hostname, os, arch, public_key are required")
|
||||
return
|
||||
}
|
||||
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(req.PublicKey)
|
||||
if err != nil {
|
||||
// Try URL-safe / no-padding flavors before giving up.
|
||||
if k2, e2 := base64.RawStdEncoding.DecodeString(req.PublicKey); e2 == nil {
|
||||
keyBytes = k2
|
||||
} else {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_public_key",
|
||||
"public_key must be base64")
|
||||
return
|
||||
}
|
||||
}
|
||||
if len(keyBytes) != ed25519.PublicKeySize {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_public_key",
|
||||
"public_key must be 32 bytes (Ed25519)")
|
||||
return
|
||||
}
|
||||
|
||||
// Global cap (cheap query — index on expires_at).
|
||||
count, err := s.deps.Store.CountPendingHosts(r.Context(), now)
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
if count >= announceGlobalCap {
|
||||
writeJSONError(w, stdhttp.StatusServiceUnavailable, "pending_cap_reached",
|
||||
"too many in-flight pending hosts; ask an admin to clear the queue")
|
||||
return
|
||||
}
|
||||
|
||||
// Hostname collision flag (informational).
|
||||
colls, err := s.deps.Store.CountPendingHostsByHostname(r.Context(), req.Hostname, now)
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ph := &store.PendingHost{
|
||||
ID: ulid.Make().String(),
|
||||
Hostname: req.Hostname,
|
||||
OS: req.OS,
|
||||
Arch: req.Arch,
|
||||
AgentVersion: req.AgentVersion,
|
||||
ResticVersion: req.ResticVersion,
|
||||
PublicKey: keyBytes,
|
||||
Fingerprint: store.FingerprintForKey(keyBytes),
|
||||
AnnouncedFromIP: ip,
|
||||
FirstSeenAt: now,
|
||||
LastSeenAt: now,
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
}
|
||||
if err := s.deps.Store.CreatePendingHost(r.Context(), ph); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
writeJSON(w, stdhttp.StatusOK, announceResponse{
|
||||
PendingID: ph.ID,
|
||||
Fingerprint: ph.Fingerprint,
|
||||
HostnameCollision: colls > 0,
|
||||
})
|
||||
}
|
||||
|
||||
// remoteIP returns r.RemoteAddr stripped of any :port suffix, plus
|
||||
// the X-Forwarded-For chain's first hop when behind a trusted proxy
|
||||
// (RM_TRUSTED_PROXY in the deployment doc). Trust-proxy lookup
|
||||
// matches the framework's existing behavior elsewhere.
|
||||
func remoteIP(r *stdhttp.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP in the chain (closest to the original
|
||||
// client) — same convention chi uses. Trim whitespace.
|
||||
parts := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
addr := r.RemoteAddr
|
||||
if i := strings.LastIndex(addr, ":"); i >= 0 {
|
||||
return addr[:i]
|
||||
}
|
||||
return addr
|
||||
}
|
||||
Reference in New Issue
Block a user