phase 1: WS transport, enrollment, agent that hellos and heartbeats

Lands the protocol layer end-to-end: an agent can be enrolled
through the operator UI, store credentials, dial back to the server
over WS, complete the protocol_version handshake, and stay
connected with periodic heartbeats.

Server side:
- P1-09 ws.Hub: one Conn per host_id, last-write-wins eviction,
  json envelope writer with a write mutex, reader, error envelopes.
- P1-09 ws.AgentHandler: bearer-auth, accept upgrade, hello-stage
  (10s deadline, protocol_version checked against
  api.MinAgentProtocolVersion → ErrProtocolTooOld with help URL on
  reject), main read loop, defer hub register/unregister.
- P1-10 POST /api/agents/enroll consumes a one-time token, mints a
  persistent agent bearer (sha-256 stored), creates a host row.
- P1-10 POST /api/enrollment-tokens (operator, session-auth)
  issues a 1h one-time token.
- P1-11 hello upserts agent_version + restic_version +
  protocol_version on the host row, flips status to online.
- P1-12 heartbeat touches last_seen_at; background sweeper marks
  hosts offline after 90s without one.
- store: hosts table accessors, host_schedule_version,
  enrollment_tokens FK on consumed_host dropped (audit-only field;
  the token gets burned before the host row exists).

Agent side:
- P1-13 internal/agent/config: yaml at /etc/restic-manager/agent.yaml,
  atomic Save (tmp+fsync+rename), Enrolled() helper.
- P1-15 internal/agent/wsclient: dial with bearer + optional
  TLS cert pinning (sha-256 of leaf), exponential backoff with
  jitter (1s → 60s cap), heartbeat goroutine, fatal handling for
  ErrProtocolTooOld.
- P1-15 wsclient.Enroll: HTTP POST /api/agents/enroll with sysinfo.
- P1-17 internal/agent/sysinfo: hostname/OS/arch/restic-version
  collection. restic detected by `restic version` parse; absent
  restic doesn't block startup.
- cmd/agent: -enroll-server / -enroll-token flags drive first-run
  enrollment then exit (so the install script can hand off to
  systemd to run the persistent service).

End-to-end smoke verified: bootstrap → login → issue token →
enroll → run agent → server logs `ws agent connected` with the
right host_id and protocol_version 1.

