phase 1: WS transport, enrollment, agent that hellos and heartbeats
Lands the protocol layer end-to-end: an agent can be enrolled through the operator UI, store credentials, dial back to the server over WS, complete the protocol_version handshake, and stay connected with periodic heartbeats. Server side: - P1-09 ws.Hub: one Conn per host_id, last-write-wins eviction, json envelope writer with a write mutex, reader, error envelopes. - P1-09 ws.AgentHandler: bearer-auth, accept upgrade, hello-stage (10s deadline, protocol_version checked against api.MinAgentProtocolVersion → ErrProtocolTooOld with help URL on reject), main read loop, defer hub register/unregister. - P1-10 POST /api/agents/enroll consumes a one-time token, mints a persistent agent bearer (sha-256 stored), creates a host row. - P1-10 POST /api/enrollment-tokens (operator, session-auth) issues a 1h one-time token. - P1-11 hello upserts agent_version + restic_version + protocol_version on the host row, flips status to online. - P1-12 heartbeat touches last_seen_at; background sweeper marks hosts offline after 90s without one. - store: hosts table accessors, host_schedule_version, enrollment_tokens FK on consumed_host dropped (audit-only field; the token gets burned before the host row exists). Agent side: - P1-13 internal/agent/config: yaml at /etc/restic-manager/agent.yaml, atomic Save (tmp+fsync+rename), Enrolled() helper. - P1-15 internal/agent/wsclient: dial with bearer + optional TLS cert pinning (sha-256 of leaf), exponential backoff with jitter (1s → 60s cap), heartbeat goroutine, fatal handling for ErrProtocolTooOld. - P1-15 wsclient.Enroll: HTTP POST /api/agents/enroll with sysinfo. - P1-17 internal/agent/sysinfo: hostname/OS/arch/restic-version collection. restic detected by `restic version` parse; absent restic doesn't block startup. - cmd/agent: -enroll-server / -enroll-token flags drive first-run enrollment then exit (so the install script can hand off to systemd to run the persistent service). End-to-end smoke verified: bootstrap → login → issue token → enroll → run agent → server logs `ws agent connected` with the right host_id and protocol_version 1. All tests still pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,183 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
stdhttp "net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// HandlerDeps is the set of collaborators the agent WS handler needs.
|
||||
type HandlerDeps struct {
|
||||
Hub *Hub
|
||||
Store *store.Store
|
||||
}
|
||||
|
||||
// AgentHandler is the http.Handler that owns /ws/agent. Agents
|
||||
// authenticate with `Authorization: Bearer <token>` (issued at
|
||||
// enrollment) before the WS upgrade.
|
||||
//
|
||||
// Lifecycle:
|
||||
// 1. Bearer token resolves to a Host row.
|
||||
// 2. Upgrade.
|
||||
// 3. First message must be `hello`; protocol_version checked here.
|
||||
// 4. Loop: read messages, dispatch by type. Heartbeats touch the
|
||||
// host row; job/log/repo messages forward to the relevant
|
||||
// handlers (TODO: lands with P1-18 onward).
|
||||
// 5. On Read error or context cancel, mark host offline, unregister
|
||||
// from the hub.
|
||||
func AgentHandler(deps HandlerDeps) stdhttp.Handler {
|
||||
return stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
host, ok := authenticateAgent(r, deps.Store)
|
||||
if !ok {
|
||||
stdhttp.Error(w, "unauthorized", stdhttp.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
InsecureSkipVerify: true, // Origin checks are pointless for an agent CLI.
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("ws accept failed", "err", err, "host_id", host.ID)
|
||||
return
|
||||
}
|
||||
|
||||
c := NewConn(host.ID, conn)
|
||||
// Keep agents alive across NAT boxes; coder/websocket
|
||||
// auto-pings under the hood when configured. The default 60s
|
||||
// works fine for a 30s heartbeat cadence.
|
||||
|
||||
runAgentLoop(r.Context(), c, host.ID, deps)
|
||||
})
|
||||
}
|
||||
|
||||
// authenticateAgent returns the host that owns the bearer token in
|
||||
// the request, or (nil, false) if anything is amiss. The same
|
||||
// "false" path is used for missing header, malformed header, unknown
|
||||
// token — no information leak about why.
|
||||
func authenticateAgent(r *stdhttp.Request, st *store.Store) (*store.Host, bool) {
|
||||
hdr := r.Header.Get("Authorization")
|
||||
const prefix = "Bearer "
|
||||
if !strings.HasPrefix(hdr, prefix) {
|
||||
return nil, false
|
||||
}
|
||||
token := strings.TrimPrefix(hdr, prefix)
|
||||
if token == "" {
|
||||
return nil, false
|
||||
}
|
||||
h, err := st.LookupHostByAgentToken(r.Context(), auth.HashToken(token))
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return h, true
|
||||
}
|
||||
|
||||
// runAgentLoop is the per-connection driver. Returns when the socket
|
||||
// is closed for any reason. It owns the hub registration: register on
|
||||
// hello acceptance, unregister on exit.
|
||||
func runAgentLoop(ctx context.Context, c *Conn, hostID string, deps HandlerDeps) {
|
||||
// Stage 1: hello (with a tight deadline).
|
||||
helloCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
hello, err := c.Read(helloCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
slog.Info("ws hello read failed", "host_id", hostID, "err", err)
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
if hello.Type != api.MsgHello {
|
||||
c.SendError(ctx, api.ErrBadRequest, "first message must be hello", "")
|
||||
return
|
||||
}
|
||||
var helloPayload api.HelloPayload
|
||||
if err := hello.UnmarshalPayload(&helloPayload); err != nil {
|
||||
c.SendError(ctx, api.ErrBadRequest, "malformed hello payload", "")
|
||||
return
|
||||
}
|
||||
if helloPayload.ProtocolVersion < api.MinAgentProtocolVersion {
|
||||
c.SendError(ctx, api.ErrProtocolTooOld,
|
||||
fmt.Sprintf("agent protocol_version %d below minimum %d",
|
||||
helloPayload.ProtocolVersion, api.MinAgentProtocolVersion),
|
||||
"https://restic-manager.example/docs/upgrade")
|
||||
return
|
||||
}
|
||||
if helloPayload.ProtocolVersion > api.CurrentProtocolVersion {
|
||||
// Forward-compat is fine — newer agents talking to older
|
||||
// servers should accept their lower version. Just log it.
|
||||
slog.Info("ws agent newer than server",
|
||||
"host_id", hostID,
|
||||
"agent_proto", helloPayload.ProtocolVersion,
|
||||
"server_proto", api.CurrentProtocolVersion)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if err := deps.Store.MarkHostHello(ctx, hostID,
|
||||
helloPayload.AgentVersion, helloPayload.ResticVersion,
|
||||
helloPayload.ProtocolVersion, now); err != nil {
|
||||
slog.Error("ws mark host hello failed", "host_id", hostID, "err", err)
|
||||
}
|
||||
|
||||
deps.Hub.Register(hostID, c)
|
||||
defer deps.Hub.Unregister(hostID, c)
|
||||
defer func() { _ = c.Close() }()
|
||||
|
||||
slog.Info("ws agent connected",
|
||||
"host_id", hostID,
|
||||
"agent_version", helloPayload.AgentVersion,
|
||||
"protocol_version", helloPayload.ProtocolVersion)
|
||||
|
||||
// Stage 2: main read loop.
|
||||
for {
|
||||
env, err := c.Read(ctx)
|
||||
if err != nil {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
slog.Info("ws agent read loop ended", "host_id", hostID, "err", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dispatchAgentMessage(ctx, c, hostID, env, deps)
|
||||
}
|
||||
}
|
||||
|
||||
// dispatchAgentMessage routes a single envelope to its handler. Only
|
||||
// hello + heartbeat are wired up in Phase 1's first slice; the rest
|
||||
// land with P1-18+ (jobs) and P2 (schedules).
|
||||
func dispatchAgentMessage(ctx context.Context, c *Conn, hostID string, env api.Envelope, deps HandlerDeps) {
|
||||
switch env.Type {
|
||||
case api.MsgHeartbeat:
|
||||
_ = deps.Store.TouchHost(ctx, hostID, time.Now().UTC())
|
||||
|
||||
case api.MsgJobStarted, api.MsgJobProgress, api.MsgJobFinished,
|
||||
api.MsgLogStream, api.MsgSnapshotsRpt, api.MsgRepoStats,
|
||||
api.MsgScheduleAck, api.MsgCommandResult:
|
||||
// TODO(P1-18+): persist + fan out to subscribed browsers.
|
||||
slog.Debug("ws msg not yet handled", "type", env.Type, "host_id", hostID)
|
||||
|
||||
case api.MsgError:
|
||||
var ep api.ErrorPayload
|
||||
_ = env.UnmarshalPayload(&ep)
|
||||
slog.Warn("ws agent reported error", "host_id", hostID,
|
||||
"code", string(ep.Code), "message", ep.Message)
|
||||
|
||||
default:
|
||||
slog.Warn("ws unknown message type from agent",
|
||||
"type", env.Type, "host_id", hostID)
|
||||
}
|
||||
}
|
||||
|
||||
// MinHeartbeatInterval is a sanity floor — any agent reporting
|
||||
// heartbeats more often than this is misbehaving. (Spec says 30s.)
|
||||
const MinHeartbeatInterval = 5 * time.Second
|
||||
|
||||
// suppress unused-import false-positives if json drops out later
|
||||
var _ = json.Marshal
|
||||
Reference in New Issue
Block a user