server: P2-18b pending WS + admin accept/reject

GET /ws/agent/pending?pending_id=… runs an Ed25519 nonce-sign
handshake against the row's stored public key, then holds the
connection open. POST /api/pending-hosts/{id}/accept (admin)
mints a real Host row + bearer + AEAD-encrypted repo creds, pushes
the bearer down the open WS, deletes the pending row, and writes
a host.accept_pending audit entry. POST /api/pending-hosts/{id}/reject
closes the socket with code 4001 and audit-logs host.reject_pending.

In-memory pendingHub keyed by pending_id wires accept/reject to
their live socket.
This commit is contained in:
2026-05-04 11:07:32 +01:00
parent a8e6c9d6d7
commit 567561a6a3
3 changed files with 566 additions and 0 deletions
+349
View File
@@ -0,0 +1,349 @@
// pending_ws.go — /ws/agent/pending and the admin accept/reject
// endpoints for the announce-and-approve enrolment flow (P2-18b).
//
// Flow:
// 1. Agent has previously called POST /api/agents/announce, which
// returned its pending_id + fingerprint. Agent persists the
// keypair locally.
// 2. Agent connects to /ws/agent/pending?pending_id=… (no auth).
// Server reads the row, generates a 32-byte nonce, sends it.
// 3. Agent signs the nonce with its Ed25519 private key, sends the
// signature back. Server verifies; close on bad sig.
// 4. The connection sits open; the agent reads but doesn't write.
// 5. Admin clicks Accept: POST /api/pending-hosts/{id}/accept with
// the same repo-creds form the token-mint flow uses. Server
// mints a Host row + bearer + encrypted creds, pushes one
// `enrolled` message down the open socket, closes cleanly.
// 6. Admin clicks Reject: socket closes with code 4001.
//
// Hub: a process-local in-memory map of pending_id → live conn so
// the accept/reject handlers can find the right socket. Sole
// instance lives on Server.pendingHub.
package http
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"log/slog"
stdhttp "net/http"
"sync"
"time"
"github.com/coder/websocket"
"github.com/go-chi/chi/v5"
"github.com/oklog/ulid/v2"
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
// pendingConn is a single live /ws/agent/pending session. There is no
// wrapper method for sending: the accept handler writes the enrolment
// message directly on conn and then closes the socket, the reject
// handler closes it with code 4001, and the hub closes it when a newer
// session for the same pending_id is registered.
type pendingConn struct {
// conn is the upgraded WebSocket for this session.
conn *websocket.Conn
// pendingID is the pending-host row this session belongs to; it is
// also the key under which the hub stores the session.
pendingID string
// closed is closed to release handlePendingWS, which blocks on it
// for the lifetime of the session.
closed chan struct{}
}
// pendingHub is the in-memory map of pending_id → live socket.
// handlePendingWS registers a session once the nonce handshake has
// succeeded and unregisters it when the handler returns; the accept and
// reject handlers look sessions up with get.
type pendingHub struct {
// mu guards conns; register, unregister and get all take it.
mu sync.Mutex
// conns holds at most one session per pending_id; register replaces
// any previous entry for the same key.
conns map[string]*pendingConn
}
// newPendingHub returns an empty hub whose map is already allocated, so
// register can store into it without a nil-map check.
func newPendingHub() *pendingHub {
	hub := new(pendingHub)
	hub.conns = make(map[string]*pendingConn)
	return hub
}
// register stores pc as the live session for pc.pendingID. If another
// session was already registered under that id (the agent reconnected),
// the older socket is closed with a normal-closure "superseded" frame.
//
// Two things are deliberate here:
//
//   - The old socket is closed after the mutex is released. Close runs
//     the WebSocket close handshake and can block for several seconds
//     on an unresponsive peer; doing that under h.mu would stall every
//     accept/reject/get in the meantime.
//   - old.closed is NOT closed here. Closing the socket is enough: the
//     old session's read loop sees its Read fail and releases its
//     handler. Closing the channel from here as well races with that
//     read loop and can panic with "close of closed channel".
func (h *pendingHub) register(pc *pendingConn) {
	h.mu.Lock()
	old := h.conns[pc.pendingID]
	h.conns[pc.pendingID] = pc
	h.mu.Unlock()

	// Nothing to supersede, or the same session registered twice.
	if old == nil || old == pc {
		return
	}
	// Best effort: the old peer may already be gone, in which case Close
	// just reports that the connection is dead.
	_ = old.conn.Close(websocket.StatusNormalClosure, "superseded")
}
// unregister removes the entry for pendingID, but only when the hub
// still points at pc. A session that has been superseded by a newer one
// therefore cannot evict its replacement when its handler unwinds.
func (h *pendingHub) unregister(pendingID string, pc *pendingConn) {
	h.mu.Lock()
	defer h.mu.Unlock()

	if h.conns[pendingID] != pc {
		return
	}
	delete(h.conns, pendingID)
}
// get returns the session currently registered for pendingID, or nil
// when there is none.
func (h *pendingHub) get(pendingID string) *pendingConn {
	h.mu.Lock()
	session := h.conns[pendingID]
	h.mu.Unlock()
	return session
}
// Wire messages exchanged on /ws/agent/pending, listed in the order
// they travel. All of them are JSON text frames discriminated by the
// "type" field.
type (
	// nonceMessage is the first frame the server sends: the challenge
	// the agent has to sign.
	nonceMessage struct {
		Type  string `json:"type"`  // "nonce"
		Nonce string `json:"nonce"` // challenge bytes, standard base64
	}

	// signedNonceMessage is the agent's reply to nonceMessage.
	signedNonceMessage struct {
		Type      string `json:"type"`      // "signed_nonce"
		Signature string `json:"signature"` // signature over the raw nonce bytes, standard base64
	}

	// enrolledMessage is pushed by the server when an admin accepts the
	// pending host. The agent persists the bearer to agent.yaml and
	// leaves announce mode.
	enrolledMessage struct {
		Type     string `json:"type"` // "enrolled"
		HostID   string `json:"host_id"`
		Bearer   string `json:"bearer"`
		ServerID string `json:"server_id,omitempty"` // omitted from the JSON when empty
	}
)
// handlePendingWS serves GET /ws/agent/pending?pending_id=….
//
// Before upgrading it loads the pending row and answers with plain HTTP
// errors: 400 for a missing pending_id, 404 when the row does not
// exist, 410 when it has expired, 500 for any other store failure or a
// stored public key of the wrong length.
//
// After the upgrade it sends a 32-byte random nonce, waits up to 30s
// for a signed_nonce reply and verifies the signature against the row's
// Ed25519 public key; any failure closes the socket with a
// policy-violation status. On success it touches the row, registers the
// session in the hub and blocks until the session ends: either the hub
// closes pc.closed, or the read loop fails because the socket was
// closed (by accept, reject, a superseding session or the agent) or
// because its 90-second read deadline expired.
func (s *Server) handlePendingWS(w stdhttp.ResponseWriter, r *stdhttp.Request) {
	pendingID := r.URL.Query().Get("pending_id")
	if pendingID == "" {
		stdhttp.Error(w, "missing pending_id", stdhttp.StatusBadRequest)
		return
	}
	row, err := s.deps.Store.GetPendingHost(r.Context(), pendingID)
	switch {
	case errors.Is(err, store.ErrNotFound):
		stdhttp.Error(w, "pending host not found", stdhttp.StatusNotFound)
		return
	case err != nil:
		// A store outage is not "row gone"; telling the agent 404 here
		// could make it discard a perfectly valid announcement.
		slog.Error("pending ws: load row", "pending_id", pendingID, "err", err)
		stdhttp.Error(w, "internal error", stdhttp.StatusInternalServerError)
		return
	}
	if time.Now().UTC().After(row.ExpiresAt) {
		stdhttp.Error(w, "pending host expired", stdhttp.StatusGone)
		return
	}
	// ed25519.Verify panics on a public key of the wrong length, so a
	// damaged row must be rejected before it gets that far.
	if len(row.PublicKey) != ed25519.PublicKeySize {
		slog.Error("pending ws: stored public key has wrong length",
			"pending_id", pendingID, "len", len(row.PublicKey))
		stdhttp.Error(w, "internal error", stdhttp.StatusInternalServerError)
		return
	}

	conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
		// Origin verification is switched off: the peer is the agent,
		// not a browser, and the session is authenticated by the
		// Ed25519 challenge below rather than by ambient credentials.
		InsecureSkipVerify: true,
	})
	if err != nil {
		slog.Warn("pending ws: accept", "pending_id", pendingID, "err", err)
		return
	}
	// Safety net so the socket is released on every return path; it is
	// harmless after one of the explicit Close calls below.
	defer func() { _ = conn.CloseNow() }()

	// Generate + send nonce.
	nonce := make([]byte, 32)
	if _, err := rand.Read(nonce); err != nil {
		_ = conn.Close(websocket.StatusInternalError, "nonce gen")
		return
	}
	raw, err := json.Marshal(nonceMessage{
		Type:  "nonce",
		Nonce: base64.StdEncoding.EncodeToString(nonce),
	})
	if err != nil {
		_ = conn.Close(websocket.StatusInternalError, "nonce encode")
		return
	}
	wctx, wcancel := context.WithTimeout(r.Context(), 5*time.Second)
	err = conn.Write(wctx, websocket.MessageText, raw)
	wcancel()
	if err != nil {
		_ = conn.Close(websocket.StatusInternalError, "send nonce")
		return
	}

	// Read signed nonce back.
	rctx, rcancel := context.WithTimeout(r.Context(), 30*time.Second)
	mt, body, err := conn.Read(rctx)
	rcancel()
	if err != nil || mt != websocket.MessageText {
		_ = conn.Close(websocket.StatusPolicyViolation, "no signed nonce")
		return
	}
	var sig signedNonceMessage
	if err := json.Unmarshal(body, &sig); err != nil || sig.Type != "signed_nonce" {
		_ = conn.Close(websocket.StatusPolicyViolation, "bad signed nonce shape")
		return
	}
	sigBytes, err := base64.StdEncoding.DecodeString(sig.Signature)
	if err != nil {
		_ = conn.Close(websocket.StatusPolicyViolation, "bad signature b64")
		return
	}
	if !ed25519.Verify(row.PublicKey, nonce, sigBytes) {
		_ = conn.Close(websocket.StatusPolicyViolation, "signature does not verify")
		return
	}

	// Touch the row so the dashboard knows the agent is live. This is
	// advisory, so a failure is logged and the session continues.
	if err := s.deps.Store.TouchPendingHost(context.Background(), pendingID, time.Now().UTC()); err != nil {
		slog.Warn("pending ws: touch row", "pending_id", pendingID, "err", err)
	}

	// Register and block until the session ends.
	pc := &pendingConn{conn: conn, pendingID: pendingID, closed: make(chan struct{})}
	s.pendingHub.register(pc)
	defer s.pendingHub.unregister(pendingID, pc)

	// Read loop: no further frames are expected from the agent, so the
	// loop exists only to notice the socket going away. It reports on
	// its own channel and never touches pc.closed, so there is exactly
	// one closer per channel and a double close is impossible. The
	// goroutine ends with the connection: every exit from this handler
	// closes conn, which fails the pending Read.
	//
	// NOTE(review): any Read error ends the session, including the
	// 90-second deadline expiring, so an agent that stays silent is
	// dropped after 90s and has to reconnect — confirm that is intended.
	readDone := make(chan struct{})
	go func() {
		defer close(readDone)
		for {
			ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
			_, _, err := conn.Read(ctx)
			cancel()
			if err != nil {
				return
			}
		}
	}()

	select {
	case <-pc.closed:
	case <-readDone:
	}
}
// acceptForm is the admin form for POST /api/pending-hosts/{id}/accept.
// It can arrive as a JSON body (keys below) or as form fields of the
// same names. handleAcceptPendingHost rejects the request unless
// repo_url and repo_password are non-empty; repo_username is passed
// through as given and may be empty.
type acceptForm struct {
RepoURL string `json:"repo_url"`
RepoUsername string `json:"repo_username"`
RepoPassword string `json:"repo_password"`
}
// handleAcceptPendingHost serves POST /api/pending-hosts/{id}/accept.
// It turns a pending host into a real one: a Host row with a freshly
// minted bearer, AEAD-encrypted repo credentials bound to the new host
// id, removal of the pending row, an `enrolled` frame pushed down the
// agent's open pending WebSocket, and a host.accept_pending audit entry.
//
// Responses: 401 when requireUser rejects the request, 404 when the
// pending row does not exist, 409 when the agent has no live pending
// socket, 400 for an unparsable body or missing repo_url/repo_password,
// 500 for store/crypto failures, and 200 with {host_id, fingerprint} on
// success.
//
// Everything that can fail without touching the store (token, message
// encoding, credential encryption) happens before the Host row is
// created, so those failures cannot leave a half-enrolled host behind.
// Once the rows are written, the remaining steps run on a context that
// ignores cancellation of the admin's request: a browser that goes away
// at that moment must not stop the bearer from reaching the agent.
func (s *Server) handleAcceptPendingHost(w stdhttp.ResponseWriter, r *stdhttp.Request) {
	user, ok := s.requireUser(r)
	if !ok {
		writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
		return
	}
	pendingID := chi.URLParam(r, "id")
	row, err := s.deps.Store.GetPendingHost(r.Context(), pendingID)
	switch {
	case errors.Is(err, store.ErrNotFound):
		writeJSONError(w, stdhttp.StatusNotFound, "pending_not_found", "")
		return
	case err != nil:
		writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
		return
	}
	pc := s.pendingHub.get(pendingID)
	if pc == nil {
		writeJSONError(w, stdhttp.StatusConflict, "agent_not_connected",
			"the pending agent is not currently connected; ask it to retry")
		return
	}

	// Accept either JSON or form-urlencoded so HTMX-style POST works.
	// The media type is matched as a prefix followed by end-of-string or
	// a parameter separator, so "application/json; charset=utf-8" counts
	// as JSON while "application/jsonl" does not.
	const jsonType = "application/json"
	ct := r.Header.Get("Content-Type")
	isJSON := len(ct) >= len(jsonType) && ct[:len(jsonType)] == jsonType &&
		(len(ct) == len(jsonType) || ct[len(jsonType)] == ';' || ct[len(jsonType)] == ' ')

	var form acceptForm
	if isJSON {
		if err := json.NewDecoder(r.Body).Decode(&form); err != nil {
			writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error())
			return
		}
	} else {
		if err := r.ParseForm(); err != nil {
			writeJSONError(w, stdhttp.StatusBadRequest, "bad_form", err.Error())
			return
		}
		form.RepoURL = r.PostForm.Get("repo_url")
		form.RepoUsername = r.PostForm.Get("repo_username")
		form.RepoPassword = r.PostForm.Get("repo_password")
	}
	if form.RepoURL == "" || form.RepoPassword == "" {
		writeJSONError(w, stdhttp.StatusBadRequest, "missing_field",
			"repo_url and repo_password are required")
		return
	}

	// Mint the host id and bearer, and prepare everything derived from
	// them, before anything is persisted.
	hostID := ulid.Make().String()
	token, err := auth.NewToken()
	if err != nil {
		writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
		return
	}
	raw, err := json.Marshal(enrolledMessage{Type: "enrolled", HostID: hostID, Bearer: token})
	if err != nil {
		writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
		return
	}
	enc, err := s.encryptRepoCreds(repoCredsBlob(form), []byte("host:"+hostID))
	if err != nil {
		writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
		return
	}

	// Persist Host row + credentials.
	host := store.Host{
		ID: hostID, Name: row.Hostname, OS: row.OS, Arch: row.Arch,
		AgentVersion: row.AgentVersion, ResticVersion: row.ResticVersion,
		EnrolledAt: time.Now().UTC(),
	}
	if err := s.deps.Store.CreateHost(r.Context(), host, auth.HashToken(token), ""); err != nil {
		writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
		return
	}
	// NOTE(review): if this write fails the Host row created above stays
	// behind without credentials; the two writes are not transactional.
	if err := s.deps.Store.SetHostCredentials(r.Context(), hostID, store.CredKindRepo, enc); err != nil {
		writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
		return
	}

	// From here on the host exists; finish regardless of whether the
	// admin's HTTP request is still alive.
	bg := context.WithoutCancel(r.Context())

	// Drop the pending row.
	if err := s.deps.Store.DeletePendingHost(bg, pendingID); err != nil {
		slog.Warn("accept pending: delete row", "pending_id", pendingID, "err", err)
	}

	// Push enrolled message + close the pending WS. The hub is consulted
	// again because the agent may have reconnected while the rows were
	// being written, which replaces the hub entry and closes the socket
	// captured at the top of the handler.
	if cur := s.pendingHub.get(pendingID); cur != nil {
		pc = cur
	}
	wctx, wcancel := context.WithTimeout(bg, 5*time.Second)
	err = pc.conn.Write(wctx, websocket.MessageText, raw)
	wcancel()
	if err != nil {
		// The host row holds only the token hash, so an undelivered
		// bearer cannot be recovered; make the failure loud.
		slog.Error("accept pending: write enrolled",
			"pending_id", pendingID, "host_id", hostID, "err", err)
	}
	_ = pc.conn.Close(websocket.StatusNormalClosure, "accepted")

	// Audit.
	uid := user.ID
	if err := s.deps.Store.AppendAudit(bg, store.AuditEntry{
		ID:         ulid.Make().String(),
		UserID:     &uid,
		Actor:      "user",
		Action:     "host.accept_pending",
		TargetKind: ptr("host"),
		TargetID:   &hostID,
		TS:         time.Now().UTC(),
	}); err != nil {
		slog.Warn("accept pending: audit", "host_id", hostID, "err", err)
	}

	writeJSON(w, stdhttp.StatusOK, map[string]any{
		"host_id":     hostID,
		"fingerprint": row.Fingerprint,
	})
}
// handleRejectPendingHost serves POST /api/pending-hosts/{id}/reject.
// It deletes the pending row, closes any live pending WebSocket for it
// with close code 4001 ("rejected") and writes a host.reject_pending
// audit entry.
//
// Responses: 401 when requireUser rejects the request, 204 on success
// and also when the row is already gone (so a repeated reject is not an
// error), 500 on store failures.
//
// The row is deleted before the socket is closed. In the opposite order
// an agent that reconnects the moment its socket drops would still find
// the row, pass the handshake and sit in the hub as a session nobody
// rejected. It also means a failed delete leaves the agent's socket
// untouched instead of half-applying the rejection.
func (s *Server) handleRejectPendingHost(w stdhttp.ResponseWriter, r *stdhttp.Request) {
	user, ok := s.requireUser(r)
	if !ok {
		writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
		return
	}
	pendingID := chi.URLParam(r, "id")
	row, err := s.deps.Store.GetPendingHost(r.Context(), pendingID)
	switch {
	case errors.Is(err, store.ErrNotFound):
		w.WriteHeader(stdhttp.StatusNoContent)
		return
	case err != nil:
		writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
		return
	}

	if err := s.deps.Store.DeletePendingHost(r.Context(), pendingID); err != nil {
		writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
		return
	}
	if pc := s.pendingHub.get(pendingID); pc != nil {
		// Best effort: the agent may already have disconnected.
		_ = pc.conn.Close(4001, "rejected")
	}

	// The rejection has happened; record it even if the admin's request
	// is cancelled at this point.
	uid := user.ID
	if err := s.deps.Store.AppendAudit(context.WithoutCancel(r.Context()), store.AuditEntry{
		ID:         ulid.Make().String(),
		UserID:     &uid,
		Actor:      "user",
		Action:     "host.reject_pending",
		TargetKind: ptr("pending_host"),
		TargetID:   &row.ID,
		TS:         time.Now().UTC(),
	}); err != nil {
		slog.Warn("reject pending: audit", "pending_id", pendingID, "err", err)
	}
	w.WriteHeader(stdhttp.StatusNoContent)
}
+203
View File
@@ -0,0 +1,203 @@
// pending_ws_test.go — end-to-end test of the announce → pending WS
// → admin accept → bearer push round trip (P2-18b/c).
package http
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"encoding/json"
stdhttp "net/http"
"net/url"
"strings"
"testing"
"time"
"github.com/coder/websocket"
"github.com/oklog/ulid/v2"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
// TestPendingWSNonceSignAcceptFlow simulates an agent end to end: seed a
// pending row → open the pending WS → sign the nonce → admin accept
// (form-encoded, with repo creds) → expect an 'enrolled' frame carrying
// a bearer, the pending row gone and a real host row present.
func TestPendingWSNonceSignAcceptFlow(t *testing.T) {
	t.Parallel()
	srv, ts, st := rawTestServer(t)

	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("ed25519: %v", err)
	}

	// Pre-seed pending row directly (bypass the announce HTTP path
	// since announce coverage lives in announce_test.go).
	pendingID := ulid.Make().String()
	if err := st.CreatePendingHost(context.Background(), &store.PendingHost{
		ID: pendingID, Hostname: "ann-host", OS: "linux", Arch: "amd64",
		AgentVersion: "1.0", ResticVersion: "0.17",
		PublicKey: pub, Fingerprint: store.FingerprintForKey(pub),
		AnnouncedFromIP: "127.0.0.1",
		FirstSeenAt:     time.Now().UTC(),
		LastSeenAt:      time.Now().UTC(),
		ExpiresAt:       time.Now().UTC().Add(time.Hour),
	}); err != nil {
		t.Fatalf("seed: %v", err)
	}

	// Open the pending WS.
	wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/agent/pending?pending_id=" + pendingID
	dialCtx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer dialCancel()
	c, res, err := websocket.Dial(dialCtx, wsURL, nil)
	if err != nil {
		t.Fatalf("dial pending ws: %v", err)
	}
	if res != nil && res.Body != nil {
		_ = res.Body.Close()
	}
	t.Cleanup(func() { _ = c.CloseNow() })

	// Read nonce.
	rctx, rcancel := context.WithTimeout(context.Background(), 3*time.Second)
	_, raw, err := c.Read(rctx)
	rcancel()
	if err != nil {
		t.Fatalf("read nonce: %v", err)
	}
	var nm nonceMessage
	if err := json.Unmarshal(raw, &nm); err != nil {
		t.Fatalf("unmarshal nonce: %v", err)
	}
	if nm.Type != "nonce" {
		t.Fatalf("first frame type = %q, want %q", nm.Type, "nonce")
	}
	nonce, err := base64.StdEncoding.DecodeString(nm.Nonce)
	if err != nil {
		t.Fatalf("decode nonce: %v", err)
	}

	// Sign + reply.
	sig := ed25519.Sign(priv, nonce)
	reply, err := json.Marshal(signedNonceMessage{
		Type: "signed_nonce", Signature: base64.StdEncoding.EncodeToString(sig),
	})
	if err != nil {
		t.Fatalf("marshal signed nonce: %v", err)
	}
	wctx, wcancel := context.WithTimeout(context.Background(), 3*time.Second)
	err = c.Write(wctx, websocket.MessageText, reply)
	wcancel()
	if err != nil {
		t.Fatalf("write signed nonce: %v", err)
	}

	// Wait for the server's hub.register to complete before firing
	// accept, and fail here — with a message that names the real
	// problem — rather than letting accept answer 409 later.
	registered := false
	for deadline := time.Now().Add(2 * time.Second); time.Now().Before(deadline); {
		if srv.pendingHub.get(pendingID) != nil {
			registered = true
			break
		}
		time.Sleep(20 * time.Millisecond)
	}
	if !registered {
		t.Fatal("pending ws never registered in the hub after a valid handshake")
	}

	// Admin POST accept (form-encoded, with cookie).
	cookie := loginAsAdmin(t, st)
	form := url.Values{
		"repo_url":      {"rest:http://r/x"},
		"repo_username": {"u"},
		"repo_password": {"p"},
	}
	req, err := stdhttp.NewRequest(stdhttp.MethodPost,
		ts.URL+"/api/pending-hosts/"+pendingID+"/accept",
		strings.NewReader(form.Encode()))
	if err != nil {
		t.Fatalf("build accept request: %v", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.AddCookie(cookie)
	resAccept, err := stdhttp.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("accept: %v", err)
	}
	defer resAccept.Body.Close()
	if resAccept.StatusCode != stdhttp.StatusOK {
		t.Fatalf("accept status: %d", resAccept.StatusCode)
	}

	// Expect 'enrolled' message + close.
	rctx2, rcancel2 := context.WithTimeout(context.Background(), 3*time.Second)
	_, raw2, err := c.Read(rctx2)
	rcancel2()
	if err != nil {
		t.Fatalf("read enrolled: %v", err)
	}
	var em enrolledMessage
	if err := json.Unmarshal(raw2, &em); err != nil {
		t.Fatalf("unmarshal enrolled: %v", err)
	}
	if em.Type != "enrolled" || em.Bearer == "" || em.HostID == "" {
		t.Fatalf("enrolled payload bad: %+v", em)
	}

	// Pending row should be gone.
	if _, err := st.GetPendingHost(context.Background(), pendingID); err == nil {
		t.Error("pending row should have been deleted on accept")
	}
	// Real host row should exist.
	if _, err := st.GetHost(context.Background(), em.HostID); err != nil {
		t.Errorf("host row not created: %v", err)
	}
}
// TestPendingWSBadSignatureClosed checks that the server closes the WS
// with a policy-violation status when the signature does not verify
// against the row's public key. Every setup step is checked, and the
// close status is asserted, so the test cannot pass merely because a
// read failed for an unrelated reason (a broken handshake or the
// client's own read timeout).
func TestPendingWSBadSignatureClosed(t *testing.T) {
	t.Parallel()
	_, ts, st := rawTestServer(t)

	// Two distinct keypairs — agent signs with the wrong one.
	pubReal, _, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("ed25519 (real): %v", err)
	}
	_, privAttacker, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("ed25519 (attacker): %v", err)
	}

	pendingID := ulid.Make().String()
	if err := st.CreatePendingHost(context.Background(), &store.PendingHost{
		ID: pendingID, Hostname: "bad-host", OS: "linux", Arch: "amd64",
		PublicKey: pubReal, Fingerprint: store.FingerprintForKey(pubReal),
		AnnouncedFromIP: "127.0.0.1",
		FirstSeenAt:     time.Now().UTC(),
		LastSeenAt:      time.Now().UTC(),
		ExpiresAt:       time.Now().UTC().Add(time.Hour),
	}); err != nil {
		t.Fatalf("seed: %v", err)
	}

	wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/agent/pending?pending_id=" + pendingID
	dialCtx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer dialCancel()
	c, res, err := websocket.Dial(dialCtx, wsURL, nil)
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	if res != nil && res.Body != nil {
		_ = res.Body.Close()
	}
	defer func() { _ = c.CloseNow() }()

	// Read nonce.
	rctx, rcancel := context.WithTimeout(context.Background(), 3*time.Second)
	_, raw, err := c.Read(rctx)
	rcancel()
	if err != nil {
		t.Fatalf("read nonce: %v", err)
	}
	var nm nonceMessage
	if err := json.Unmarshal(raw, &nm); err != nil {
		t.Fatalf("unmarshal nonce: %v", err)
	}
	nonce, err := base64.StdEncoding.DecodeString(nm.Nonce)
	if err != nil {
		t.Fatalf("decode nonce: %v", err)
	}

	// Sign with the wrong key.
	sig := ed25519.Sign(privAttacker, nonce)
	reply, err := json.Marshal(signedNonceMessage{
		Type: "signed_nonce", Signature: base64.StdEncoding.EncodeToString(sig),
	})
	if err != nil {
		t.Fatalf("marshal signed nonce: %v", err)
	}
	wctx, wcancel := context.WithTimeout(context.Background(), 3*time.Second)
	err = c.Write(wctx, websocket.MessageText, reply)
	wcancel()
	if err != nil {
		t.Fatalf("write signed nonce: %v", err)
	}

	// Server should close. The next read must fail with the server's
	// close frame, not with our own deadline.
	rctx2, rcancel2 := context.WithTimeout(context.Background(), 3*time.Second)
	_, _, err = c.Read(rctx2)
	rcancel2()
	if err == nil {
		t.Fatal("expected ws to close on bad signature")
	}
	if got := websocket.CloseStatus(err); got != websocket.StatusPolicyViolation {
		t.Fatalf("close status = %d, want %d (err: %v)", got, websocket.StatusPolicyViolation, err)
	}
}
+14
View File
@@ -53,6 +53,11 @@ type Server struct {
// announceRL is the per-source-IP token-bucket guarding
// POST /api/agents/announce (P2-18). One process-local map.
announceRL *announceLimiter
// pendingHub holds live /ws/agent/pending sockets keyed by
// pending_id so the accept/reject handlers can push the bearer
// or close cleanly (P2-18b).
pendingHub *pendingHub
}
// New builds a configured but not-yet-started server.
@@ -75,6 +80,7 @@ func New(deps Deps) *Server {
deps: deps,
drainLocks: make(map[string]*sync.Mutex),
announceRL: newAnnounceLimiter(),
pendingHub: newPendingHub(),
}
s.routes(r)
@@ -105,6 +111,10 @@ func (s *Server) routes(r chi.Router) {
// globally capped (P2-18).
r.Post("/agents/announce", s.handleAnnounce)
// Pending host management — admin-only (gated inside the handler).
r.Post("/pending-hosts/{id}/accept", s.handleAcceptPendingHost)
r.Post("/pending-hosts/{id}/reject", s.handleRejectPendingHost)
// Operator → server (authenticated). Spec.md §6.1's
// /hosts/{id}/enrollment-token (regenerate) lands when the
// host page can call it; for now just the create endpoint.
@@ -185,6 +195,10 @@ func (s *Server) routes(r chi.Router) {
r.Post("/hosts/{id}/run-backup", s.handleUIRunBackupGone)
r.Post("/hosts/{id}/init-repo", s.handleUIInitRepoGone)
// Pending-host WebSocket (announce-and-approve, P2-18b). Mounted
// before /ws/agent so the more-specific route matches first.
r.Get("/ws/agent/pending", s.handlePendingWS)
// Agent ↔ server WebSocket. Bearer-authenticated inside the handler.
if s.deps.Hub != nil {
r.Mount("/ws/agent", ws.AgentHandler(ws.HandlerDeps{