All tests still pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-01 00:39:00 +01:00
parent df2c584b23
commit 9cc0caff1e
18 changed files with 1670 additions and 14 deletions
+110
View File
@@ -0,0 +1,110 @@
// Package config loads the agent's persistent configuration. After
// enrollment, the file holds the bearer token + server URL; it is
// only ever written via Save (which replaces atomically).
package config
import (
"fmt"
"os"
"path/filepath"
"gopkg.in/yaml.v3"
)
// Config is the on-disk shape of the agent's config file. Exported
// fields are serialised as YAML under the keys given in their tags;
// the unexported path field is excluded from the file.
type Config struct {
// ServerURL is the base URL of the control plane, e.g.
// https://restic.lab.example. The agent appends /ws/agent and
// /api/agents/enroll.
ServerURL string `yaml:"server_url"`
// AgentToken is the bearer credential issued at enrollment.
// Empty means "not yet enrolled." Enrolled additionally requires
// HostID and ServerURL to be non-empty.
AgentToken string `yaml:"agent_token"`
// HostID is what the server thinks this host is.
HostID string `yaml:"host_id"`
// CertPinSHA256 (optional) is the SHA-256 of the server's TLS
// cert. When set, the agent refuses to connect to a server
// whose cert hash doesn't match. Omitted from the file when empty.
CertPinSHA256 string `yaml:"cert_pin_sha256,omitempty"`
// ResticPath overrides the auto-detected restic binary path.
// Omitted from the file when empty.
ResticPath string `yaml:"restic_path,omitempty"`
// path is the file we loaded from. Used by Save. Load sets it on
// every Config it returns, including the empty one for a missing file.
path string `yaml:"-"`
}
// DefaultPath reports the canonical location of the agent config file.
// Only the Linux location exists in Phase 1; the Windows location is
// described in the spec and arrives with P2.
func DefaultPath() string {
	const linuxConfigPath = "/etc/restic-manager/agent.yaml"
	return linuxConfigPath
}
// Load reads the YAML config stored at path and decodes it into a
// Config that remembers path for a later Save.
//
// A file that does not exist is not an error: a first-run agent has
// not been enrolled yet, so Load hands back an empty Config bound to
// path. Any other read failure, or a YAML parse failure, is returned
// wrapped with the offending path.
func Load(path string) (*Config, error) {
	cfg := &Config{path: path}

	raw, readErr := os.ReadFile(path)
	switch {
	case readErr == nil:
		// File present; decode it below.
	case os.IsNotExist(readErr):
		return cfg, nil
	default:
		return nil, fmt.Errorf("agent config: read %q: %w", path, readErr)
	}

	if parseErr := yaml.Unmarshal(raw, cfg); parseErr != nil {
		return nil, fmt.Errorf("agent config: parse %q: %w", path, parseErr)
	}

	// Re-assert the path after decoding so Save always targets the
	// file this Config came from.
	cfg.path = path
	return cfg, nil
}
// Save persists the config to the path it was loaded from, replacing
// the file atomically: the YAML is written to "<path>.tmp", fsynced,
// closed, and then renamed over the real file. A crash part-way
// through leaves either the previous file or the new one in place,
// never a truncated mix.
//
// The parent directory is created with mode 0700 if needed and the
// temp file is created with mode 0600. On any failure after the temp
// file exists, it is removed before the error is returned.
func (c *Config) Save() error {
	target := c.path
	if target == "" {
		return fmt.Errorf("agent config: no path set")
	}

	parent := filepath.Dir(target)
	if err := os.MkdirAll(parent, 0o700); err != nil {
		return fmt.Errorf("agent config: mkdir %q: %w", parent, err)
	}

	encoded, err := yaml.Marshal(c)
	if err != nil {
		return fmt.Errorf("agent config: marshal: %w", err)
	}

	staging := target + ".tmp"
	out, err := os.OpenFile(staging, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
	if err != nil {
		return fmt.Errorf("agent config: create tmp: %w", err)
	}

	// discard throws the staging file away; stillOpen says whether the
	// handle has to be closed first. Cleanup is best-effort, so its own
	// errors are deliberately ignored in favour of the original failure.
	discard := func(stillOpen bool) {
		if stillOpen {
			_ = out.Close()
		}
		_ = os.Remove(staging)
	}

	if _, err = out.Write(encoded); err != nil {
		discard(true)
		return fmt.Errorf("agent config: write tmp: %w", err)
	}
	if err = out.Sync(); err != nil {
		discard(true)
		return fmt.Errorf("agent config: fsync tmp: %w", err)
	}
	if err = out.Close(); err != nil {
		discard(false)
		return fmt.Errorf("agent config: close tmp: %w", err)
	}
	if err = os.Rename(staging, target); err != nil {
		discard(false)
		return fmt.Errorf("agent config: rename: %w", err)
	}
	return nil
}
// Enrolled reports whether enrollment has completed, i.e. the agent
// token, host ID and server URL are all present.
func (c *Config) Enrolled() bool {
	for _, required := range [...]string{c.AgentToken, c.HostID, c.ServerURL} {
		if required == "" {
			return false
		}
	}
	return true
}
+78
View File
@@ -0,0 +1,78 @@
// Package sysinfo collects host metadata at agent startup: OS, arch,
// hostname, restic version. The agent sends this in `hello` so the
// server's Host row stays current.
package sysinfo
import (
"context"
"fmt"
"os"
"os/exec"
"runtime"
"strings"
"time"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
)
// Snapshot is the bundle of metadata reported in `hello`.
type Snapshot struct {
// Hostname is the value returned by os.Hostname.
Hostname string
// OS is runtime.GOOS converted to the API's host-OS type.
OS api.HostOS
// Arch is runtime.GOARCH converted to the API's host-arch type.
Arch api.HostArch
// ResticVersion is the version token parsed from `restic version`
// (e.g. "0.17.1"), or empty when restic could not be detected.
ResticVersion string
// ProtocolVersion is api.CurrentProtocolVersion at collection time.
ProtocolVersion int
// BootTime is the host's boot time.
// NOTE(review): Collect does not populate this field, so it is the
// zero time.Time unless a caller sets it — confirm whether intended.
BootTime time.Time
}
// Collect gathers the metadata the agent reports about its host:
// hostname, OS, architecture, restic version and protocol version.
//
// resticPath, when non-empty, names the restic binary to query instead
// of searching PATH. Failing to determine the restic version is not an
// error; the field is simply left empty. The only failure reported is
// being unable to read the hostname.
func Collect(ctx context.Context, resticPath string) (Snapshot, error) {
	hostname, err := os.Hostname()
	if err != nil {
		return Snapshot{}, fmt.Errorf("sysinfo: hostname: %w", err)
	}

	var snap Snapshot
	snap.Hostname = hostname
	snap.OS = api.HostOS(runtime.GOOS)
	snap.Arch = api.HostArch(runtime.GOARCH)
	snap.ProtocolVersion = api.CurrentProtocolVersion

	// Best effort: a missing or broken restic must not block startup,
	// so the detection error is intentionally dropped.
	snap.ResticVersion, _ = detectResticVersion(ctx, resticPath)

	return snap, nil
}
// detectResticVersion asks the restic binary for its version by
// running `<bin> version` and reading the first line of stdout, which
// is expected to look like:
//
//	restic 0.17.1 compiled with go1.22.5 on linux/amd64
//
// override, when non-empty, is used as the binary; otherwise "restic"
// is looked up on PATH. The command is given at most five seconds.
//
// On success the second whitespace-separated token (e.g. "0.17.1") is
// returned. If the binary cannot be found or run, or its first line
// does not start with "restic" followed by another token, the result
// is "" plus an error. Callers treat that as "restic not installed
// yet" rather than as a reason to stop the agent.
func detectResticVersion(ctx context.Context, override string) (string, error) {
	binary := override
	if binary == "" {
		found, err := exec.LookPath("restic")
		if err != nil {
			return "", err
		}
		binary = found
	}

	runCtx, stop := context.WithTimeout(ctx, 5*time.Second)
	defer stop()

	stdout, err := exec.CommandContext(runCtx, binary, "version").Output()
	if err != nil {
		return "", err
	}

	// Only the first line matters; anything after it is ignored.
	banner, _, _ := strings.Cut(strings.TrimSpace(string(stdout)), "\n")
	if fields := strings.Fields(banner); len(fields) >= 2 && fields[0] == "restic" {
		return fields[1], nil
	}
	return "", fmt.Errorf("sysinfo: unrecognised restic version output: %q", banner)
}
+246
View File
@@ -0,0 +1,246 @@
// Package wsclient is the agent's outbound WebSocket connection to
// the control plane: dial with bearer auth, perform the hello
// handshake, send heartbeats, dispatch server-pushed commands.
//
// Run reconnects in a loop, sleeping with exponential backoff (1s,
// doubling to a 60s cap, plus jitter) after every failed or dropped
// connection, not only after dial failures.
package wsclient
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log/slog"
"math/rand"
stdhttp "net/http"
"net/url"
"strings"
"time"
"github.com/coder/websocket"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
)
// Config holds the agent's connection settings. It is passed by value
// to Run, which fills in a default HeartbeatPeriod on its own copy.
type Config struct {
// ServerURL is the control-plane base URL. http/https are rewritten
// to ws/wss and "/ws/agent" is appended before dialing.
ServerURL string
// AgentToken is sent as "Authorization: Bearer <token>" on the
// WebSocket upgrade request.
AgentToken string
// HostID identifies this host.
// NOTE(review): not read by the connection code in this file —
// confirm where it is consumed.
HostID string
// CertPinSHA256 pins the server's leaf certificate. It is only
// applied when the dial URL uses the wss scheme.
CertPinSHA256 string // hex; empty disables pinning
// HeartbeatPeriod is the interval between heartbeat messages.
// Values <= 0 are replaced with 30 seconds by Run.
HeartbeatPeriod time.Duration
// HelloPayload is sent verbatim as the payload of the hello
// message right after the socket is established.
HelloPayload api.HelloPayload
}
// Handler is invoked for every server-sent message. The agent's main
// program supplies one that knows how to dispatch command.run etc.
// to the runner package.
//
// Only text frames that decode as an envelope and are not error
// envelopes reach the Handler. An error returned by the Handler is
// logged and the connection stays open. A nil Handler is allowed;
// messages are then read and dropped.
type Handler func(ctx context.Context, env api.Envelope) error
// errFatalProtocolTooOld marks the one disconnect reason reconnecting
// cannot cure: the server rejected this agent's protocol version.
// connectOnce wraps it; Run stops the reconnect loop when it sees it.
var errFatalProtocolTooOld = errors.New("protocol too old")

