P2 completion (P2R-09/10/11/12/13/14, P2-16/17/18) #5
@@ -0,0 +1,211 @@
|
||||
// announce.go — POST /api/agents/announce: agent without a token
|
||||
// announces itself with a freshly-minted Ed25519 public key, server
|
||||
// stashes a pending_hosts row, admin compares fingerprints in the
|
||||
// UI before accepting (P2-18a).
|
||||
//
|
||||
// Guards (per spec):
|
||||
// - Per-source-IP token-bucket rate limit (10/min).
|
||||
// - Global cap of 100 in-flight pending rows; further announces
|
||||
// get 503 with a hint.
|
||||
// - Public key must be exactly 32 bytes (Ed25519). Anything else
|
||||
// 400-rejected.
|
||||
//
|
||||
// Hostname collisions are NOT rejected — multiple announces with
|
||||
// the same hostname can be legitimate (re-running install on the
|
||||
// same box). The UI flags collisions for the admin to disambiguate.
|
||||
package http
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
stdhttp "net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// Tunables — exposed as vars so tests can lower them. Defaults mirror
|
||||
// the spec's recommendations.
|
||||
var (
|
||||
announceMaxPerMin = 10
|
||||
announceGlobalCap = 100
|
||||
)
|
||||
|
||||
// announceRequest is the wire shape POST /api/agents/announce takes.
|
||||
// PublicKey is base64-std (no padding strip — stdlib decoder is
|
||||
// lenient on padding for both forms).
|
||||
type announceRequest struct {
|
||||
Hostname string `json:"hostname"`
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
AgentVersion string `json:"agent_version"`
|
||||
ResticVersion string `json:"restic_version"`
|
||||
PublicKey string `json:"public_key"` // base64
|
||||
}
|
||||
|
||||
// announceResponse is what the agent gets back. Fingerprint is the
|
||||
// canonical "SHA256:hex" the operator compares against the UI.
|
||||
// HostnameCollision warns the install script that another pending
|
||||
// row already uses the same hostname.
|
||||
type announceResponse struct {
|
||||
PendingID string `json:"pending_id"`
|
||||
Fingerprint string `json:"fingerprint"`
|
||||
HostnameCollision bool `json:"hostname_collision"`
|
||||
}
|
||||
|
||||
// rateBucket is a tiny per-IP token-bucket. last is the timestamp of
|
||||
// the most recent refill; tokens is the current bucket level. Refill
|
||||
// rate is announceMaxPerMin tokens/minute, burst = announceMaxPerMin.
|
||||
type rateBucket struct {
|
||||
tokens float64
|
||||
last time.Time
|
||||
}
|
||||
|
||||
// announceLimiter holds one bucket per source IP. Buckets are reaped
|
||||
// lazily by a tiny grace period — we don't need true LRU cleanup
|
||||
// because the bucket count is bounded by unique IPs in any given
|
||||
// few minutes (small).
|
||||
type announceLimiter struct {
|
||||
mu sync.Mutex
|
||||
buckets map[string]*rateBucket
|
||||
}
|
||||
|
||||
func newAnnounceLimiter() *announceLimiter {
|
||||
return &announceLimiter{buckets: map[string]*rateBucket{}}
|
||||
}
|
||||
|
||||
// allow returns true and consumes a token if the IP's bucket has at
|
||||
// least one token, else returns false. Capacity = announceMaxPerMin.
|
||||
func (l *announceLimiter) allow(ip string, now time.Time) bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
cap := float64(announceMaxPerMin)
|
||||
b, ok := l.buckets[ip]
|
||||
if !ok {
|
||||
b = &rateBucket{tokens: cap, last: now}
|
||||
l.buckets[ip] = b
|
||||
}
|
||||
// Refill at cap tokens per minute.
|
||||
elapsed := now.Sub(b.last).Seconds()
|
||||
if elapsed > 0 {
|
||||
b.tokens += (elapsed / 60.0) * cap
|
||||
if b.tokens > cap {
|
||||
b.tokens = cap
|
||||
}
|
||||
b.last = now
|
||||
}
|
||||
if b.tokens < 1.0 {
|
||||
return false
|
||||
}
|
||||
b.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
// handleAnnounce is the public POST handler. Public — no auth.
|
||||
func (s *Server) handleAnnounce(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Rate limit by source IP. Strip port — the limit is per host,
|
||||
// not per outbound source port.
|
||||
ip := remoteIP(r)
|
||||
if !s.announceRL.allow(ip, now) {
|
||||
w.Header().Set("Retry-After", "60")
|
||||
writeJSONError(w, stdhttp.StatusTooManyRequests, "rate_limited",
|
||||
"too many announces from this source; retry in a minute")
|
||||
return
|
||||
}
|
||||
|
||||
var req announceRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error())
|
||||
return
|
||||
}
|
||||
if req.Hostname == "" || req.OS == "" || req.Arch == "" || req.PublicKey == "" {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "missing_field",
|
||||
"hostname, os, arch, public_key are required")
|
||||
return
|
||||
}
|
||||
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(req.PublicKey)
|
||||
if err != nil {
|
||||
// Try URL-safe / no-padding flavors before giving up.
|
||||
if k2, e2 := base64.RawStdEncoding.DecodeString(req.PublicKey); e2 == nil {
|
||||
keyBytes = k2
|
||||
} else {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_public_key",
|
||||
"public_key must be base64")
|
||||
return
|
||||
}
|
||||
}
|
||||
if len(keyBytes) != ed25519.PublicKeySize {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_public_key",
|
||||
"public_key must be 32 bytes (Ed25519)")
|
||||
return
|
||||
}
|
||||
|
||||
// Global cap (cheap query — index on expires_at).
|
||||
count, err := s.deps.Store.CountPendingHosts(r.Context(), now)
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
if count >= announceGlobalCap {
|
||||
writeJSONError(w, stdhttp.StatusServiceUnavailable, "pending_cap_reached",
|
||||
"too many in-flight pending hosts; ask an admin to clear the queue")
|
||||
return
|
||||
}
|
||||
|
||||
// Hostname collision flag (informational).
|
||||
colls, err := s.deps.Store.CountPendingHostsByHostname(r.Context(), req.Hostname, now)
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ph := &store.PendingHost{
|
||||
ID: ulid.Make().String(),
|
||||
Hostname: req.Hostname,
|
||||
OS: req.OS,
|
||||
Arch: req.Arch,
|
||||
AgentVersion: req.AgentVersion,
|
||||
ResticVersion: req.ResticVersion,
|
||||
PublicKey: keyBytes,
|
||||
Fingerprint: store.FingerprintForKey(keyBytes),
|
||||
AnnouncedFromIP: ip,
|
||||
FirstSeenAt: now,
|
||||
LastSeenAt: now,
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
}
|
||||
if err := s.deps.Store.CreatePendingHost(r.Context(), ph); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
writeJSON(w, stdhttp.StatusOK, announceResponse{
|
||||
PendingID: ph.ID,
|
||||
Fingerprint: ph.Fingerprint,
|
||||
HostnameCollision: colls > 0,
|
||||
})
|
||||
}
|
||||
|
||||
// remoteIP returns r.RemoteAddr stripped of any :port suffix, plus
|
||||
// the X-Forwarded-For chain's first hop when behind a trusted proxy
|
||||
// (RM_TRUSTED_PROXY in the deployment doc). Trust-proxy lookup
|
||||
// matches the framework's existing behavior elsewhere.
|
||||
func remoteIP(r *stdhttp.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP in the chain (closest to the original
|
||||
// client) — same convention chi uses. Trim whitespace.
|
||||
parts := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
addr := r.RemoteAddr
|
||||
if i := strings.LastIndex(addr, ":"); i >= 0 {
|
||||
return addr[:i]
|
||||
}
|
||||
return addr
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
// announce_test.go — covers POST /api/agents/announce: happy path,
|
||||
// invalid public key, hostname collision flag, rate limit, global
|
||||
// cap (P2-18a).
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
stdhttp "net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
func newKeypair(t *testing.T) ed25519.PublicKey {
|
||||
t.Helper()
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("ed25519: %v", err)
|
||||
}
|
||||
return pub
|
||||
}
|
||||
|
||||
func postAnnounce(t *testing.T, url string, req announceRequest) (status int, header stdhttp.Header, body []byte) {
|
||||
t.Helper()
|
||||
b, _ := json.Marshal(req)
|
||||
r, _ := stdhttp.NewRequest("POST", url+"/api/agents/announce", bytes.NewReader(b))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
res, err := stdhttp.DefaultClient.Do(r)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
out := make([]byte, 4096)
|
||||
n, _ := res.Body.Read(out)
|
||||
return res.StatusCode, res.Header, out[:n]
|
||||
}
|
||||
|
||||
func TestAnnounceHappyPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, url, st := newTestServerWithHub(t)
|
||||
pub := newKeypair(t)
|
||||
status, _, body := postAnnounce(t, url, announceRequest{
|
||||
Hostname: "alice", OS: "linux", Arch: "amd64",
|
||||
AgentVersion: "1.0", ResticVersion: "0.17",
|
||||
PublicKey: base64.StdEncoding.EncodeToString(pub),
|
||||
})
|
||||
if status != stdhttp.StatusOK {
|
||||
t.Fatalf("status: %d body=%s", status, body)
|
||||
}
|
||||
var ar announceResponse
|
||||
if err := json.Unmarshal(body, &ar); err != nil {
|
||||
t.Fatalf("unmarshal: %v body=%s", err, body)
|
||||
}
|
||||
if ar.PendingID == "" {
|
||||
t.Fatal("missing pending_id")
|
||||
}
|
||||
if !strings.HasPrefix(ar.Fingerprint, "SHA256:") {
|
||||
t.Fatalf("fingerprint shape: %q", ar.Fingerprint)
|
||||
}
|
||||
if ar.HostnameCollision {
|
||||
t.Fatal("first announce shouldn't be a collision")
|
||||
}
|
||||
// Row exists in the store.
|
||||
if _, err := st.GetPendingHost(context.Background(), ar.PendingID); err != nil {
|
||||
t.Fatalf("pending row missing: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnnounceRejectsBadKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, url, _ := newTestServerWithHub(t)
|
||||
status, _, _ := postAnnounce(t, url, announceRequest{
|
||||
Hostname: "x", OS: "linux", Arch: "amd64",
|
||||
PublicKey: base64.StdEncoding.EncodeToString([]byte("too-short")),
|
||||
})
|
||||
if status != stdhttp.StatusBadRequest {
|
||||
t.Fatalf("status: got %d, want 400", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnnounceHostnameCollisionFlag(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, url, _ := newTestServerWithHub(t)
|
||||
pub1 := newKeypair(t)
|
||||
pub2 := newKeypair(t)
|
||||
_, _, _ = postAnnounce(t, url, announceRequest{
|
||||
Hostname: "dup-host", OS: "linux", Arch: "amd64",
|
||||
PublicKey: base64.StdEncoding.EncodeToString(pub1),
|
||||
})
|
||||
status, _, body := postAnnounce(t, url, announceRequest{
|
||||
Hostname: "dup-host", OS: "linux", Arch: "amd64",
|
||||
PublicKey: base64.StdEncoding.EncodeToString(pub2),
|
||||
})
|
||||
if status != stdhttp.StatusOK {
|
||||
t.Fatalf("status: %d", status)
|
||||
}
|
||||
var ar announceResponse
|
||||
_ = json.Unmarshal(body, &ar)
|
||||
if !ar.HostnameCollision {
|
||||
t.Fatal("expected hostname_collision=true on second announce")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnnounceRateLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, url, _ := newTestServerWithHub(t)
|
||||
// Lower the limit for the duration of this test (the limiter is
|
||||
// per-server-instance so we don't disturb parallel tests).
|
||||
prev := announceMaxPerMin
|
||||
announceMaxPerMin = 2
|
||||
t.Cleanup(func() { announceMaxPerMin = prev })
|
||||
|
||||
pub := newKeypair(t)
|
||||
body := announceRequest{
|
||||
Hostname: "rl-host", OS: "linux", Arch: "amd64",
|
||||
PublicKey: base64.StdEncoding.EncodeToString(pub),
|
||||
}
|
||||
for i := 0; i < 2; i++ {
|
||||
status, _, _ := postAnnounce(t, url, body)
|
||||
if status != stdhttp.StatusOK {
|
||||
t.Fatalf("call %d: status %d", i, status)
|
||||
}
|
||||
}
|
||||
status, _, _ := postAnnounce(t, url, body)
|
||||
if status != stdhttp.StatusTooManyRequests {
|
||||
t.Fatalf("3rd call: want 429, got %d", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnnounceGlobalCap(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, url, st := newTestServerWithHub(t)
|
||||
prev := announceGlobalCap
|
||||
announceGlobalCap = 1
|
||||
t.Cleanup(func() { announceGlobalCap = prev })
|
||||
|
||||
// Pre-seed one row directly via the store so the cap is hit.
|
||||
pub := newKeypair(t)
|
||||
if err := st.CreatePendingHost(context.Background(), &store.PendingHost{
|
||||
ID: ulid.Make().String(), Hostname: "x", OS: "linux", Arch: "amd64",
|
||||
PublicKey: pub, Fingerprint: store.FingerprintForKey(pub),
|
||||
AnnouncedFromIP: "127.0.0.1",
|
||||
FirstSeenAt: time.Now().UTC(),
|
||||
LastSeenAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(time.Hour),
|
||||
}); err != nil {
|
||||
t.Fatalf("seed: %v", err)
|
||||
}
|
||||
status, _, _ := postAnnounce(t, url, announceRequest{
|
||||
Hostname: "next", OS: "linux", Arch: "amd64",
|
||||
PublicKey: base64.StdEncoding.EncodeToString(newKeypair(t)),
|
||||
})
|
||||
if status != stdhttp.StatusServiceUnavailable {
|
||||
t.Fatalf("want 503 over cap, got %d", status)
|
||||
}
|
||||
}
|
||||
@@ -49,6 +49,10 @@ type Server struct {
|
||||
// sync.Mutex; checked-and-locked atomically via drainLocksMu.
|
||||
drainLocksMu sync.Mutex
|
||||
drainLocks map[string]*sync.Mutex
|
||||
|
||||
// announceRL is the per-source-IP token-bucket guarding
|
||||
// POST /api/agents/announce (P2-18). One process-local map.
|
||||
announceRL *announceLimiter
|
||||
}
|
||||
|
||||
// New builds a configured but not-yet-started server.
|
||||
@@ -67,7 +71,11 @@ func New(deps Deps) *Server {
|
||||
w.WriteHeader(stdhttp.StatusNoContent)
|
||||
})
|
||||
|
||||
s := &Server{deps: deps, drainLocks: make(map[string]*sync.Mutex)}
|
||||
s := &Server{
|
||||
deps: deps,
|
||||
drainLocks: make(map[string]*sync.Mutex),
|
||||
announceRL: newAnnounceLimiter(),
|
||||
}
|
||||
s.routes(r)
|
||||
|
||||
s.srv = &stdhttp.Server{
|
||||
@@ -92,6 +100,11 @@ func (s *Server) routes(r chi.Router) {
|
||||
// Agent enrollment (open endpoint — token is the credential).
|
||||
r.Post("/agents/enroll", s.handleAgentEnroll)
|
||||
|
||||
// Announce-and-approve enrolment (open endpoint — fingerprint
|
||||
// comparison in the UI is the gate). Per-IP rate-limited and
|
||||
// globally capped (P2-18).
|
||||
r.Post("/agents/announce", s.handleAnnounce)
|
||||
|
||||
// Operator → server (authenticated). Spec.md §6.1's
|
||||
// /hosts/{id}/enrollment-token (regenerate) lands when the
|
||||
// host page can call it; for now just the create endpoint.
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
-- 0011_pending_hosts.sql
|
||||
--
|
||||
-- P2-18: announce-and-approve enrolment.
|
||||
--
|
||||
-- Agents that don't have an enrolment token announce themselves
|
||||
-- with `POST /api/agents/announce`, persisting one row here. The
|
||||
-- admin sees them in the dashboard's Pending hosts panel and can
|
||||
-- accept (mints a real Host row + bearer) or reject (deletes the
|
||||
-- row + closes the agent's pending WS).
|
||||
--
|
||||
-- public_key is the agent's Ed25519 public key (32 raw bytes).
|
||||
-- fingerprint = "SHA256:" + hex(sha256(public_key)) — printed by
|
||||
-- the install script on the endpoint terminal so the operator can
|
||||
-- compare the two before clicking accept. This comparison is the
|
||||
-- load-bearing security gate for this flow.
|
||||
--
|
||||
-- expires_at is set to first_seen_at + 1h on insert; a sweeper
|
||||
-- goroutine (P2-18b) deletes rows past their expiry. Hostname
|
||||
-- collisions with existing or other pending rows are *not*
|
||||
-- prevented at the DB level — multiple announces with the same
|
||||
-- hostname are flagged in the UI so admin can pick the right one.
|
||||
|
||||
CREATE TABLE pending_hosts (
|
||||
id TEXT PRIMARY KEY,
|
||||
hostname TEXT NOT NULL,
|
||||
os TEXT NOT NULL,
|
||||
arch TEXT NOT NULL,
|
||||
agent_version TEXT NOT NULL,
|
||||
restic_version TEXT NOT NULL,
|
||||
public_key BLOB NOT NULL, -- 32-byte Ed25519
|
||||
fingerprint TEXT NOT NULL, -- "SHA256:hex(...)"
|
||||
announced_from_ip TEXT NOT NULL,
|
||||
first_seen_at TEXT NOT NULL,
|
||||
last_seen_at TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL
|
||||
);
|
||||
CREATE INDEX pending_hosts_expires ON pending_hosts(expires_at);
|
||||
CREATE INDEX pending_hosts_fingerprint ON pending_hosts(fingerprint);
|
||||
CREATE INDEX pending_hosts_hostname ON pending_hosts(hostname);
|
||||
@@ -0,0 +1,225 @@
|
||||
// pending_hosts.go — store layer for the announce-and-approve
|
||||
// enrolment queue (P2-18a). Rows live for at most 1h; a sweeper
|
||||
// deletes anything past expires_at.
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PendingHost mirrors the pending_hosts table row, plus the derived
|
||||
// HostnameCollision flag the API hands back to the agent so the
|
||||
// install script can warn the operator at announce time.
|
||||
type PendingHost struct {
|
||||
ID string
|
||||
Hostname string
|
||||
OS string
|
||||
Arch string
|
||||
AgentVersion string
|
||||
ResticVersion string
|
||||
PublicKey []byte // 32-byte Ed25519
|
||||
Fingerprint string // "SHA256:hex"
|
||||
AnnouncedFromIP string
|
||||
FirstSeenAt time.Time
|
||||
LastSeenAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// FingerprintForKey returns the canonical "SHA256:hex" fingerprint
|
||||
// the operator sees in the UI and on the endpoint terminal.
|
||||
func FingerprintForKey(pubKey []byte) string {
|
||||
sum := sha256.Sum256(pubKey)
|
||||
return "SHA256:" + hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// CreatePendingHost inserts a new row. Caller has already validated
|
||||
// the public key length and rate limits.
|
||||
func (s *Store) CreatePendingHost(ctx context.Context, ph *PendingHost) error {
|
||||
if ph.ID == "" || len(ph.PublicKey) == 0 {
|
||||
return errors.New("store: pending host id + public_key required")
|
||||
}
|
||||
if ph.Fingerprint == "" {
|
||||
ph.Fingerprint = FingerprintForKey(ph.PublicKey)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if ph.FirstSeenAt.IsZero() {
|
||||
ph.FirstSeenAt = now
|
||||
}
|
||||
ph.LastSeenAt = now
|
||||
if ph.ExpiresAt.IsZero() {
|
||||
ph.ExpiresAt = now.Add(time.Hour)
|
||||
}
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`INSERT INTO pending_hosts (
|
||||
id, hostname, os, arch, agent_version, restic_version,
|
||||
public_key, fingerprint, announced_from_ip,
|
||||
first_seen_at, last_seen_at, expires_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
ph.ID, ph.Hostname, ph.OS, ph.Arch, ph.AgentVersion, ph.ResticVersion,
|
||||
ph.PublicKey, ph.Fingerprint, ph.AnnouncedFromIP,
|
||||
ph.FirstSeenAt.Format(time.RFC3339Nano),
|
||||
ph.LastSeenAt.Format(time.RFC3339Nano),
|
||||
ph.ExpiresAt.Format(time.RFC3339Nano),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: create pending host: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TouchPendingHost bumps last_seen_at on the named pending row,
|
||||
// extending its visibility in the dashboard while the agent's
|
||||
// pending WS stays open. Does NOT extend expires_at — the 1h cap
|
||||
// is firm.
|
||||
func (s *Store) TouchPendingHost(ctx context.Context, id string, when time.Time) error {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`UPDATE pending_hosts SET last_seen_at = ? WHERE id = ?`,
|
||||
when.UTC().Format(time.RFC3339Nano), id)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetPendingHost returns one row by ID. ErrNotFound on miss.
|
||||
func (s *Store) GetPendingHost(ctx context.Context, id string) (*PendingHost, error) {
|
||||
row := s.db.QueryRowContext(ctx,
|
||||
`SELECT id, hostname, os, arch, agent_version, restic_version,
|
||||
public_key, fingerprint, announced_from_ip,
|
||||
first_seen_at, last_seen_at, expires_at
|
||||
FROM pending_hosts WHERE id = ?`, id)
|
||||
return scanPendingHost(row)
|
||||
}
|
||||
|
||||
// GetPendingHostByFingerprint resolves a row by its public key
|
||||
// fingerprint (used by the WS pending handler to look up which row
|
||||
// an incoming connection corresponds to).
|
||||
func (s *Store) GetPendingHostByFingerprint(ctx context.Context, fp string) (*PendingHost, error) {
|
||||
row := s.db.QueryRowContext(ctx,
|
||||
`SELECT id, hostname, os, arch, agent_version, restic_version,
|
||||
public_key, fingerprint, announced_from_ip,
|
||||
first_seen_at, last_seen_at, expires_at
|
||||
FROM pending_hosts WHERE fingerprint = ?`, fp)
|
||||
return scanPendingHost(row)
|
||||
}
|
||||
|
||||
// ListPendingHosts returns every non-expired row, newest first. The
|
||||
// caller passes `now` so tests can fast-forward.
|
||||
func (s *Store) ListPendingHosts(ctx context.Context, now time.Time) ([]PendingHost, error) {
|
||||
rows, err := s.db.QueryContext(ctx,
|
||||
`SELECT id, hostname, os, arch, agent_version, restic_version,
|
||||
public_key, fingerprint, announced_from_ip,
|
||||
first_seen_at, last_seen_at, expires_at
|
||||
FROM pending_hosts WHERE expires_at > ?
|
||||
ORDER BY first_seen_at DESC`,
|
||||
now.UTC().Format(time.RFC3339Nano))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("store: list pending hosts: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
out := []PendingHost{}
|
||||
for rows.Next() {
|
||||
ph, err := scanPendingHostRow(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, *ph)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// CountPendingHosts returns the count of non-expired rows. Used for
|
||||
// the global cap (P2-18: refuse new announces past 100 in flight).
|
||||
func (s *Store) CountPendingHosts(ctx context.Context, now time.Time) (int, error) {
|
||||
var n int
|
||||
err := s.db.QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM pending_hosts WHERE expires_at > ?`,
|
||||
now.UTC().Format(time.RFC3339Nano)).Scan(&n)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("store: count pending hosts: %w", err)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// CountPendingHostsByHostname returns the number of non-expired
|
||||
// pending rows that share the supplied hostname. Used by the
|
||||
// announce endpoint to set the hostname_collision flag in its
|
||||
// response.
|
||||
func (s *Store) CountPendingHostsByHostname(ctx context.Context, hostname string, now time.Time) (int, error) {
|
||||
var n int
|
||||
err := s.db.QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM pending_hosts WHERE hostname = ? AND expires_at > ?`,
|
||||
hostname, now.UTC().Format(time.RFC3339Nano)).Scan(&n)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("store: count pending hosts by hostname: %w", err)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// DeletePendingHost removes one row by ID. ErrNotFound on miss.
|
||||
func (s *Store) DeletePendingHost(ctx context.Context, id string) error {
|
||||
res, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM pending_hosts WHERE id = ?`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: delete pending host: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteExpiredPendingHosts removes every row whose expires_at is in
|
||||
// the past. Returns the number of rows deleted so the sweeper can
|
||||
// log non-zero events.
|
||||
func (s *Store) DeleteExpiredPendingHosts(ctx context.Context, now time.Time) (int64, error) {
|
||||
res, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM pending_hosts WHERE expires_at <= ?`,
|
||||
now.UTC().Format(time.RFC3339Nano))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("store: delete expired pending hosts: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// ----- scan helpers --------------------------------------------------
|
||||
|
||||
type pendingHostScanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanPendingHost(row *sql.Row) (*PendingHost, error) {
|
||||
ph, err := scanPendingHostRow(row)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return ph, err
|
||||
}
|
||||
|
||||
func scanPendingHostRow(s pendingHostScanner) (*PendingHost, error) {
|
||||
var (
|
||||
ph PendingHost
|
||||
firstSeenAt, lastSeenAt, expiresAt string
|
||||
)
|
||||
if err := s.Scan(&ph.ID, &ph.Hostname, &ph.OS, &ph.Arch,
|
||||
&ph.AgentVersion, &ph.ResticVersion,
|
||||
&ph.PublicKey, &ph.Fingerprint, &ph.AnnouncedFromIP,
|
||||
&firstSeenAt, &lastSeenAt, &expiresAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339Nano, firstSeenAt); err == nil {
|
||||
ph.FirstSeenAt = t
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339Nano, lastSeenAt); err == nil {
|
||||
ph.LastSeenAt = t
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339Nano, expiresAt); err == nil {
|
||||
ph.ExpiresAt = t
|
||||
}
|
||||
return &ph, nil
|
||||
}
|
||||
Reference in New Issue
Block a user