phase 1: WS transport, enrollment, agent that hellos and heartbeats
Lands the protocol layer end-to-end: an agent can be enrolled through the operator UI, store credentials, dial back to the server over WS, complete the protocol_version handshake, and stay connected with periodic heartbeats.

Server side:
- P1-09 ws.Hub: one Conn per host_id, last-write-wins eviction, JSON envelope writer with a write mutex, reader, error envelopes.
- P1-09 ws.AgentHandler: bearer-auth, accept upgrade, hello-stage (10s deadline, protocol_version checked against api.MinAgentProtocolVersion → ErrProtocolTooOld with help URL on reject), main read loop, defer hub register/unregister.
- P1-10 POST /api/agents/enroll consumes a one-time token, mints a persistent agent bearer (sha-256 stored), creates a host row.
- P1-10 POST /api/enrollment-tokens (operator, session-auth) issues a 1h one-time token.
- P1-11 hello upserts agent_version + restic_version + protocol_version on the host row, flips status to online.
- P1-12 heartbeat touches last_seen_at; background sweeper marks hosts offline after 90s without one.
- store: hosts table accessors, host_schedule_version, enrollment_tokens FK on consumed_host dropped (audit-only field; the token gets burned before the host row exists).

Agent side:
- P1-13 internal/agent/config: YAML at /etc/restic-manager/agent.yaml, atomic Save (tmp+fsync+rename), Enrolled() helper.
- P1-15 internal/agent/wsclient: dial with bearer + optional TLS cert pinning (sha-256 of leaf), exponential backoff with jitter (1s → 60s cap), heartbeat goroutine, fatal handling for ErrProtocolTooOld.
- P1-15 wsclient.Enroll: HTTP POST /api/agents/enroll with sysinfo.
- P1-17 internal/agent/sysinfo: hostname/OS/arch/restic-version collection. restic is detected by parsing `restic version`; an absent restic doesn't block startup.
- cmd/agent: -enroll-server / -enroll-token flags drive first-run enrollment, then exit (so the install script can hand off to systemd to run the persistent service).
End-to-end smoke verified: bootstrap → login → issue token → enroll → run agent → server logs `ws agent connected` with the right host_id and protocol_version 1. All tests still pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,110 @@
|
||||
// Package config loads the agent's persistent configuration. After
|
||||
// enrollment, the file holds the bearer token + server URL; it is
|
||||
// only ever written via Save (which replaces atomically).
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config mirrors the agent's YAML config file, field for field.
type Config struct {
	// ServerURL: base URL of the control plane (for example
	// https://restic.lab.example). The agent appends /ws/agent and
	// /api/agents/enroll to it.
	ServerURL string `yaml:"server_url"`

	// AgentToken: bearer credential handed out at enrollment. An
	// empty value means the agent has not enrolled yet.
	AgentToken string `yaml:"agent_token"`

	// HostID: the identifier the server uses for this host.
	HostID string `yaml:"host_id"`

	// CertPinSHA256: optional SHA-256 of the server's TLS cert. When
	// present, the agent refuses to talk to a server whose cert hash
	// differs.
	CertPinSHA256 string `yaml:"cert_pin_sha256,omitempty"`

	// ResticPath: optional override for the auto-detected restic
	// binary.
	ResticPath string `yaml:"restic_path,omitempty"`

	// path remembers which file this Config was loaded from so Save
	// knows where to write. It is never serialised.
	path string `yaml:"-"`
}
|
||||
|
||||
// DefaultPath reports where the agent expects to find its config
// file. Only the Linux location exists in Phase 1; the Windows path
// is specified for P2.
func DefaultPath() string {
	const linuxConfigPath = "/etc/restic-manager/agent.yaml"
	return linuxConfigPath
}
|
||||
|
||||
// Load reads and parses the config file at path. A missing file is
|
||||
// returned as an empty Config (not an error) — first-run agents
|
||||
// haven't been enrolled yet.
|
||||
func Load(path string) (*Config, error) {
|
||||
c := &Config{path: path}
|
||||
body, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return c, nil
|
||||
}
|
||||
return nil, fmt.Errorf("agent config: read %q: %w", path, err)
|
||||
}
|
||||
if err := yaml.Unmarshal(body, c); err != nil {
|
||||
return nil, fmt.Errorf("agent config: parse %q: %w", path, err)
|
||||
}
|
||||
c.path = path
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Save writes the config back atomically: write to <path>.tmp, fsync,
|
||||
// rename. A crash mid-write either leaves the old file or the new one,
|
||||
// never a half-written one.
|
||||
func (c *Config) Save() error {
|
||||
if c.path == "" {
|
||||
return fmt.Errorf("agent config: no path set")
|
||||
}
|
||||
dir := filepath.Dir(c.path)
|
||||
if err := os.MkdirAll(dir, 0o700); err != nil {
|
||||
return fmt.Errorf("agent config: mkdir %q: %w", dir, err)
|
||||
}
|
||||
body, err := yaml.Marshal(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("agent config: marshal: %w", err)
|
||||
}
|
||||
|
||||
tmp := c.path + ".tmp"
|
||||
f, err := os.OpenFile(tmp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("agent config: create tmp: %w", err)
|
||||
}
|
||||
if _, err := f.Write(body); err != nil {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(tmp)
|
||||
return fmt.Errorf("agent config: write tmp: %w", err)
|
||||
}
|
||||
if err := f.Sync(); err != nil {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(tmp)
|
||||
return fmt.Errorf("agent config: fsync tmp: %w", err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
_ = os.Remove(tmp)
|
||||
return fmt.Errorf("agent config: close tmp: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmp, c.path); err != nil {
|
||||
_ = os.Remove(tmp)
|
||||
return fmt.Errorf("agent config: rename: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Enrolled reports whether the agent has finished enrollment.
|
||||
func (c *Config) Enrolled() bool {
|
||||
return c.AgentToken != "" && c.HostID != "" && c.ServerURL != ""
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
// Package sysinfo collects host metadata at agent startup: OS, arch,
|
||||
// hostname, restic version. The agent sends this in `hello` so the
|
||||
// server's Host row stays current.
|
||||
package sysinfo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
)
|
||||
|
||||
// Snapshot is the bundle of metadata reported in `hello`.
|
||||
type Snapshot struct {
|
||||
Hostname string
|
||||
OS api.HostOS
|
||||
Arch api.HostArch
|
||||
ResticVersion string
|
||||
ProtocolVersion int
|
||||
BootTime time.Time
|
||||
}
|
||||
|
||||
// Collect probes the running host. resticPath, if non-empty,
|
||||
// overrides PATH lookup.
|
||||
func Collect(ctx context.Context, resticPath string) (Snapshot, error) {
|
||||
hn, err := os.Hostname()
|
||||
if err != nil {
|
||||
return Snapshot{}, fmt.Errorf("sysinfo: hostname: %w", err)
|
||||
}
|
||||
|
||||
osTag := api.HostOS(runtime.GOOS)
|
||||
archTag := api.HostArch(runtime.GOARCH)
|
||||
|
||||
resticVer, _ := detectResticVersion(ctx, resticPath) // empty on failure is fine
|
||||
return Snapshot{
|
||||
Hostname: hn,
|
||||
OS: osTag,
|
||||
Arch: archTag,
|
||||
ResticVersion: resticVer,
|
||||
ProtocolVersion: api.CurrentProtocolVersion,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// detectResticVersion executes `<restic> version` with a 5-second
// timeout and extracts the version from the first output line, which
// is expected to look like:
//
//	restic 0.17.1 compiled with go1.22.5 on linux/amd64
//
// override, when non-empty, names the binary to run; otherwise restic
// is looked up on PATH. The result is the version token (e.g.
// "0.17.1"). Every failure — binary not found, command error,
// unexpected output — yields "" and a non-nil error, and callers are
// expected to carry on without a version rather than block startup.
func detectResticVersion(ctx context.Context, override string) (string, error) {
	binary := override
	if binary == "" {
		found, lookErr := exec.LookPath("restic")
		if lookErr != nil {
			return "", lookErr
		}
		binary = found
	}

	runCtx, stop := context.WithTimeout(ctx, 5*time.Second)
	defer stop()

	raw, runErr := exec.CommandContext(runCtx, binary, "version").Output()
	if runErr != nil {
		return "", runErr
	}

	lines := strings.SplitN(strings.TrimSpace(string(raw)), "\n", 2)
	headline := lines[0]
	if tokens := strings.Fields(headline); len(tokens) >= 2 && tokens[0] == "restic" {
		return tokens[1], nil
	}
	return "", fmt.Errorf("sysinfo: unrecognised restic version output: %q", headline)
}
|
||||
@@ -0,0 +1,246 @@
|
||||
// Package wsclient is the agent's outbound WebSocket connection to
|
||||
// the control plane: dial with bearer auth, perform the hello
|
||||
// handshake, send heartbeats, dispatch server-pushed commands.
|
||||
//
|
||||
// The Run loop is a forever-loop with exponential backoff on dial
|
||||
// failures, capped at 60s. Disconnected agents keep retrying.
|
||||
package wsclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
stdhttp "net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
)
|
||||
|
||||
// Config holds the agent's connection settings.
|
||||
type Config struct {
|
||||
ServerURL string
|
||||
AgentToken string
|
||||
HostID string
|
||||
CertPinSHA256 string // hex; empty disables pinning
|
||||
HeartbeatPeriod time.Duration
|
||||
HelloPayload api.HelloPayload
|
||||
}
|
||||
|
||||
// Handler is invoked for every server-sent message. The agent's main
|
||||
// program supplies one that knows how to dispatch command.run etc.
|
||||
// to the runner package.
|
||||
type Handler func(ctx context.Context, env api.Envelope) error
|
||||
|
||||
// Run keeps the agent connected indefinitely. Returns when ctx is
|
||||
// cancelled. Errors during a single connection attempt are logged and
|
||||
// trigger reconnect-with-backoff; only ctx.Done() ends the loop.
|
||||
func Run(ctx context.Context, cfg Config, handle Handler) error {
|
||||
if cfg.HeartbeatPeriod <= 0 {
|
||||
cfg.HeartbeatPeriod = 30 * time.Second
|
||||
}
|
||||
|
||||
backoff := newBackoff(time.Second, 60*time.Second)
|
||||
for {
|
||||
err := connectOnce(ctx, cfg, handle)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
slog.Warn("ws agent disconnect", "err", err)
|
||||
}
|
||||
if err := sleepCtx(ctx, backoff.next()); err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// connectOnce performs one full connection lifecycle: dial → hello →
|
||||
// heartbeat loop + read loop → close. Returns when either side closes
|
||||
// the socket.
|
||||
func connectOnce(ctx context.Context, cfg Config, handle Handler) error {
|
||||
wsURL, err := buildWSURL(cfg.ServerURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ws agent: bad server url: %w", err)
|
||||
}
|
||||
|
||||
dialOpts := &websocket.DialOptions{
|
||||
HTTPHeader: stdhttp.Header{
|
||||
"Authorization": []string{"Bearer " + cfg.AgentToken},
|
||||
},
|
||||
}
|
||||
if cfg.CertPinSHA256 != "" && strings.HasPrefix(wsURL, "wss") {
|
||||
dialOpts.HTTPClient = &stdhttp.Client{
|
||||
Transport: &stdhttp.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
VerifyPeerCertificate: pinChecker(cfg.CertPinSHA256),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
dialCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
conn, _, err := websocket.Dial(dialCtx, wsURL, dialOpts)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial: %w", err)
|
||||
}
|
||||
defer conn.CloseNow() //nolint:errcheck
|
||||
|
||||
// Send hello.
|
||||
helloEnv, err := api.Marshal(api.MsgHello, "", cfg.HelloPayload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal hello: %w", err)
|
||||
}
|
||||
if err := writeEnv(ctx, conn, helloEnv); err != nil {
|
||||
return fmt.Errorf("write hello: %w", err)
|
||||
}
|
||||
slog.Info("ws agent connected", "server", wsURL)
|
||||
|
||||
// Heartbeat goroutine.
|
||||
heartbeatCtx, cancelHeartbeat := context.WithCancel(ctx)
|
||||
defer cancelHeartbeat()
|
||||
go heartbeatLoop(heartbeatCtx, conn, cfg.HeartbeatPeriod)
|
||||
|
||||
// Read loop. A read error returns and closes the conn.
|
||||
for {
|
||||
mt, raw, err := conn.Read(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read: %w", err)
|
||||
}
|
||||
if mt != websocket.MessageText {
|
||||
continue
|
||||
}
|
||||
var env api.Envelope
|
||||
if err := json.Unmarshal(raw, &env); err != nil {
|
||||
slog.Warn("ws agent: bad envelope from server", "err", err)
|
||||
continue
|
||||
}
|
||||
if env.Type == api.MsgError {
|
||||
var ep api.ErrorPayload
|
||||
_ = env.UnmarshalPayload(&ep)
|
||||
slog.Error("ws agent: server reported error",
|
||||
"code", ep.Code, "message", ep.Message, "help", ep.HelpURL)
|
||||
// protocol_too_old is fatal — keep retrying won't help.
|
||||
if ep.Code == api.ErrProtocolTooOld {
|
||||
return fmt.Errorf("protocol too old: %s", ep.Message)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if handle != nil {
|
||||
if err := handle(ctx, env); err != nil {
|
||||
slog.Warn("ws agent: handler returned error", "type", env.Type, "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func heartbeatLoop(ctx context.Context, conn *websocket.Conn, period time.Duration) {
|
||||
t := time.NewTicker(period)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
env, err := api.Marshal(api.MsgHeartbeat, "",
|
||||
api.HeartbeatPayload{SentAt: time.Now().UTC()})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if err := writeEnv(ctx, conn, env); err != nil {
|
||||
slog.Warn("ws agent: heartbeat write failed", "err", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeEnv(ctx context.Context, conn *websocket.Conn, env api.Envelope) error {
|
||||
raw, err := json.Marshal(env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.Write(ctx, websocket.MessageText, raw)
|
||||
}
|
||||
|
||||
// buildWSURL turns the control plane's base URL into the agent
// WebSocket endpoint: https becomes wss, http becomes ws, ws/wss are
// kept, and "/ws/agent" is appended to the path after stripping any
// trailing slashes. An unparsable URL or any other scheme is an error.
func buildWSURL(serverURL string) (string, error) {
	parsed, err := url.Parse(serverURL)
	if err != nil {
		return "", err
	}

	var wsScheme string
	switch parsed.Scheme {
	case "https", "wss":
		wsScheme = "wss"
	case "http", "ws":
		wsScheme = "ws"
	default:
		return "", fmt.Errorf("unsupported scheme %q", parsed.Scheme)
	}

	parsed.Scheme = wsScheme
	parsed.Path = strings.TrimRight(parsed.Path, "/") + "/ws/agent"
	return parsed.String(), nil
}
|
||||
|
||||
// pinChecker returns a VerifyPeerCertificate callback that requires
|
||||
// the leaf cert's SHA-256 to match wantHex. We do this *in addition*
|
||||
// to the OS root verification (we don't replace it).
|
||||
func pinChecker(wantHex string) func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
||||
return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
||||
if len(rawCerts) == 0 {
|
||||
return errors.New("ws agent: no peer certs")
|
||||
}
|
||||
got := sha256Hex(rawCerts[0])
|
||||
if got != wantHex {
|
||||
return fmt.Errorf("ws agent: cert pin mismatch (got %s want %s)", got, wantHex)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func sha256Hex(b []byte) string {
|
||||
// avoid pulling in crypto/sha256 in this top-level file twice;
|
||||
// indirection through hex-encode is the classic shape.
|
||||
h := newSHA256()
|
||||
h.Write(b)
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// ----- backoff -------------------------------------------------------
|
||||
|
||||
// backoff generates reconnect delays. cur is the delay the next call
// to next will be based on; max is the ceiling cur is clamped to
// after each doubling.
type backoff struct {
	cur, max time.Duration
}

// newBackoff returns a backoff whose first delay is base and whose
// later delays double until they reach max.
func newBackoff(base, max time.Duration) *backoff { return &backoff{cur: base, max: max} }

// next returns the current delay plus up to 20% random jitter, then
// doubles the stored delay (clamped to max) for the following call.
func (b *backoff) next() time.Duration {
	d := b.cur

	b.cur *= 2
	if b.cur > b.max {
		b.cur = b.max
	}

	// rand.Int63n panics when its argument is <= 0, which is what
	// d/5 evaluates to for delays shorter than 5ns. Such delays are
	// returned without jitter instead of crashing the agent.
	if spread := int64(d) / 5; spread > 0 {
		d += time.Duration(rand.Int63n(spread)) //nolint:gosec
	}
	return d
}
|
||||
|
||||
// sleepCtx waits for d to elapse or for ctx to finish, whichever
// comes first. It returns nil after a full sleep and ctx.Err() when
// the context ended the wait.
func sleepCtx(ctx context.Context, d time.Duration) error {
	wake := time.After(d)
	select {
	case <-wake:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package wsclient
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
stdhttp "net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
)
|
||||
|
||||
// EnrollRequest is what we POST to /api/agents/enroll.
|
||||
type EnrollRequest struct {
|
||||
Token string `json:"token"`
|
||||
HostName string `json:"hostname"`
|
||||
OS api.HostOS `json:"os"`
|
||||
Arch api.HostArch `json:"arch"`
|
||||
AgentVersion string `json:"agent_version"`
|
||||
ResticVersion string `json:"restic_version"`
|
||||
}
|
||||
|
||||
// EnrollResponse is the server's answer to a successful enrollment.
type EnrollResponse struct {
	HostID        string `json:"host_id"`                   // identifier the server assigned to this host
	AgentToken    string `json:"agent_token"`               // persistent bearer credential
	CertPinSHA256 string `json:"cert_pin_sha256,omitempty"` // optional TLS cert pin; may be absent
}
|
||||
|
||||
// Enroll exchanges a one-time enrollment token for persistent agent
|
||||
// credentials. Called by the install script on first run.
|
||||
func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) {
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("agent enroll: marshal: %w", err)
|
||||
}
|
||||
postURL := strings.TrimRight(serverURL, "/") + "/api/agents/enroll"
|
||||
|
||||
httpReq, err := stdhttp.NewRequestWithContext(ctx, stdhttp.MethodPost, postURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("agent enroll: build request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &stdhttp.Client{Timeout: 30 * time.Second}
|
||||
res, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("agent enroll: post: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
rawRes, _ := io.ReadAll(res.Body)
|
||||
if res.StatusCode != stdhttp.StatusCreated {
|
||||
return nil, fmt.Errorf("agent enroll: server returned %d: %s",
|
||||
res.StatusCode, rawRes)
|
||||
}
|
||||
var er EnrollResponse
|
||||
if err := json.Unmarshal(rawRes, &er); err != nil {
|
||||
return nil, fmt.Errorf("agent enroll: parse response: %w", err)
|
||||
}
|
||||
if er.AgentToken == "" || er.HostID == "" {
|
||||
return nil, fmt.Errorf("agent enroll: incomplete response: %+v", er)
|
||||
}
|
||||
return &er, nil
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package wsclient
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"hash"
|
||||
)
|
||||
|
||||
// newSHA256 returns a fresh SHA-256 hasher as a hash.Hash.
func newSHA256() hash.Hash {
	return sha256.New()
}
|
||||
@@ -0,0 +1,165 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
stdhttp "net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// enrollRequest is the body posted by the agent installer. The token
|
||||
// was issued by the operator via the UI ("Add host" → P1-27); the
|
||||
// host metadata comes from the agent's own sysinfo collection.
|
||||
type enrollRequest struct {
|
||||
Token string `json:"token"`
|
||||
HostName string `json:"hostname"`
|
||||
OS api.HostOS `json:"os"`
|
||||
Arch api.HostArch `json:"arch"`
|
||||
AgentVersion string `json:"agent_version"`
|
||||
ResticVersion string `json:"restic_version"`
|
||||
}
|
||||
|
||||
// enrollResponse hands the agent the credentials it will use from
// then on. AgentToken is shown exactly once; the server keeps only
// its hash.
//
// CertPinSHA256 is the SHA-256 of the server's certificate, encoded
// as hex — the agent's pin check compares hex fingerprints, so any
// other encoding would never match. The agent pins this on every
// reconnect so a stolen DB at the control plane can't be replayed
// against an attacker's TLS endpoint. The field is omitted from the
// JSON when empty.
type enrollResponse struct {
	HostID        string `json:"host_id"`
	AgentToken    string `json:"agent_token"`
	CertPinSHA256 string `json:"cert_pin_sha256,omitempty"`
}
|
||||
|
||||
// enrollOperatorRequest is the body of the authenticated UI route an
// operator calls to obtain a one-time enrollment token before
// installing an agent.
//
// NOTE(review): the create-token handler decodes this body but does
// not appear to read HostName or Tags yet — confirm intended use.
type enrollOperatorRequest struct {
	HostName string   `json:"hostname"`
	Tags     []string `json:"tags,omitempty"`
}
|
||||
|
||||
// enrollOperatorResponse returns a newly issued enrollment token to
// the operator together with the time it stops being valid.
type enrollOperatorResponse struct {
	Token     string    `json:"token"`      // raw one-time token
	ExpiresAt time.Time `json:"expires_at"` // expiry time reported to the operator
}
|
||||
|
||||
// handleAgentEnroll consumes a one-time token, persists a Host row,
|
||||
// and returns persistent agent credentials. Open endpoint (no
|
||||
// session) — the token is the credential.
|
||||
func (s *Server) handleAgentEnroll(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
var req enrollRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error())
|
||||
return
|
||||
}
|
||||
if req.Token == "" || req.HostName == "" || req.OS == "" || req.Arch == "" {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "missing_field",
|
||||
"token, hostname, os, arch all required")
|
||||
return
|
||||
}
|
||||
|
||||
hostID := ulid.Make().String()
|
||||
|
||||
// Atomically: validate + consume token, then create the host.
|
||||
// We do these in two statements; if create-host fails, the token
|
||||
// is already burned. That's acceptable — operator just regens.
|
||||
tokHash := auth.HashToken(req.Token)
|
||||
if err := s.deps.Store.ConsumeEnrollmentToken(r.Context(), tokHash, hostID); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "invalid_token",
|
||||
"token unknown, expired, or already used")
|
||||
return
|
||||
}
|
||||
|
||||
// Mint the persistent agent bearer.
|
||||
agentToken, err := auth.NewToken()
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "")
|
||||
return
|
||||
}
|
||||
|
||||
host := store.Host{
|
||||
ID: hostID,
|
||||
Name: strings.TrimSpace(req.HostName),
|
||||
OS: string(req.OS),
|
||||
Arch: string(req.Arch),
|
||||
AgentVersion: req.AgentVersion,
|
||||
ResticVersion: req.ResticVersion,
|
||||
EnrolledAt: time.Now().UTC(),
|
||||
}
|
||||
if err := s.deps.Store.CreateHost(r.Context(), host,
|
||||
auth.HashToken(agentToken), ""); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusConflict, "host_exists", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(),
|
||||
Actor: "system",
|
||||
Action: "host.enrolled",
|
||||
TargetKind: ptr("host"),
|
||||
TargetID: &hostID,
|
||||
TS: host.EnrolledAt,
|
||||
})
|
||||
|
||||
writeJSON(w, stdhttp.StatusCreated, enrollResponse{
|
||||
HostID: hostID,
|
||||
AgentToken: agentToken,
|
||||
// CertPinSHA256 is populated by a TLS-aware future revision.
|
||||
// For now (HTTP-or-TLS-by-Caddy) we leave it empty and rely
|
||||
// on the agent trusting its OS root store.
|
||||
})
|
||||
}
|
||||
|
||||
// handleCreateEnrollmentToken (operator-facing) — generates a
|
||||
// short-lived token for a new host. Authenticated; admin/operator only.
|
||||
//
|
||||
// TODO: gate by authn middleware once login session lookup lands.
|
||||
// For Phase 1's first slice, we accept the bootstrap-shipped admin
|
||||
// session cookie and trust it, validating the cookie via store.
|
||||
func (s *Server) handleCreateEnrollmentToken(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if !s.authedUser(r) {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
|
||||
var req enrollOperatorRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.NewToken()
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "")
|
||||
return
|
||||
}
|
||||
const ttl = time.Hour
|
||||
if err := s.deps.Store.CreateEnrollmentToken(r.Context(), auth.HashToken(token), ttl); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, stdhttp.StatusCreated, enrollOperatorResponse{
|
||||
Token: token,
|
||||
ExpiresAt: time.Now().Add(ttl).UTC(),
|
||||
})
|
||||
}
|
||||
|
||||
// authedUser returns true iff the request carries a valid session
|
||||
// cookie. Minimal stub for now; full RBAC middleware lands with
|
||||
// P4-03.
|
||||
func (s *Server) authedUser(r *stdhttp.Request) bool {
|
||||
c, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_, err = s.deps.Store.LookupSession(r.Context(), auth.HashToken(c.Value))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ptr returns a pointer to a copy of s, for filling optional *string
// fields from a literal.
func ptr(s string) *string {
	return &s
}
|
||||
@@ -0,0 +1,118 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
stdhttp "net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/config"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// newTestServerWithHub mirrors newTestServer but plugs in a real
|
||||
// ws.Hub so /ws/agent is available.
|
||||
func newTestServerWithHub(t *testing.T) (*Server, string, *store.Store) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("store: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = st.Close() })
|
||||
|
||||
keyPath := filepath.Join(dir, "secret.key")
|
||||
_ = crypto.GenerateKeyFile(keyPath)
|
||||
key, _ := crypto.LoadKeyFromFile(keyPath)
|
||||
aead, _ := crypto.NewAEAD(key)
|
||||
|
||||
deps := Deps{
|
||||
Cfg: config.Config{Listen: ":0", DataDir: dir, SecretKeyFile: keyPath},
|
||||
Store: st,
|
||||
AEAD: aead,
|
||||
Hub: ws.NewHub(),
|
||||
}
|
||||
s := New(deps)
|
||||
ts := httptest.NewServer(s.srv.Handler)
|
||||
t.Cleanup(ts.Close)
|
||||
return s, ts.URL, st
|
||||
}
|
||||
|
||||
func TestEnrollmentBadToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, url, _ := newTestServerWithHub(t)
|
||||
|
||||
body, _ := json.Marshal(enrollRequest{
|
||||
Token: "no-such-token", HostName: "host1",
|
||||
OS: api.OSLinux, Arch: api.ArchAmd64,
|
||||
AgentVersion: "0.1", ResticVersion: "0.17",
|
||||
})
|
||||
res, err := stdhttp.Post(url+"/api/agents/enroll", "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("post: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusUnauthorized {
|
||||
t.Errorf("status: %d", res.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnrollmentHappyPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, url, st := newTestServerWithHub(t)
|
||||
|
||||
// Issue a token directly via the store (skipping the operator UI).
|
||||
rawToken, _ := auth.NewToken()
|
||||
if err := st.CreateEnrollmentToken(context.Background(),
|
||||
auth.HashToken(rawToken), 5*time.Minute); err != nil {
|
||||
t.Fatalf("issue: %v", err)
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(enrollRequest{
|
||||
Token: rawToken, HostName: "test-host",
|
||||
OS: api.OSLinux, Arch: api.ArchAmd64,
|
||||
AgentVersion: "0.1", ResticVersion: "0.17",
|
||||
})
|
||||
res, err := stdhttp.Post(url+"/api/agents/enroll", "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("post: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusCreated {
|
||||
buf, _ := io.ReadAll(res.Body)
|
||||
t.Fatalf("status %d: %s", res.StatusCode, buf)
|
||||
}
|
||||
|
||||
var er enrollResponse
|
||||
if err := json.NewDecoder(res.Body).Decode(&er); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if er.HostID == "" || er.AgentToken == "" {
|
||||
t.Errorf("missing fields in response: %+v", er)
|
||||
}
|
||||
|
||||
// Token must not be reusable.
|
||||
res2, _ := stdhttp.Post(url+"/api/agents/enroll", "application/json", bytes.NewReader(body))
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != stdhttp.StatusUnauthorized {
|
||||
t.Errorf("re-enrollment with same token should fail, got %d", res2.StatusCode)
|
||||
}
|
||||
|
||||
// Host row exists with matching agent_token_hash.
|
||||
got, err := st.LookupHostByAgentToken(context.Background(), auth.HashToken(er.AgentToken))
|
||||
if err != nil {
|
||||
t.Fatalf("lookup by token: %v", err)
|
||||
}
|
||||
if got.Name != "test-host" || got.OS != "linux" {
|
||||
t.Errorf("host fields: %+v", got)
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/config"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
@@ -24,6 +25,7 @@ type Deps struct {
|
||||
Cfg config.Config
|
||||
Store *store.Store
|
||||
AEAD *crypto.AEAD
|
||||
Hub *ws.Hub
|
||||
// BootstrapToken (optional, populated only on first run) is the raw
|
||||
// admin-bootstrap token printed in the server logs. While set, the
|
||||
// /bootstrap endpoint accepts it to create the first admin user.
|
||||
@@ -73,8 +75,24 @@ func (s *Server) routes(r chi.Router) {
|
||||
r.Post("/auth/login", s.handleLogin)
|
||||
r.Post("/auth/logout", s.handleLogout)
|
||||
r.Post("/bootstrap", s.handleBootstrap)
|
||||
|
||||
// Agent enrollment (open endpoint — token is the credential).
|
||||
r.Post("/agents/enroll", s.handleAgentEnroll)
|
||||
|
||||
// Operator → server (authenticated). Spec.md §6.1's
|
||||
// /hosts/{id}/enrollment-token (regenerate) lands when the
|
||||
// host page can call it; for now just the create endpoint.
|
||||
r.Post("/enrollment-tokens", s.handleCreateEnrollmentToken)
|
||||
})
|
||||
|
||||
// Agent ↔ server WebSocket. Bearer-authenticated inside the handler.
|
||||
if s.deps.Hub != nil {
|
||||
r.Mount("/ws/agent", ws.AgentHandler(ws.HandlerDeps{
|
||||
Hub: s.deps.Hub,
|
||||
Store: s.deps.Store,
|
||||
}))
|
||||
}
|
||||
|
||||
// UI handlers will hang off / — Phase 1 will add them.
|
||||
r.Get("/", func(w stdhttp.ResponseWriter, _ *stdhttp.Request) {
|
||||
_, _ = fmt.Fprint(w, "restic-manager — UI not yet implemented")
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
// Package ws hosts the WebSocket transport for agent ↔ server and the
|
||||
// browser-facing live job log stream.
|
||||
package ws
|
||||
@@ -0,0 +1,183 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
stdhttp "net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// HandlerDeps is the set of collaborators the agent WS handler needs.
|
||||
type HandlerDeps struct {
|
||||
Hub *Hub
|
||||
Store *store.Store
|
||||
}
|
||||
|
||||
// AgentHandler is the http.Handler that owns /ws/agent. Agents
|
||||
// authenticate with `Authorization: Bearer <token>` (issued at
|
||||
// enrollment) before the WS upgrade.
|
||||
//
|
||||
// Lifecycle:
|
||||
// 1. Bearer token resolves to a Host row.
|
||||
// 2. Upgrade.
|
||||
// 3. First message must be `hello`; protocol_version checked here.
|
||||
// 4. Loop: read messages, dispatch by type. Heartbeats touch the
|
||||
// host row; job/log/repo messages forward to the relevant
|
||||
// handlers (TODO: lands with P1-18 onward).
|
||||
// 5. On Read error or context cancel, mark host offline, unregister
|
||||
// from the hub.
|
||||
func AgentHandler(deps HandlerDeps) stdhttp.Handler {
|
||||
return stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
host, ok := authenticateAgent(r, deps.Store)
|
||||
if !ok {
|
||||
stdhttp.Error(w, "unauthorized", stdhttp.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
InsecureSkipVerify: true, // Origin checks are pointless for an agent CLI.
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("ws accept failed", "err", err, "host_id", host.ID)
|
||||
return
|
||||
}
|
||||
|
||||
c := NewConn(host.ID, conn)
|
||||
// Keep agents alive across NAT boxes; coder/websocket
|
||||
// auto-pings under the hood when configured. The default 60s
|
||||
// works fine for a 30s heartbeat cadence.
|
||||
|
||||
runAgentLoop(r.Context(), c, host.ID, deps)
|
||||
})
|
||||
}
|
||||
|
||||
// authenticateAgent returns the host that owns the bearer token in
|
||||
// the request, or (nil, false) if anything is amiss. The same
|
||||
// "false" path is used for missing header, malformed header, unknown
|
||||
// token — no information leak about why.
|
||||
func authenticateAgent(r *stdhttp.Request, st *store.Store) (*store.Host, bool) {
|
||||
hdr := r.Header.Get("Authorization")
|
||||
const prefix = "Bearer "
|
||||
if !strings.HasPrefix(hdr, prefix) {
|
||||
return nil, false
|
||||
}
|
||||
token := strings.TrimPrefix(hdr, prefix)
|
||||
if token == "" {
|
||||
return nil, false
|
||||
}
|
||||
h, err := st.LookupHostByAgentToken(r.Context(), auth.HashToken(token))
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return h, true
|
||||
}
|
||||
|
||||
// runAgentLoop is the per-connection driver. Returns when the socket
|
||||
// is closed for any reason. It owns the hub registration: register on
|
||||
// hello acceptance, unregister on exit.
|
||||
func runAgentLoop(ctx context.Context, c *Conn, hostID string, deps HandlerDeps) {
|
||||
// Stage 1: hello (with a tight deadline).
|
||||
helloCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
hello, err := c.Read(helloCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
slog.Info("ws hello read failed", "host_id", hostID, "err", err)
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
if hello.Type != api.MsgHello {
|
||||
c.SendError(ctx, api.ErrBadRequest, "first message must be hello", "")
|
||||
return
|
||||
}
|
||||
var helloPayload api.HelloPayload
|
||||
if err := hello.UnmarshalPayload(&helloPayload); err != nil {
|
||||
c.SendError(ctx, api.ErrBadRequest, "malformed hello payload", "")
|
||||
return
|
||||
}
|
||||
if helloPayload.ProtocolVersion < api.MinAgentProtocolVersion {
|
||||
c.SendError(ctx, api.ErrProtocolTooOld,
|
||||
fmt.Sprintf("agent protocol_version %d below minimum %d",
|
||||
helloPayload.ProtocolVersion, api.MinAgentProtocolVersion),
|
||||
"https://restic-manager.example/docs/upgrade")
|
||||
return
|
||||
}
|
||||
if helloPayload.ProtocolVersion > api.CurrentProtocolVersion {
|
||||
// Forward-compat is fine — newer agents talking to older
|
||||
// servers should accept their lower version. Just log it.
|
||||
slog.Info("ws agent newer than server",
|
||||
"host_id", hostID,
|
||||
"agent_proto", helloPayload.ProtocolVersion,
|
||||
"server_proto", api.CurrentProtocolVersion)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if err := deps.Store.MarkHostHello(ctx, hostID,
|
||||
helloPayload.AgentVersion, helloPayload.ResticVersion,
|
||||
helloPayload.ProtocolVersion, now); err != nil {
|
||||
slog.Error("ws mark host hello failed", "host_id", hostID, "err", err)
|
||||
}
|
||||
|
||||
deps.Hub.Register(hostID, c)
|
||||
defer deps.Hub.Unregister(hostID, c)
|
||||
defer func() { _ = c.Close() }()
|
||||
|
||||
slog.Info("ws agent connected",
|
||||
"host_id", hostID,
|
||||
"agent_version", helloPayload.AgentVersion,
|
||||
"protocol_version", helloPayload.ProtocolVersion)
|
||||
|
||||
// Stage 2: main read loop.
|
||||
for {
|
||||
env, err := c.Read(ctx)
|
||||
if err != nil {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
slog.Info("ws agent read loop ended", "host_id", hostID, "err", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
dispatchAgentMessage(ctx, c, hostID, env, deps)
|
||||
}
|
||||
}
|
||||
|
||||
// dispatchAgentMessage routes a single envelope to its handler. Only
|
||||
// hello + heartbeat are wired up in Phase 1's first slice; the rest
|
||||
// land with P1-18+ (jobs) and P2 (schedules).
|
||||
func dispatchAgentMessage(ctx context.Context, c *Conn, hostID string, env api.Envelope, deps HandlerDeps) {
|
||||
switch env.Type {
|
||||
case api.MsgHeartbeat:
|
||||
_ = deps.Store.TouchHost(ctx, hostID, time.Now().UTC())
|
||||
|
||||
case api.MsgJobStarted, api.MsgJobProgress, api.MsgJobFinished,
|
||||
api.MsgLogStream, api.MsgSnapshotsRpt, api.MsgRepoStats,
|
||||
api.MsgScheduleAck, api.MsgCommandResult:
|
||||
// TODO(P1-18+): persist + fan out to subscribed browsers.
|
||||
slog.Debug("ws msg not yet handled", "type", env.Type, "host_id", hostID)
|
||||
|
||||
case api.MsgError:
|
||||
var ep api.ErrorPayload
|
||||
_ = env.UnmarshalPayload(&ep)
|
||||
slog.Warn("ws agent reported error", "host_id", hostID,
|
||||
"code", string(ep.Code), "message", ep.Message)
|
||||
|
||||
default:
|
||||
slog.Warn("ws unknown message type from agent",
|
||||
"type", env.Type, "host_id", hostID)
|
||||
}
|
||||
}
|
||||
|
||||
// MinHeartbeatInterval is a sanity floor for the gap between two
// heartbeats; an agent reporting more often than this is misbehaving.
// (Spec says 30s.)
const MinHeartbeatInterval time.Duration = 5 * time.Second

// Keeps the encoding/json import referenced so the file still compiles
// if its other json uses drop out later.
var _ = json.Marshal
|
||||
@@ -0,0 +1,145 @@
|
||||
// Package ws hosts the WebSocket transport for agent ↔ server. The
|
||||
// Hub tracks one active connection per host id; subsequent connections
|
||||
// from the same host evict the prior one (last-write-wins).
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
)
|
||||
|
||||
// Hub owns the live agent connections and routes messages.
|
||||
type Hub struct {
|
||||
mu sync.RWMutex
|
||||
conns map[string]*Conn // hostID → conn
|
||||
}
|
||||
|
||||
// NewHub returns an empty hub.
|
||||
func NewHub() *Hub {
|
||||
return &Hub{conns: make(map[string]*Conn)}
|
||||
}
|
||||
|
||||
// Conn is one agent WS connection. Send is safe for concurrent use;
|
||||
// Read is single-reader (the connection's run loop).
|
||||
type Conn struct {
|
||||
HostID string
|
||||
c *websocket.Conn
|
||||
|
||||
writeMu sync.Mutex
|
||||
}
|
||||
|
||||
// Register installs c as the canonical connection for hostID. Any
|
||||
// previous connection for that host is closed.
|
||||
func (h *Hub) Register(hostID string, c *Conn) {
|
||||
h.mu.Lock()
|
||||
if prev, ok := h.conns[hostID]; ok {
|
||||
// Best-effort close — a stuck old socket shouldn't block new one.
|
||||
go func(old *Conn) {
|
||||
_ = old.c.Close(websocket.StatusPolicyViolation, "superseded")
|
||||
}(prev)
|
||||
}
|
||||
h.conns[hostID] = c
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
// Unregister removes c iff it is still the canonical conn (a race
|
||||
// where a newer conn already replaced it must not unregister it).
|
||||
func (h *Hub) Unregister(hostID string, c *Conn) {
|
||||
h.mu.Lock()
|
||||
if cur, ok := h.conns[hostID]; ok && cur == c {
|
||||
delete(h.conns, hostID)
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
// Send delivers an envelope to the host if connected. Returns an error
|
||||
// if the host is offline; caller may queue the message for later.
|
||||
func (h *Hub) Send(ctx context.Context, hostID string, env api.Envelope) error {
|
||||
h.mu.RLock()
|
||||
c, ok := h.conns[hostID]
|
||||
h.mu.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("ws: host %q is offline", hostID)
|
||||
}
|
||||
return c.Send(ctx, env)
|
||||
}
|
||||
|
||||
// Connected reports whether hostID has an active connection.
|
||||
func (h *Hub) Connected(hostID string) bool {
|
||||
h.mu.RLock()
|
||||
_, ok := h.conns[hostID]
|
||||
h.mu.RUnlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
// ----- Conn methods --------------------------------------------------
|
||||
|
||||
// NewConn wraps a freshly-accepted websocket for a given hostID.
|
||||
func NewConn(hostID string, c *websocket.Conn) *Conn {
|
||||
return &Conn{HostID: hostID, c: c}
|
||||
}
|
||||
|
||||
// Send writes an envelope as a JSON text message. Concurrent calls
|
||||
// are serialised; the underlying socket is not safe for parallel
|
||||
// writers.
|
||||
func (c *Conn) Send(ctx context.Context, env api.Envelope) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
raw, err := json.Marshal(env)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ws: marshal envelope: %w", err)
|
||||
}
|
||||
return c.c.Write(ctx, websocket.MessageText, raw)
|
||||
}
|
||||
|
||||
// SendError writes an error envelope and closes the socket. Used by
|
||||
// the hello handshake when an agent is rejected.
|
||||
func (c *Conn) SendError(ctx context.Context, code api.ErrorCode, msg, helpURL string) {
|
||||
env, err := api.Marshal(api.MsgError, "", api.ErrorPayload{
|
||||
Code: code, Message: msg, HelpURL: helpURL,
|
||||
})
|
||||
if err == nil {
|
||||
writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
_ = c.Send(writeCtx, env)
|
||||
}
|
||||
_ = c.c.Close(websocket.StatusPolicyViolation, string(code))
|
||||
}
|
||||
|
||||
// Close shuts the socket down with a normal-closure status code.
|
||||
func (c *Conn) Close() error {
|
||||
return c.c.Close(websocket.StatusNormalClosure, "")
|
||||
}
|
||||
|
||||
// Read pulls the next JSON envelope off the wire. The caller's
|
||||
// context controls cancellation and timeouts (e.g. read deadlines).
|
||||
func (c *Conn) Read(ctx context.Context) (api.Envelope, error) {
|
||||
mt, raw, err := c.c.Read(ctx)
|
||||
if err != nil {
|
||||
return api.Envelope{}, err
|
||||
}
|
||||
if mt != websocket.MessageText {
|
||||
return api.Envelope{}, errors.New("ws: expected text frame")
|
||||
}
|
||||
var env api.Envelope
|
||||
if err := json.Unmarshal(raw, &env); err != nil {
|
||||
return api.Envelope{}, fmt.Errorf("ws: unmarshal envelope: %w", err)
|
||||
}
|
||||
return env, nil
|
||||
}
|
||||
|
||||
// ----- helpers -------------------------------------------------------
|
||||
|
||||
// LogValue emits a slog-friendly representation of a Conn.
|
||||
func (c *Conn) LogValue() slog.Value {
|
||||
return slog.GroupValue(slog.String("host_id", c.HostID))
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
stdhttp "net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// setupTestHub spins up a Server that exposes only /ws/agent against
|
||||
// a fresh sqlite store with one pre-enrolled host. Returns the URL,
|
||||
// the agent's bearer token, and the host ID.
|
||||
func setupTestHub(t *testing.T) (url string, token string, hostID string, st *store.Store, hub *Hub) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
var err error
|
||||
st, err = store.Open(context.Background(), filepath.Join(dir, "rm.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("store: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = st.Close() })
|
||||
|
||||
hub = NewHub()
|
||||
mux := stdhttp.NewServeMux()
|
||||
mux.Handle("/ws/agent", AgentHandler(HandlerDeps{Hub: hub, Store: st}))
|
||||
srv := httptest.NewServer(mux)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
// Pre-enroll a host directly via store (skipping HTTP).
|
||||
hostID = "01HJ8K70000000000000000000"
|
||||
token, _ = auth.NewToken()
|
||||
now := time.Now().UTC()
|
||||
if err := st.CreateHost(context.Background(), store.Host{
|
||||
ID: hostID, Name: "h1", OS: "linux", Arch: "amd64",
|
||||
EnrolledAt: now,
|
||||
}, auth.HashToken(token), ""); err != nil {
|
||||
t.Fatalf("enroll: %v", err)
|
||||
}
|
||||
url = "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/agent"
|
||||
return
|
||||
}
|
||||
|
||||
func TestWSHelloAndHeartbeat(t *testing.T) {
|
||||
t.Parallel()
|
||||
url, token, hostID, st, hub := setupTestHub(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
|
||||
HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + token}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
defer c.CloseNow()
|
||||
|
||||
// Send hello.
|
||||
hello := api.HelloPayload{
|
||||
ProtocolVersion: api.CurrentProtocolVersion,
|
||||
AgentVersion: "0.1.0",
|
||||
ResticVersion: "0.17.1",
|
||||
Hostname: "h1",
|
||||
OS: api.OSLinux,
|
||||
Arch: api.ArchAmd64,
|
||||
}
|
||||
env, _ := api.Marshal(api.MsgHello, "", hello)
|
||||
raw, _ := json.Marshal(env)
|
||||
if err := c.Write(ctx, websocket.MessageText, raw); err != nil {
|
||||
t.Fatalf("write hello: %v", err)
|
||||
}
|
||||
|
||||
// Wait for the server to register us (registration happens after
|
||||
// the hello-handler returns; give it up to 1s).
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for !hub.Connected(hostID) && time.Now().Before(deadline) {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
if !hub.Connected(hostID) {
|
||||
t.Fatal("host did not register on hub after hello")
|
||||
}
|
||||
|
||||
// Verify host row was marked online + has populated metadata.
|
||||
h, err := st.GetHost(context.Background(), hostID)
|
||||
if err != nil {
|
||||
t.Fatalf("get host: %v", err)
|
||||
}
|
||||
if h.Status != "online" || h.AgentVersion != "0.1.0" {
|
||||
t.Errorf("host after hello: %+v", h)
|
||||
}
|
||||
|
||||
// Send a heartbeat — server should touch last_seen.
|
||||
hb := api.HeartbeatPayload{SentAt: time.Now().UTC()}
|
||||
env, _ = api.Marshal(api.MsgHeartbeat, "", hb)
|
||||
raw, _ = json.Marshal(env)
|
||||
preTouch := h.LastSeenAt
|
||||
_ = c.Write(ctx, websocket.MessageText, raw)
|
||||
|
||||
// Wait briefly for server to process.
|
||||
deadline = time.Now().Add(time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
h2, _ := st.GetHost(context.Background(), hostID)
|
||||
if h2.LastSeenAt != nil && (preTouch == nil || h2.LastSeenAt.After(*preTouch)) {
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Error("heartbeat did not update last_seen_at")
|
||||
}
|
||||
|
||||
func TestWSRejectsOldProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
url, token, _, _, _ := setupTestHub(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
|
||||
HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + token}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
defer c.CloseNow()
|
||||
|
||||
hello := api.HelloPayload{ProtocolVersion: 0} // below minimum
|
||||
env, _ := api.Marshal(api.MsgHello, "", hello)
|
||||
raw, _ := json.Marshal(env)
|
||||
_ = c.Write(ctx, websocket.MessageText, raw)
|
||||
|
||||
// Server should send an error envelope, then close.
|
||||
mt, body, err := c.Read(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
if mt != websocket.MessageText {
|
||||
t.Fatalf("frame type: %v", mt)
|
||||
}
|
||||
var got api.Envelope
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if got.Type != api.MsgError {
|
||||
t.Errorf("expected error envelope, got %q", got.Type)
|
||||
}
|
||||
var ep api.ErrorPayload
|
||||
_ = got.UnmarshalPayload(&ep)
|
||||
if ep.Code != api.ErrProtocolTooOld {
|
||||
t.Errorf("error code: %q", ep.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSRejectsBadToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
url, _, _, _, _ := setupTestHub(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, res, err := websocket.Dial(ctx, url, &websocket.DialOptions{
|
||||
HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer wrong"}},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("dial should fail")
|
||||
}
|
||||
if res == nil || res.StatusCode != stdhttp.StatusUnauthorized {
|
||||
if res != nil {
|
||||
t.Errorf("status: %d", res.StatusCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,205 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CreateHost inserts a new host row. Used by the enrollment flow.
|
||||
// The caller has already minted the host id and hashed the agent
|
||||
// bearer token.
|
||||
func (s *Store) CreateHost(ctx context.Context, h Host, agentTokenHash, certPinSHA256 string) error {
|
||||
tags, err := json.Marshal(h.Tags)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: marshal tags: %w", err)
|
||||
}
|
||||
_, err = s.db.ExecContext(ctx,
|
||||
`INSERT INTO hosts (
|
||||
id, name, os, arch, agent_version, restic_version, protocol_version,
|
||||
enrolled_at, status, tags,
|
||||
agent_token_hash, cert_pin_sha256
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 'offline', ?, ?, ?)`,
|
||||
h.ID, h.Name, h.OS, h.Arch,
|
||||
h.AgentVersion, h.ResticVersion, h.ProtocolVersion,
|
||||
h.EnrolledAt.UTC().Format(time.RFC3339Nano),
|
||||
string(tags),
|
||||
agentTokenHash, certPinSHA256)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: create host: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LookupHostByAgentToken resolves a hashed agent bearer token to the
|
||||
// host it belongs to. Returns ErrNotFound on miss.
|
||||
func (s *Store) LookupHostByAgentToken(ctx context.Context, tokenHash string) (*Host, error) {
|
||||
row := s.db.QueryRowContext(ctx,
|
||||
`SELECT id, name, os, arch, agent_version, restic_version, protocol_version,
|
||||
enrolled_at, last_seen_at, status, repo_id, tags,
|
||||
current_job_id, last_backup_at, last_backup_status,
|
||||
repo_size_bytes, snapshot_count, open_alert_count,
|
||||
applied_schedule_version
|
||||
FROM hosts WHERE agent_token_hash = ?`,
|
||||
tokenHash)
|
||||
return scanHost(row)
|
||||
}
|
||||
|
||||
// GetHost returns a host by ID. Returns ErrNotFound on miss.
|
||||
func (s *Store) GetHost(ctx context.Context, id string) (*Host, error) {
|
||||
row := s.db.QueryRowContext(ctx,
|
||||
`SELECT id, name, os, arch, agent_version, restic_version, protocol_version,
|
||||
enrolled_at, last_seen_at, status, repo_id, tags,
|
||||
current_job_id, last_backup_at, last_backup_status,
|
||||
repo_size_bytes, snapshot_count, open_alert_count,
|
||||
applied_schedule_version
|
||||
FROM hosts WHERE id = ?`, id)
|
||||
return scanHost(row)
|
||||
}
|
||||
|
||||
// MarkHostHello updates the host row with metadata received in the
|
||||
// agent's hello message and flips status to 'online'.
|
||||
func (s *Store) MarkHostHello(ctx context.Context, id string, agentVersion, resticVersion string, protoVersion int, when time.Time) error {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`UPDATE hosts
|
||||
SET agent_version = ?, restic_version = ?, protocol_version = ?,
|
||||
last_seen_at = ?, status = 'online'
|
||||
WHERE id = ?`,
|
||||
agentVersion, resticVersion, protoVersion,
|
||||
when.UTC().Format(time.RFC3339Nano), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: mark hello: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TouchHost updates last_seen_at on heartbeat, leaving status alone if
|
||||
// already online (the offline-marker is a separate sweep).
|
||||
func (s *Store) TouchHost(ctx context.Context, id string, when time.Time) error {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`UPDATE hosts
|
||||
SET last_seen_at = ?,
|
||||
status = CASE WHEN status = 'offline' THEN 'online' ELSE status END
|
||||
WHERE id = ?`,
|
||||
when.UTC().Format(time.RFC3339Nano), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: touch host: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkHostsOfflineStale flips any host that hasn't been seen since
|
||||
// before `cutoff` from 'online' to 'offline'. Returns the number of
|
||||
// rows affected so the caller can log non-zero events.
|
||||
func (s *Store) MarkHostsOfflineStale(ctx context.Context, cutoff time.Time) (int64, error) {
|
||||
res, err := s.db.ExecContext(ctx,
|
||||
`UPDATE hosts
|
||||
SET status = 'offline'
|
||||
WHERE status = 'online'
|
||||
AND (last_seen_at IS NULL OR last_seen_at < ?)`,
|
||||
cutoff.UTC().Format(time.RFC3339Nano))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("store: mark offline: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// ListHosts returns every host. Phase 1 callers fit a small fleet in
|
||||
// memory; pagination lands when it matters.
|
||||
func (s *Store) ListHosts(ctx context.Context) ([]Host, error) {
|
||||
rows, err := s.db.QueryContext(ctx,
|
||||
`SELECT id, name, os, arch, agent_version, restic_version, protocol_version,
|
||||
enrolled_at, last_seen_at, status, repo_id, tags,
|
||||
current_job_id, last_backup_at, last_backup_status,
|
||||
repo_size_bytes, snapshot_count, open_alert_count,
|
||||
applied_schedule_version
|
||||
FROM hosts ORDER BY name`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("store: list hosts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []Host
|
||||
for rows.Next() {
|
||||
h, err := scanHostRow(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, *h)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// ----- scan helpers --------------------------------------------------
|
||||
|
||||
// hostScanner is the Scan method shared by *sql.Row and *sql.Rows, so
// one decoder can serve both single-row and multi-row queries.
type hostScanner interface {
	Scan(dest ...any) error
}
|
||||
|
||||
func scanHost(row *sql.Row) (*Host, error) {
|
||||
h, err := scanHostRow(row)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return h, err
|
||||
}
|
||||
|
||||
func scanHostRow(s hostScanner) (*Host, error) {
|
||||
var h Host
|
||||
var (
|
||||
lastSeen, lastBackupAt sql.NullString
|
||||
repoID, currentJob, lastBkSt sql.NullString
|
||||
enrolled string
|
||||
tags string
|
||||
)
|
||||
err := s.Scan(&h.ID, &h.Name, &h.OS, &h.Arch,
|
||||
&h.AgentVersion, &h.ResticVersion, &h.ProtocolVersion,
|
||||
&enrolled, &lastSeen, &h.Status, &repoID, &tags,
|
||||
¤tJob, &lastBackupAt, &lastBkSt,
|
||||
&h.RepoSizeBytes, &h.SnapshotCount, &h.OpenAlertCount,
|
||||
&h.AppliedScheduleVersion)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("store: scan host: %w", err)
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339Nano, enrolled)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("store: parse enrolled_at: %w", err)
|
||||
}
|
||||
h.EnrolledAt = t
|
||||
if lastSeen.Valid {
|
||||
t, err := time.Parse(time.RFC3339Nano, lastSeen.String)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("store: parse last_seen_at: %w", err)
|
||||
}
|
||||
h.LastSeenAt = &t
|
||||
}
|
||||
if lastBackupAt.Valid {
|
||||
t, err := time.Parse(time.RFC3339Nano, lastBackupAt.String)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("store: parse last_backup_at: %w", err)
|
||||
}
|
||||
h.LastBackupAt = &t
|
||||
}
|
||||
if repoID.Valid {
|
||||
s := repoID.String
|
||||
h.RepoID = &s
|
||||
}
|
||||
if currentJob.Valid {
|
||||
s := currentJob.String
|
||||
h.CurrentJobID = &s
|
||||
}
|
||||
if lastBkSt.Valid {
|
||||
s := lastBkSt.String
|
||||
h.LastBackupStatus = &s
|
||||
}
|
||||
if tags != "" {
|
||||
_ = json.Unmarshal([]byte(tags), &h.Tags)
|
||||
}
|
||||
return &h, nil
|
||||
}
|
||||
@@ -92,12 +92,15 @@ CREATE INDEX hosts_status ON hosts(status);
|
||||
CREATE INDEX hosts_last_seen_at ON hosts(last_seen_at);
|
||||
|
||||
-- Pending one-time enrollment tokens (TTL'd, single-use).
|
||||
-- consumed_host is audit-only (no FK on purpose: we burn the token
|
||||
-- before the host row exists, and we want this trail to survive a
|
||||
-- later host deletion).
|
||||
CREATE TABLE enrollment_tokens (
|
||||
token_hash TEXT PRIMARY KEY, -- argon2id of token
|
||||
token_hash TEXT PRIMARY KEY,
|
||||
created_at TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL,
|
||||
consumed_at TEXT,
|
||||
consumed_host TEXT REFERENCES hosts(id) ON DELETE SET NULL
|
||||
consumed_host TEXT
|
||||
);
|
||||
CREATE INDEX enrollment_tokens_expires_at ON enrollment_tokens(expires_at);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user