phase 1: WS transport, enrollment, agent that hellos and heartbeats

Lands the protocol layer end-to-end: an agent can be enrolled
through the operator UI, store credentials, dial back to the server
over WS, complete the protocol_version handshake, and stay
connected with periodic heartbeats.

Server side:
- P1-09 ws.Hub: one Conn per host_id, last-write-wins eviction,
  json envelope writer with a write mutex, reader, error envelopes.
- P1-09 ws.AgentHandler: bearer-auth, accept upgrade, hello-stage
  (10s deadline, protocol_version checked against
  api.MinAgentProtocolVersion → ErrProtocolTooOld with help URL on
  reject), main read loop, defer hub register/unregister.
- P1-10 POST /api/agents/enroll consumes a one-time token, mints a
  persistent agent bearer (sha-256 stored), creates a host row.
- P1-10 POST /api/enrollment-tokens (operator, session-auth)
  issues a 1h one-time token.
- P1-11 hello upserts agent_version + restic_version +
  protocol_version on the host row, flips status to online.
- P1-12 heartbeat touches last_seen_at; background sweeper marks
  hosts offline after 90s without one.
- store: hosts table accessors, host_schedule_version,
  enrollment_tokens FK on consumed_host dropped (audit-only field;
  the token gets burned before the host row exists).

Agent side:
- P1-13 internal/agent/config: yaml at /etc/restic-manager/agent.yaml,
  atomic Save (tmp+fsync+rename), Enrolled() helper.
- P1-15 internal/agent/wsclient: dial with bearer + optional
  TLS cert pinning (sha-256 of leaf), exponential backoff with
  jitter (1s → 60s cap), heartbeat goroutine, fatal handling for
  ErrProtocolTooOld.
- P1-15 wsclient.Enroll: HTTP POST /api/agents/enroll with sysinfo.
- P1-17 internal/agent/sysinfo: hostname/OS/arch/restic-version
  collection. restic detected by `restic version` parse; absent
  restic doesn't block startup.
- cmd/agent: -enroll-server / -enroll-token flags drive first-run
  enrollment then exit (so the install script can hand off to
  systemd to run the persistent service).

End-to-end smoke verified: bootstrap → login → issue token →
enroll → run agent → server logs `ws agent connected` with the
right host_id and protocol_version 1.

