Files
restic-manager/internal/agent/wsclient/client.go
T
steve f0dfa689fe P3 follow-up: editable target dir, conditional --no-ownership, UK lint
Three small follow-ups from review:

1. Restore target is now operator-editable. Default value is the
   literal '$HOME/rm-restore/<job-id>/' (agent expands $HOME at
   run time using os.UserHomeDir(); also handles ${HOME} and ~/
   prefixes). Operator can replace with any absolute path.
   - ui_restore.go validates the input is either absolute or starts
     with one of the recognised prefixes; other env-var refs ($PATH
     etc.) are deliberately rejected so operator paths can't pick up
     arbitrary agent env values.
   - host_restore.html replaces the read-only mono-text display with
     a real <input>; help text spells out that $HOME resolves
     agent-side and <job-id> is substituted on dispatch.
   - install.sh + the systemd unit prep /root/rm-restore so the
     default works under the sandbox: ReadWritePaths gains a soft
     '-/root/rm-restore' entry (the '-' makes the bind-mount soft-fail
     if missing, but install.sh pre-creates it root-owned 0700).

2. --no-ownership flag now gated on restic version. The flag was
   added in restic 0.17 and 0.16 rejects it. Previously dropped it
   wholesale — that meant new-dir restores silently preserved
   ownership against design intent on 0.17+. Now the agent threads
   its detected restic version (sysinfo already collects it) through
   runner.Config -> restic.Env, and RunRestore appends --no-ownership
   only when AtLeastVersion(0, 17) returns true. 0.16 hosts still
   restore with original uid/gid; help text in the wizard explicitly
   notes this. The previous 'Original ownership is preserved' copy
   was wrong for new-dir mode and is corrected.

3. golangci-lint misspell locale switched US -> UK and the codebase
   swept (73 corrections, mostly behaviour/serialise/recognise/honour).
   Wire-format ErrorCode 'unauthorized' -> 'unauthorised' is a tiny
   contract change but the agent doesn't parse those codes today and
   no external API consumers exist yet. Tests passed before + after.

Tests:
- internal/restic/version_test.go covers Env.AtLeastVersion across
  edge cases (empty, exact match, patch above, minor below,
  non-numeric) and expandHome on $HOME / ${HOME} / ~/, plus
  pass-through for absolute paths and refusal of other env vars.
- ui_restore_test updated: TargetDir now starts '$HOME/rm-restore/'
  with the job_id substituted into the placeholder.

Live verified on the smoke env: default target restored to
/root/rm-restore/<job-id>/, i.e. the agent's expanded $HOME (2 files,
14 bytes); custom override '/tmp/custom-restore/<job-id>/' restored
into the agent's PrivateTmp namespace (1 file, 6 bytes); both jobs
'succeeded', exit 0.
2026-05-04 17:27:52 +01:00

287 lines
7.9 KiB
Go

