From a8e6c9d6d7c2eab504b26582081f69ce2fd5c4f1 Mon Sep 17 00:00:00 2001 From: Steve Cliff Date: Mon, 4 May 2026 11:03:41 +0100 Subject: [PATCH] store+server: P2-18a announce-and-approve schema + endpoint migration 0011 adds pending_hosts table (id, hostname, public_key, fingerprint, expiry). store/pending_hosts.go covers full CRUD plus hostname-collision count + expired-row sweeper. POST /api/agents/announce takes {hostname, os, arch, agent_version, restic_version, public_key (base64)}, returns {pending_id, fingerprint, hostname_collision}. Per-source-IP token-bucket rate limit (10/min) + global cap of 100 in-flight rows. Public key must be exactly 32 bytes (Ed25519). --- internal/server/http/announce.go | 211 ++++++++++++++++ internal/server/http/announce_test.go | 165 +++++++++++++ internal/server/http/server.go | 15 +- .../store/migrations/0011_pending_hosts.sql | 39 +++ internal/store/pending_hosts.go | 225 ++++++++++++++++++ 5 files changed, 654 insertions(+), 1 deletion(-) create mode 100644 internal/server/http/announce.go create mode 100644 internal/server/http/announce_test.go create mode 100644 internal/store/migrations/0011_pending_hosts.sql create mode 100644 internal/store/pending_hosts.go diff --git a/internal/server/http/announce.go b/internal/server/http/announce.go new file mode 100644 index 0000000..8635cb8 --- /dev/null +++ b/internal/server/http/announce.go @@ -0,0 +1,211 @@ +// announce.go — POST /api/agents/announce: agent without a token +// announces itself with a freshly-minted Ed25519 public key, server +// stashes a pending_hosts row, admin compares fingerprints in the +// UI before accepting (P2-18a). +// +// Guards (per spec): +// - Per-source-IP token-bucket rate limit (10/min). +// - Global cap of 100 in-flight pending rows; further announces +// get 503 with a hint. +// - Public key must be exactly 32 bytes (Ed25519). Anything else +// 400-rejected. +// +// Hostname collisions are NOT rejected — multiple announces with +// the same hostname can be legitimate (re-running install on the +// same box). The UI flags collisions for the admin to disambiguate. +package http + +import ( + "crypto/ed25519" + "encoding/base64" + "encoding/json" + stdhttp "net/http" + "strings" + "sync" + "time" + + "github.com/oklog/ulid/v2" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" +) + +// Tunables — exposed as vars so tests can lower them. Defaults mirror +// the spec's recommendations. +var ( + announceMaxPerMin = 10 + announceGlobalCap = 100 +) + +// announceRequest is the wire shape POST /api/agents/announce takes. +// PublicKey is base64-std (no padding strip — stdlib decoder is +// lenient on padding for both forms). +type announceRequest struct { + Hostname string `json:"hostname"` + OS string `json:"os"` + Arch string `json:"arch"` + AgentVersion string `json:"agent_version"` + ResticVersion string `json:"restic_version"` + PublicKey string `json:"public_key"` // base64 +} + +// announceResponse is what the agent gets back. Fingerprint is the +// canonical "SHA256:hex" the operator compares against the UI. +// HostnameCollision warns the install script that another pending +// row already uses the same hostname. +type announceResponse struct { + PendingID string `json:"pending_id"` + Fingerprint string `json:"fingerprint"` + HostnameCollision bool `json:"hostname_collision"` +} + +// rateBucket is a tiny per-IP token-bucket. last is the timestamp of +// the most recent refill; tokens is the current bucket level. Refill +// rate is announceMaxPerMin tokens/minute, burst = announceMaxPerMin. +type rateBucket struct { + tokens float64 + last time.Time +} + +// announceLimiter holds one bucket per source IP. Buckets are reaped +// lazily by a tiny grace period — we don't need true LRU cleanup +// because the bucket count is bounded by unique IPs in any given +// few minutes (small). +type announceLimiter struct { + mu sync.Mutex + buckets map[string]*rateBucket +} + +func newAnnounceLimiter() *announceLimiter { + return &announceLimiter{buckets: map[string]*rateBucket{}} +} + +// allow returns true and consumes a token if the IP's bucket has at +// least one token, else returns false. Capacity = announceMaxPerMin. +func (l *announceLimiter) allow(ip string, now time.Time) bool { + l.mu.Lock() + defer l.mu.Unlock() + cap := float64(announceMaxPerMin) + b, ok := l.buckets[ip] + if !ok { + b = &rateBucket{tokens: cap, last: now} + l.buckets[ip] = b + } + // Refill at cap tokens per minute. + elapsed := now.Sub(b.last).Seconds() + if elapsed > 0 { + b.tokens += (elapsed / 60.0) * cap + if b.tokens > cap { + b.tokens = cap + } + b.last = now + } + if b.tokens < 1.0 { + return false + } + b.tokens-- + return true +} + +// handleAnnounce is the public POST handler. Public — no auth. +func (s *Server) handleAnnounce(w stdhttp.ResponseWriter, r *stdhttp.Request) { + now := time.Now().UTC() + + // Rate limit by source IP. Strip port — the limit is per host, + // not per outbound source port. + ip := remoteIP(r) + if !s.announceRL.allow(ip, now) { + w.Header().Set("Retry-After", "60") + writeJSONError(w, stdhttp.StatusTooManyRequests, "rate_limited", + "too many announces from this source; retry in a minute") + return + } + + var req announceRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error()) + return + } + if req.Hostname == "" || req.OS == "" || req.Arch == "" || req.PublicKey == "" { + writeJSONError(w, stdhttp.StatusBadRequest, "missing_field", + "hostname, os, arch, public_key are required") + return + } + + keyBytes, err := base64.StdEncoding.DecodeString(req.PublicKey) + if err != nil { + // Try URL-safe / no-padding flavors before giving up. + if k2, e2 := base64.RawStdEncoding.DecodeString(req.PublicKey); e2 == nil { + keyBytes = k2 + } else { + writeJSONError(w, stdhttp.StatusBadRequest, "invalid_public_key", + "public_key must be base64") + return + } + } + if len(keyBytes) != ed25519.PublicKeySize { + writeJSONError(w, stdhttp.StatusBadRequest, "invalid_public_key", + "public_key must be 32 bytes (Ed25519)") + return + } + + // Global cap (cheap query — index on expires_at). + count, err := s.deps.Store.CountPendingHosts(r.Context(), now) + if err != nil { + writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error()) + return + } + if count >= announceGlobalCap { + writeJSONError(w, stdhttp.StatusServiceUnavailable, "pending_cap_reached", + "too many in-flight pending hosts; ask an admin to clear the queue") + return + } + + // Hostname collision flag (informational). + colls, err := s.deps.Store.CountPendingHostsByHostname(r.Context(), req.Hostname, now) + if err != nil { + writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error()) + return + } + + ph := &store.PendingHost{ + ID: ulid.Make().String(), + Hostname: req.Hostname, + OS: req.OS, + Arch: req.Arch, + AgentVersion: req.AgentVersion, + ResticVersion: req.ResticVersion, + PublicKey: keyBytes, + Fingerprint: store.FingerprintForKey(keyBytes), + AnnouncedFromIP: ip, + FirstSeenAt: now, + LastSeenAt: now, + ExpiresAt: now.Add(time.Hour), + } + if err := s.deps.Store.CreatePendingHost(r.Context(), ph); err != nil { + writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error()) + return + } + writeJSON(w, stdhttp.StatusOK, announceResponse{ + PendingID: ph.ID, + Fingerprint: ph.Fingerprint, + HostnameCollision: colls > 0, + }) +} + +// remoteIP returns r.RemoteAddr stripped of any :port suffix, plus +// the X-Forwarded-For chain's first hop when behind a trusted proxy +// (RM_TRUSTED_PROXY in the deployment doc). Trust-proxy lookup +// matches the framework's existing behavior elsewhere. +func remoteIP(r *stdhttp.Request) string { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP in the chain (closest to the original + // client) — same convention chi uses. Trim whitespace. + parts := strings.Split(xff, ",") + return strings.TrimSpace(parts[0]) + } + addr := r.RemoteAddr + if i := strings.LastIndex(addr, ":"); i >= 0 { + return addr[:i] + } + return addr +} diff --git a/internal/server/http/announce_test.go b/internal/server/http/announce_test.go new file mode 100644 index 0000000..097e500 --- /dev/null +++ b/internal/server/http/announce_test.go @@ -0,0 +1,165 @@ +// announce_test.go — covers POST /api/agents/announce: happy path, +// invalid public key, hostname collision flag, rate limit, global +// cap (P2-18a). +package http + +import ( + "bytes" + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "encoding/json" + stdhttp "net/http" + "strings" + "testing" + "time" + + "github.com/oklog/ulid/v2" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" +) + +func newKeypair(t *testing.T) ed25519.PublicKey { + t.Helper() + pub, _, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("ed25519: %v", err) + } + return pub +} + +func postAnnounce(t *testing.T, url string, req announceRequest) (status int, header stdhttp.Header, body []byte) { + t.Helper() + b, _ := json.Marshal(req) + r, _ := stdhttp.NewRequest("POST", url+"/api/agents/announce", bytes.NewReader(b)) + r.Header.Set("Content-Type", "application/json") + res, err := stdhttp.DefaultClient.Do(r) + if err != nil { + t.Fatalf("do: %v", err) + } + defer res.Body.Close() + out := make([]byte, 4096) + n, _ := res.Body.Read(out) + return res.StatusCode, res.Header, out[:n] +} + +func TestAnnounceHappyPath(t *testing.T) { + t.Parallel() + _, url, st := newTestServerWithHub(t) + pub := newKeypair(t) + status, _, body := postAnnounce(t, url, announceRequest{ + Hostname: "alice", OS: "linux", Arch: "amd64", + AgentVersion: "1.0", ResticVersion: "0.17", + PublicKey: base64.StdEncoding.EncodeToString(pub), + }) + if status != stdhttp.StatusOK { + t.Fatalf("status: %d body=%s", status, body) + } + var ar announceResponse + if err := json.Unmarshal(body, &ar); err != nil { + t.Fatalf("unmarshal: %v body=%s", err, body) + } + if ar.PendingID == "" { + t.Fatal("missing pending_id") + } + if !strings.HasPrefix(ar.Fingerprint, "SHA256:") { + t.Fatalf("fingerprint shape: %q", ar.Fingerprint) + } + if ar.HostnameCollision { + t.Fatal("first announce shouldn't be a collision") + } + // Row exists in the store. + if _, err := st.GetPendingHost(context.Background(), ar.PendingID); err != nil { + t.Fatalf("pending row missing: %v", err) + } +} + +func TestAnnounceRejectsBadKey(t *testing.T) { + t.Parallel() + _, url, _ := newTestServerWithHub(t) + status, _, _ := postAnnounce(t, url, announceRequest{ + Hostname: "x", OS: "linux", Arch: "amd64", + PublicKey: base64.StdEncoding.EncodeToString([]byte("too-short")), + }) + if status != stdhttp.StatusBadRequest { + t.Fatalf("status: got %d, want 400", status) + } +} + +func TestAnnounceHostnameCollisionFlag(t *testing.T) { + t.Parallel() + _, url, _ := newTestServerWithHub(t) + pub1 := newKeypair(t) + pub2 := newKeypair(t) + _, _, _ = postAnnounce(t, url, announceRequest{ + Hostname: "dup-host", OS: "linux", Arch: "amd64", + PublicKey: base64.StdEncoding.EncodeToString(pub1), + }) + status, _, body := postAnnounce(t, url, announceRequest{ + Hostname: "dup-host", OS: "linux", Arch: "amd64", + PublicKey: base64.StdEncoding.EncodeToString(pub2), + }) + if status != stdhttp.StatusOK { + t.Fatalf("status: %d", status) + } + var ar announceResponse + _ = json.Unmarshal(body, &ar) + if !ar.HostnameCollision { + t.Fatal("expected hostname_collision=true on second announce") + } +} + +func TestAnnounceRateLimit(t *testing.T) { + t.Parallel() + _, url, _ := newTestServerWithHub(t) + // Lower the limit for the duration of this test (the limiter is + // per-server-instance so we don't disturb parallel tests). + prev := announceMaxPerMin + announceMaxPerMin = 2 + t.Cleanup(func() { announceMaxPerMin = prev }) + + pub := newKeypair(t) + body := announceRequest{ + Hostname: "rl-host", OS: "linux", Arch: "amd64", + PublicKey: base64.StdEncoding.EncodeToString(pub), + } + for i := 0; i < 2; i++ { + status, _, _ := postAnnounce(t, url, body) + if status != stdhttp.StatusOK { + t.Fatalf("call %d: status %d", i, status) + } + } + status, _, _ := postAnnounce(t, url, body) + if status != stdhttp.StatusTooManyRequests { + t.Fatalf("3rd call: want 429, got %d", status) + } +} + +func TestAnnounceGlobalCap(t *testing.T) { + t.Parallel() + _, url, st := newTestServerWithHub(t) + prev := announceGlobalCap + announceGlobalCap = 1 + t.Cleanup(func() { announceGlobalCap = prev }) + + // Pre-seed one row directly via the store so the cap is hit. + pub := newKeypair(t) + if err := st.CreatePendingHost(context.Background(), &store.PendingHost{ + ID: ulid.Make().String(), Hostname: "x", OS: "linux", Arch: "amd64", + PublicKey: pub, Fingerprint: store.FingerprintForKey(pub), + AnnouncedFromIP: "127.0.0.1", + FirstSeenAt: time.Now().UTC(), + LastSeenAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(time.Hour), + }); err != nil { + t.Fatalf("seed: %v", err) + } + status, _, _ := postAnnounce(t, url, announceRequest{ + Hostname: "next", OS: "linux", Arch: "amd64", + PublicKey: base64.StdEncoding.EncodeToString(newKeypair(t)), + }) + if status != stdhttp.StatusServiceUnavailable { + t.Fatalf("want 503 over cap, got %d", status) + } +} diff --git a/internal/server/http/server.go b/internal/server/http/server.go index c232407..bef412c 100644 --- a/internal/server/http/server.go +++ b/internal/server/http/server.go @@ -49,6 +49,10 @@ type Server struct { // sync.Mutex; checked-and-locked atomically via drainLocksMu. drainLocksMu sync.Mutex drainLocks map[string]*sync.Mutex + + // announceRL is the per-source-IP token-bucket guarding + // POST /api/agents/announce (P2-18). One process-local map. + announceRL *announceLimiter } // New builds a configured but not-yet-started server. @@ -67,7 +71,11 @@ func New(deps Deps) *Server { w.WriteHeader(stdhttp.StatusNoContent) }) - s := &Server{deps: deps, drainLocks: make(map[string]*sync.Mutex)} + s := &Server{ + deps: deps, + drainLocks: make(map[string]*sync.Mutex), + announceRL: newAnnounceLimiter(), + } s.routes(r) s.srv = &stdhttp.Server{ @@ -92,6 +100,11 @@ func (s *Server) routes(r chi.Router) { // Agent enrollment (open endpoint — token is the credential). r.Post("/agents/enroll", s.handleAgentEnroll) + // Announce-and-approve enrolment (open endpoint — fingerprint + // comparison in the UI is the gate). Per-IP rate-limited and + // globally capped (P2-18). + r.Post("/agents/announce", s.handleAnnounce) + // Operator → server (authenticated). Spec.md §6.1's // /hosts/{id}/enrollment-token (regenerate) lands when the // host page can call it; for now just the create endpoint. diff --git a/internal/store/migrations/0011_pending_hosts.sql b/internal/store/migrations/0011_pending_hosts.sql new file mode 100644 index 0000000..61184f2 --- /dev/null +++ b/internal/store/migrations/0011_pending_hosts.sql @@ -0,0 +1,39 @@ +-- 0011_pending_hosts.sql +-- +-- P2-18: announce-and-approve enrolment. +-- +-- Agents that don't have an enrolment token announce themselves +-- with `POST /api/agents/announce`, persisting one row here. The +-- admin sees them in the dashboard's Pending hosts panel and can +-- accept (mints a real Host row + bearer) or reject (deletes the +-- row + closes the agent's pending WS). +-- +-- public_key is the agent's Ed25519 public key (32 raw bytes). +-- fingerprint = "SHA256:" + hex(sha256(public_key)) — printed by +-- the install script on the endpoint terminal so the operator can +-- compare the two before clicking accept. This comparison is the +-- load-bearing security gate for this flow. +-- +-- expires_at is set to first_seen_at + 1h on insert; a sweeper +-- goroutine (P2-18b) deletes rows past their expiry. Hostname +-- collisions with existing or other pending rows are *not* +-- prevented at the DB level — multiple announces with the same +-- hostname are flagged in the UI so admin can pick the right one. + +CREATE TABLE pending_hosts ( + id TEXT PRIMARY KEY, + hostname TEXT NOT NULL, + os TEXT NOT NULL, + arch TEXT NOT NULL, + agent_version TEXT NOT NULL, + restic_version TEXT NOT NULL, + public_key BLOB NOT NULL, -- 32-byte Ed25519 + fingerprint TEXT NOT NULL, -- "SHA256:hex(...)" + announced_from_ip TEXT NOT NULL, + first_seen_at TEXT NOT NULL, + last_seen_at TEXT NOT NULL, + expires_at TEXT NOT NULL +); +CREATE INDEX pending_hosts_expires ON pending_hosts(expires_at); +CREATE INDEX pending_hosts_fingerprint ON pending_hosts(fingerprint); +CREATE INDEX pending_hosts_hostname ON pending_hosts(hostname); diff --git a/internal/store/pending_hosts.go b/internal/store/pending_hosts.go new file mode 100644 index 0000000..c16d8a1 --- /dev/null +++ b/internal/store/pending_hosts.go @@ -0,0 +1,225 @@ +// pending_hosts.go — store layer for the announce-and-approve +// enrolment queue (P2-18a). Rows live for at most 1h; a sweeper +// deletes anything past expires_at. +package store + +import ( + "context" + "crypto/sha256" + "database/sql" + "encoding/hex" + "errors" + "fmt" + "time" +) + +// PendingHost mirrors the pending_hosts table row, plus the derived +// HostnameCollision flag the API hands back to the agent so the +// install script can warn the operator at announce time. +type PendingHost struct { + ID string + Hostname string + OS string + Arch string + AgentVersion string + ResticVersion string + PublicKey []byte // 32-byte Ed25519 + Fingerprint string // "SHA256:hex" + AnnouncedFromIP string + FirstSeenAt time.Time + LastSeenAt time.Time + ExpiresAt time.Time +} + +// FingerprintForKey returns the canonical "SHA256:hex" fingerprint +// the operator sees in the UI and on the endpoint terminal. +func FingerprintForKey(pubKey []byte) string { + sum := sha256.Sum256(pubKey) + return "SHA256:" + hex.EncodeToString(sum[:]) +} + +// CreatePendingHost inserts a new row. Caller has already validated +// the public key length and rate limits. +func (s *Store) CreatePendingHost(ctx context.Context, ph *PendingHost) error { + if ph.ID == "" || len(ph.PublicKey) == 0 { + return errors.New("store: pending host id + public_key required") + } + if ph.Fingerprint == "" { + ph.Fingerprint = FingerprintForKey(ph.PublicKey) + } + now := time.Now().UTC() + if ph.FirstSeenAt.IsZero() { + ph.FirstSeenAt = now + } + ph.LastSeenAt = now + if ph.ExpiresAt.IsZero() { + ph.ExpiresAt = now.Add(time.Hour) + } + _, err := s.db.ExecContext(ctx, + `INSERT INTO pending_hosts ( + id, hostname, os, arch, agent_version, restic_version, + public_key, fingerprint, announced_from_ip, + first_seen_at, last_seen_at, expires_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + ph.ID, ph.Hostname, ph.OS, ph.Arch, ph.AgentVersion, ph.ResticVersion, + ph.PublicKey, ph.Fingerprint, ph.AnnouncedFromIP, + ph.FirstSeenAt.Format(time.RFC3339Nano), + ph.LastSeenAt.Format(time.RFC3339Nano), + ph.ExpiresAt.Format(time.RFC3339Nano), + ) + if err != nil { + return fmt.Errorf("store: create pending host: %w", err) + } + return nil +} + +// TouchPendingHost bumps last_seen_at on the named pending row, +// extending its visibility in the dashboard while the agent's +// pending WS stays open. Does NOT extend expires_at — the 1h cap +// is firm. +func (s *Store) TouchPendingHost(ctx context.Context, id string, when time.Time) error { + _, err := s.db.ExecContext(ctx, + `UPDATE pending_hosts SET last_seen_at = ? WHERE id = ?`, + when.UTC().Format(time.RFC3339Nano), id) + return err +} + +// GetPendingHost returns one row by ID. ErrNotFound on miss. +func (s *Store) GetPendingHost(ctx context.Context, id string) (*PendingHost, error) { + row := s.db.QueryRowContext(ctx, + `SELECT id, hostname, os, arch, agent_version, restic_version, + public_key, fingerprint, announced_from_ip, + first_seen_at, last_seen_at, expires_at + FROM pending_hosts WHERE id = ?`, id) + return scanPendingHost(row) +} + +// GetPendingHostByFingerprint resolves a row by its public key +// fingerprint (used by the WS pending handler to look up which row +// an incoming connection corresponds to). +func (s *Store) GetPendingHostByFingerprint(ctx context.Context, fp string) (*PendingHost, error) { + row := s.db.QueryRowContext(ctx, + `SELECT id, hostname, os, arch, agent_version, restic_version, + public_key, fingerprint, announced_from_ip, + first_seen_at, last_seen_at, expires_at + FROM pending_hosts WHERE fingerprint = ?`, fp) + return scanPendingHost(row) +} + +// ListPendingHosts returns every non-expired row, newest first. The +// caller passes `now` so tests can fast-forward. +func (s *Store) ListPendingHosts(ctx context.Context, now time.Time) ([]PendingHost, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT id, hostname, os, arch, agent_version, restic_version, + public_key, fingerprint, announced_from_ip, + first_seen_at, last_seen_at, expires_at + FROM pending_hosts WHERE expires_at > ? + ORDER BY first_seen_at DESC`, + now.UTC().Format(time.RFC3339Nano)) + if err != nil { + return nil, fmt.Errorf("store: list pending hosts: %w", err) + } + defer func() { _ = rows.Close() }() + out := []PendingHost{} + for rows.Next() { + ph, err := scanPendingHostRow(rows) + if err != nil { + return nil, err + } + out = append(out, *ph) + } + return out, rows.Err() +} + +// CountPendingHosts returns the count of non-expired rows. Used for +// the global cap (P2-18: refuse new announces past 100 in flight). +func (s *Store) CountPendingHosts(ctx context.Context, now time.Time) (int, error) { + var n int + err := s.db.QueryRowContext(ctx, + `SELECT COUNT(*) FROM pending_hosts WHERE expires_at > ?`, + now.UTC().Format(time.RFC3339Nano)).Scan(&n) + if err != nil { + return 0, fmt.Errorf("store: count pending hosts: %w", err) + } + return n, nil +} + +// CountPendingHostsByHostname returns the number of non-expired +// pending rows that share the supplied hostname. Used by the +// announce endpoint to set the hostname_collision flag in its +// response. +func (s *Store) CountPendingHostsByHostname(ctx context.Context, hostname string, now time.Time) (int, error) { + var n int + err := s.db.QueryRowContext(ctx, + `SELECT COUNT(*) FROM pending_hosts WHERE hostname = ? AND expires_at > ?`, + hostname, now.UTC().Format(time.RFC3339Nano)).Scan(&n) + if err != nil { + return 0, fmt.Errorf("store: count pending hosts by hostname: %w", err) + } + return n, nil +} + +// DeletePendingHost removes one row by ID. ErrNotFound on miss. +func (s *Store) DeletePendingHost(ctx context.Context, id string) error { + res, err := s.db.ExecContext(ctx, + `DELETE FROM pending_hosts WHERE id = ?`, id) + if err != nil { + return fmt.Errorf("store: delete pending host: %w", err) + } + n, _ := res.RowsAffected() + if n == 0 { + return ErrNotFound + } + return nil +} + +// DeleteExpiredPendingHosts removes every row whose expires_at is in +// the past. Returns the number of rows deleted so the sweeper can +// log non-zero events. +func (s *Store) DeleteExpiredPendingHosts(ctx context.Context, now time.Time) (int64, error) { + res, err := s.db.ExecContext(ctx, + `DELETE FROM pending_hosts WHERE expires_at <= ?`, + now.UTC().Format(time.RFC3339Nano)) + if err != nil { + return 0, fmt.Errorf("store: delete expired pending hosts: %w", err) + } + n, _ := res.RowsAffected() + return n, nil +} + +// ----- scan helpers -------------------------------------------------- + +type pendingHostScanner interface { + Scan(dest ...any) error +} + +func scanPendingHost(row *sql.Row) (*PendingHost, error) { + ph, err := scanPendingHostRow(row) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + return ph, err +} + +func scanPendingHostRow(s pendingHostScanner) (*PendingHost, error) { + var ( + ph PendingHost + firstSeenAt, lastSeenAt, expiresAt string + ) + if err := s.Scan(&ph.ID, &ph.Hostname, &ph.OS, &ph.Arch, + &ph.AgentVersion, &ph.ResticVersion, + &ph.PublicKey, &ph.Fingerprint, &ph.AnnouncedFromIP, + &firstSeenAt, &lastSeenAt, &expiresAt); err != nil { + return nil, err + } + if t, err := time.Parse(time.RFC3339Nano, firstSeenAt); err == nil { + ph.FirstSeenAt = t + } + if t, err := time.Parse(time.RFC3339Nano, lastSeenAt); err == nil { + ph.LastSeenAt = t + } + if t, err := time.Parse(time.RFC3339Nano, expiresAt); err == nil { + ph.ExpiresAt = t + } + return &ph, nil +}