a7c6a6e09c
Lands the operator → server → agent → restic → server roundtrip for
on-demand backups. The flow:
POST /api/hosts/{id}/jobs {kind:"backup",args:["/path"]}
→ server creates a queued Job row
→ server emits command.run over WS to the host's agent
→ agent dispatcher spawns runner.RunBackup in a goroutine
→ runner spawns `restic backup --json`, parses each line
→ forwards: job.started, log.stream (every line), job.progress
(throttled to 1/sec), job.finished (with summary stats blob)
→ server WS handler persists those into jobs / job_logs
P1-16 internal/restic: thin Locate + Env wrapper that runs `restic
backup --json`, scans stdout/stderr, parses BackupStatus +
BackupSummary, calls back into a LineHandler so the agent can fan
out to log.stream + job.progress. Treats exit code 3 as
"succeeded with issues" (matches restic's contract).
P1-18 store: jobs accessors (CreateJob, MarkJobStarted,
MarkJobFinished, AppendJobLog, GetJob).
P1-19 server: POST /api/hosts/{id}/jobs creates the Job row,
validates kind, dispatches via Hub.Send, audit-logs the action.
P1-20 agent runner: wraps restic.RunBackup with throttled progress
emission. Sender abstraction was added to wsclient.Handler so
background goroutines can keep replying after dispatch returns.
P1-21 server WS: dispatchAgentMessage now persists job.started,
job.finished, log.stream into the database. Browser fan-out for
live tailing lands with the UI work.
The agent reads repo_url + repo_password from agent.yaml in plaintext
for now (mode 0600, owned by the service user); moving them into the
keyring storage described in spec.md §7.3 is deferred to P2. A
config.update message over WS overrides the in-memory copy but does
not persist it.
Build is clean and all tests pass. An end-to-end run against a real
restic still needs a host with restic installed; in the meantime the
wire shape is covered by the existing hello/heartbeat round-trip test.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
280 lines
7.5 KiB
Go
280 lines
7.5 KiB
Go
// Package wsclient is the agent's outbound WebSocket connection to
// the control plane: it dials with bearer auth, performs the hello
// handshake, sends heartbeats, and dispatches server-pushed commands.
//
// Run loops forever: after a failed dial or a dropped connection it
// reconnects with jittered exponential backoff capped at 60s, so a
// disconnected agent keeps retrying until its context is cancelled.
package wsclient
|
|
|
|
import (
	"context"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"math/rand"
	stdhttp "net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/coder/websocket"

	"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
)
|
|
|
|
// Config holds the agent's connection settings.
|
|
type Config struct {
|
|
ServerURL string
|
|
AgentToken string
|
|
HostID string
|
|
CertPinSHA256 string // hex; empty disables pinning
|
|
HeartbeatPeriod time.Duration
|
|
HelloPayload api.HelloPayload
|
|
}
|
|
|
|
// Sender is what handlers use to push agent → server messages
|
|
// (job.progress, job.finished, log.stream, command.result, …).
|
|
// Returned by the WS client to the dispatch handler. Write operations
|
|
// serialise behind a single mutex on the conn; concurrent calls are
|
|
// safe.
|
|
type Sender interface {
|
|
Send(env api.Envelope) error
|
|
}
|
|
|
|
// Handler is invoked for every server-sent message. tx lets the
|
|
// handler push replies back; it is valid only for the lifetime of
|
|
// the connection (calls fail if the agent has reconnected since).
|
|
type Handler func(ctx context.Context, env api.Envelope, tx Sender) error
|
|
|
|
// Run keeps the agent connected indefinitely. Returns when ctx is
|
|
// cancelled. Errors during a single connection attempt are logged and
|
|
// trigger reconnect-with-backoff; only ctx.Done() ends the loop.
|
|
func Run(ctx context.Context, cfg Config, handle Handler) error {
|
|
if cfg.HeartbeatPeriod <= 0 {
|
|
cfg.HeartbeatPeriod = 30 * time.Second
|
|
}
|
|
|
|
backoff := newBackoff(time.Second, 60*time.Second)
|
|
for {
|
|
err := connectOnce(ctx, cfg, handle)
|
|
if errors.Is(err, context.Canceled) {
|
|
return nil
|
|
}
|
|
if err != nil {
|
|
slog.Warn("ws agent disconnect", "err", err)
|
|
}
|
|
if err := sleepCtx(ctx, backoff.next()); err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// connectOnce performs one full connection lifecycle: dial → hello →
|
|
// heartbeat loop + read loop → close. Returns when either side closes
|
|
// the socket.
|
|
func connectOnce(ctx context.Context, cfg Config, handle Handler) error {
|
|
wsURL, err := buildWSURL(cfg.ServerURL)
|
|
if err != nil {
|
|
return fmt.Errorf("ws agent: bad server url: %w", err)
|
|
}
|
|
|
|
dialOpts := &websocket.DialOptions{
|
|
HTTPHeader: stdhttp.Header{
|
|
"Authorization": []string{"Bearer " + cfg.AgentToken},
|
|
},
|
|
}
|
|
if cfg.CertPinSHA256 != "" && strings.HasPrefix(wsURL, "wss") {
|
|
dialOpts.HTTPClient = &stdhttp.Client{
|
|
Transport: &stdhttp.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
VerifyPeerCertificate: pinChecker(cfg.CertPinSHA256),
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
dialCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
conn, _, err := websocket.Dial(dialCtx, wsURL, dialOpts)
|
|
cancel()
|
|
if err != nil {
|
|
return fmt.Errorf("dial: %w", err)
|
|
}
|
|
defer conn.CloseNow() //nolint:errcheck
|
|
|
|
// Send hello.
|
|
helloEnv, err := api.Marshal(api.MsgHello, "", cfg.HelloPayload)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal hello: %w", err)
|
|
}
|
|
if err := writeEnv(ctx, conn, helloEnv); err != nil {
|
|
return fmt.Errorf("write hello: %w", err)
|
|
}
|
|
slog.Info("ws agent connected", "server", wsURL)
|
|
|
|
tx := &connSender{conn: conn, ctx: ctx}
|
|
|
|
// Heartbeat goroutine.
|
|
heartbeatCtx, cancelHeartbeat := context.WithCancel(ctx)
|
|
defer cancelHeartbeat()
|
|
go heartbeatLoop(heartbeatCtx, conn, cfg.HeartbeatPeriod)
|
|
|
|
// Read loop. A read error returns and closes the conn.
|
|
for {
|
|
mt, raw, err := conn.Read(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("read: %w", err)
|
|
}
|
|
if mt != websocket.MessageText {
|
|
continue
|
|
}
|
|
var env api.Envelope
|
|
if err := json.Unmarshal(raw, &env); err != nil {
|
|
slog.Warn("ws agent: bad envelope from server", "err", err)
|
|
continue
|
|
}
|
|
if env.Type == api.MsgError {
|
|
var ep api.ErrorPayload
|
|
_ = env.UnmarshalPayload(&ep)
|
|
slog.Error("ws agent: server reported error",
|
|
"code", ep.Code, "message", ep.Message, "help", ep.HelpURL)
|
|
// protocol_too_old is fatal — keep retrying won't help.
|
|
if ep.Code == api.ErrProtocolTooOld {
|
|
return fmt.Errorf("protocol too old: %s", ep.Message)
|
|
}
|
|
continue
|
|
}
|
|
if handle != nil {
|
|
if err := handle(ctx, env, tx); err != nil {
|
|
slog.Warn("ws agent: handler returned error", "type", env.Type, "err", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// connSender is the per-connection Sender. Goroutines beyond the
|
|
// read loop (e.g. a backup running in its own goroutine) keep a
|
|
// reference to one of these for the duration of their work.
|
|
type connSender struct {
|
|
conn *websocket.Conn
|
|
ctx context.Context
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func (s *connSender) Send(env api.Envelope) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
raw, err := json.Marshal(env)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
writeCtx, cancel := context.WithTimeout(s.ctx, 30*time.Second)
|
|
defer cancel()
|
|
return s.conn.Write(writeCtx, websocket.MessageText, raw)
|
|
}
|
|
|
|
func heartbeatLoop(ctx context.Context, conn *websocket.Conn, period time.Duration) {
|
|
t := time.NewTicker(period)
|
|
defer t.Stop()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-t.C:
|
|
env, err := api.Marshal(api.MsgHeartbeat, "",
|
|
api.HeartbeatPayload{SentAt: time.Now().UTC()})
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if err := writeEnv(ctx, conn, env); err != nil {
|
|
slog.Warn("ws agent: heartbeat write failed", "err", err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func writeEnv(ctx context.Context, conn *websocket.Conn, env api.Envelope) error {
|
|
raw, err := json.Marshal(env)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return conn.Write(ctx, websocket.MessageText, raw)
|
|
}
|
|
|
|
// buildWSURL turns the configured server URL into the agent's WebSocket
// endpoint:
//   - http becomes ws and https becomes wss; ws and wss pass through.
//   - Any other scheme is an error.
//   - Trailing slashes are trimmed from the path and "/ws/agent" is
//     appended.
//   - Query parameters are preserved.
//
// A URL without a host is rejected up front rather than failing later
// with an obscure dial error.
func buildWSURL(serverURL string) (string, error) {
	u, err := url.Parse(serverURL)
	if err != nil {
		return "", err
	}
	switch u.Scheme {
	case "https":
		u.Scheme = "wss"
	case "http":
		u.Scheme = "ws"
	case "ws", "wss":
		// already correct
	default:
		return "", fmt.Errorf("unsupported scheme %q", u.Scheme)
	}
	// e.g. "https:///" or "https:example.com" (opaque form) parse cleanly
	// but carry no host to dial.
	if u.Host == "" {
		return "", fmt.Errorf("missing host in %q", serverURL)
	}
	u.Path = strings.TrimRight(u.Path, "/") + "/ws/agent"
	return u.String(), nil
}
|
|
|
|
// pinChecker returns a VerifyPeerCertificate callback that requires
|
|
// the leaf cert's SHA-256 to match wantHex. We do this *in addition*
|
|
// to the OS root verification (we don't replace it).
|
|
func pinChecker(wantHex string) func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
|
return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
|
if len(rawCerts) == 0 {
|
|
return errors.New("ws agent: no peer certs")
|
|
}
|
|
got := sha256Hex(rawCerts[0])
|
|
if got != wantHex {
|
|
return fmt.Errorf("ws agent: cert pin mismatch (got %s want %s)", got, wantHex)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// sha256Hex returns the lower-case hex encoding of the SHA-256 digest of b.
func sha256Hex(b []byte) string {
	sum := sha256.Sum256(b)
	return hex.EncodeToString(sum[:])
}
|
|
|
|
// ----- backoff -------------------------------------------------------

// backoff produces exponentially growing retry delays. Each call to next
// returns the current delay plus up to 20% random jitter, then doubles the
// delay for the following call, clamped to max.
type backoff struct {
	cur, max time.Duration
}

// newBackoff returns a backoff whose first delay is base and whose delay
// stops growing at max.
func newBackoff(base, max time.Duration) *backoff { return &backoff{cur: base, max: max} }

// next returns the delay to wait before the next attempt and advances the
// schedule.
func (b *backoff) next() time.Duration {
	d := b.cur
	// Up to 20% jitter so reconnect attempts don't line up exactly.
	// rand.Int63n panics for n <= 0, so delays too small to jitter
	// (under 5ns, including a zero base) get none.
	var jitter time.Duration
	if n := int64(d) / 5; n > 0 {
		jitter = time.Duration(rand.Int63n(n)) //nolint:gosec // jitter, not security
	}
	b.cur *= 2
	if b.cur > b.max {
		b.cur = b.max
	}
	return d + jitter
}
|
|
|
|
// sleepCtx waits for d and returns nil, unless ctx finishes first, in which
// case it returns ctx.Err().
func sleepCtx(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-timer.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
|