// Run keeps the agent connected until ctx is cancelled or the server
// rejects the agent's protocol version.
//
// Each connection attempt is handled by connectOnce. When an attempt
// ends, Run logs the reason and sleeps with exponential backoff (1s
// doubling up to 60s, plus jitter) before trying again. The backoff is
// reset once a connection attempt has lasted at least as long as the
// maximum delay, so an agent that was healthy for a long time
// reconnects quickly after a blip instead of inheriting the delay
// accumulated by earlier failures.
//
// Returns nil when ctx is done. Returns a non-nil error only for the
// fatal protocol-too-old rejection, where retrying cannot help and
// would just keep hitting the server.
//
// A HeartbeatPeriod <= 0 in cfg is replaced with 30 seconds.
func Run(ctx context.Context, cfg Config, handle Handler) error {
	if cfg.HeartbeatPeriod <= 0 {
		cfg.HeartbeatPeriod = 30 * time.Second
	}

	const (
		minDelay = time.Second
		maxDelay = 60 * time.Second
	)
	backoff := newBackoff(minDelay, maxDelay)

	for {
		started := time.Now()
		err := connectOnce(ctx, cfg, handle)

		switch {
		case ctx.Err() != nil || errors.Is(err, context.Canceled):
			// Shutdown requested; not worth a warning.
			return nil
		case errors.Is(err, errFatalProtocolTooOld):
			return fmt.Errorf("ws agent: %w", err)
		case err != nil:
			slog.Warn("ws agent disconnect", "err", err)
		}

		// The dial timeout is 30s, so only an attempt that actually
		// held a connection can have lasted this long.
		if time.Since(started) >= maxDelay {
			backoff = newBackoff(minDelay, maxDelay)
		}

		if err := sleepCtx(ctx, backoff.next()); err != nil {
			return nil
		}
	}
}

