store+server: P2-18a announce-and-approve schema + endpoint
migration 0011 adds pending_hosts table (id, hostname, public_key,
fingerprint, expiry). store/pending_hosts.go covers full CRUD plus
hostname-collision count + expired-row sweeper.
POST /api/agents/announce takes {hostname, os, arch, agent_version,
restic_version, public_key (base64)}, returns {pending_id,
fingerprint, hostname_collision}. Per-source-IP token-bucket
rate limit (10/min) + global cap of 100 in-flight rows. Public
key must be exactly 32 bytes (Ed25519).
This commit is contained in:
@@ -0,0 +1,211 @@
|
|||||||
|
// announce.go — POST /api/agents/announce: agent without a token
|
||||||
|
// announces itself with a freshly-minted Ed25519 public key, server
|
||||||
|
// stashes a pending_hosts row, admin compares fingerprints in the
|
||||||
|
// UI before accepting (P2-18a).
|
||||||
|
//
|
||||||
|
// Guards (per spec):
|
||||||
|
// - Per-source-IP token-bucket rate limit (10/min).
|
||||||
|
// - Global cap of 100 in-flight pending rows; further announces
|
||||||
|
// get 503 with a hint.
|
||||||
|
// - Public key must be exactly 32 bytes (Ed25519). Anything else
|
||||||
|
// 400-rejected.
|
||||||
|
//
|
||||||
|
// Hostname collisions are NOT rejected — multiple announces with
|
||||||
|
// the same hostname can be legitimate (re-running install on the
|
||||||
|
// same box). The UI flags collisions for the admin to disambiguate.
|
||||||
|
package http
|
||||||
|
|
||||||
|
import (
	"crypto/ed25519"
	"encoding/base64"
	"encoding/json"
	"net"
	stdhttp "net/http"
	"strings"
	"sync"
	"time"

	"github.com/oklog/ulid/v2"

	"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
|
||||||
|
|
||||||
|
// Announce-endpoint tunables. They are variables rather than constants
// so tests can lower them.
var (
	// announceMaxPerMin is both the burst size and the per-minute
	// refill rate of each source IP's token bucket.
	announceMaxPerMin = 10

	// announceGlobalCap is the number of non-expired pending rows at
	// which further announces are refused.
	announceGlobalCap = 100
)
|
||||||
|
|
||||||
|
// announceRequest is the JSON body accepted by POST /api/agents/announce.
//
// PublicKey carries the agent's key as base64 text; handleAnnounce
// decodes it and rejects anything that is not exactly
// ed25519.PublicKeySize bytes.
type announceRequest struct {
	// Identity and platform of the announcing machine.
	Hostname string `json:"hostname"`
	OS       string `json:"os"`
	Arch     string `json:"arch"`

	// Software versions reported by the agent.
	AgentVersion  string `json:"agent_version"`
	ResticVersion string `json:"restic_version"`

	// Base64-encoded public key.
	PublicKey string `json:"public_key"`
}
|
||||||
|
|
||||||
|
// announceResponse is the JSON body returned to the agent after a
// successful announce.
type announceResponse struct {
	// PendingID is the ID of the pending row that was created.
	PendingID string `json:"pending_id"`

	// Fingerprint is the "SHA256:hex" fingerprint of the announced
	// key, which the operator compares against the UI.
	Fingerprint string `json:"fingerprint"`

	// HostnameCollision is true when another non-expired pending row
	// already uses the same hostname.
	HostnameCollision bool `json:"hostname_collision"`
}
|
||||||
|
|
||||||
|
// rateBucket is a tiny per-IP token-bucket. last is the timestamp of
|
||||||
|
// the most recent refill; tokens is the current bucket level. Refill
|
||||||
|
// rate is announceMaxPerMin tokens/minute, burst = announceMaxPerMin.
|
||||||
|
type rateBucket struct {
|
||||||
|
tokens float64
|
||||||
|
last time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// announceLimiter holds one bucket per source IP. Buckets are reaped
|
||||||
|
// lazily by a tiny grace period — we don't need true LRU cleanup
|
||||||
|
// because the bucket count is bounded by unique IPs in any given
|
||||||
|
// few minutes (small).
|
||||||
|
type announceLimiter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
buckets map[string]*rateBucket
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAnnounceLimiter() *announceLimiter {
|
||||||
|
return &announceLimiter{buckets: map[string]*rateBucket{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// allow returns true and consumes a token if the IP's bucket has at
|
||||||
|
// least one token, else returns false. Capacity = announceMaxPerMin.
|
||||||
|
func (l *announceLimiter) allow(ip string, now time.Time) bool {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
cap := float64(announceMaxPerMin)
|
||||||
|
b, ok := l.buckets[ip]
|
||||||
|
if !ok {
|
||||||
|
b = &rateBucket{tokens: cap, last: now}
|
||||||
|
l.buckets[ip] = b
|
||||||
|
}
|
||||||
|
// Refill at cap tokens per minute.
|
||||||
|
elapsed := now.Sub(b.last).Seconds()
|
||||||
|
if elapsed > 0 {
|
||||||
|
b.tokens += (elapsed / 60.0) * cap
|
||||||
|
if b.tokens > cap {
|
||||||
|
b.tokens = cap
|
||||||
|
}
|
||||||
|
b.last = now
|
||||||
|
}
|
||||||
|
if b.tokens < 1.0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
b.tokens--
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAnnounce is the public POST handler. Public — no auth.
|
||||||
|
func (s *Server) handleAnnounce(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
// Rate limit by source IP. Strip port — the limit is per host,
|
||||||
|
// not per outbound source port.
|
||||||
|
ip := remoteIP(r)
|
||||||
|
if !s.announceRL.allow(ip, now) {
|
||||||
|
w.Header().Set("Retry-After", "60")
|
||||||
|
writeJSONError(w, stdhttp.StatusTooManyRequests, "rate_limited",
|
||||||
|
"too many announces from this source; retry in a minute")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req announceRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Hostname == "" || req.OS == "" || req.Arch == "" || req.PublicKey == "" {
|
||||||
|
writeJSONError(w, stdhttp.StatusBadRequest, "missing_field",
|
||||||
|
"hostname, os, arch, public_key are required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
keyBytes, err := base64.StdEncoding.DecodeString(req.PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
// Try URL-safe / no-padding flavors before giving up.
|
||||||
|
if k2, e2 := base64.RawStdEncoding.DecodeString(req.PublicKey); e2 == nil {
|
||||||
|
keyBytes = k2
|
||||||
|
} else {
|
||||||
|
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_public_key",
|
||||||
|
"public_key must be base64")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(keyBytes) != ed25519.PublicKeySize {
|
||||||
|
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_public_key",
|
||||||
|
"public_key must be 32 bytes (Ed25519)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Global cap (cheap query — index on expires_at).
|
||||||
|
count, err := s.deps.Store.CountPendingHosts(r.Context(), now)
|
||||||
|
if err != nil {
|
||||||
|
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if count >= announceGlobalCap {
|
||||||
|
writeJSONError(w, stdhttp.StatusServiceUnavailable, "pending_cap_reached",
|
||||||
|
"too many in-flight pending hosts; ask an admin to clear the queue")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hostname collision flag (informational).
|
||||||
|
colls, err := s.deps.Store.CountPendingHostsByHostname(r.Context(), req.Hostname, now)
|
||||||
|
if err != nil {
|
||||||
|
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ph := &store.PendingHost{
|
||||||
|
ID: ulid.Make().String(),
|
||||||
|
Hostname: req.Hostname,
|
||||||
|
OS: req.OS,
|
||||||
|
Arch: req.Arch,
|
||||||
|
AgentVersion: req.AgentVersion,
|
||||||
|
ResticVersion: req.ResticVersion,
|
||||||
|
PublicKey: keyBytes,
|
||||||
|
Fingerprint: store.FingerprintForKey(keyBytes),
|
||||||
|
AnnouncedFromIP: ip,
|
||||||
|
FirstSeenAt: now,
|
||||||
|
LastSeenAt: now,
|
||||||
|
ExpiresAt: now.Add(time.Hour),
|
||||||
|
}
|
||||||
|
if err := s.deps.Store.CreatePendingHost(r.Context(), ph); err != nil {
|
||||||
|
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, stdhttp.StatusOK, announceResponse{
|
||||||
|
PendingID: ph.ID,
|
||||||
|
Fingerprint: ph.Fingerprint,
|
||||||
|
HostnameCollision: colls > 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// remoteIP returns the client address used as the rate-limit key and
// recorded on the pending row.
//
// When the request carries an X-Forwarded-For header, the first
// non-empty hop of the chain is returned. Otherwise (or when that hop
// is empty) the host part of r.RemoteAddr is returned with the port
// and any IPv6 brackets removed; a RemoteAddr that has no port is
// returned unchanged.
//
// NOTE(review): X-Forwarded-For is honoured unconditionally — nothing
// here checks that the peer is a trusted proxy. A direct client can
// therefore pick its own rate-limit bucket by sending the header. This
// should be gated on the deployment's trusted-proxy setting.
func remoteIP(r *stdhttp.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if hop := strings.TrimSpace(first); hop != "" {
			return hop
		}
	}

	// SplitHostPort understands "[v6]:port"; cutting at the last ':'
	// would leave the brackets on, and would truncate a bare IPv6
	// address.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
|
||||||
@@ -0,0 +1,165 @@
|
|||||||
|
// announce_test.go — covers POST /api/agents/announce: happy path,
|
||||||
|
// invalid public key, hostname collision flag, rate limit, global
|
||||||
|
// cap (P2-18a).
|
||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
stdhttp "net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oklog/ulid/v2"
|
||||||
|
|
||||||
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newKeypair generates a fresh Ed25519 key pair and returns its public
// half, failing the test if generation fails.
func newKeypair(t *testing.T) ed25519.PublicKey {
	t.Helper()

	publicKey, _, genErr := ed25519.GenerateKey(rand.Reader)
	if genErr == nil {
		return publicKey
	}
	t.Fatalf("ed25519: %v", genErr)
	return nil // unreachable: Fatalf stops the test
}
|
||||||
|
|
||||||
|
func postAnnounce(t *testing.T, url string, req announceRequest) (status int, header stdhttp.Header, body []byte) {
|
||||||
|
t.Helper()
|
||||||
|
b, _ := json.Marshal(req)
|
||||||
|
r, _ := stdhttp.NewRequest("POST", url+"/api/agents/announce", bytes.NewReader(b))
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
res, err := stdhttp.DefaultClient.Do(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("do: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
out := make([]byte, 4096)
|
||||||
|
n, _ := res.Body.Read(out)
|
||||||
|
return res.StatusCode, res.Header, out[:n]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnnounceHappyPath(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
_, url, st := newTestServerWithHub(t)
|
||||||
|
pub := newKeypair(t)
|
||||||
|
status, _, body := postAnnounce(t, url, announceRequest{
|
||||||
|
Hostname: "alice", OS: "linux", Arch: "amd64",
|
||||||
|
AgentVersion: "1.0", ResticVersion: "0.17",
|
||||||
|
PublicKey: base64.StdEncoding.EncodeToString(pub),
|
||||||
|
})
|
||||||
|
if status != stdhttp.StatusOK {
|
||||||
|
t.Fatalf("status: %d body=%s", status, body)
|
||||||
|
}
|
||||||
|
var ar announceResponse
|
||||||
|
if err := json.Unmarshal(body, &ar); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v body=%s", err, body)
|
||||||
|
}
|
||||||
|
if ar.PendingID == "" {
|
||||||
|
t.Fatal("missing pending_id")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(ar.Fingerprint, "SHA256:") {
|
||||||
|
t.Fatalf("fingerprint shape: %q", ar.Fingerprint)
|
||||||
|
}
|
||||||
|
if ar.HostnameCollision {
|
||||||
|
t.Fatal("first announce shouldn't be a collision")
|
||||||
|
}
|
||||||
|
// Row exists in the store.
|
||||||
|
if _, err := st.GetPendingHost(context.Background(), ar.PendingID); err != nil {
|
||||||
|
t.Fatalf("pending row missing: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnnounceRejectsBadKey(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
_, url, _ := newTestServerWithHub(t)
|
||||||
|
status, _, _ := postAnnounce(t, url, announceRequest{
|
||||||
|
Hostname: "x", OS: "linux", Arch: "amd64",
|
||||||
|
PublicKey: base64.StdEncoding.EncodeToString([]byte("too-short")),
|
||||||
|
})
|
||||||
|
if status != stdhttp.StatusBadRequest {
|
||||||
|
t.Fatalf("status: got %d, want 400", status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnnounceHostnameCollisionFlag(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
_, url, _ := newTestServerWithHub(t)
|
||||||
|
pub1 := newKeypair(t)
|
||||||
|
pub2 := newKeypair(t)
|
||||||
|
_, _, _ = postAnnounce(t, url, announceRequest{
|
||||||
|
Hostname: "dup-host", OS: "linux", Arch: "amd64",
|
||||||
|
PublicKey: base64.StdEncoding.EncodeToString(pub1),
|
||||||
|
})
|
||||||
|
status, _, body := postAnnounce(t, url, announceRequest{
|
||||||
|
Hostname: "dup-host", OS: "linux", Arch: "amd64",
|
||||||
|
PublicKey: base64.StdEncoding.EncodeToString(pub2),
|
||||||
|
})
|
||||||
|
if status != stdhttp.StatusOK {
|
||||||
|
t.Fatalf("status: %d", status)
|
||||||
|
}
|
||||||
|
var ar announceResponse
|
||||||
|
_ = json.Unmarshal(body, &ar)
|
||||||
|
if !ar.HostnameCollision {
|
||||||
|
t.Fatal("expected hostname_collision=true on second announce")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnnounceRateLimit(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
_, url, _ := newTestServerWithHub(t)
|
||||||
|
// Lower the limit for the duration of this test (the limiter is
|
||||||
|
// per-server-instance so we don't disturb parallel tests).
|
||||||
|
prev := announceMaxPerMin
|
||||||
|
announceMaxPerMin = 2
|
||||||
|
t.Cleanup(func() { announceMaxPerMin = prev })
|
||||||
|
|
||||||
|
pub := newKeypair(t)
|
||||||
|
body := announceRequest{
|
||||||
|
Hostname: "rl-host", OS: "linux", Arch: "amd64",
|
||||||
|
PublicKey: base64.StdEncoding.EncodeToString(pub),
|
||||||
|
}
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
status, _, _ := postAnnounce(t, url, body)
|
||||||
|
if status != stdhttp.StatusOK {
|
||||||
|
t.Fatalf("call %d: status %d", i, status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
status, _, _ := postAnnounce(t, url, body)
|
||||||
|
if status != stdhttp.StatusTooManyRequests {
|
||||||
|
t.Fatalf("3rd call: want 429, got %d", status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnnounceGlobalCap(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
_, url, st := newTestServerWithHub(t)
|
||||||
|
prev := announceGlobalCap
|
||||||
|
announceGlobalCap = 1
|
||||||
|
t.Cleanup(func() { announceGlobalCap = prev })
|
||||||
|
|
||||||
|
// Pre-seed one row directly via the store so the cap is hit.
|
||||||
|
pub := newKeypair(t)
|
||||||
|
if err := st.CreatePendingHost(context.Background(), &store.PendingHost{
|
||||||
|
ID: ulid.Make().String(), Hostname: "x", OS: "linux", Arch: "amd64",
|
||||||
|
PublicKey: pub, Fingerprint: store.FingerprintForKey(pub),
|
||||||
|
AnnouncedFromIP: "127.0.0.1",
|
||||||
|
FirstSeenAt: time.Now().UTC(),
|
||||||
|
LastSeenAt: time.Now().UTC(),
|
||||||
|
ExpiresAt: time.Now().UTC().Add(time.Hour),
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("seed: %v", err)
|
||||||
|
}
|
||||||
|
status, _, _ := postAnnounce(t, url, announceRequest{
|
||||||
|
Hostname: "next", OS: "linux", Arch: "amd64",
|
||||||
|
PublicKey: base64.StdEncoding.EncodeToString(newKeypair(t)),
|
||||||
|
})
|
||||||
|
if status != stdhttp.StatusServiceUnavailable {
|
||||||
|
t.Fatalf("want 503 over cap, got %d", status)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -49,6 +49,10 @@ type Server struct {
|
|||||||
// sync.Mutex; checked-and-locked atomically via drainLocksMu.
|
// sync.Mutex; checked-and-locked atomically via drainLocksMu.
|
||||||
drainLocksMu sync.Mutex
|
drainLocksMu sync.Mutex
|
||||||
drainLocks map[string]*sync.Mutex
|
drainLocks map[string]*sync.Mutex
|
||||||
|
|
||||||
|
// announceRL is the per-source-IP token-bucket guarding
|
||||||
|
// POST /api/agents/announce (P2-18). One process-local map.
|
||||||
|
announceRL *announceLimiter
|
||||||
}
|
}
|
||||||
|
|
||||||
// New builds a configured but not-yet-started server.
|
// New builds a configured but not-yet-started server.
|
||||||
@@ -67,7 +71,11 @@ func New(deps Deps) *Server {
|
|||||||
w.WriteHeader(stdhttp.StatusNoContent)
|
w.WriteHeader(stdhttp.StatusNoContent)
|
||||||
})
|
})
|
||||||
|
|
||||||
s := &Server{deps: deps, drainLocks: make(map[string]*sync.Mutex)}
|
s := &Server{
|
||||||
|
deps: deps,
|
||||||
|
drainLocks: make(map[string]*sync.Mutex),
|
||||||
|
announceRL: newAnnounceLimiter(),
|
||||||
|
}
|
||||||
s.routes(r)
|
s.routes(r)
|
||||||
|
|
||||||
s.srv = &stdhttp.Server{
|
s.srv = &stdhttp.Server{
|
||||||
@@ -92,6 +100,11 @@ func (s *Server) routes(r chi.Router) {
|
|||||||
// Agent enrollment (open endpoint — token is the credential).
|
// Agent enrollment (open endpoint — token is the credential).
|
||||||
r.Post("/agents/enroll", s.handleAgentEnroll)
|
r.Post("/agents/enroll", s.handleAgentEnroll)
|
||||||
|
|
||||||
|
// Announce-and-approve enrolment (open endpoint — fingerprint
|
||||||
|
// comparison in the UI is the gate). Per-IP rate-limited and
|
||||||
|
// globally capped (P2-18).
|
||||||
|
r.Post("/agents/announce", s.handleAnnounce)
|
||||||
|
|
||||||
// Operator → server (authenticated). Spec.md §6.1's
|
// Operator → server (authenticated). Spec.md §6.1's
|
||||||
// /hosts/{id}/enrollment-token (regenerate) lands when the
|
// /hosts/{id}/enrollment-token (regenerate) lands when the
|
||||||
// host page can call it; for now just the create endpoint.
|
// host page can call it; for now just the create endpoint.
|
||||||
|
|||||||
@@ -0,0 +1,39 @@
|
|||||||
|
-- 0011_pending_hosts.sql
|
||||||
|
--
|
||||||
|
-- P2-18: announce-and-approve enrolment.
|
||||||
|
--
|
||||||
|
-- Agents that don't have an enrolment token announce themselves
|
||||||
|
-- with `POST /api/agents/announce`, persisting one row here. The
|
||||||
|
-- admin sees them in the dashboard's Pending hosts panel and can
|
||||||
|
-- accept (mints a real Host row + bearer) or reject (deletes the
|
||||||
|
-- row + closes the agent's pending WS).
|
||||||
|
--
|
||||||
|
-- public_key is the agent's Ed25519 public key (32 raw bytes).
|
||||||
|
-- fingerprint = "SHA256:" + hex(sha256(public_key)) — printed by
|
||||||
|
-- the install script on the endpoint terminal so the operator can
|
||||||
|
-- compare the two before clicking accept. This comparison is the
|
||||||
|
-- load-bearing security gate for this flow.
|
||||||
|
--
|
||||||
|
-- expires_at is set to first_seen_at + 1h on insert; a sweeper
|
||||||
|
-- goroutine (P2-18b) deletes rows past their expiry. Hostname
|
||||||
|
-- collisions with existing or other pending rows are *not*
|
||||||
|
-- prevented at the DB level — multiple announces with the same
|
||||||
|
-- hostname are flagged in the UI so admin can pick the right one.
|
||||||
|
|
||||||
|
CREATE TABLE pending_hosts (
    id                TEXT PRIMARY KEY,
    hostname          TEXT NOT NULL,
    os                TEXT NOT NULL,
    arch              TEXT NOT NULL,
    agent_version     TEXT NOT NULL,
    restic_version    TEXT NOT NULL,
    public_key        BLOB NOT NULL,  -- 32-byte Ed25519
    fingerprint       TEXT NOT NULL,  -- "SHA256:hex(...)"
    announced_from_ip TEXT NOT NULL,
    -- Timestamps are stored as text.
    first_seen_at     TEXT NOT NULL,
    last_seen_at      TEXT NOT NULL,
    expires_at        TEXT NOT NULL
);

-- Lookup indexes. None is UNIQUE: duplicate hostnames are allowed by
-- design and are flagged in the UI instead.
CREATE INDEX pending_hosts_expires     ON pending_hosts(expires_at);
CREATE INDEX pending_hosts_fingerprint ON pending_hosts(fingerprint);
CREATE INDEX pending_hosts_hostname    ON pending_hosts(hostname);
|
||||||
@@ -0,0 +1,225 @@
|
|||||||
|
// pending_hosts.go — store layer for the announce-and-approve
|
||||||
|
// enrolment queue (P2-18a). Rows live for at most 1h; a sweeper
|
||||||
|
// deletes anything past expires_at.
|
||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PendingHost is one row of the pending_hosts table: a host that has
// announced itself and is waiting for an admin to accept or reject it.
type PendingHost struct {
	ID string

	// What the agent reported about itself.
	Hostname      string
	OS            string
	Arch          string
	AgentVersion  string
	ResticVersion string

	// Key material.
	PublicKey   []byte // 32-byte Ed25519
	Fingerprint string // "SHA256:hex"

	// Source address the announce came from.
	AnnouncedFromIP string

	// Lifecycle timestamps.
	FirstSeenAt time.Time
	LastSeenAt  time.Time
	ExpiresAt   time.Time
}
|
||||||
|
|
||||||
|
// FingerprintForKey returns the fingerprint of pubKey: the literal
// prefix "SHA256:" followed by the lowercase hex SHA-256 digest of the
// key bytes.
func FingerprintForKey(pubKey []byte) string {
	digest := sha256.New()
	digest.Write(pubKey) // hash.Hash.Write never returns an error
	return "SHA256:" + hex.EncodeToString(digest.Sum(nil))
}
|
||||||
|
|
||||||
|
// CreatePendingHost inserts a new row. Caller has already validated
|
||||||
|
// the public key length and rate limits.
|
||||||
|
func (s *Store) CreatePendingHost(ctx context.Context, ph *PendingHost) error {
|
||||||
|
if ph.ID == "" || len(ph.PublicKey) == 0 {
|
||||||
|
return errors.New("store: pending host id + public_key required")
|
||||||
|
}
|
||||||
|
if ph.Fingerprint == "" {
|
||||||
|
ph.Fingerprint = FingerprintForKey(ph.PublicKey)
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
if ph.FirstSeenAt.IsZero() {
|
||||||
|
ph.FirstSeenAt = now
|
||||||
|
}
|
||||||
|
ph.LastSeenAt = now
|
||||||
|
if ph.ExpiresAt.IsZero() {
|
||||||
|
ph.ExpiresAt = now.Add(time.Hour)
|
||||||
|
}
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`INSERT INTO pending_hosts (
|
||||||
|
id, hostname, os, arch, agent_version, restic_version,
|
||||||
|
public_key, fingerprint, announced_from_ip,
|
||||||
|
first_seen_at, last_seen_at, expires_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
|
ph.ID, ph.Hostname, ph.OS, ph.Arch, ph.AgentVersion, ph.ResticVersion,
|
||||||
|
ph.PublicKey, ph.Fingerprint, ph.AnnouncedFromIP,
|
||||||
|
ph.FirstSeenAt.Format(time.RFC3339Nano),
|
||||||
|
ph.LastSeenAt.Format(time.RFC3339Nano),
|
||||||
|
ph.ExpiresAt.Format(time.RFC3339Nano),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("store: create pending host: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TouchPendingHost bumps last_seen_at on the named pending row,
|
||||||
|
// extending its visibility in the dashboard while the agent's
|
||||||
|
// pending WS stays open. Does NOT extend expires_at — the 1h cap
|
||||||
|
// is firm.
|
||||||
|
func (s *Store) TouchPendingHost(ctx context.Context, id string, when time.Time) error {
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`UPDATE pending_hosts SET last_seen_at = ? WHERE id = ?`,
|
||||||
|
when.UTC().Format(time.RFC3339Nano), id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPendingHost returns one row by ID. ErrNotFound on miss.
|
||||||
|
func (s *Store) GetPendingHost(ctx context.Context, id string) (*PendingHost, error) {
|
||||||
|
row := s.db.QueryRowContext(ctx,
|
||||||
|
`SELECT id, hostname, os, arch, agent_version, restic_version,
|
||||||
|
public_key, fingerprint, announced_from_ip,
|
||||||
|
first_seen_at, last_seen_at, expires_at
|
||||||
|
FROM pending_hosts WHERE id = ?`, id)
|
||||||
|
return scanPendingHost(row)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPendingHostByFingerprint resolves a row by its public key
|
||||||
|
// fingerprint (used by the WS pending handler to look up which row
|
||||||
|
// an incoming connection corresponds to).
|
||||||
|
func (s *Store) GetPendingHostByFingerprint(ctx context.Context, fp string) (*PendingHost, error) {
|
||||||
|
row := s.db.QueryRowContext(ctx,
|
||||||
|
`SELECT id, hostname, os, arch, agent_version, restic_version,
|
||||||
|
public_key, fingerprint, announced_from_ip,
|
||||||
|
first_seen_at, last_seen_at, expires_at
|
||||||
|
FROM pending_hosts WHERE fingerprint = ?`, fp)
|
||||||
|
return scanPendingHost(row)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListPendingHosts returns every non-expired row, newest first. The
|
||||||
|
// caller passes `now` so tests can fast-forward.
|
||||||
|
func (s *Store) ListPendingHosts(ctx context.Context, now time.Time) ([]PendingHost, error) {
|
||||||
|
rows, err := s.db.QueryContext(ctx,
|
||||||
|
`SELECT id, hostname, os, arch, agent_version, restic_version,
|
||||||
|
public_key, fingerprint, announced_from_ip,
|
||||||
|
first_seen_at, last_seen_at, expires_at
|
||||||
|
FROM pending_hosts WHERE expires_at > ?
|
||||||
|
ORDER BY first_seen_at DESC`,
|
||||||
|
now.UTC().Format(time.RFC3339Nano))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("store: list pending hosts: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
out := []PendingHost{}
|
||||||
|
for rows.Next() {
|
||||||
|
ph, err := scanPendingHostRow(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out = append(out, *ph)
|
||||||
|
}
|
||||||
|
return out, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountPendingHosts returns the count of non-expired rows. Used for
|
||||||
|
// the global cap (P2-18: refuse new announces past 100 in flight).
|
||||||
|
func (s *Store) CountPendingHosts(ctx context.Context, now time.Time) (int, error) {
|
||||||
|
var n int
|
||||||
|
err := s.db.QueryRowContext(ctx,
|
||||||
|
`SELECT COUNT(*) FROM pending_hosts WHERE expires_at > ?`,
|
||||||
|
now.UTC().Format(time.RFC3339Nano)).Scan(&n)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("store: count pending hosts: %w", err)
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountPendingHostsByHostname returns the number of non-expired
|
||||||
|
// pending rows that share the supplied hostname. Used by the
|
||||||
|
// announce endpoint to set the hostname_collision flag in its
|
||||||
|
// response.
|
||||||
|
func (s *Store) CountPendingHostsByHostname(ctx context.Context, hostname string, now time.Time) (int, error) {
|
||||||
|
var n int
|
||||||
|
err := s.db.QueryRowContext(ctx,
|
||||||
|
`SELECT COUNT(*) FROM pending_hosts WHERE hostname = ? AND expires_at > ?`,
|
||||||
|
hostname, now.UTC().Format(time.RFC3339Nano)).Scan(&n)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("store: count pending hosts by hostname: %w", err)
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePendingHost removes one row by ID. ErrNotFound on miss.
|
||||||
|
func (s *Store) DeletePendingHost(ctx context.Context, id string) error {
|
||||||
|
res, err := s.db.ExecContext(ctx,
|
||||||
|
`DELETE FROM pending_hosts WHERE id = ?`, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("store: delete pending host: %w", err)
|
||||||
|
}
|
||||||
|
n, _ := res.RowsAffected()
|
||||||
|
if n == 0 {
|
||||||
|
return ErrNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteExpiredPendingHosts removes every row whose expires_at is in
|
||||||
|
// the past. Returns the number of rows deleted so the sweeper can
|
||||||
|
// log non-zero events.
|
||||||
|
func (s *Store) DeleteExpiredPendingHosts(ctx context.Context, now time.Time) (int64, error) {
|
||||||
|
res, err := s.db.ExecContext(ctx,
|
||||||
|
`DELETE FROM pending_hosts WHERE expires_at <= ?`,
|
||||||
|
now.UTC().Format(time.RFC3339Nano))
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("store: delete expired pending hosts: %w", err)
|
||||||
|
}
|
||||||
|
n, _ := res.RowsAffected()
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- scan helpers --------------------------------------------------
|
||||||
|
|
||||||
|
// pendingHostScanner is the one method scanPendingHostRow needs. It
// lets the same scan code serve both a single-row lookup (*sql.Row)
// and a result-set cursor (*sql.Rows).
type pendingHostScanner interface {
	// Scan copies the current row's columns into dest, in order.
	Scan(dest ...any) error
}
|
||||||
|
|
||||||
|
func scanPendingHost(row *sql.Row) (*PendingHost, error) {
|
||||||
|
ph, err := scanPendingHostRow(row)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
return ph, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanPendingHostRow(s pendingHostScanner) (*PendingHost, error) {
|
||||||
|
var (
|
||||||
|
ph PendingHost
|
||||||
|
firstSeenAt, lastSeenAt, expiresAt string
|
||||||
|
)
|
||||||
|
if err := s.Scan(&ph.ID, &ph.Hostname, &ph.OS, &ph.Arch,
|
||||||
|
&ph.AgentVersion, &ph.ResticVersion,
|
||||||
|
&ph.PublicKey, &ph.Fingerprint, &ph.AnnouncedFromIP,
|
||||||
|
&firstSeenAt, &lastSeenAt, &expiresAt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if t, err := time.Parse(time.RFC3339Nano, firstSeenAt); err == nil {
|
||||||
|
ph.FirstSeenAt = t
|
||||||
|
}
|
||||||
|
if t, err := time.Parse(time.RFC3339Nano, lastSeenAt); err == nil {
|
||||||
|
ph.LastSeenAt = t
|
||||||
|
}
|
||||||
|
if t, err := time.Parse(time.RFC3339Nano, expiresAt); err == nil {
|
||||||
|
ph.ExpiresAt = t
|
||||||
|
}
|
||||||
|
return &ph, nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user