// Package wsclient is the agent's outbound WebSocket connection to // the control plane: dial with bearer auth, perform the hello // handshake, send heartbeats, dispatch server-pushed commands. // // The Run loop is a forever-loop with exponential backoff on dial // failures, capped at 60s. Disconnected agents keep retrying. package wsclient import ( "context" "crypto/tls" "crypto/x509" "encoding/hex" "encoding/json" "errors" "fmt" "log/slog" "math/rand" stdhttp "net/http" "net/url" "strings" "sync" "time" "github.com/coder/websocket" "gitea.dcglab.co.uk/steve/restic-manager/internal/api" ) // Config holds the agent's connection settings. type Config struct { ServerURL string AgentToken string HostID string CertPinSHA256 string // hex; empty disables pinning HeartbeatPeriod time.Duration HelloPayload api.HelloPayload } // Sender is what handlers use to push agent → server messages // (job.progress, job.finished, log.stream, command.result, …). // Returned by the WS client to the dispatch handler. Write operations // serialise behind a single mutex on the conn; concurrent calls are // safe. type Sender interface { Send(env api.Envelope) error } // Handler is invoked for every server-sent message. tx lets the // handler push replies back; it is valid only for the lifetime of // the connection (calls fail if the agent has reconnected since). type Handler func(ctx context.Context, env api.Envelope, tx Sender) error // Run keeps the agent connected indefinitely. Returns when ctx is // cancelled. Errors during a single connection attempt are logged and // trigger reconnect-with-backoff; only ctx.Done() ends the loop. func Run(ctx context.Context, cfg Config, handle Handler) error { if cfg.HeartbeatPeriod <= 0 { cfg.HeartbeatPeriod = 30 * time.Second } backoff := newBackoff(time.Second, 60*time.Second) for { err := connectOnce(ctx, cfg, handle) if errors.Is(err, context.Canceled) { return nil } if err != nil { slog.Warn("ws agent disconnect", "err", err) } if err := sleepCtx(ctx, backoff.next()); err != nil { return nil } } } // connectOnce performs one full connection lifecycle: dial → hello → // heartbeat loop + read loop → close. Returns when either side closes // the socket. func connectOnce(ctx context.Context, cfg Config, handle Handler) error { wsURL, err := buildWSURL(cfg.ServerURL) if err != nil { return fmt.Errorf("ws agent: bad server url: %w", err) } dialOpts := &websocket.DialOptions{ HTTPHeader: stdhttp.Header{ "Authorization": []string{"Bearer " + cfg.AgentToken}, }, } if cfg.CertPinSHA256 != "" && strings.HasPrefix(wsURL, "wss") { dialOpts.HTTPClient = &stdhttp.Client{ Transport: &stdhttp.Transport{ TLSClientConfig: &tls.Config{ MinVersion: tls.VersionTLS12, VerifyPeerCertificate: pinChecker(cfg.CertPinSHA256), }, }, } } dialCtx, cancel := context.WithTimeout(ctx, 30*time.Second) conn, _, err := websocket.Dial(dialCtx, wsURL, dialOpts) cancel() if err != nil { return fmt.Errorf("dial: %w", err) } defer conn.CloseNow() //nolint:errcheck // Send hello. helloEnv, err := api.Marshal(api.MsgHello, "", cfg.HelloPayload) if err != nil { return fmt.Errorf("marshal hello: %w", err) } if err := writeEnv(ctx, conn, helloEnv); err != nil { return fmt.Errorf("write hello: %w", err) } slog.Info("ws agent connected", "server", wsURL) tx := &connSender{conn: conn, ctx: ctx} // Heartbeat goroutine. heartbeatCtx, cancelHeartbeat := context.WithCancel(ctx) defer cancelHeartbeat() go heartbeatLoop(heartbeatCtx, conn, cfg.HeartbeatPeriod) // Read loop. A read error returns and closes the conn. for { mt, raw, err := conn.Read(ctx) if err != nil { return fmt.Errorf("read: %w", err) } if mt != websocket.MessageText { continue } var env api.Envelope if err := json.Unmarshal(raw, &env); err != nil { slog.Warn("ws agent: bad envelope from server", "err", err) continue } if env.Type == api.MsgError { var ep api.ErrorPayload _ = env.UnmarshalPayload(&ep) slog.Error("ws agent: server reported error", "code", ep.Code, "message", ep.Message, "help", ep.HelpURL) // protocol_too_old is fatal — keep retrying won't help. if ep.Code == api.ErrProtocolTooOld { return fmt.Errorf("protocol too old: %s", ep.Message) } continue } if handle != nil { if err := handle(ctx, env, tx); err != nil { slog.Warn("ws agent: handler returned error", "type", env.Type, "err", err) } } } } // connSender is the per-connection Sender. Goroutines beyond the // read loop (e.g. a backup running in its own goroutine) keep a // reference to one of these for the duration of their work. type connSender struct { conn *websocket.Conn ctx context.Context mu sync.Mutex } func (s *connSender) Send(env api.Envelope) error { s.mu.Lock() defer s.mu.Unlock() raw, err := json.Marshal(env) if err != nil { return err } writeCtx, cancel := context.WithTimeout(s.ctx, 30*time.Second) defer cancel() return s.conn.Write(writeCtx, websocket.MessageText, raw) } func heartbeatLoop(ctx context.Context, conn *websocket.Conn, period time.Duration) { t := time.NewTicker(period) defer t.Stop() for { select { case <-ctx.Done(): return case <-t.C: env, err := api.Marshal(api.MsgHeartbeat, "", api.HeartbeatPayload{SentAt: time.Now().UTC()}) if err != nil { continue } if err := writeEnv(ctx, conn, env); err != nil { slog.Warn("ws agent: heartbeat write failed", "err", err) return } } } } func writeEnv(ctx context.Context, conn *websocket.Conn, env api.Envelope) error { raw, err := json.Marshal(env) if err != nil { return err } return conn.Write(ctx, websocket.MessageText, raw) } func buildWSURL(serverURL string) (string, error) { u, err := url.Parse(serverURL) if err != nil { return "", err } switch u.Scheme { case "https": u.Scheme = "wss" case "http": u.Scheme = "ws" case "ws", "wss": // already correct default: return "", fmt.Errorf("unsupported scheme %q", u.Scheme) } u.Path = strings.TrimRight(u.Path, "/") + "/ws/agent" return u.String(), nil } // pinChecker returns a VerifyPeerCertificate callback that requires // the leaf cert's SHA-256 to match wantHex. We do this *in addition* // to the OS root verification (we don't replace it). func pinChecker(wantHex string) func(rawCerts [][]byte, _ [][]*x509.Certificate) error { return func(rawCerts [][]byte, _ [][]*x509.Certificate) error { if len(rawCerts) == 0 { return errors.New("ws agent: no peer certs") } got := sha256Hex(rawCerts[0]) if got != wantHex { return fmt.Errorf("ws agent: cert pin mismatch (got %s want %s)", got, wantHex) } return nil } } func sha256Hex(b []byte) string { // avoid pulling in crypto/sha256 in this top-level file twice; // indirection through hex-encode is the classic shape. h := newSHA256() h.Write(b) return hex.EncodeToString(h.Sum(nil)) } // ----- backoff ------------------------------------------------------- type backoff struct { cur, max time.Duration } func newBackoff(base, max time.Duration) *backoff { return &backoff{cur: base, max: max} } func (b *backoff) next() time.Duration { d := b.cur // 20% jitter, deterministic-enough randomness. jitter := time.Duration(rand.Int63n(int64(d) / 5)) //nolint:gosec b.cur *= 2 if b.cur > b.max { b.cur = b.max } return d + jitter } func sleepCtx(ctx context.Context, d time.Duration) error { select { case <-ctx.Done(): return ctx.Err() case <-time.After(d): return nil } }