// connectOnce performs one full connection lifecycle: dial → hello →
// heartbeat loop + read loop → close. It returns when either side
// closes the socket, when ctx ends, or when the server reports a
// protocol-too-old error (wrapped in errFatalProtocolTooOld).
//
// The dial carries the agent token as a bearer Authorization header
// and is limited to 30 seconds. When a cert pin is configured and the
// URL is wss, the TLS handshake additionally checks the leaf
// certificate against the pin.
//
// In the read loop, non-text frames and undecodable envelopes are
// skipped. Error envelopes are logged and otherwise ignored unless
// fatal. Everything else goes to handle (if non-nil), whose errors are
// logged without dropping the connection.
func connectOnce(ctx context.Context, cfg Config, handle Handler) error {
	wsURL, err := buildWSURL(cfg.ServerURL)
	if err != nil {
		return fmt.Errorf("ws agent: bad server url: %w", err)
	}

	dialOpts := &websocket.DialOptions{
		HTTPHeader: stdhttp.Header{
			"Authorization": []string{"Bearer " + cfg.AgentToken},
		},
	}
	if cfg.CertPinSHA256 != "" && strings.HasPrefix(wsURL, "wss") {
		dialOpts.HTTPClient = &stdhttp.Client{
			Transport: &stdhttp.Transport{
				TLSClientConfig: &tls.Config{
					MinVersion:            tls.VersionTLS12,
					VerifyPeerCertificate: pinChecker(cfg.CertPinSHA256),
				},
			},
		}
	}

	// The timeout covers the handshake only; cancel is called right
	// after Dial returns, matching the original lifecycle.
	dialCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	conn, _, err := websocket.Dial(dialCtx, wsURL, dialOpts)
	cancel()
	if err != nil {
		return fmt.Errorf("dial: %w", err)
	}
	defer conn.CloseNow() //nolint:errcheck

	// Send hello.
	helloEnv, err := api.Marshal(api.MsgHello, "", cfg.HelloPayload)
	if err != nil {
		return fmt.Errorf("marshal hello: %w", err)
	}
	if err := writeEnv(ctx, conn, helloEnv); err != nil {
		return fmt.Errorf("write hello: %w", err)
	}
	slog.Info("ws agent connected", "server", wsURL)

	// Heartbeat goroutine. It is stopped via cancelHeartbeat when this
	// function returns, whatever the reason.
	heartbeatCtx, cancelHeartbeat := context.WithCancel(ctx)
	defer cancelHeartbeat()
	go heartbeatLoop(heartbeatCtx, conn, cfg.HeartbeatPeriod)

	// Read loop. A read error returns and closes the conn.
	for {
		mt, raw, err := conn.Read(ctx)
		if err != nil {
			return fmt.Errorf("read: %w", err)
		}
		if mt != websocket.MessageText {
			continue
		}

		var env api.Envelope
		if err := json.Unmarshal(raw, &env); err != nil {
			slog.Warn("ws agent: bad envelope from server", "err", err)
			continue
		}

		if env.Type == api.MsgError {
			var ep api.ErrorPayload
			// A malformed payload leaves ep zero-valued; the error is
			// still logged below, just without details.
			_ = env.UnmarshalPayload(&ep)
			slog.Error("ws agent: server reported error",
				"code", ep.Code, "message", ep.Message, "help", ep.HelpURL)
			// protocol_too_old is fatal — retrying won't help, so tell
			// Run to stop instead of reconnecting.
			if ep.Code == api.ErrProtocolTooOld {
				return fmt.Errorf("%w: %s", errFatalProtocolTooOld, ep.Message)
			}
			continue
		}

		if handle != nil {
			if err := handle(ctx, env); err != nil {
				slog.Warn("ws agent: handler returned error", "type", env.Type, "err", err)
			}
		}
	}
}
// heartbeatLoop sends a heartbeat envelope on conn every period until
// ctx is done or a write fails. Each heartbeat carries the current UTC
// time. A heartbeat that cannot be marshalled is skipped silently; a
// failed write is logged and ends the loop.
func heartbeatLoop(ctx context.Context, conn *websocket.Conn, period time.Duration) {
	ticker := time.NewTicker(period)
	defer ticker.Stop()

	for {
		// Wait for the next tick, or stop when the connection is done.
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}

		beat, marshalErr := api.Marshal(api.MsgHeartbeat, "",
			api.HeartbeatPayload{SentAt: time.Now().UTC()})
		if marshalErr != nil {
			continue
		}

		if writeErr := writeEnv(ctx, conn, beat); writeErr != nil {
			slog.Warn("ws agent: heartbeat write failed", "err", writeErr)
			return
		}
	}
}
// writeEnv JSON-encodes env and sends it on conn as a single text
// message. It returns the marshal error, or else whatever the write
// returns.
func writeEnv(ctx context.Context, conn *websocket.Conn, env api.Envelope) error {
	payload, marshalErr := json.Marshal(env)
	if marshalErr == nil {
		return conn.Write(ctx, websocket.MessageText, payload)
	}
	return marshalErr
}
// buildWSURL turns the control-plane base URL into the agent
// WebSocket endpoint: https becomes wss, http becomes ws, ws/wss are
// kept, and "/ws/agent" is appended to the path after trailing slashes
// are removed. Any other scheme (including none) is an error, as is a
// URL that does not parse.
func buildWSURL(serverURL string) (string, error) {
	parsed, err := url.Parse(serverURL)
	if err != nil {
		return "", err
	}

	wsScheme, known := map[string]string{
		"https": "wss",
		"http":  "ws",
		"ws":    "ws",
		"wss":   "wss",
	}[parsed.Scheme]
	if !known {
		return "", fmt.Errorf("unsupported scheme %q", parsed.Scheme)
	}
	parsed.Scheme = wsScheme

	parsed.Path = strings.TrimRight(parsed.Path, "/") + "/ws/agent"
	return parsed.String(), nil
}
// pinChecker returns a VerifyPeerCertificate callback that requires
// the leaf cert's SHA-256 to match wantHex. We do this *in addition*
// to the OS root verification (we don't replace it).
//
// wantHex is normalised once up front: surrounding whitespace and ":"
// separators are dropped and the digits are lowercased. The computed
// digest is lowercase hex, so without this an uppercase or
// colon-separated fingerprint (the form many tools print) would never
// match and the agent could not connect.
//
// The callback fails when the peer presents no certificates or when
// the first (leaf) certificate's digest differs from the pin.
func pinChecker(wantHex string) func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
	want := strings.ToLower(strings.ReplaceAll(strings.TrimSpace(wantHex), ":", ""))
	return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
		if len(rawCerts) == 0 {
			return errors.New("ws agent: no peer certs")
		}
		// rawCerts[0] is the leaf; both sides of the comparison are
		// public values, so a plain comparison is fine.
		got := sha256Hex(rawCerts[0])
		if got != want {
			return fmt.Errorf("ws agent: cert pin mismatch (got %s want %s)", got, want)
		}
		return nil
	}
}
// sha256Hex returns the lowercase hexadecimal SHA-256 digest of b.
func sha256Hex(b []byte) string {
	// The hasher comes from newSHA256, defined elsewhere in this
	// package, which keeps the crypto/sha256 import out of this file.
	digest := newSHA256()
	_, _ = digest.Write(b) // hash.Hash documents that Write never fails
	sum := digest.Sum(nil)
	return hex.EncodeToString(sum)
}
// ----- backoff -------------------------------------------------------
// backoff yields a doubling series of retry delays with jitter.
// cur is the base of the next delay; max is the ceiling for that base.
type backoff struct {
	cur, max time.Duration
}