// Package wsclient is the agent's outbound WebSocket connection to
// the control plane: dial with bearer auth, perform the hello
// handshake, send heartbeats, dispatch server-pushed commands.
//
// The Run loop is a forever-loop with exponential backoff on dial
// failures, capped at 60s. Disconnected agents keep retrying.
package wsclient
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log/slog"
"math/rand"
stdhttp "net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/coder/websocket"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
)
// Config holds the agent's connection settings. Run and connectOnce
// read it; Run fills in a default HeartbeatPeriod on its own copy.
type Config struct {
	// ServerURL is the control-plane base URL. buildWSURL accepts the
	// http, https, ws and wss schemes and appends "/ws/agent" to the path.
	ServerURL string
	// AgentToken is sent on the upgrade request as
	// "Authorization: Bearer <AgentToken>".
	AgentToken string
	// HostID is not read in this file; presumably it identifies this
	// host to the control plane — confirm against callers.
	HostID string
	// CertPinSHA256 is the hex SHA-256 of the server's leaf certificate,
	// checked by pinChecker. Empty disables pinning. It is only applied
	// when the resolved WebSocket URL uses wss.
	CertPinSHA256 string
	// HeartbeatPeriod is the interval between heartbeat messages.
	// Run substitutes 30s when it is <= 0.
	HeartbeatPeriod time.Duration
	// HelloPayload is sent as the MsgHello envelope immediately after
	// each successful dial.
	HelloPayload api.HelloPayload
}
// Sender is what handlers use to push agent → server messages
// (job.progress, job.finished, log.stream, command.result, …).
// The WS client hands one to the dispatch Handler for each connection.
// The connSender implementation serialises its own Send calls behind a
// per-connection mutex, so concurrent calls to Send are safe and do not
// interleave with each other. Heartbeat frames are written separately
// (via writeEnv) and do not take that mutex.
type Sender interface {
	// Send JSON-encodes env and writes it to the server as a single
	// text message, returning any encoding or write error.
	Send(env api.Envelope) error
}
// Handler is invoked from the connection's read loop for each text
// message that decodes as an api.Envelope. MsgError envelopes are not
// passed on; the client logs them itself. Binary frames and messages
// that fail to decode never reach the handler either.
//
// Handlers run synchronously on the read loop: no further server
// messages are read until the handler returns, so long-running work
// belongs in its own goroutine (keeping tx for replies). A non-nil
// return is logged and does not close the connection.
//
// tx lets the handler push replies back; it is valid only for the
// lifetime of the connection (calls fail if the agent has reconnected
// since).
type Handler func(ctx context.Context, env api.Envelope, tx Sender) error
// Run keeps the agent connected indefinitely. It returns nil once ctx is
// done. Every other outcome of a connection attempt (dial failure,
// server-side close, protocol error) is logged and followed by a
// reconnect after an exponential backoff delay.
//
// The backoff starts at 1s and doubles up to 60s (plus jitter). It is
// reset after any connection that stayed up for at least that 60s cap.
// That way a long-lived session that drops reconnects promptly instead
// of inheriting the delay built up by failures long before.
//
// A HeartbeatPeriod <= 0 in cfg is replaced with 30s.
func Run(ctx context.Context, cfg Config, handle Handler) error {
	const (
		backoffBase = time.Second
		backoffMax  = 60 * time.Second
		// connectOnce bounds the dial at 30s, so a session lasting at
		// least this long almost certainly got past the handshake and
		// counts as a healthy connection.
		stableAfter = backoffMax
	)
	if cfg.HeartbeatPeriod <= 0 {
		cfg.HeartbeatPeriod = 30 * time.Second
	}
	retry := newBackoff(backoffBase, backoffMax)
	for {
		started := time.Now()
		err := connectOnce(ctx, cfg, handle)
		// Check ctx itself as well as the error. On shutdown the read
		// error does not necessarily wrap context.Canceled, and logging
		// it as a disconnect would be noise.
		if ctx.Err() != nil || errors.Is(err, context.Canceled) {
			return nil
		}
		if err != nil {
			slog.Warn("ws agent disconnect", "err", err)
		}
		if time.Since(started) >= stableAfter {
			retry = newBackoff(backoffBase, backoffMax)
		}
		if err := sleepCtx(ctx, retry.next()); err != nil {
			// ctx cancellation mid-backoff means the parent shut us down —
			// exit the reconnect loop quietly rather than propagating
			// a context error up to a caller that will discard it.
			return nil //nolint:nilerr
		}
	}
}
// connectOnce performs one full connection lifecycle: dial → hello →
// heartbeat loop + read loop → close. It returns a non-nil error when
// any of the following happens:
//   - the URL is invalid;
//   - the dial or hello write fails;
//   - a read fails (either side closed the socket, or ctx was cancelled);
//   - the server reports protocol_too_old.
//
// On every return after a successful dial the socket is closed without a
// close handshake and the heartbeat goroutine's context is cancelled.
func connectOnce(ctx context.Context, cfg Config, handle Handler) error {
	wsURL, err := buildWSURL(cfg.ServerURL)
	if err != nil {
		return fmt.Errorf("ws agent: bad server url: %w", err)
	}
	dialOpts := &websocket.DialOptions{
		HTTPHeader: stdhttp.Header{
			"Authorization": []string{"Bearer " + cfg.AgentToken},
		},
	}
	if cfg.CertPinSHA256 != "" {
		if strings.HasPrefix(wsURL, "wss") {
			dialOpts.HTTPClient = &stdhttp.Client{
				Transport: &stdhttp.Transport{
					// A zero Transport ignores HTTP(S)_PROXY. Without an
					// HTTPClient, websocket.Dial uses http.DefaultClient,
					// which honours it, so pinned and unpinned connections
					// must route the same way.
					Proxy: stdhttp.ProxyFromEnvironment,
					TLSClientConfig: &tls.Config{
						MinVersion: tls.VersionTLS12,
						// Runs after normal chain verification
						// (InsecureSkipVerify is not set), so the pin is
						// an extra check, not a replacement.
						VerifyPeerCertificate: pinChecker(cfg.CertPinSHA256),
					},
				},
			}
		} else {
			// A pin can only be enforced over TLS. Say so rather than
			// silently connecting in plaintext with the pin unused.
			slog.Warn("ws agent: cert pin configured but server url is not TLS; pin ignored",
				"server", wsURL)
		}
	}
	dialCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	conn, res, err := websocket.Dial(dialCtx, wsURL, dialOpts)
	cancel()
	if err != nil {
		return fmt.Errorf("dial: %w", err)
	}
	// websocket.Dial returns the upgrade response separately from the
	// conn. Body is empty on a successful upgrade but Go's net/http
	// still expects it closed to release the connection.
	defer func() {
		if res != nil && res.Body != nil {
			_ = res.Body.Close()
		}
	}()
	defer conn.CloseNow() //nolint:errcheck

	// Send hello.
	helloEnv, err := api.Marshal(api.MsgHello, "", cfg.HelloPayload)
	if err != nil {
		return fmt.Errorf("marshal hello: %w", err)
	}
	if err := writeEnv(ctx, conn, helloEnv); err != nil {
		return fmt.Errorf("write hello: %w", err)
	}
	slog.Info("ws agent connected", "server", wsURL)
	tx := &connSender{conn: conn, ctx: ctx}

	// Heartbeat goroutine; its context is cancelled when we return.
	heartbeatCtx, cancelHeartbeat := context.WithCancel(ctx)
	defer cancelHeartbeat()
	go heartbeatLoop(heartbeatCtx, conn, cfg.HeartbeatPeriod)

	// Read loop. A read error returns and closes the conn.
	for {
		mt, raw, err := conn.Read(ctx)
		if err != nil {
			return fmt.Errorf("read: %w", err)
		}
		if mt != websocket.MessageText {
			continue
		}
		var env api.Envelope
		if err := json.Unmarshal(raw, &env); err != nil {
			slog.Warn("ws agent: bad envelope from server", "err", err)
			continue
		}
		if env.Type == api.MsgError {
			var ep api.ErrorPayload
			if err := env.UnmarshalPayload(&ep); err != nil {
				slog.Warn("ws agent: undecodable error payload from server", "err", err)
			}
			slog.Error("ws agent: server reported error",
				"code", ep.Code, "message", ep.Message, "help", ep.HelpURL)
			// protocol_too_old is fatal for this connection — retrying
			// the same handshake immediately won't help.
			if ep.Code == api.ErrProtocolTooOld {
				return fmt.Errorf("protocol too old: %s", ep.Message)
			}
			continue
		}
		if handle != nil {
			// Dispatch is synchronous: reads pause until the handler returns.
			if err := handle(ctx, env, tx); err != nil {
				slog.Warn("ws agent: handler returned error", "type", env.Type, "err", err)
			}
		}
	}
}
// connSender is the per-connection Sender. Goroutines beyond the
// read loop (e.g. a backup running in its own goroutine) keep a
// reference to one of these for the duration of their work.
type connSender struct {
	// conn is the connection's WebSocket; connectOnce closes it on
	// return, after which writes through this sender fail.
	conn *websocket.Conn
	// ctx is the context connectOnce was called with. Each Send derives
	// a 30s write timeout from it, so Send also fails once it is cancelled.
	ctx context.Context
	// mu serialises Send calls on this sender.
	mu sync.Mutex
}
// Send implements Sender. While holding s.mu it JSON-encodes env and
// writes it as one text frame, giving the write at most 30 seconds.
// Encoding errors are returned unwrapped.
func (s *connSender) Send(env api.Envelope) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	payload, encodeErr := json.Marshal(env)
	if encodeErr != nil {
		return encodeErr
	}

	deadlineCtx, release := context.WithTimeout(s.ctx, 30*time.Second)
	defer release()
	return s.conn.Write(deadlineCtx, websocket.MessageText, payload)
}
// heartbeatLoop sends a MsgHeartbeat envelope (stamped with the current
// UTC time) every period until ctx is cancelled or a write fails.
//
// Each write is bounded by a 30s timeout so a stalled peer cannot park this
// goroutine forever. If a write fails while ctx is still live, the
// connection is closed. That makes the pending conn.Read in the read loop
// return, and the agent reconnects instead of sitting on a socket that can
// no longer carry heartbeats. Failures caused by ctx cancellation (normal
// teardown) are not logged.
func heartbeatLoop(ctx context.Context, conn *websocket.Conn, period time.Duration) {
	const writeTimeout = 30 * time.Second
	t := time.NewTicker(period)
	defer t.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-t.C:
			env, err := api.Marshal(api.MsgHeartbeat, "",
				api.HeartbeatPayload{SentAt: time.Now().UTC()})
			if err != nil {
				slog.Warn("ws agent: heartbeat marshal failed", "err", err)
				continue
			}
			writeCtx, cancel := context.WithTimeout(ctx, writeTimeout)
			err = writeEnv(writeCtx, conn, env)
			cancel()
			if err != nil {
				if ctx.Err() != nil {
					// connectOnce is tearing the connection down.
					return
				}
				slog.Warn("ws agent: heartbeat write failed", "err", err)
				_ = conn.CloseNow()
				return
			}
		}
	}
}
// writeEnv JSON-encodes env and writes it to conn as a single text
// message, bounded only by ctx. Unlike connSender.Send it takes no
// lock and applies no timeout of its own.
func writeEnv(ctx context.Context, conn *websocket.Conn, env api.Envelope) error {
	payload, marshalErr := json.Marshal(env)
	if marshalErr == nil {
		return conn.Write(ctx, websocket.MessageText, payload)
	}
	return marshalErr
}
// buildWSURL converts the configured control-plane URL into the agent
// WebSocket endpoint:
//   - http/https map to ws/wss; ws/wss pass through unchanged;
//   - trailing slashes are trimmed from the path and "/ws/agent" is appended;
//   - the query string is preserved;
//   - any fragment is dropped, because RFC 6455 forbids fragments in
//     WebSocket URIs.
//
// It returns an error for parse failures, other schemes, and URLs with no
// host. These mistakes then surface once as a bad server URL rather than
// as an opaque dial failure on every reconnect.
func buildWSURL(serverURL string) (string, error) {
	u, err := url.Parse(serverURL)
	if err != nil {
		return "", err
	}
	switch u.Scheme {
	case "https":
		u.Scheme = "wss"
	case "http":
		u.Scheme = "ws"
	case "ws", "wss":
		// already correct
	default:
		return "", fmt.Errorf("unsupported scheme %q", u.Scheme)
	}
	if u.Host == "" {
		// Don't echo serverURL: it may carry userinfo credentials.
		return "", fmt.Errorf("missing host")
	}
	u.Path = strings.TrimRight(u.Path, "/") + "/ws/agent"
	u.Fragment = ""
	u.RawFragment = ""
	return u.String(), nil
}
// pinChecker returns a VerifyPeerCertificate callback that requires the
// SHA-256 of the peer's leaf certificate (rawCerts[0]) to match wantHex.
// crypto/tls only calls this after normal chain verification succeeds
// (when InsecureSkipVerify is unset), so the pin is enforced *in addition*
// to OS root verification, not instead of it.
//
// wantHex is normalised once, up front: surrounding whitespace and ':'
// separators are removed and it is lower-cased. Plain lower-case hex
// therefore matches as before, and the upper-case, colon-separated form
// printed by `openssl x509 -fingerprint -sha256` works too. sha256Hex
// yields lower-case hex, so without this step such pins could never match.
func pinChecker(wantHex string) func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
	want := strings.ToLower(strings.ReplaceAll(strings.TrimSpace(wantHex), ":", ""))
	return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
		if len(rawCerts) == 0 {
			return errors.New("ws agent: no peer certs")
		}
		got := sha256Hex(rawCerts[0])
		if got != want {
			return fmt.Errorf("ws agent: cert pin mismatch (got %s want %s)", got, want)
		}
		return nil
	}
}
// sha256Hex returns the lower-case hex encoding of the digest that
// newSHA256 produces over b. newSHA256 is defined elsewhere in the
// package; presumably it returns a crypto/sha256 hash.Hash — confirm.
func sha256Hex(b []byte) string {
	hasher := newSHA256()
	hasher.Write(b)
	digest := hasher.Sum(nil)
	return hex.EncodeToString(digest)
}
// ----- backoff -------------------------------------------------------
// backoff produces exponentially growing retry delays with additive
// jitter. cur is the un-jittered delay the next call to next returns;
// max caps cur. Construct it with newBackoff.
type backoff struct {
	cur, max time.Duration
}

