store+server: P2-18a announce-and-approve schema + endpoint

migration 0011 adds pending_hosts table (id, hostname, public_key,
fingerprint, expiry). store/pending_hosts.go covers full CRUD plus
hostname-collision count + expired-row sweeper.

POST /api/agents/announce takes {hostname, os, arch, agent_version,
restic_version, public_key (base64)}, returns {pending_id,
fingerprint, hostname_collision}. Per-source-IP token-bucket
rate limit (10/min) + global cap of 100 in-flight rows. Public
key must be exactly 32 bytes (Ed25519).
This commit is contained in:
2026-05-04 11:03:41 +01:00
parent a5a2cb91d0
commit cd80be3b13
5 changed files with 654 additions and 1 deletions
+211
View File
@@ -0,0 +1,211 @@
// announce.go — POST /api/agents/announce: agent without a token
// announces itself with a freshly-minted Ed25519 public key, server
// stashes a pending_hosts row, admin compares fingerprints in the
// UI before accepting (P2-18a).
//
// Guards (per spec):
// - Per-source-IP token-bucket rate limit (10/min).
// - Global cap of 100 in-flight pending rows; further announces
// get 503 with a hint.
// - Public key must be exactly 32 bytes (Ed25519). Anything else
// 400-rejected.
//
// Hostname collisions are NOT rejected — multiple announces with
// the same hostname can be legitimate (re-running install on the
// same box). The UI flags collisions for the admin to disambiguate.
package http
import (
"crypto/ed25519"
"encoding/base64"
"encoding/json"
stdhttp "net/http"
"strings"
"sync"
"time"
"github.com/oklog/ulid/v2"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
// Tunables — exposed as vars so tests can lower them. Defaults mirror
// the spec's recommendations.
var (
announceMaxPerMin = 10
announceGlobalCap = 100
)
// announceRequest is the wire shape POST /api/agents/announce takes.
// PublicKey is base64-std (no padding strip — stdlib decoder is
// lenient on padding for both forms).
type announceRequest struct {
Hostname string `json:"hostname"`
OS string `json:"os"`
Arch string `json:"arch"`
AgentVersion string `json:"agent_version"`
ResticVersion string `json:"restic_version"`
PublicKey string `json:"public_key"` // base64
}
// announceResponse is what the agent gets back. Fingerprint is the
// canonical "SHA256:hex" the operator compares against the UI.
// HostnameCollision warns the install script that another pending
// row already uses the same hostname.
type announceResponse struct {
PendingID string `json:"pending_id"`
Fingerprint string `json:"fingerprint"`
HostnameCollision bool `json:"hostname_collision"`
}
// rateBucket is a tiny per-IP token-bucket. last is the timestamp of
// the most recent refill; tokens is the current bucket level. Refill
// rate is announceMaxPerMin tokens/minute, burst = announceMaxPerMin.
type rateBucket struct {
tokens float64
last time.Time
}
// announceLimiter holds one bucket per source IP. Buckets are reaped
// lazily by a tiny grace period — we don't need true LRU cleanup
// because the bucket count is bounded by unique IPs in any given
// few minutes (small).
type announceLimiter struct {
mu sync.Mutex
buckets map[string]*rateBucket
}
func newAnnounceLimiter() *announceLimiter {
return &announceLimiter{buckets: map[string]*rateBucket{}}
}
// allow returns true and consumes a token if the IP's bucket has at
// least one token, else returns false. Capacity = announceMaxPerMin.
func (l *announceLimiter) allow(ip string, now time.Time) bool {
l.mu.Lock()
defer l.mu.Unlock()
cap := float64(announceMaxPerMin)
b, ok := l.buckets[ip]
if !ok {
b = &rateBucket{tokens: cap, last: now}
l.buckets[ip] = b
}
// Refill at cap tokens per minute.
elapsed := now.Sub(b.last).Seconds()
if elapsed > 0 {
b.tokens += (elapsed / 60.0) * cap
if b.tokens > cap {
b.tokens = cap
}
b.last = now
}
if b.tokens < 1.0 {
return false
}
b.tokens--
return true
}
// handleAnnounce is the public POST handler. Public — no auth.
func (s *Server) handleAnnounce(w stdhttp.ResponseWriter, r *stdhttp.Request) {
now := time.Now().UTC()
// Rate limit by source IP. Strip port — the limit is per host,
// not per outbound source port.
ip := remoteIP(r)
if !s.announceRL.allow(ip, now) {
w.Header().Set("Retry-After", "60")
writeJSONError(w, stdhttp.StatusTooManyRequests, "rate_limited",
"too many announces from this source; retry in a minute")
return
}
var req announceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error())
return
}
if req.Hostname == "" || req.OS == "" || req.Arch == "" || req.PublicKey == "" {
writeJSONError(w, stdhttp.StatusBadRequest, "missing_field",
"hostname, os, arch, public_key are required")
return
}
keyBytes, err := base64.StdEncoding.DecodeString(req.PublicKey)
if err != nil {
// Try URL-safe / no-padding flavors before giving up.
if k2, e2 := base64.RawStdEncoding.DecodeString(req.PublicKey); e2 == nil {
keyBytes = k2
} else {
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_public_key",
"public_key must be base64")
return
}
}
if len(keyBytes) != ed25519.PublicKeySize {
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_public_key",
"public_key must be 32 bytes (Ed25519)")
return
}
// Global cap (cheap query — index on expires_at).
count, err := s.deps.Store.CountPendingHosts(r.Context(), now)
if err != nil {
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
return
}
if count >= announceGlobalCap {
writeJSONError(w, stdhttp.StatusServiceUnavailable, "pending_cap_reached",
"too many in-flight pending hosts; ask an admin to clear the queue")
return
}
// Hostname collision flag (informational).
colls, err := s.deps.Store.CountPendingHostsByHostname(r.Context(), req.Hostname, now)
if err != nil {
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
return
}
ph := &store.PendingHost{
ID: ulid.Make().String(),
Hostname: req.Hostname,
OS: req.OS,
Arch: req.Arch,
AgentVersion: req.AgentVersion,
ResticVersion: req.ResticVersion,
PublicKey: keyBytes,
Fingerprint: store.FingerprintForKey(keyBytes),
AnnouncedFromIP: ip,
FirstSeenAt: now,
LastSeenAt: now,
ExpiresAt: now.Add(time.Hour),
}
if err := s.deps.Store.CreatePendingHost(r.Context(), ph); err != nil {
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
return
}
writeJSON(w, stdhttp.StatusOK, announceResponse{
PendingID: ph.ID,
Fingerprint: ph.Fingerprint,
HostnameCollision: colls > 0,
})
}
// remoteIP returns r.RemoteAddr stripped of any :port suffix, plus
// the X-Forwarded-For chain's first hop when behind a trusted proxy
// (RM_TRUSTED_PROXY in the deployment doc). Trust-proxy lookup
// matches the framework's existing behavior elsewhere.
func remoteIP(r *stdhttp.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP in the chain (closest to the original
// client) — same convention chi uses. Trim whitespace.
parts := strings.Split(xff, ",")
return strings.TrimSpace(parts[0])
}
addr := r.RemoteAddr
if i := strings.LastIndex(addr, ":"); i >= 0 {
return addr[:i]
}
return addr
}
+165
View File
@@ -0,0 +1,165 @@
// announce_test.go — covers POST /api/agents/announce: happy path,
// invalid public key, hostname collision flag, rate limit, global
// cap (P2-18a).
package http
import (
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"encoding/json"
stdhttp "net/http"
"strings"
"testing"
"time"
"github.com/oklog/ulid/v2"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
func newKeypair(t *testing.T) ed25519.PublicKey {
t.Helper()
pub, _, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("ed25519: %v", err)
}
return pub
}
func postAnnounce(t *testing.T, url string, req announceRequest) (status int, header stdhttp.Header, body []byte) {
t.Helper()
b, _ := json.Marshal(req)
r, _ := stdhttp.NewRequest("POST", url+"/api/agents/announce", bytes.NewReader(b))
r.Header.Set("Content-Type", "application/json")
res, err := stdhttp.DefaultClient.Do(r)
if err != nil {
t.Fatalf("do: %v", err)
}
defer res.Body.Close()
out := make([]byte, 4096)
n, _ := res.Body.Read(out)
return res.StatusCode, res.Header, out[:n]
}
func TestAnnounceHappyPath(t *testing.T) {
t.Parallel()
_, url, st := newTestServerWithHub(t)
pub := newKeypair(t)
status, _, body := postAnnounce(t, url, announceRequest{
Hostname: "alice", OS: "linux", Arch: "amd64",
AgentVersion: "1.0", ResticVersion: "0.17",
PublicKey: base64.StdEncoding.EncodeToString(pub),
})
if status != stdhttp.StatusOK {
t.Fatalf("status: %d body=%s", status, body)
}
var ar announceResponse
if err := json.Unmarshal(body, &ar); err != nil {
t.Fatalf("unmarshal: %v body=%s", err, body)
}
if ar.PendingID == "" {
t.Fatal("missing pending_id")
}
if !strings.HasPrefix(ar.Fingerprint, "SHA256:") {
t.Fatalf("fingerprint shape: %q", ar.Fingerprint)
}
if ar.HostnameCollision {
t.Fatal("first announce shouldn't be a collision")
}
// Row exists in the store.
if _, err := st.GetPendingHost(context.Background(), ar.PendingID); err != nil {
t.Fatalf("pending row missing: %v", err)
}
}
func TestAnnounceRejectsBadKey(t *testing.T) {
t.Parallel()
_, url, _ := newTestServerWithHub(t)
status, _, _ := postAnnounce(t, url, announceRequest{
Hostname: "x", OS: "linux", Arch: "amd64",
PublicKey: base64.StdEncoding.EncodeToString([]byte("too-short")),
})
if status != stdhttp.StatusBadRequest {
t.Fatalf("status: got %d, want 400", status)
}
}
func TestAnnounceHostnameCollisionFlag(t *testing.T) {
t.Parallel()
_, url, _ := newTestServerWithHub(t)
pub1 := newKeypair(t)
pub2 := newKeypair(t)
_, _, _ = postAnnounce(t, url, announceRequest{
Hostname: "dup-host", OS: "linux", Arch: "amd64",
PublicKey: base64.StdEncoding.EncodeToString(pub1),
})
status, _, body := postAnnounce(t, url, announceRequest{
Hostname: "dup-host", OS: "linux", Arch: "amd64",
PublicKey: base64.StdEncoding.EncodeToString(pub2),
})
if status != stdhttp.StatusOK {
t.Fatalf("status: %d", status)
}
var ar announceResponse
_ = json.Unmarshal(body, &ar)
if !ar.HostnameCollision {
t.Fatal("expected hostname_collision=true on second announce")
}
}
func TestAnnounceRateLimit(t *testing.T) {
t.Parallel()
_, url, _ := newTestServerWithHub(t)
// Lower the limit for the duration of this test (the limiter is
// per-server-instance so we don't disturb parallel tests).
prev := announceMaxPerMin
announceMaxPerMin = 2
t.Cleanup(func() { announceMaxPerMin = prev })
pub := newKeypair(t)
body := announceRequest{
Hostname: "rl-host", OS: "linux", Arch: "amd64",
PublicKey: base64.StdEncoding.EncodeToString(pub),
}
for i := 0; i < 2; i++ {
status, _, _ := postAnnounce(t, url, body)
if status != stdhttp.StatusOK {
t.Fatalf("call %d: status %d", i, status)
}
}
status, _, _ := postAnnounce(t, url, body)
if status != stdhttp.StatusTooManyRequests {
t.Fatalf("3rd call: want 429, got %d", status)
}
}
func TestAnnounceGlobalCap(t *testing.T) {
t.Parallel()
_, url, st := newTestServerWithHub(t)
prev := announceGlobalCap
announceGlobalCap = 1
t.Cleanup(func() { announceGlobalCap = prev })
// Pre-seed one row directly via the store so the cap is hit.
pub := newKeypair(t)
if err := st.CreatePendingHost(context.Background(), &store.PendingHost{
ID: ulid.Make().String(), Hostname: "x", OS: "linux", Arch: "amd64",
PublicKey: pub, Fingerprint: store.FingerprintForKey(pub),
AnnouncedFromIP: "127.0.0.1",
FirstSeenAt: time.Now().UTC(),
LastSeenAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(time.Hour),
}); err != nil {
t.Fatalf("seed: %v", err)
}
status, _, _ := postAnnounce(t, url, announceRequest{
Hostname: "next", OS: "linux", Arch: "amd64",
PublicKey: base64.StdEncoding.EncodeToString(newKeypair(t)),
})
if status != stdhttp.StatusServiceUnavailable {
t.Fatalf("want 503 over cap, got %d", status)
}
}
+14 -1
View File
@@ -49,6 +49,10 @@ type Server struct {
// sync.Mutex; checked-and-locked atomically via drainLocksMu.
drainLocksMu sync.Mutex
drainLocks map[string]*sync.Mutex
// announceRL is the per-source-IP token-bucket guarding
// POST /api/agents/announce (P2-18). One process-local map.
announceRL *announceLimiter
}
// New builds a configured but not-yet-started server.
@@ -67,7 +71,11 @@ func New(deps Deps) *Server {
w.WriteHeader(stdhttp.StatusNoContent)
})
s := &Server{deps: deps, drainLocks: make(map[string]*sync.Mutex)}
s := &Server{
deps: deps,
drainLocks: make(map[string]*sync.Mutex),
announceRL: newAnnounceLimiter(),
}
s.routes(r)
s.srv = &stdhttp.Server{
@@ -92,6 +100,11 @@ func (s *Server) routes(r chi.Router) {
// Agent enrollment (open endpoint — token is the credential).
r.Post("/agents/enroll", s.handleAgentEnroll)
// Announce-and-approve enrolment (open endpoint — fingerprint
// comparison in the UI is the gate). Per-IP rate-limited and
// globally capped (P2-18).
r.Post("/agents/announce", s.handleAnnounce)
// Operator → server (authenticated). Spec.md §6.1's
// /hosts/{id}/enrollment-token (regenerate) lands when the
// host page can call it; for now just the create endpoint.