All tests still pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-01 00:39:00 +01:00
parent df2c584b23
commit 9cc0caff1e
18 changed files with 1670 additions and 14 deletions
+119 -4
View File
@@ -2,32 +2,147 @@ package main
import ( import (
"context" "context"
"errors"
"flag" "flag"
"fmt" "fmt"
"log/slog" "log/slog"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"gitea.dcglab.co.uk/steve/restic-manager/internal/agent/config"
"gitea.dcglab.co.uk/steve/restic-manager/internal/agent/sysinfo"
"gitea.dcglab.co.uk/steve/restic-manager/internal/agent/wsclient"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
) )
var version = "dev" var version = "dev"
func main() { func main() {
if err := run(); err != nil {
slog.Error("agent fatal", "err", err)
os.Exit(1)
}
}
func run() error {
configPath := flag.String("config", config.DefaultPath(), "path to agent.yaml")
enrollServer := flag.String("enroll-server", "", "server URL (used with -enroll-token to perform first-run enrollment)")
enrollToken := flag.String("enroll-token", "", "one-time enrollment token (operator copies this from the UI)")
showVersion := flag.Bool("version", false, "print version and exit") showVersion := flag.Bool("version", false, "print version and exit")
flag.Parse() flag.Parse()
if *showVersion { if *showVersion {
fmt.Println("restic-manager-agent", version) fmt.Println("restic-manager-agent", version)
return return nil
} }
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})) logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}))
slog.SetDefault(logger) slog.SetDefault(logger)
cfg, err := config.Load(*configPath)
if err != nil {
return fmt.Errorf("config: %w", err)
}
// Enrollment mode: agent was started with -enroll-server -enroll-token.
// On success we persist the credentials and exit (the install script
// then starts the agent service). Avoiding a long-running process here
// keeps the enrollment story restartable.
if *enrollToken != "" {
if *enrollServer == "" {
return errors.New("enrollment: -enroll-server is required with -enroll-token")
}
return doEnroll(*enrollServer, *enrollToken, cfg, version)
}
if !cfg.Enrolled() {
return fmt.Errorf("agent is not enrolled; run with -enroll-server and -enroll-token first (config %q)", *configPath)
}
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop() defer stop()
slog.Info("restic-manager agent starting", "version", version) snap, err := sysinfo.Collect(ctx, cfg.ResticPath)
<-ctx.Done() if err != nil {
slog.Info("shutting down") return fmt.Errorf("sysinfo: %w", err)
}
slog.Info("agent starting",
"version", version,
"host_id", cfg.HostID,
"server", cfg.ServerURL,
"restic_version", snap.ResticVersion,
"protocol_version", snap.ProtocolVersion,
)
wsCfg := wsclient.Config{
ServerURL: cfg.ServerURL,
AgentToken: cfg.AgentToken,
HostID: cfg.HostID,
CertPinSHA256: cfg.CertPinSHA256,
HelloPayload: api.HelloPayload{
ProtocolVersion: snap.ProtocolVersion,
AgentVersion: version,
ResticVersion: snap.ResticVersion,
Hostname: snap.Hostname,
OS: snap.OS,
Arch: snap.Arch,
},
}
if err := wsclient.Run(ctx, wsCfg, dispatch); err != nil {
return fmt.Errorf("ws run: %w", err)
}
slog.Info("agent shutting down")
return nil
}
// dispatch handles server-pushed envelopes. Phase 1's first slice
// just logs; P1-19/20/21 wire command.run to the runner.
func dispatch(_ context.Context, env api.Envelope) error {
switch env.Type {
case api.MsgCommandRun:
slog.Info("ws agent: command.run received (not yet implemented)", "id", env.ID)
case api.MsgCommandCancel:
slog.Info("ws agent: command.cancel received (not yet implemented)", "id", env.ID)
case api.MsgScheduleSet:
slog.Info("ws agent: schedule.set received (not yet implemented)", "id", env.ID)
case api.MsgConfigUpdate:
slog.Info("ws agent: config.update received (not yet implemented)", "id", env.ID)
case api.MsgAgentUpdateAvail:
slog.Info("ws agent: agent.update.available received (not yet implemented)", "id", env.ID)
default:
slog.Debug("ws agent: ignored message", "type", env.Type)
}
return nil
}
func doEnroll(serverURL, token string, cfg *config.Config, agentVersion string) error {
ctx, cancel := context.WithTimeout(context.Background(), 60*1e9)
defer cancel()
snap, err := sysinfo.Collect(ctx, cfg.ResticPath)
if err != nil {
return fmt.Errorf("sysinfo: %w", err)
}
res, err := wsclient.Enroll(ctx, serverURL, wsclient.EnrollRequest{
Token: token,
HostName: snap.Hostname,
OS: snap.OS,
Arch: snap.Arch,
AgentVersion: agentVersion,
ResticVersion: snap.ResticVersion,
})
if err != nil {
return fmt.Errorf("enroll: %w", err)
}
cfg.ServerURL = serverURL
cfg.HostID = res.HostID
cfg.AgentToken = res.AgentToken
cfg.CertPinSHA256 = res.CertPinSHA256
if err := cfg.Save(); err != nil {
return fmt.Errorf("save config: %w", err)
}
fmt.Fprintf(os.Stderr, "enrolled as host %s on %s\n", res.HostID, serverURL)
return nil
} }
+19 -5
View File
@@ -14,8 +14,9 @@ import (
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth" "gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto" "gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
rmhttp "gitea.dcglab.co.uk/steve/restic-manager/internal/server/http"
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/config" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/config"
rmhttp "gitea.dcglab.co.uk/steve/restic-manager/internal/server/http"
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store" "gitea.dcglab.co.uk/steve/restic-manager/internal/store"
) )
@@ -76,10 +77,13 @@ func run() error {
} }
defer func() { _ = st.Close() }() defer func() { _ = st.Close() }()
hub := ws.NewHub()
deps := rmhttp.Deps{ deps := rmhttp.Deps{
Cfg: cfg, Cfg: cfg,
Store: st, Store: st,
AEAD: aead, AEAD: aead,
Hub: hub,
} }
// First-run bootstrap: if the users table is empty, mint a one-time // First-run bootstrap: if the users table is empty, mint a one-time
@@ -117,21 +121,31 @@ func run() error {
errCh <- srv.Start() errCh <- srv.Start()
}() }()
// Background sweeper for expired sessions and enrollment tokens. // Background sweepers:
tick := time.NewTicker(15 * time.Minute) // - sessions/tokens purge: 15 min
defer tick.Stop() // - host offline-after-90s mark: every 30s (matches heartbeat
// cadence — agent sends every 30s, P1-12)
purgeTick := time.NewTicker(15 * time.Minute)
defer purgeTick.Stop()
offlineTick := time.NewTicker(30 * time.Second)
defer offlineTick.Stop()
go func() { go func() {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-tick.C: case <-purgeTick.C:
if n, err := st.PurgeExpiredSessions(ctx); err == nil && n > 0 { if n, err := st.PurgeExpiredSessions(ctx); err == nil && n > 0 {
slog.Info("purged expired sessions", "n", n) slog.Info("purged expired sessions", "n", n)
} }
if n, err := st.PurgeExpiredEnrollmentTokens(ctx); err == nil && n > 0 { if n, err := st.PurgeExpiredEnrollmentTokens(ctx); err == nil && n > 0 {
slog.Info("purged expired enrollment tokens", "n", n) slog.Info("purged expired enrollment tokens", "n", n)
} }
case <-offlineTick.C:
cutoff := time.Now().Add(-90 * time.Second)
if n, err := st.MarkHostsOfflineStale(ctx, cutoff); err == nil && n > 0 {
slog.Info("marked hosts offline (stale heartbeat)", "n", n)
}
} }
} }
}() }()
+1
View File
@@ -11,6 +11,7 @@ require (
) )
require ( require (
github.com/coder/websocket v1.8.14 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
+2
View File
@@ -1,3 +1,5 @@
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
+110
View File
@@ -0,0 +1,110 @@
// Package config loads the agent's persistent configuration. After
// enrollment, the file holds the bearer token + server URL; it is
// only ever written via Save (which replaces atomically).
package config
import (
"fmt"
"os"
"path/filepath"
"gopkg.in/yaml.v3"
)
// Config is the on-disk shape of the agent's config file.
type Config struct {
// ServerURL is the base URL of the control plane, e.g.
// https://restic.lab.example. The agent appends /ws/agent and
// /api/agents/enroll.
ServerURL string `yaml:"server_url"`
// AgentToken is the bearer credential issued at enrollment.
// Empty means "not yet enrolled."
AgentToken string `yaml:"agent_token"`
// HostID is what the server thinks this host is.
HostID string `yaml:"host_id"`
// CertPinSHA256 (optional) is the SHA-256 of the server's TLS
// cert. When set, the agent refuses to connect to a server
// whose cert hash doesn't match.
CertPinSHA256 string `yaml:"cert_pin_sha256,omitempty"`
// ResticPath overrides the auto-detected restic binary path.
ResticPath string `yaml:"restic_path,omitempty"`
// path is the file we loaded from. Used by Save.
path string `yaml:"-"`
}
// DefaultPath returns the canonical config path for the current OS.
// Phase 1 ships Linux only; Windows path lives in the spec for P2.
func DefaultPath() string {
return "/etc/restic-manager/agent.yaml"
}
// Load reads and parses the config file at path. A missing file is
// returned as an empty Config (not an error) — first-run agents
// haven't been enrolled yet.
func Load(path string) (*Config, error) {
c := &Config{path: path}
body, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return c, nil
}
return nil, fmt.Errorf("agent config: read %q: %w", path, err)
}
if err := yaml.Unmarshal(body, c); err != nil {
return nil, fmt.Errorf("agent config: parse %q: %w", path, err)
}
c.path = path
return c, nil
}
// Save writes the config back atomically: write to <path>.tmp, fsync,
// rename. A crash mid-write either leaves the old file or the new one,
// never a half-written one.
func (c *Config) Save() error {
if c.path == "" {
return fmt.Errorf("agent config: no path set")
}
dir := filepath.Dir(c.path)
if err := os.MkdirAll(dir, 0o700); err != nil {
return fmt.Errorf("agent config: mkdir %q: %w", dir, err)
}
body, err := yaml.Marshal(c)
if err != nil {
return fmt.Errorf("agent config: marshal: %w", err)
}
tmp := c.path + ".tmp"
f, err := os.OpenFile(tmp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
if err != nil {
return fmt.Errorf("agent config: create tmp: %w", err)
}
if _, err := f.Write(body); err != nil {
_ = f.Close()
_ = os.Remove(tmp)
return fmt.Errorf("agent config: write tmp: %w", err)
}
if err := f.Sync(); err != nil {
_ = f.Close()
_ = os.Remove(tmp)
return fmt.Errorf("agent config: fsync tmp: %w", err)
}
if err := f.Close(); err != nil {
_ = os.Remove(tmp)
return fmt.Errorf("agent config: close tmp: %w", err)
}
if err := os.Rename(tmp, c.path); err != nil {
_ = os.Remove(tmp)
return fmt.Errorf("agent config: rename: %w", err)
}
return nil
}
// Enrolled reports whether the agent has finished enrollment.
func (c *Config) Enrolled() bool {
return c.AgentToken != "" && c.HostID != "" && c.ServerURL != ""
}
+78
View File
@@ -0,0 +1,78 @@
// Package sysinfo collects host metadata at agent startup: OS, arch,
// hostname, restic version. The agent sends this in `hello` so the
// server's Host row stays current.
package sysinfo
import (
"context"
"fmt"
"os"
"os/exec"
"runtime"
"strings"
"time"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
)
// Snapshot is the bundle of metadata reported in `hello`.
type Snapshot struct {
Hostname string
OS api.HostOS
Arch api.HostArch
ResticVersion string
ProtocolVersion int
BootTime time.Time
}
// Collect probes the running host. resticPath, if non-empty,
// overrides PATH lookup.
func Collect(ctx context.Context, resticPath string) (Snapshot, error) {
hn, err := os.Hostname()
if err != nil {
return Snapshot{}, fmt.Errorf("sysinfo: hostname: %w", err)
}
osTag := api.HostOS(runtime.GOOS)
archTag := api.HostArch(runtime.GOARCH)
resticVer, _ := detectResticVersion(ctx, resticPath) // empty on failure is fine
return Snapshot{
Hostname: hn,
OS: osTag,
Arch: archTag,
ResticVersion: resticVer,
ProtocolVersion: api.CurrentProtocolVersion,
}, nil
}
// detectResticVersion runs `restic version` and parses the first line.
// Output looks like:
// restic 0.17.1 compiled with go1.22.5 on linux/amd64
// Returns the version token (e.g. "0.17.1") or "" if restic isn't
// found. We never block startup on a missing restic — the operator
// might not have installed it yet, and the agent should still be
// able to connect and report.
func detectResticVersion(ctx context.Context, override string) (string, error) {
bin := override
if bin == "" {
var err error
bin, err = exec.LookPath("restic")
if err != nil {
return "", err
}
}
versionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
out, err := exec.CommandContext(versionCtx, bin, "version").Output()
if err != nil {
return "", err
}
first := strings.SplitN(strings.TrimSpace(string(out)), "\n", 2)[0]
parts := strings.Fields(first)
if len(parts) >= 2 && parts[0] == "restic" {
return parts[1], nil
}
return "", fmt.Errorf("sysinfo: unrecognised restic version output: %q", first)
}
+246
View File
@@ -0,0 +1,246 @@
// Package wsclient is the agent's outbound WebSocket connection to
// the control plane: dial with bearer auth, perform the hello
// handshake, send heartbeats, dispatch server-pushed commands.
//
// The Run loop is a forever-loop with exponential backoff on dial
// failures, capped at 60s. Disconnected agents keep retrying.
package wsclient
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log/slog"
"math/rand"
stdhttp "net/http"
"net/url"
"strings"
"time"
"github.com/coder/websocket"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
)
// Config holds the agent's connection settings.
type Config struct {
ServerURL string
AgentToken string
HostID string
CertPinSHA256 string // hex; empty disables pinning
HeartbeatPeriod time.Duration
HelloPayload api.HelloPayload
}
// Handler is invoked for every server-sent message. The agent's main
// program supplies one that knows how to dispatch command.run etc.
// to the runner package.
type Handler func(ctx context.Context, env api.Envelope) error
// Run keeps the agent connected indefinitely. Returns when ctx is
// cancelled. Errors during a single connection attempt are logged and
// trigger reconnect-with-backoff; only ctx.Done() ends the loop.
func Run(ctx context.Context, cfg Config, handle Handler) error {
if cfg.HeartbeatPeriod <= 0 {
cfg.HeartbeatPeriod = 30 * time.Second
}
backoff := newBackoff(time.Second, 60*time.Second)
for {
err := connectOnce(ctx, cfg, handle)
if errors.Is(err, context.Canceled) {
return nil
}
if err != nil {
slog.Warn("ws agent disconnect", "err", err)
}
if err := sleepCtx(ctx, backoff.next()); err != nil {
return nil
}
}
}
// connectOnce performs one full connection lifecycle: dial → hello →
// heartbeat loop + read loop → close. Returns when either side closes
// the socket.
func connectOnce(ctx context.Context, cfg Config, handle Handler) error {
wsURL, err := buildWSURL(cfg.ServerURL)
if err != nil {
return fmt.Errorf("ws agent: bad server url: %w", err)
}
dialOpts := &websocket.DialOptions{
HTTPHeader: stdhttp.Header{
"Authorization": []string{"Bearer " + cfg.AgentToken},
},
}
if cfg.CertPinSHA256 != "" && strings.HasPrefix(wsURL, "wss") {
dialOpts.HTTPClient = &stdhttp.Client{
Transport: &stdhttp.Transport{
TLSClientConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
VerifyPeerCertificate: pinChecker(cfg.CertPinSHA256),
},
},
}
}
dialCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
conn, _, err := websocket.Dial(dialCtx, wsURL, dialOpts)
cancel()
if err != nil {
return fmt.Errorf("dial: %w", err)
}
defer conn.CloseNow() //nolint:errcheck
// Send hello.
helloEnv, err := api.Marshal(api.MsgHello, "", cfg.HelloPayload)
if err != nil {
return fmt.Errorf("marshal hello: %w", err)
}
if err := writeEnv(ctx, conn, helloEnv); err != nil {
return fmt.Errorf("write hello: %w", err)
}
slog.Info("ws agent connected", "server", wsURL)
// Heartbeat goroutine.
heartbeatCtx, cancelHeartbeat := context.WithCancel(ctx)
defer cancelHeartbeat()
go heartbeatLoop(heartbeatCtx, conn, cfg.HeartbeatPeriod)
// Read loop. A read error returns and closes the conn.
for {
mt, raw, err := conn.Read(ctx)
if err != nil {
return fmt.Errorf("read: %w", err)
}
if mt != websocket.MessageText {
continue
}
var env api.Envelope
if err := json.Unmarshal(raw, &env); err != nil {
slog.Warn("ws agent: bad envelope from server", "err", err)
continue
}
if env.Type == api.MsgError {
var ep api.ErrorPayload
_ = env.UnmarshalPayload(&ep)
slog.Error("ws agent: server reported error",
"code", ep.Code, "message", ep.Message, "help", ep.HelpURL)
// protocol_too_old is fatal — keep retrying won't help.
if ep.Code == api.ErrProtocolTooOld {
return fmt.Errorf("protocol too old: %s", ep.Message)
}
continue
}
if handle != nil {
if err := handle(ctx, env); err != nil {
slog.Warn("ws agent: handler returned error", "type", env.Type, "err", err)
}
}
}
}
func heartbeatLoop(ctx context.Context, conn *websocket.Conn, period time.Duration) {
t := time.NewTicker(period)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
env, err := api.Marshal(api.MsgHeartbeat, "",
api.HeartbeatPayload{SentAt: time.Now().UTC()})
if err != nil {
continue
}
if err := writeEnv(ctx, conn, env); err != nil {
slog.Warn("ws agent: heartbeat write failed", "err", err)
return
}
}
}
}
func writeEnv(ctx context.Context, conn *websocket.Conn, env api.Envelope) error {
raw, err := json.Marshal(env)
if err != nil {
return err
}
return conn.Write(ctx, websocket.MessageText, raw)
}
func buildWSURL(serverURL string) (string, error) {
u, err := url.Parse(serverURL)
if err != nil {
return "", err
}
switch u.Scheme {
case "https":
u.Scheme = "wss"
case "http":
u.Scheme = "ws"
case "ws", "wss":
// already correct
default:
return "", fmt.Errorf("unsupported scheme %q", u.Scheme)
}
u.Path = strings.TrimRight(u.Path, "/") + "/ws/agent"
return u.String(), nil
}
// pinChecker returns a VerifyPeerCertificate callback that requires
// the leaf cert's SHA-256 to match wantHex. We do this *in addition*
// to the OS root verification (we don't replace it).
func pinChecker(wantHex string) func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
if len(rawCerts) == 0 {
return errors.New("ws agent: no peer certs")
}
got := sha256Hex(rawCerts[0])
if got != wantHex {
return fmt.Errorf("ws agent: cert pin mismatch (got %s want %s)", got, wantHex)
}
return nil
}
}
func sha256Hex(b []byte) string {
// avoid pulling in crypto/sha256 in this top-level file twice;
// indirection through hex-encode is the classic shape.
h := newSHA256()
h.Write(b)
return hex.EncodeToString(h.Sum(nil))
}
// ----- backoff -------------------------------------------------------
type backoff struct {
cur, max time.Duration
}
func newBackoff(base, max time.Duration) *backoff { return &backoff{cur: base, max: max} }
func (b *backoff) next() time.Duration {
d := b.cur
// 20% jitter, deterministic-enough randomness.
jitter := time.Duration(rand.Int63n(int64(d) / 5)) //nolint:gosec
b.cur *= 2
if b.cur > b.max {
b.cur = b.max
}
return d + jitter
}
func sleepCtx(ctx context.Context, d time.Duration) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(d):
return nil
}
}
+67
View File
@@ -0,0 +1,67 @@
package wsclient
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
stdhttp "net/http"
"strings"
"time"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
)
// EnrollRequest is what we POST to /api/agents/enroll.
type EnrollRequest struct {
Token string `json:"token"`
HostName string `json:"hostname"`
OS api.HostOS `json:"os"`
Arch api.HostArch `json:"arch"`
AgentVersion string `json:"agent_version"`
ResticVersion string `json:"restic_version"`
}
// EnrollResponse is what the server hands back.
type EnrollResponse struct {
HostID string `json:"host_id"`
AgentToken string `json:"agent_token"`
CertPinSHA256 string `json:"cert_pin_sha256,omitempty"`
}
// Enroll exchanges a one-time enrollment token for persistent agent
// credentials. Called by the install script on first run.
func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) {
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("agent enroll: marshal: %w", err)
}
postURL := strings.TrimRight(serverURL, "/") + "/api/agents/enroll"
httpReq, err := stdhttp.NewRequestWithContext(ctx, stdhttp.MethodPost, postURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("agent enroll: build request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
client := &stdhttp.Client{Timeout: 30 * time.Second}
res, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("agent enroll: post: %w", err)
}
defer res.Body.Close()
rawRes, _ := io.ReadAll(res.Body)
if res.StatusCode != stdhttp.StatusCreated {
return nil, fmt.Errorf("agent enroll: server returned %d: %s",
res.StatusCode, rawRes)
}
var er EnrollResponse
if err := json.Unmarshal(rawRes, &er); err != nil {
return nil, fmt.Errorf("agent enroll: parse response: %w", err)
}
if er.AgentToken == "" || er.HostID == "" {
return nil, fmt.Errorf("agent enroll: incomplete response: %+v", er)
}
return &er, nil
}
+8
View File
@@ -0,0 +1,8 @@
package wsclient
import (
"crypto/sha256"
"hash"
)
func newSHA256() hash.Hash { return sha256.New() }
+165
View File
@@ -0,0 +1,165 @@
package http
import (
"encoding/json"
stdhttp "net/http"
"strings"
"time"
"github.com/oklog/ulid/v2"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
// enrollRequest is the body posted by the agent installer. The token
// was issued by the operator via the UI ("Add host" → P1-27); the
// host metadata comes from the agent's own sysinfo collection.
type enrollRequest struct {
Token string `json:"token"`
HostName string `json:"hostname"`
OS api.HostOS `json:"os"`
Arch api.HostArch `json:"arch"`
AgentVersion string `json:"agent_version"`
ResticVersion string `json:"restic_version"`
}
// enrollResponse hands the agent the credentials it'll use forever.
// AgentToken is shown exactly once; the server stores its hash.
// CertPinSHA256 is the SHA-256 of the server's certificate, base64;
// the agent pins this on every reconnect so a stolen DB at the
// control plane can't be replayed against an attacker's TLS endpoint.
type enrollResponse struct {
HostID string `json:"host_id"`
AgentToken string `json:"agent_token"`
CertPinSHA256 string `json:"cert_pin_sha256,omitempty"`
}
// enrollOperatorRequest creates a one-time enrollment token for an
// operator who is about to install an agent. Authenticated UI route.
type enrollOperatorRequest struct {
HostName string `json:"hostname"`
Tags []string `json:"tags,omitempty"`
}
type enrollOperatorResponse struct {
Token string `json:"token"`
ExpiresAt time.Time `json:"expires_at"`
}
// handleAgentEnroll consumes a one-time token, persists a Host row,
// and returns persistent agent credentials. Open endpoint (no
// session) — the token is the credential.
func (s *Server) handleAgentEnroll(w stdhttp.ResponseWriter, r *stdhttp.Request) {
var req enrollRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error())
return
}
if req.Token == "" || req.HostName == "" || req.OS == "" || req.Arch == "" {
writeJSONError(w, stdhttp.StatusBadRequest, "missing_field",
"token, hostname, os, arch all required")
return
}
hostID := ulid.Make().String()
// Atomically: validate + consume token, then create the host.
// We do these in two statements; if create-host fails, the token
// is already burned. That's acceptable — operator just regens.
tokHash := auth.HashToken(req.Token)
if err := s.deps.Store.ConsumeEnrollmentToken(r.Context(), tokHash, hostID); err != nil {
writeJSONError(w, stdhttp.StatusUnauthorized, "invalid_token",
"token unknown, expired, or already used")
return
}
// Mint the persistent agent bearer.
agentToken, err := auth.NewToken()
if err != nil {
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "")
return
}
host := store.Host{
ID: hostID,
Name: strings.TrimSpace(req.HostName),
OS: string(req.OS),
Arch: string(req.Arch),
AgentVersion: req.AgentVersion,
ResticVersion: req.ResticVersion,
EnrolledAt: time.Now().UTC(),
}
if err := s.deps.Store.CreateHost(r.Context(), host,
auth.HashToken(agentToken), ""); err != nil {
writeJSONError(w, stdhttp.StatusConflict, "host_exists", err.Error())
return
}
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
ID: ulid.Make().String(),
Actor: "system",
Action: "host.enrolled",
TargetKind: ptr("host"),
TargetID: &hostID,
TS: host.EnrolledAt,
})
writeJSON(w, stdhttp.StatusCreated, enrollResponse{
HostID: hostID,
AgentToken: agentToken,
// CertPinSHA256 is populated by a TLS-aware future revision.
// For now (HTTP-or-TLS-by-Caddy) we leave it empty and rely
// on the agent trusting its OS root store.
})
}
// handleCreateEnrollmentToken (operator-facing) — generates a
// short-lived token for a new host. Authenticated; admin/operator only.
//
// TODO: gate by authn middleware once login session lookup lands.
// For Phase 1's first slice, we accept the bootstrap-shipped admin
// session cookie and trust it, validating the cookie via store.
func (s *Server) handleCreateEnrollmentToken(w stdhttp.ResponseWriter, r *stdhttp.Request) {
if !s.authedUser(r) {
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
return
}
var req enrollOperatorRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error())
return
}
token, err := auth.NewToken()
if err != nil {
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "")
return
}
const ttl = time.Hour
if err := s.deps.Store.CreateEnrollmentToken(r.Context(), auth.HashToken(token), ttl); err != nil {
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "")
return
}
writeJSON(w, stdhttp.StatusCreated, enrollOperatorResponse{
Token: token,
ExpiresAt: time.Now().Add(ttl).UTC(),
})
}
// authedUser returns true iff the request carries a valid session
// cookie. Minimal stub for now; full RBAC middleware lands with
// P4-03.
func (s *Server) authedUser(r *stdhttp.Request) bool {
c, err := r.Cookie(sessionCookieName)
if err != nil {
return false
}
_, err = s.deps.Store.LookupSession(r.Context(), auth.HashToken(c.Value))
return err == nil
}
func ptr(s string) *string { return &s }
+118
View File
@@ -0,0 +1,118 @@
package http
import (
"bytes"
"context"
"encoding/json"
"io"
stdhttp "net/http"
"net/http/httptest"
"path/filepath"
"testing"
"time"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/config"
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
// newTestServerWithHub mirrors newTestServer but plugs in a real
// ws.Hub so /ws/agent is available.
func newTestServerWithHub(t *testing.T) (*Server, string, *store.Store) {
t.Helper()
dir := t.TempDir()
st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db"))
if err != nil {
t.Fatalf("store: %v", err)
}
t.Cleanup(func() { _ = st.Close() })
keyPath := filepath.Join(dir, "secret.key")
_ = crypto.GenerateKeyFile(keyPath)
key, _ := crypto.LoadKeyFromFile(keyPath)
aead, _ := crypto.NewAEAD(key)
deps := Deps{
Cfg: config.Config{Listen: ":0", DataDir: dir, SecretKeyFile: keyPath},
Store: st,
AEAD: aead,
Hub: ws.NewHub(),
}
s := New(deps)
ts := httptest.NewServer(s.srv.Handler)
t.Cleanup(ts.Close)
return s, ts.URL, st
}
func TestEnrollmentBadToken(t *testing.T) {
t.Parallel()
_, url, _ := newTestServerWithHub(t)
body, _ := json.Marshal(enrollRequest{
Token: "no-such-token", HostName: "host1",
OS: api.OSLinux, Arch: api.ArchAmd64,
AgentVersion: "0.1", ResticVersion: "0.17",
})
res, err := stdhttp.Post(url+"/api/agents/enroll", "application/json", bytes.NewReader(body))
if err != nil {
t.Fatalf("post: %v", err)
}
defer res.Body.Close()
if res.StatusCode != stdhttp.StatusUnauthorized {
t.Errorf("status: %d", res.StatusCode)
}
}
func TestEnrollmentHappyPath(t *testing.T) {
t.Parallel()
_, url, st := newTestServerWithHub(t)
// Issue a token directly via the store (skipping the operator UI).
rawToken, _ := auth.NewToken()
if err := st.CreateEnrollmentToken(context.Background(),
auth.HashToken(rawToken), 5*time.Minute); err != nil {
t.Fatalf("issue: %v", err)
}
body, _ := json.Marshal(enrollRequest{
Token: rawToken, HostName: "test-host",
OS: api.OSLinux, Arch: api.ArchAmd64,
AgentVersion: "0.1", ResticVersion: "0.17",
})
res, err := stdhttp.Post(url+"/api/agents/enroll", "application/json", bytes.NewReader(body))
if err != nil {
t.Fatalf("post: %v", err)
}
defer res.Body.Close()
if res.StatusCode != stdhttp.StatusCreated {
buf, _ := io.ReadAll(res.Body)
t.Fatalf("status %d: %s", res.StatusCode, buf)
}
var er enrollResponse
if err := json.NewDecoder(res.Body).Decode(&er); err != nil {
t.Fatalf("decode: %v", err)
}
if er.HostID == "" || er.AgentToken == "" {
t.Errorf("missing fields in response: %+v", er)
}
// Token must not be reusable.
res2, _ := stdhttp.Post(url+"/api/agents/enroll", "application/json", bytes.NewReader(body))
defer res2.Body.Close()
if res2.StatusCode != stdhttp.StatusUnauthorized {
t.Errorf("re-enrollment with same token should fail, got %d", res2.StatusCode)
}
// Host row exists with matching agent_token_hash.
got, err := st.LookupHostByAgentToken(context.Background(), auth.HashToken(er.AgentToken))
if err != nil {
t.Fatalf("lookup by token: %v", err)
}
if got.Name != "test-host" || got.OS != "linux" {
t.Errorf("host fields: %+v", got)
}
}
+18
View File
@@ -15,6 +15,7 @@ import (
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto" "gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/config" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/config"
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store" "gitea.dcglab.co.uk/steve/restic-manager/internal/store"
) )
@@ -24,6 +25,7 @@ type Deps struct {
Cfg config.Config Cfg config.Config
Store *store.Store Store *store.Store
AEAD *crypto.AEAD AEAD *crypto.AEAD
Hub *ws.Hub
// BootstrapToken (optional, populated only on first run) is the raw // BootstrapToken (optional, populated only on first run) is the raw
// admin-bootstrap token printed in the server logs. While set, the // admin-bootstrap token printed in the server logs. While set, the
// /bootstrap endpoint accepts it to create the first admin user. // /bootstrap endpoint accepts it to create the first admin user.
@@ -73,8 +75,24 @@ func (s *Server) routes(r chi.Router) {
r.Post("/auth/login", s.handleLogin) r.Post("/auth/login", s.handleLogin)
r.Post("/auth/logout", s.handleLogout) r.Post("/auth/logout", s.handleLogout)
r.Post("/bootstrap", s.handleBootstrap) r.Post("/bootstrap", s.handleBootstrap)
// Agent enrollment (open endpoint — token is the credential).
r.Post("/agents/enroll", s.handleAgentEnroll)
// Operator → server (authenticated). Spec.md §6.1's
// /hosts/{id}/enrollment-token (regenerate) lands when the
// host page can call it; for now just the create endpoint.
r.Post("/enrollment-tokens", s.handleCreateEnrollmentToken)
}) })
// Agent ↔ server WebSocket. Bearer-authenticated inside the handler.
if s.deps.Hub != nil {
r.Mount("/ws/agent", ws.AgentHandler(ws.HandlerDeps{
Hub: s.deps.Hub,
Store: s.deps.Store,
}))
}
// UI handlers will hang off / — Phase 1 will add them. // UI handlers will hang off / — Phase 1 will add them.
r.Get("/", func(w stdhttp.ResponseWriter, _ *stdhttp.Request) { r.Get("/", func(w stdhttp.ResponseWriter, _ *stdhttp.Request) {
_, _ = fmt.Fprint(w, "restic-manager — UI not yet implemented") _, _ = fmt.Fprint(w, "restic-manager — UI not yet implemented")
-3
View File
@@ -1,3 +0,0 @@
// Package ws hosts the WebSocket transport for agent ↔ server and the
// browser-facing live job log stream.
package ws
+183
View File
@@ -0,0 +1,183 @@
package ws
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
stdhttp "net/http"
"strings"
"time"
"github.com/coder/websocket"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
// HandlerDeps is the set of collaborators the agent WS handler needs.
type HandlerDeps struct {
Hub *Hub
Store *store.Store
}
// AgentHandler is the http.Handler that owns /ws/agent. Agents
// authenticate with `Authorization: Bearer <token>` (issued at
// enrollment) before the WS upgrade.
//
// Lifecycle:
// 1. Bearer token resolves to a Host row.
// 2. Upgrade.
// 3. First message must be `hello`; protocol_version checked here.
// 4. Loop: read messages, dispatch by type. Heartbeats touch the
// host row; job/log/repo messages forward to the relevant
// handlers (TODO: lands with P1-18 onward).
// 5. On Read error or context cancel, mark host offline, unregister
// from the hub.
func AgentHandler(deps HandlerDeps) stdhttp.Handler {
return stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
host, ok := authenticateAgent(r, deps.Store)
if !ok {
stdhttp.Error(w, "unauthorized", stdhttp.StatusUnauthorized)
return
}
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
InsecureSkipVerify: true, // Origin checks are pointless for an agent CLI.
})
if err != nil {
slog.Warn("ws accept failed", "err", err, "host_id", host.ID)
return
}
c := NewConn(host.ID, conn)
// Keep agents alive across NAT boxes; coder/websocket
// auto-pings under the hood when configured. The default 60s
// works fine for a 30s heartbeat cadence.
runAgentLoop(r.Context(), c, host.ID, deps)
})
}
// authenticateAgent returns the host that owns the bearer token in
// the request, or (nil, false) if anything is amiss. The same
// "false" path is used for missing header, malformed header, unknown
// token — no information leak about why.
func authenticateAgent(r *stdhttp.Request, st *store.Store) (*store.Host, bool) {
hdr := r.Header.Get("Authorization")
const prefix = "Bearer "
if !strings.HasPrefix(hdr, prefix) {
return nil, false
}
token := strings.TrimPrefix(hdr, prefix)
if token == "" {
return nil, false
}
h, err := st.LookupHostByAgentToken(r.Context(), auth.HashToken(token))
if err != nil {
return nil, false
}
return h, true
}
// runAgentLoop is the per-connection driver. Returns when the socket
// is closed for any reason. It owns the hub registration: register on
// hello acceptance, unregister on exit.
func runAgentLoop(ctx context.Context, c *Conn, hostID string, deps HandlerDeps) {
// Stage 1: hello (with a tight deadline).
helloCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
hello, err := c.Read(helloCtx)
cancel()
if err != nil {
slog.Info("ws hello read failed", "host_id", hostID, "err", err)
_ = c.Close()
return
}
if hello.Type != api.MsgHello {
c.SendError(ctx, api.ErrBadRequest, "first message must be hello", "")
return
}
var helloPayload api.HelloPayload
if err := hello.UnmarshalPayload(&helloPayload); err != nil {
c.SendError(ctx, api.ErrBadRequest, "malformed hello payload", "")
return
}
if helloPayload.ProtocolVersion < api.MinAgentProtocolVersion {
c.SendError(ctx, api.ErrProtocolTooOld,
fmt.Sprintf("agent protocol_version %d below minimum %d",
helloPayload.ProtocolVersion, api.MinAgentProtocolVersion),
"https://restic-manager.example/docs/upgrade")
return
}
if helloPayload.ProtocolVersion > api.CurrentProtocolVersion {
// Forward-compat is fine — newer agents talking to older
// servers should accept their lower version. Just log it.
slog.Info("ws agent newer than server",
"host_id", hostID,
"agent_proto", helloPayload.ProtocolVersion,
"server_proto", api.CurrentProtocolVersion)
}
now := time.Now().UTC()
if err := deps.Store.MarkHostHello(ctx, hostID,
helloPayload.AgentVersion, helloPayload.ResticVersion,
helloPayload.ProtocolVersion, now); err != nil {
slog.Error("ws mark host hello failed", "host_id", hostID, "err", err)
}
deps.Hub.Register(hostID, c)
defer deps.Hub.Unregister(hostID, c)
defer func() { _ = c.Close() }()
slog.Info("ws agent connected",
"host_id", hostID,
"agent_version", helloPayload.AgentVersion,
"protocol_version", helloPayload.ProtocolVersion)
// Stage 2: main read loop.
for {
env, err := c.Read(ctx)
if err != nil {
if !errors.Is(err, context.Canceled) {
slog.Info("ws agent read loop ended", "host_id", hostID, "err", err)
}
return
}
dispatchAgentMessage(ctx, c, hostID, env, deps)
}
}
// dispatchAgentMessage routes a single envelope to its handler. Only
// hello + heartbeat are wired up in Phase 1's first slice; the rest
// land with P1-18+ (jobs) and P2 (schedules).
func dispatchAgentMessage(ctx context.Context, c *Conn, hostID string, env api.Envelope, deps HandlerDeps) {
switch env.Type {
case api.MsgHeartbeat:
_ = deps.Store.TouchHost(ctx, hostID, time.Now().UTC())
case api.MsgJobStarted, api.MsgJobProgress, api.MsgJobFinished,
api.MsgLogStream, api.MsgSnapshotsRpt, api.MsgRepoStats,
api.MsgScheduleAck, api.MsgCommandResult:
// TODO(P1-18+): persist + fan out to subscribed browsers.
slog.Debug("ws msg not yet handled", "type", env.Type, "host_id", hostID)
case api.MsgError:
var ep api.ErrorPayload
_ = env.UnmarshalPayload(&ep)
slog.Warn("ws agent reported error", "host_id", hostID,
"code", string(ep.Code), "message", ep.Message)
default:
slog.Warn("ws unknown message type from agent",
"type", env.Type, "host_id", hostID)
}
}
// MinHeartbeatInterval is a sanity floor — any agent reporting
// heartbeats more often than this is misbehaving. (Spec says 30s.)
const MinHeartbeatInterval = 5 * time.Second
// suppress unused-import false-positives if json drops out later
var _ = json.Marshal
+145
View File
@@ -0,0 +1,145 @@
// Package ws hosts the WebSocket transport for agent ↔ server. The
// Hub tracks one active connection per host id; subsequent connections
// from the same host evict the prior one (last-write-wins).
package ws
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"sync"
"time"
"github.com/coder/websocket"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
)
// Hub owns the live agent connections and routes messages.
type Hub struct {
mu sync.RWMutex
conns map[string]*Conn // hostID → conn
}
// NewHub returns an empty hub.
func NewHub() *Hub {
return &Hub{conns: make(map[string]*Conn)}
}
// Conn is one agent WS connection. Send is safe for concurrent use;
// Read is single-reader (the connection's run loop).
type Conn struct {
HostID string
c *websocket.Conn
writeMu sync.Mutex
}
// Register installs c as the canonical connection for hostID. Any
// previous connection for that host is closed.
func (h *Hub) Register(hostID string, c *Conn) {
h.mu.Lock()
if prev, ok := h.conns[hostID]; ok {
// Best-effort close — a stuck old socket shouldn't block new one.
go func(old *Conn) {
_ = old.c.Close(websocket.StatusPolicyViolation, "superseded")
}(prev)
}
h.conns[hostID] = c
h.mu.Unlock()
}
// Unregister removes c iff it is still the canonical conn (a race
// where a newer conn already replaced it must not unregister it).
func (h *Hub) Unregister(hostID string, c *Conn) {
h.mu.Lock()
if cur, ok := h.conns[hostID]; ok && cur == c {
delete(h.conns, hostID)
}
h.mu.Unlock()
}
// Send delivers an envelope to the host if connected. Returns an error
// if the host is offline; caller may queue the message for later.
func (h *Hub) Send(ctx context.Context, hostID string, env api.Envelope) error {
h.mu.RLock()
c, ok := h.conns[hostID]
h.mu.RUnlock()
if !ok {
return fmt.Errorf("ws: host %q is offline", hostID)
}
return c.Send(ctx, env)
}
// Connected reports whether hostID has an active connection.
func (h *Hub) Connected(hostID string) bool {
h.mu.RLock()
_, ok := h.conns[hostID]
h.mu.RUnlock()
return ok
}
// ----- Conn methods --------------------------------------------------
// NewConn wraps a freshly-accepted websocket for a given hostID.
func NewConn(hostID string, c *websocket.Conn) *Conn {
return &Conn{HostID: hostID, c: c}
}
// Send writes an envelope as a JSON text message. Concurrent calls
// are serialised; the underlying socket is not safe for parallel
// writers.
func (c *Conn) Send(ctx context.Context, env api.Envelope) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
raw, err := json.Marshal(env)
if err != nil {
return fmt.Errorf("ws: marshal envelope: %w", err)
}
return c.c.Write(ctx, websocket.MessageText, raw)
}
// SendError writes an error envelope and closes the socket. Used by
// the hello handshake when an agent is rejected.
func (c *Conn) SendError(ctx context.Context, code api.ErrorCode, msg, helpURL string) {
env, err := api.Marshal(api.MsgError, "", api.ErrorPayload{
Code: code, Message: msg, HelpURL: helpURL,
})
if err == nil {
writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
_ = c.Send(writeCtx, env)
}
_ = c.c.Close(websocket.StatusPolicyViolation, string(code))
}
// Close shuts the socket down with a normal-closure status code.
func (c *Conn) Close() error {
return c.c.Close(websocket.StatusNormalClosure, "")
}
// Read pulls the next JSON envelope off the wire. The caller's
// context controls cancellation and timeouts (e.g. read deadlines).
func (c *Conn) Read(ctx context.Context) (api.Envelope, error) {
mt, raw, err := c.c.Read(ctx)
if err != nil {
return api.Envelope{}, err
}
if mt != websocket.MessageText {
return api.Envelope{}, errors.New("ws: expected text frame")
}
var env api.Envelope
if err := json.Unmarshal(raw, &env); err != nil {
return api.Envelope{}, fmt.Errorf("ws: unmarshal envelope: %w", err)
}
return env, nil
}
// ----- helpers -------------------------------------------------------
// LogValue emits a slog-friendly representation of a Conn.
func (c *Conn) LogValue() slog.Value {
return slog.GroupValue(slog.String("host_id", c.HostID))
}
+181
View File
@@ -0,0 +1,181 @@
package ws
import (
"context"
"encoding/json"
stdhttp "net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"time"
"github.com/coder/websocket"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
// setupTestHub spins up a Server that exposes only /ws/agent against
// a fresh sqlite store with one pre-enrolled host. Returns the URL,
// the agent's bearer token, and the host ID.
func setupTestHub(t *testing.T) (url string, token string, hostID string, st *store.Store, hub *Hub) {
t.Helper()
dir := t.TempDir()
var err error
st, err = store.Open(context.Background(), filepath.Join(dir, "rm.db"))
if err != nil {
t.Fatalf("store: %v", err)
}
t.Cleanup(func() { _ = st.Close() })
hub = NewHub()
mux := stdhttp.NewServeMux()
mux.Handle("/ws/agent", AgentHandler(HandlerDeps{Hub: hub, Store: st}))
srv := httptest.NewServer(mux)
t.Cleanup(srv.Close)
// Pre-enroll a host directly via store (skipping HTTP).
hostID = "01HJ8K70000000000000000000"
token, _ = auth.NewToken()
now := time.Now().UTC()
if err := st.CreateHost(context.Background(), store.Host{
ID: hostID, Name: "h1", OS: "linux", Arch: "amd64",
EnrolledAt: now,
}, auth.HashToken(token), ""); err != nil {
t.Fatalf("enroll: %v", err)
}
url = "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/agent"
return
}
func TestWSHelloAndHeartbeat(t *testing.T) {
t.Parallel()
url, token, hostID, st, hub := setupTestHub(t)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + token}},
})
if err != nil {
t.Fatalf("dial: %v", err)
}
defer c.CloseNow()
// Send hello.
hello := api.HelloPayload{
ProtocolVersion: api.CurrentProtocolVersion,
AgentVersion: "0.1.0",
ResticVersion: "0.17.1",
Hostname: "h1",
OS: api.OSLinux,
Arch: api.ArchAmd64,
}
env, _ := api.Marshal(api.MsgHello, "", hello)
raw, _ := json.Marshal(env)
if err := c.Write(ctx, websocket.MessageText, raw); err != nil {
t.Fatalf("write hello: %v", err)
}
// Wait for the server to register us (registration happens after
// the hello-handler returns; give it up to 1s).
deadline := time.Now().Add(time.Second)
for !hub.Connected(hostID) && time.Now().Before(deadline) {
time.Sleep(20 * time.Millisecond)
}
if !hub.Connected(hostID) {
t.Fatal("host did not register on hub after hello")
}
// Verify host row was marked online + has populated metadata.
h, err := st.GetHost(context.Background(), hostID)
if err != nil {
t.Fatalf("get host: %v", err)
}
if h.Status != "online" || h.AgentVersion != "0.1.0" {
t.Errorf("host after hello: %+v", h)
}
// Send a heartbeat — server should touch last_seen.
hb := api.HeartbeatPayload{SentAt: time.Now().UTC()}
env, _ = api.Marshal(api.MsgHeartbeat, "", hb)
raw, _ = json.Marshal(env)
preTouch := h.LastSeenAt
_ = c.Write(ctx, websocket.MessageText, raw)
// Wait briefly for server to process.
deadline = time.Now().Add(time.Second)
for time.Now().Before(deadline) {
h2, _ := st.GetHost(context.Background(), hostID)
if h2.LastSeenAt != nil && (preTouch == nil || h2.LastSeenAt.After(*preTouch)) {
return
}
time.Sleep(20 * time.Millisecond)
}
t.Error("heartbeat did not update last_seen_at")
}
func TestWSRejectsOldProtocol(t *testing.T) {
t.Parallel()
url, token, _, _, _ := setupTestHub(t)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + token}},
})
if err != nil {
t.Fatalf("dial: %v", err)
}
defer c.CloseNow()
hello := api.HelloPayload{ProtocolVersion: 0} // below minimum
env, _ := api.Marshal(api.MsgHello, "", hello)
raw, _ := json.Marshal(env)
_ = c.Write(ctx, websocket.MessageText, raw)
// Server should send an error envelope, then close.
mt, body, err := c.Read(ctx)
if err != nil {
t.Fatalf("read: %v", err)
}
if mt != websocket.MessageText {
t.Fatalf("frame type: %v", mt)
}
var got api.Envelope
if err := json.Unmarshal(body, &got); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if got.Type != api.MsgError {
t.Errorf("expected error envelope, got %q", got.Type)
}
var ep api.ErrorPayload
_ = got.UnmarshalPayload(&ep)
if ep.Code != api.ErrProtocolTooOld {
t.Errorf("error code: %q", ep.Code)
}
}
func TestWSRejectsBadToken(t *testing.T) {
t.Parallel()
url, _, _, _, _ := setupTestHub(t)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, res, err := websocket.Dial(ctx, url, &websocket.DialOptions{
HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer wrong"}},
})
if err == nil {
t.Fatal("dial should fail")
}
if res == nil || res.StatusCode != stdhttp.StatusUnauthorized {
if res != nil {
t.Errorf("status: %d", res.StatusCode)
}
}
}
+205
View File
@@ -0,0 +1,205 @@
package store
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"time"
)
// CreateHost inserts a new host row. Used by the enrollment flow.
// The caller has already minted the host id and hashed the agent
// bearer token.
func (s *Store) CreateHost(ctx context.Context, h Host, agentTokenHash, certPinSHA256 string) error {
tags, err := json.Marshal(h.Tags)
if err != nil {
return fmt.Errorf("store: marshal tags: %w", err)
}
_, err = s.db.ExecContext(ctx,
`INSERT INTO hosts (
id, name, os, arch, agent_version, restic_version, protocol_version,
enrolled_at, status, tags,
agent_token_hash, cert_pin_sha256
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 'offline', ?, ?, ?)`,
h.ID, h.Name, h.OS, h.Arch,
h.AgentVersion, h.ResticVersion, h.ProtocolVersion,
h.EnrolledAt.UTC().Format(time.RFC3339Nano),
string(tags),
agentTokenHash, certPinSHA256)
if err != nil {
return fmt.Errorf("store: create host: %w", err)
}
return nil
}
// LookupHostByAgentToken resolves a hashed agent bearer token to the
// host it belongs to. Returns ErrNotFound on miss.
func (s *Store) LookupHostByAgentToken(ctx context.Context, tokenHash string) (*Host, error) {
row := s.db.QueryRowContext(ctx,
`SELECT id, name, os, arch, agent_version, restic_version, protocol_version,
enrolled_at, last_seen_at, status, repo_id, tags,
current_job_id, last_backup_at, last_backup_status,
repo_size_bytes, snapshot_count, open_alert_count,
applied_schedule_version
FROM hosts WHERE agent_token_hash = ?`,
tokenHash)
return scanHost(row)
}
// GetHost returns a host by ID. Returns ErrNotFound on miss.
func (s *Store) GetHost(ctx context.Context, id string) (*Host, error) {
row := s.db.QueryRowContext(ctx,
`SELECT id, name, os, arch, agent_version, restic_version, protocol_version,
enrolled_at, last_seen_at, status, repo_id, tags,
current_job_id, last_backup_at, last_backup_status,
repo_size_bytes, snapshot_count, open_alert_count,
applied_schedule_version
FROM hosts WHERE id = ?`, id)
return scanHost(row)
}
// MarkHostHello updates the host row with metadata received in the
// agent's hello message and flips status to 'online'.
func (s *Store) MarkHostHello(ctx context.Context, id string, agentVersion, resticVersion string, protoVersion int, when time.Time) error {
_, err := s.db.ExecContext(ctx,
`UPDATE hosts
SET agent_version = ?, restic_version = ?, protocol_version = ?,
last_seen_at = ?, status = 'online'
WHERE id = ?`,
agentVersion, resticVersion, protoVersion,
when.UTC().Format(time.RFC3339Nano), id)
if err != nil {
return fmt.Errorf("store: mark hello: %w", err)
}
return nil
}
// TouchHost updates last_seen_at on heartbeat, leaving status alone if
// already online (the offline-marker is a separate sweep).
func (s *Store) TouchHost(ctx context.Context, id string, when time.Time) error {
_, err := s.db.ExecContext(ctx,
`UPDATE hosts
SET last_seen_at = ?,
status = CASE WHEN status = 'offline' THEN 'online' ELSE status END
WHERE id = ?`,
when.UTC().Format(time.RFC3339Nano), id)
if err != nil {
return fmt.Errorf("store: touch host: %w", err)
}
return nil
}
// MarkHostsOfflineStale flips any host that hasn't been seen since
// before `cutoff` from 'online' to 'offline'. Returns the number of
// rows affected so the caller can log non-zero events.
func (s *Store) MarkHostsOfflineStale(ctx context.Context, cutoff time.Time) (int64, error) {
res, err := s.db.ExecContext(ctx,
`UPDATE hosts
SET status = 'offline'
WHERE status = 'online'
AND (last_seen_at IS NULL OR last_seen_at < ?)`,
cutoff.UTC().Format(time.RFC3339Nano))
if err != nil {
return 0, fmt.Errorf("store: mark offline: %w", err)
}
n, _ := res.RowsAffected()
return n, nil
}
// ListHosts returns every host. Phase 1 callers fit a small fleet in
// memory; pagination lands when it matters.
func (s *Store) ListHosts(ctx context.Context) ([]Host, error) {
rows, err := s.db.QueryContext(ctx,
`SELECT id, name, os, arch, agent_version, restic_version, protocol_version,
enrolled_at, last_seen_at, status, repo_id, tags,
current_job_id, last_backup_at, last_backup_status,
repo_size_bytes, snapshot_count, open_alert_count,
applied_schedule_version
FROM hosts ORDER BY name`)
if err != nil {
return nil, fmt.Errorf("store: list hosts: %w", err)
}
defer rows.Close()
var out []Host
for rows.Next() {
h, err := scanHostRow(rows)
if err != nil {
return nil, err
}
out = append(out, *h)
}
return out, rows.Err()
}
// ----- scan helpers --------------------------------------------------
type hostScanner interface {
Scan(dest ...any) error
}
func scanHost(row *sql.Row) (*Host, error) {
h, err := scanHostRow(row)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
return h, err
}
func scanHostRow(s hostScanner) (*Host, error) {
var h Host
var (
lastSeen, lastBackupAt sql.NullString
repoID, currentJob, lastBkSt sql.NullString
enrolled string
tags string
)
err := s.Scan(&h.ID, &h.Name, &h.OS, &h.Arch,
&h.AgentVersion, &h.ResticVersion, &h.ProtocolVersion,
&enrolled, &lastSeen, &h.Status, &repoID, &tags,
&currentJob, &lastBackupAt, &lastBkSt,
&h.RepoSizeBytes, &h.SnapshotCount, &h.OpenAlertCount,
&h.AppliedScheduleVersion)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("store: scan host: %w", err)
}
t, err := time.Parse(time.RFC3339Nano, enrolled)
if err != nil {
return nil, fmt.Errorf("store: parse enrolled_at: %w", err)
}
h.EnrolledAt = t
if lastSeen.Valid {
t, err := time.Parse(time.RFC3339Nano, lastSeen.String)
if err != nil {
return nil, fmt.Errorf("store: parse last_seen_at: %w", err)
}
h.LastSeenAt = &t
}
if lastBackupAt.Valid {
t, err := time.Parse(time.RFC3339Nano, lastBackupAt.String)
if err != nil {
return nil, fmt.Errorf("store: parse last_backup_at: %w", err)
}
h.LastBackupAt = &t
}
if repoID.Valid {
s := repoID.String
h.RepoID = &s
}
if currentJob.Valid {
s := currentJob.String
h.CurrentJobID = &s
}
if lastBkSt.Valid {
s := lastBkSt.String
h.LastBackupStatus = &s
}
if tags != "" {
_ = json.Unmarshal([]byte(tags), &h.Tags)
}
return &h, nil
}
+5 -2
View File
@@ -92,12 +92,15 @@ CREATE INDEX hosts_status ON hosts(status);
CREATE INDEX hosts_last_seen_at ON hosts(last_seen_at); CREATE INDEX hosts_last_seen_at ON hosts(last_seen_at);
-- Pending one-time enrollment tokens (TTL'd, single-use). -- Pending one-time enrollment tokens (TTL'd, single-use).
-- consumed_host is audit-only (no FK on purpose: we burn the token
-- before the host row exists, and we want this trail to survive a
-- later host deletion).
CREATE TABLE enrollment_tokens ( CREATE TABLE enrollment_tokens (
token_hash TEXT PRIMARY KEY, -- argon2id of token token_hash TEXT PRIMARY KEY,
created_at TEXT NOT NULL, created_at TEXT NOT NULL,
expires_at TEXT NOT NULL, expires_at TEXT NOT NULL,
consumed_at TEXT, consumed_at TEXT,
consumed_host TEXT REFERENCES hosts(id) ON DELETE SET NULL consumed_host TEXT
); );
CREATE INDEX enrollment_tokens_expires_at ON enrollment_tokens(expires_at); CREATE INDEX enrollment_tokens_expires_at ON enrollment_tokens(expires_at);