// newBackoff returns a backoff whose first delay is base (plus jitter)
// and whose un-jittered delay doubles on every call, up to max.
func newBackoff(base, max time.Duration) *backoff { return &backoff{cur: base, max: max} }

// next returns the current delay plus a random jitter in [0, 20%) of it,
// then doubles the stored delay, capping it at max. Jitter is added after
// the cap, so a returned delay can exceed max by up to 20%.
//
// Delays too small for 20% to be at least 1ns (including zero) get no
// jitter: rand.Int63n panics on a non-positive bound.
func (b *backoff) next() time.Duration {
	d := b.cur
	var jitter time.Duration
	if span := int64(d) / 5; span > 0 {
		// Only used to desynchronise reconnecting agents; it does not
		// need to be cryptographically random.
		jitter = time.Duration(rand.Int63n(span)) //nolint:gosec
	}
	// Compare against max/2 instead of doubling first, so a very large
	// max cannot overflow cur into a negative duration.
	if b.cur > b.max/2 {
		b.cur = b.max
	} else {
		b.cur *= 2
	}
	return d + jitter
}
// sleepCtx blocks for d or until ctx is done, whichever comes first.
// It returns nil after a full sleep and ctx.Err() if ctx finished first.
// The timer is stopped on return so an early exit releases it at once,
// rather than leaving it live until d elapses (unstopped timers were
// not collectable before Go 1.23).
func sleepCtx(ctx context.Context, d time.Duration) error {
	t := time.NewTimer(d)
	defer t.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-t.C:
		return nil
	}
}