phase 1: WS transport, enrollment, agent that hellos and heartbeats

Lands the protocol layer end-to-end: an agent can be enrolled
through the operator UI, store credentials, dial back to the server
over WS, complete the protocol_version handshake, and stay
connected with periodic heartbeats.

Server side:
- P1-09 ws.Hub: one Conn per host_id, last-write-wins eviction,
  json envelope writer with a write mutex, reader, error envelopes.
- P1-09 ws.AgentHandler: bearer-auth, accept upgrade, hello-stage
  (10s deadline, protocol_version checked against
  api.MinAgentProtocolVersion → ErrProtocolTooOld with help URL on
  reject), main read loop, defer hub register/unregister.
- P1-10 POST /api/agents/enroll consumes a one-time token, mints a
  persistent agent bearer (sha-256 stored), creates a host row.
- P1-10 POST /api/enrollment-tokens (operator, session-auth)
  issues a 1h one-time token.
- P1-11 hello upserts agent_version + restic_version +
  protocol_version on the host row, flips status to online.
- P1-12 heartbeat touches last_seen_at; background sweeper marks
  hosts offline after 90s without one.
- store: hosts table accessors, host_schedule_version,
  enrollment_tokens FK on consumed_host dropped (audit-only field;
  the token gets burned before the host row exists).

Agent side:
- P1-13 internal/agent/config: yaml at /etc/restic-manager/agent.yaml,
  atomic Save (tmp+fsync+rename), Enrolled() helper.
- P1-15 internal/agent/wsclient: dial with bearer + optional
  TLS cert pinning (sha-256 of leaf), exponential backoff with
  jitter (1s → 60s cap), heartbeat goroutine, fatal handling for
  ErrProtocolTooOld.
- P1-15 wsclient.Enroll: HTTP POST /api/agents/enroll with sysinfo.
- P1-17 internal/agent/sysinfo: hostname/OS/arch/restic-version
  collection. restic detected by `restic version` parse; absent
  restic doesn't block startup.
- cmd/agent: -enroll-server / -enroll-token flags drive first-run
  enrollment then exit (so the install script can hand off to
  systemd to run the persistent service).

End-to-end smoke verified: bootstrap → login → issue token →
enroll → run agent → server logs `ws agent connected` with the
right host_id and protocol_version 1.

All tests still pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-01 00:39:00 +01:00
parent df2c584b23
commit 9cc0caff1e
18 changed files with 1670 additions and 14 deletions
+183
View File
@@ -0,0 +1,183 @@
package ws
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
stdhttp "net/http"
"strings"
"time"
"github.com/coder/websocket"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
// HandlerDeps is the set of collaborators the agent WS handler needs.
type HandlerDeps struct {
// Hub receives the Register/Unregister calls made for an agent
// connection once its hello has been accepted.
Hub *Hub
// Store backs the bearer-token lookup, the hello upsert and the
// heartbeat touch.
Store *store.Store
}
// AgentHandler builds the http.Handler that owns /ws/agent.
//
// Agents present `Authorization: Bearer <token>` (the token issued at
// enrollment). The token is checked before the WebSocket upgrade, so
// an unauthenticated caller only ever receives a plain 401.
//
// Per-request flow:
//  1. Resolve the bearer token to a host row; reply 401 otherwise.
//  2. Accept the upgrade; a failure is logged and the request ends.
//  3. Wrap the socket in a Conn and hand it to runAgentLoop, which
//     runs the hello stage, the hub registration and the read loop
//     until the connection ends.
//
// NOTE(review): nothing in this handler marks the host offline on
// disconnect, and no keep-alive ping is configured here — confirm
// both are covered elsewhere.
func AgentHandler(deps HandlerDeps) stdhttp.Handler {
	serve := func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
		agentHost, authed := authenticateAgent(r, deps.Store)
		if !authed {
			stdhttp.Error(w, "unauthorized", stdhttp.StatusUnauthorized)
			return
		}

		// Skip the Origin check: agents are CLI clients, not browsers.
		acceptOpts := websocket.AcceptOptions{InsecureSkipVerify: true}
		sock, err := websocket.Accept(w, r, &acceptOpts)
		if err != nil {
			slog.Warn("ws accept failed", "err", err, "host_id", agentHost.ID)
			return
		}

		// Blocks for the lifetime of the agent connection.
		runAgentLoop(r.Context(), NewConn(agentHost.ID, sock), agentHost.ID, deps)
	}
	return stdhttp.HandlerFunc(serve)
}
// authenticateAgent resolves the request's bearer token to the host
// that owns it. Every failure — header absent, wrong scheme, empty
// token, or a lookup error — collapses into the same (nil, false)
// result so the caller cannot leak which check failed.
func authenticateAgent(r *stdhttp.Request, st *store.Store) (*store.Host, bool) {
	token, hasScheme := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ")
	if !hasScheme || token == "" {
		return nil, false
	}

	// The lookup is keyed by the token's hash, not the raw token.
	owner, err := st.LookupHostByAgentToken(r.Context(), auth.HashToken(token))
	if err != nil {
		return nil, false
	}
	return owner, true
}
// runAgentLoop is the per-connection driver. It returns when the
// socket is closed for any reason, and the connection is closed on
// every exit path — including a rejected hello, which previously
// returned without closing the socket.
//
// It owns the hub registration: register on hello acceptance,
// unregister on exit.
//
// ctx is the request context; c is the wrapped agent socket; hostID
// is the authenticated host the socket belongs to.
func runAgentLoop(ctx context.Context, c *Conn, hostID string, deps HandlerDeps) {
	// Stage 1: hello. Any failure here ends the session before the
	// connection is ever registered with the hub.
	helloPayload, ok := awaitAgentHello(ctx, c, hostID)
	if !ok {
		_ = c.Close() // best-effort: the session is over either way
		return
	}

	// Best-effort upsert: a store failure is logged and the session
	// continues rather than dropping an otherwise healthy agent.
	now := time.Now().UTC()
	if err := deps.Store.MarkHostHello(ctx, hostID,
		helloPayload.AgentVersion, helloPayload.ResticVersion,
		helloPayload.ProtocolVersion, now); err != nil {
		slog.Error("ws mark host hello failed", "host_id", hostID, "err", err)
	}

	deps.Hub.Register(hostID, c)
	// Defers run last-in-first-out: close the socket, then unregister.
	defer deps.Hub.Unregister(hostID, c)
	defer func() { _ = c.Close() }()

	slog.Info("ws agent connected",
		"host_id", hostID,
		"agent_version", helloPayload.AgentVersion,
		"protocol_version", helloPayload.ProtocolVersion)

	// Stage 2: main read loop.
	for {
		env, err := c.Read(ctx)
		if err != nil {
			// A cancelled context is not worth a log line; anything
			// else is.
			if !errors.Is(err, context.Canceled) {
				slog.Info("ws agent read loop ended", "host_id", hostID, "err", err)
			}
			return
		}
		dispatchAgentMessage(ctx, c, hostID, env, deps)
	}
}

