phase 1: WS transport, enrollment, agent that hellos and heartbeats
Lands the protocol layer end-to-end: an agent can be enrolled through the operator UI, store credentials, dial back to the server over WS, complete the protocol_version handshake, and stay connected with periodic heartbeats. Server side: - P1-09 ws.Hub: one Conn per host_id, last-write-wins eviction, json envelope writer with a write mutex, reader, error envelopes. - P1-09 ws.AgentHandler: bearer-auth, accept upgrade, hello-stage (10s deadline, protocol_version checked against api.MinAgentProtocolVersion → ErrProtocolTooOld with help URL on reject), main read loop, defer hub register/unregister. - P1-10 POST /api/agents/enroll consumes a one-time token, mints a persistent agent bearer (sha-256 stored), creates a host row. - P1-10 POST /api/enrollment-tokens (operator, session-auth) issues a 1h one-time token. - P1-11 hello upserts agent_version + restic_version + protocol_version on the host row, flips status to online. - P1-12 heartbeat touches last_seen_at; background sweeper marks hosts offline after 90s without one. - store: hosts table accessors, host_schedule_version, enrollment_tokens FK on consumed_host dropped (audit-only field; the token gets burned before the host row exists). Agent side: - P1-13 internal/agent/config: yaml at /etc/restic-manager/agent.yaml, atomic Save (tmp+fsync+rename), Enrolled() helper. - P1-15 internal/agent/wsclient: dial with bearer + optional TLS cert pinning (sha-256 of leaf), exponential backoff with jitter (1s → 60s cap), heartbeat goroutine, fatal handling for ErrProtocolTooOld. - P1-15 wsclient.Enroll: HTTP POST /api/agents/enroll with sysinfo. - P1-17 internal/agent/sysinfo: hostname/OS/arch/restic-version collection. restic detected by `restic version` parse; absent restic doesn't block startup. - cmd/agent: -enroll-server / -enroll-token flags drive first-run enrollment then exit (so the install script can hand off to systemd to run the persistent service). End-to-end smoke verified: bootstrap → login → issue token → enroll → run agent → server logs `ws agent connected` with the right host_id and protocol_version 1. All tests still pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,110 @@
|
||||
// Package config loads the agent's persistent configuration. After
|
||||
// enrollment, the file holds the bearer token + server URL; it is
|
||||
// only ever written via Save (which replaces atomically).
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config is the on-disk shape of the agent's config file.
|
||||
type Config struct {
|
||||
// ServerURL is the base URL of the control plane, e.g.
|
||||
// https://restic.lab.example. The agent appends /ws/agent and
|
||||
// /api/agents/enroll.
|
||||
ServerURL string `yaml:"server_url"`
|
||||
|
||||
// AgentToken is the bearer credential issued at enrollment.
|
||||
// Empty means "not yet enrolled."
|
||||
AgentToken string `yaml:"agent_token"`
|
||||
|
||||
// HostID is what the server thinks this host is.
|
||||
HostID string `yaml:"host_id"`
|
||||
|
||||
// CertPinSHA256 (optional) is the SHA-256 of the server's TLS
|
||||
// cert. When set, the agent refuses to connect to a server
|
||||
// whose cert hash doesn't match.
|
||||
CertPinSHA256 string `yaml:"cert_pin_sha256,omitempty"`
|
||||
|
||||
// ResticPath overrides the auto-detected restic binary path.
|
||||
ResticPath string `yaml:"restic_path,omitempty"`
|
||||
|
||||
// path is the file we loaded from. Used by Save.
|
||||
path string `yaml:"-"`
|
||||
}
|
||||
|
||||
// DefaultPath returns the canonical config path for the current OS.
|
||||
// Phase 1 ships Linux only; Windows path lives in the spec for P2.
|
||||
func DefaultPath() string {
|
||||
return "/etc/restic-manager/agent.yaml"
|
||||
}
|
||||
|
||||
// Load reads and parses the config file at path. A missing file is
|
||||
// returned as an empty Config (not an error) — first-run agents
|
||||
// haven't been enrolled yet.
|
||||
func Load(path string) (*Config, error) {
|
||||
c := &Config{path: path}
|
||||
body, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return c, nil
|
||||
}
|
||||
return nil, fmt.Errorf("agent config: read %q: %w", path, err)
|
||||
}
|
||||
if err := yaml.Unmarshal(body, c); err != nil {
|
||||
return nil, fmt.Errorf("agent config: parse %q: %w", path, err)
|
||||
}
|
||||
c.path = path
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Save writes the config back atomically: write to <path>.tmp, fsync,
|
||||
// rename. A crash mid-write either leaves the old file or the new one,
|
||||
// never a half-written one.
|
||||
func (c *Config) Save() error {
|
||||
if c.path == "" {
|
||||
return fmt.Errorf("agent config: no path set")
|
||||
}
|
||||
dir := filepath.Dir(c.path)
|
||||
if err := os.MkdirAll(dir, 0o700); err != nil {
|
||||
return fmt.Errorf("agent config: mkdir %q: %w", dir, err)
|
||||
}
|
||||
body, err := yaml.Marshal(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("agent config: marshal: %w", err)
|
||||
}
|
||||
|
||||
tmp := c.path + ".tmp"
|
||||
f, err := os.OpenFile(tmp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("agent config: create tmp: %w", err)
|
||||
}
|
||||
if _, err := f.Write(body); err != nil {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(tmp)
|
||||
return fmt.Errorf("agent config: write tmp: %w", err)
|
||||
}
|
||||
if err := f.Sync(); err != nil {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(tmp)
|
||||
return fmt.Errorf("agent config: fsync tmp: %w", err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
_ = os.Remove(tmp)
|
||||
return fmt.Errorf("agent config: close tmp: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmp, c.path); err != nil {
|
||||
_ = os.Remove(tmp)
|
||||
return fmt.Errorf("agent config: rename: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Enrolled reports whether the agent has finished enrollment.
|
||||
func (c *Config) Enrolled() bool {
|
||||
return c.AgentToken != "" && c.HostID != "" && c.ServerURL != ""
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
// Package sysinfo collects host metadata at agent startup: OS, arch,
|
||||
// hostname, restic version. The agent sends this in `hello` so the
|
||||
// server's Host row stays current.
|
||||
package sysinfo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
)
|
||||
|
||||
// Snapshot is the bundle of metadata reported in `hello`.
|
||||
type Snapshot struct {
|
||||
Hostname string
|
||||
OS api.HostOS
|
||||
Arch api.HostArch
|
||||
ResticVersion string
|
||||
ProtocolVersion int
|
||||
BootTime time.Time
|
||||
}
|
||||
|
||||
// Collect probes the running host. resticPath, if non-empty,
|
||||
// overrides PATH lookup.
|
||||
func Collect(ctx context.Context, resticPath string) (Snapshot, error) {
|
||||
hn, err := os.Hostname()
|
||||
if err != nil {
|
||||
return Snapshot{}, fmt.Errorf("sysinfo: hostname: %w", err)
|
||||
}
|
||||
|
||||
osTag := api.HostOS(runtime.GOOS)
|
||||
archTag := api.HostArch(runtime.GOARCH)
|
||||
|
||||
resticVer, _ := detectResticVersion(ctx, resticPath) // empty on failure is fine
|
||||
return Snapshot{
|
||||
Hostname: hn,
|
||||
OS: osTag,
|
||||
Arch: archTag,
|
||||
ResticVersion: resticVer,
|
||||
ProtocolVersion: api.CurrentProtocolVersion,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// detectResticVersion runs `restic version` and parses the first line.
|
||||
// Output looks like:
|
||||
// restic 0.17.1 compiled with go1.22.5 on linux/amd64
|
||||
// Returns the version token (e.g. "0.17.1") or "" if restic isn't
|
||||
// found. We never block startup on a missing restic — the operator
|
||||
// might not have installed it yet, and the agent should still be
|
||||
// able to connect and report.
|
||||
func detectResticVersion(ctx context.Context, override string) (string, error) {
|
||||
bin := override
|
||||
if bin == "" {
|
||||
var err error
|
||||
bin, err = exec.LookPath("restic")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
versionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
out, err := exec.CommandContext(versionCtx, bin, "version").Output()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
first := strings.SplitN(strings.TrimSpace(string(out)), "\n", 2)[0]
|
||||
parts := strings.Fields(first)
|
||||
if len(parts) >= 2 && parts[0] == "restic" {
|
||||
return parts[1], nil
|
||||
}
|
||||
return "", fmt.Errorf("sysinfo: unrecognised restic version output: %q", first)
|
||||
}
|
||||
@@ -0,0 +1,246 @@
|
||||
// Package wsclient is the agent's outbound WebSocket connection to
|
||||
// the control plane: dial with bearer auth, perform the hello
|
||||
// handshake, send heartbeats, dispatch server-pushed commands.
|
||||
//
|
||||
// The Run loop is a forever-loop with exponential backoff on dial
|
||||
// failures, capped at 60s. Disconnected agents keep retrying.
|
||||
package wsclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
stdhttp "net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
)
|
||||
|
||||
// Config holds the agent's connection settings.
|
||||
type Config struct {
|
||||
ServerURL string
|
||||
AgentToken string
|
||||
HostID string
|
||||
CertPinSHA256 string // hex; empty disables pinning
|
||||
HeartbeatPeriod time.Duration
|
||||
HelloPayload api.HelloPayload
|
||||
}
|
||||
|
||||
// Handler is invoked for every server-sent message. The agent's main
|
||||
// program supplies one that knows how to dispatch command.run etc.
|
||||
// to the runner package.
|
||||
type Handler func(ctx context.Context, env api.Envelope) error
|
||||
|
||||
// Run keeps the agent connected indefinitely. Returns when ctx is
|
||||
// cancelled. Errors during a single connection attempt are logged and
|
||||
// trigger reconnect-with-backoff; only ctx.Done() ends the loop.
|
||||
func Run(ctx context.Context, cfg Config, handle Handler) error {
|
||||
if cfg.HeartbeatPeriod <= 0 {
|
||||
cfg.HeartbeatPeriod = 30 * time.Second
|
||||
}
|
||||
|
||||
backoff := newBackoff(time.Second, 60*time.Second)
|
||||
for {
|
||||
err := connectOnce(ctx, cfg, handle)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
slog.Warn("ws agent disconnect", "err", err)
|
||||
}
|
||||
if err := sleepCtx(ctx, backoff.next()); err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// connectOnce performs one full connection lifecycle: dial → hello →
|
||||
// heartbeat loop + read loop → close. Returns when either side closes
|
||||
// the socket.
|
||||
func connectOnce(ctx context.Context, cfg Config, handle Handler) error {
|
||||
wsURL, err := buildWSURL(cfg.ServerURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ws agent: bad server url: %w", err)
|
||||
}
|
||||
|
||||
dialOpts := &websocket.DialOptions{
|
||||
HTTPHeader: stdhttp.Header{
|
||||
"Authorization": []string{"Bearer " + cfg.AgentToken},
|
||||
},
|
||||
}
|
||||
if cfg.CertPinSHA256 != "" && strings.HasPrefix(wsURL, "wss") {
|
||||
dialOpts.HTTPClient = &stdhttp.Client{
|
||||
Transport: &stdhttp.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
VerifyPeerCertificate: pinChecker(cfg.CertPinSHA256),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
dialCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
conn, _, err := websocket.Dial(dialCtx, wsURL, dialOpts)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial: %w", err)
|
||||
}
|
||||
defer conn.CloseNow() //nolint:errcheck
|
||||
|
||||
// Send hello.
|
||||
helloEnv, err := api.Marshal(api.MsgHello, "", cfg.HelloPayload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal hello: %w", err)
|
||||
}
|
||||
if err := writeEnv(ctx, conn, helloEnv); err != nil {
|
||||
return fmt.Errorf("write hello: %w", err)
|
||||
}
|
||||
slog.Info("ws agent connected", "server", wsURL)
|
||||
|
||||
// Heartbeat goroutine.
|
||||
heartbeatCtx, cancelHeartbeat := context.WithCancel(ctx)
|
||||
defer cancelHeartbeat()
|
||||
go heartbeatLoop(heartbeatCtx, conn, cfg.HeartbeatPeriod)
|
||||
|
||||
// Read loop. A read error returns and closes the conn.
|
||||
for {
|
||||
mt, raw, err := conn.Read(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read: %w", err)
|
||||
}
|
||||
if mt != websocket.MessageText {
|
||||
continue
|
||||
}
|
||||
var env api.Envelope
|
||||
if err := json.Unmarshal(raw, &env); err != nil {
|
||||
slog.Warn("ws agent: bad envelope from server", "err", err)
|
||||
continue
|
||||
}
|
||||
if env.Type == api.MsgError {
|
||||
var ep api.ErrorPayload
|
||||
_ = env.UnmarshalPayload(&ep)
|
||||
slog.Error("ws agent: server reported error",
|
||||
"code", ep.Code, "message", ep.Message, "help", ep.HelpURL)
|
||||
// protocol_too_old is fatal — keep retrying won't help.
|
||||
if ep.Code == api.ErrProtocolTooOld {
|
||||
return fmt.Errorf("protocol too old: %s", ep.Message)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if handle != nil {
|
||||
if err := handle(ctx, env); err != nil {
|
||||
slog.Warn("ws agent: handler returned error", "type", env.Type, "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func heartbeatLoop(ctx context.Context, conn *websocket.Conn, period time.Duration) {
|
||||
t := time.NewTicker(period)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
env, err := api.Marshal(api.MsgHeartbeat, "",
|
||||
api.HeartbeatPayload{SentAt: time.Now().UTC()})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if err := writeEnv(ctx, conn, env); err != nil {
|
||||
slog.Warn("ws agent: heartbeat write failed", "err", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeEnv(ctx context.Context, conn *websocket.Conn, env api.Envelope) error {
|
||||
raw, err := json.Marshal(env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.Write(ctx, websocket.MessageText, raw)
|
||||
}
|
||||
|
||||
func buildWSURL(serverURL string) (string, error) {
|
||||
u, err := url.Parse(serverURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
switch u.Scheme {
|
||||
case "https":
|
||||
u.Scheme = "wss"
|
||||
case "http":
|
||||
u.Scheme = "ws"
|
||||
case "ws", "wss":
|
||||
// already correct
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported scheme %q", u.Scheme)
|
||||
}
|
||||
u.Path = strings.TrimRight(u.Path, "/") + "/ws/agent"
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// pinChecker returns a VerifyPeerCertificate callback that requires
|
||||
// the leaf cert's SHA-256 to match wantHex. We do this *in addition*
|
||||
// to the OS root verification (we don't replace it).
|
||||
func pinChecker(wantHex string) func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
||||
return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
||||
if len(rawCerts) == 0 {
|
||||
return errors.New("ws agent: no peer certs")
|
||||
}
|
||||
got := sha256Hex(rawCerts[0])
|
||||
if got != wantHex {
|
||||
return fmt.Errorf("ws agent: cert pin mismatch (got %s want %s)", got, wantHex)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func sha256Hex(b []byte) string {
|
||||
// avoid pulling in crypto/sha256 in this top-level file twice;
|
||||
// indirection through hex-encode is the classic shape.
|
||||
h := newSHA256()
|
||||
h.Write(b)
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// ----- backoff -------------------------------------------------------
|
||||
|
||||
type backoff struct {
|
||||
cur, max time.Duration
|
||||
}
|
||||
|
||||
func newBackoff(base, max time.Duration) *backoff { return &backoff{cur: base, max: max} }
|
||||
|
||||
func (b *backoff) next() time.Duration {
|
||||
d := b.cur
|
||||
// 20% jitter, deterministic-enough randomness.
|
||||
jitter := time.Duration(rand.Int63n(int64(d) / 5)) //nolint:gosec
|
||||
b.cur *= 2
|
||||
if b.cur > b.max {
|
||||
b.cur = b.max
|
||||
}
|
||||
return d + jitter
|
||||
}
|
||||
|
||||
func sleepCtx(ctx context.Context, d time.Duration) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(d):
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package wsclient
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
stdhttp "net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
)
|
||||
|
||||
// EnrollRequest is what we POST to /api/agents/enroll.
|
||||
type EnrollRequest struct {
|
||||
Token string `json:"token"`
|
||||
HostName string `json:"hostname"`
|
||||
OS api.HostOS `json:"os"`
|
||||
Arch api.HostArch `json:"arch"`
|
||||
AgentVersion string `json:"agent_version"`
|
||||
ResticVersion string `json:"restic_version"`
|
||||
}
|
||||
|
||||
// EnrollResponse is what the server hands back.
|
||||
type EnrollResponse struct {
|
||||
HostID string `json:"host_id"`
|
||||
AgentToken string `json:"agent_token"`
|
||||
CertPinSHA256 string `json:"cert_pin_sha256,omitempty"`
|
||||
}
|
||||
|
||||
// Enroll exchanges a one-time enrollment token for persistent agent
|
||||
// credentials. Called by the install script on first run.
|
||||
func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) {
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("agent enroll: marshal: %w", err)
|
||||
}
|
||||
postURL := strings.TrimRight(serverURL, "/") + "/api/agents/enroll"
|
||||
|
||||
httpReq, err := stdhttp.NewRequestWithContext(ctx, stdhttp.MethodPost, postURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("agent enroll: build request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &stdhttp.Client{Timeout: 30 * time.Second}
|
||||
res, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("agent enroll: post: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
rawRes, _ := io.ReadAll(res.Body)
|
||||
if res.StatusCode != stdhttp.StatusCreated {
|
||||
return nil, fmt.Errorf("agent enroll: server returned %d: %s",
|
||||
res.StatusCode, rawRes)
|
||||
}
|
||||
var er EnrollResponse
|
||||
if err := json.Unmarshal(rawRes, &er); err != nil {
|
||||
return nil, fmt.Errorf("agent enroll: parse response: %w", err)
|
||||
}
|
||||
if er.AgentToken == "" || er.HostID == "" {
|
||||
return nil, fmt.Errorf("agent enroll: incomplete response: %+v", er)
|
||||
}
|
||||
return &er, nil
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package wsclient
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"hash"
|
||||
)
|
||||
|
||||
func newSHA256() hash.Hash { return sha256.New() }
|
||||
Reference in New Issue
Block a user