// newBackoff returns a backoff whose first delay is based on base and
// whose base never grows beyond max.
func newBackoff(base, max time.Duration) *backoff { return &backoff{cur: base, max: max} }

// next returns the delay to wait before the next attempt and advances
// the series. The result is the current base plus random jitter in
// [0, base/5), i.e. up to 20% extra. The base then doubles, capped at
// max.
//
// Jitter is skipped when base/5 is not positive (a base under 5ns, or
// zero/negative), because rand.Int63n panics for a non-positive bound.
func (b *backoff) next() time.Duration {
	d := b.cur

	// 20% jitter, deterministic-enough randomness.
	var jitter time.Duration
	if span := int64(d) / 5; span > 0 {
		jitter = time.Duration(rand.Int63n(span)) //nolint:gosec
	}

	b.cur *= 2
	if b.cur > b.max {
		b.cur = b.max
	}
	return d + jitter
}
// sleepCtx blocks for d or until ctx is done, whichever comes first.
// It returns nil after a full sleep and ctx.Err() when interrupted.
//
// An explicit timer is used instead of time.After so it can be stopped
// as soon as ctx wins; otherwise the timer stays pending for the rest
// of d (up to the full backoff delay) after shutdown.
func sleepCtx(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}
+67
View File
@@ -0,0 +1,67 @@
package wsclient
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
stdhttp "net/http"
"strings"
"time"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
)
// EnrollRequest is what we POST to /api/agents/enroll. It is encoded
// as a JSON object using the keys in the field tags.
type EnrollRequest struct {
// Token is the one-time enrollment token being exchanged.
Token string `json:"token"`
// HostName is sent under the JSON key "hostname".
HostName string `json:"hostname"`
// OS and Arch describe the host platform using the API's types.
OS api.HostOS `json:"os"`
Arch api.HostArch `json:"arch"`
// AgentVersion is the version of the agent binary making the request.
AgentVersion string `json:"agent_version"`
// ResticVersion is the detected restic version.
// NOTE(review): presumably empty when restic is absent — confirm
// against the caller that fills this in.
ResticVersion string `json:"restic_version"`
}
// EnrollResponse is what the server hands back. Enroll rejects a
// response in which HostID or AgentToken is empty.
type EnrollResponse struct {
// HostID is the identifier the server assigned to this host.
HostID string `json:"host_id"`
// AgentToken is the persistent bearer credential. It is a secret:
// keep it out of logs and error messages.
AgentToken string `json:"agent_token"`
// CertPinSHA256 is optional and may be absent from the response.
CertPinSHA256 string `json:"cert_pin_sha256,omitempty"`
}
// Enroll exchanges a one-time enrollment token for persistent agent
// credentials. Called by the install script on first run.
//
// It POSTs req as JSON to <serverURL>/api/agents/enroll (trailing
// slashes on serverURL are ignored) with a 30 second client timeout
// and expects 201 Created with a JSON EnrollResponse body.
//
// Errors are returned for: a request that cannot be built or sent, a
// non-201 status (the message includes the status and response body),
// a body that cannot be read or parsed, and a response missing the
// host ID or agent token.
//
// The "incomplete response" error reports only which fields were
// present. It never prints the response itself, because that would
// put the newly issued bearer token into logs.
func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) {
	// Enrollment responses are tiny; the cap stops a misbehaving
	// server from making the agent buffer an unbounded body.
	const maxResponseBytes = 1 << 20

	body, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("agent enroll: marshal: %w", err)
	}

	postURL := strings.TrimRight(serverURL, "/") + "/api/agents/enroll"
	httpReq, err := stdhttp.NewRequestWithContext(ctx, stdhttp.MethodPost, postURL, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("agent enroll: build request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	client := &stdhttp.Client{Timeout: 30 * time.Second}
	res, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("agent enroll: post: %w", err)
	}
	defer res.Body.Close()

	rawRes, readErr := io.ReadAll(io.LimitReader(res.Body, maxResponseBytes))

	// Report a bad status first: it is the more useful diagnosis even
	// if the body was only partially read.
	if res.StatusCode != stdhttp.StatusCreated {
		return nil, fmt.Errorf("agent enroll: server returned %d: %s",
			res.StatusCode, rawRes)
	}
	if readErr != nil {
		return nil, fmt.Errorf("agent enroll: read response: %w", readErr)
	}

	var er EnrollResponse
	if err := json.Unmarshal(rawRes, &er); err != nil {
		return nil, fmt.Errorf("agent enroll: parse response: %w", err)
	}
	if er.AgentToken == "" || er.HostID == "" {
		return nil, fmt.Errorf(
			"agent enroll: incomplete response (host_id present: %t, agent_token present: %t)",
			er.HostID != "", er.AgentToken != "")
	}
	return &er, nil
}
+8
View File
@@ -0,0 +1,8 @@
package wsclient
import (
"crypto/sha256"
"hash"
)
// newSHA256 returns a fresh SHA-256 hasher. It lives in its own file
// so the rest of the package can hash without importing crypto/sha256
// directly.
func newSHA256() hash.Hash {
	hasher := sha256.New()
	return hasher
}