// awaitAgentHello reads and validates the first message on c. It
// returns the decoded hello payload and true on success. On failure it
// logs or sends an error envelope to the agent and returns false; the
// caller is responsible for closing the connection.
func awaitAgentHello(ctx context.Context, c *Conn, hostID string) (api.HelloPayload, bool) {
	// The deadline applies to the hello read only; error envelopes
	// below are sent on the parent ctx.
	helloCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	hello, err := c.Read(helloCtx)
	cancel()
	if err != nil {
		slog.Info("ws hello read failed", "host_id", hostID, "err", err)
		return api.HelloPayload{}, false
	}

	if hello.Type != api.MsgHello {
		c.SendError(ctx, api.ErrBadRequest, "first message must be hello", "")
		return api.HelloPayload{}, false
	}

	var payload api.HelloPayload
	if err := hello.UnmarshalPayload(&payload); err != nil {
		c.SendError(ctx, api.ErrBadRequest, "malformed hello payload", "")
		return api.HelloPayload{}, false
	}

	if payload.ProtocolVersion < api.MinAgentProtocolVersion {
		c.SendError(ctx, api.ErrProtocolTooOld,
			fmt.Sprintf("agent protocol_version %d below minimum %d",
				payload.ProtocolVersion, api.MinAgentProtocolVersion),
			"https://restic-manager.example/docs/upgrade")
		return api.HelloPayload{}, false
	}

	if payload.ProtocolVersion > api.CurrentProtocolVersion {
		// Forward-compat is fine — newer agents talking to older
		// servers should accept their lower version. Just log it.
		slog.Info("ws agent newer than server",
			"host_id", hostID,
			"agent_proto", payload.ProtocolVersion,
			"server_proto", api.CurrentProtocolVersion)
	}
	return payload, true
}
// dispatchAgentMessage routes a single envelope to its handler. Only
// heartbeat and error messages do real work in Phase 1's first slice;
// the job/log/repo/schedule types are recognised but only logged at
// debug level until P1-18+ (jobs) and P2 (schedules) land.
//
// Failures are logged rather than returned, so one bad message never
// ends the read loop. c is not used yet.
func dispatchAgentMessage(ctx context.Context, c *Conn, hostID string, env api.Envelope, deps HandlerDeps) {
	switch env.Type {
	case api.MsgHeartbeat:
		// A heartbeat that fails to persist must leave a trace:
		// per the accompanying commit notes, a background sweeper
		// marks hosts offline when last_seen_at goes stale.
		if err := deps.Store.TouchHost(ctx, hostID, time.Now().UTC()); err != nil {
			slog.Warn("ws heartbeat touch failed", "host_id", hostID, "err", err)
		}
	case api.MsgJobStarted, api.MsgJobProgress, api.MsgJobFinished,
		api.MsgLogStream, api.MsgSnapshotsRpt, api.MsgRepoStats,
		api.MsgScheduleAck, api.MsgCommandResult:
		// TODO(P1-18+): persist + fan out to subscribed browsers.
		slog.Debug("ws msg not yet handled", "type", env.Type, "host_id", hostID)
	case api.MsgError:
		var ep api.ErrorPayload
		if err := env.UnmarshalPayload(&ep); err != nil {
			// Say the payload was unreadable instead of logging an
			// error report with empty code and message.
			slog.Warn("ws agent sent malformed error payload",
				"host_id", hostID, "err", err)
			return
		}
		slog.Warn("ws agent reported error", "host_id", hostID,
			"code", string(ep.Code), "message", ep.Message)
	default:
		slog.Warn("ws unknown message type from agent",
			"type", env.Type, "host_id", hostID)
	}
}
// MinHeartbeatInterval is a sanity floor: an agent that heartbeats
// more often than this is misbehaving. (The spec cadence is 30s.)
// NOTE(review): nothing in the visible code enforces this floor yet —
// confirm whether a caller elsewhere uses it.
const MinHeartbeatInterval = 5 * time.Second
// Keep the encoding/json import referenced so the file still compiles
// when no other code uses the package.
// NOTE(review): this is the only visible reference to json — confirm,
// then consider dropping this line together with the import.
var _ = json.Marshal