diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 91340b1..1953862 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -2,32 +2,147 @@ package main import ( "context" + "errors" "flag" "fmt" "log/slog" "os" "os/signal" "syscall" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/agent/config" + "gitea.dcglab.co.uk/steve/restic-manager/internal/agent/sysinfo" + "gitea.dcglab.co.uk/steve/restic-manager/internal/agent/wsclient" + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" ) var version = "dev" func main() { + if err := run(); err != nil { + slog.Error("agent fatal", "err", err) + os.Exit(1) + } +} + +func run() error { + configPath := flag.String("config", config.DefaultPath(), "path to agent.yaml") + enrollServer := flag.String("enroll-server", "", "server URL (used with -enroll-token to perform first-run enrollment)") + enrollToken := flag.String("enroll-token", "", "one-time enrollment token (operator copies this from the UI)") showVersion := flag.Bool("version", false, "print version and exit") flag.Parse() if *showVersion { fmt.Println("restic-manager-agent", version) - return + return nil } logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})) slog.SetDefault(logger) + cfg, err := config.Load(*configPath) + if err != nil { + return fmt.Errorf("config: %w", err) + } + + // Enrollment mode: agent was started with -enroll-server -enroll-token. + // On success we persist the credentials and exit (the install script + // then starts the agent service). Avoiding a long-running process here + // keeps the enrollment story restartable. + if *enrollToken != "" { + if *enrollServer == "" { + return errors.New("enrollment: -enroll-server is required with -enroll-token") + } + return doEnroll(*enrollServer, *enrollToken, cfg, version) + } + + if !cfg.Enrolled() { + return fmt.Errorf("agent is not enrolled; run with -enroll-server and -enroll-token first (config %q)", *configPath) + } + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() - slog.Info("restic-manager agent starting", "version", version) - <-ctx.Done() - slog.Info("shutting down") + snap, err := sysinfo.Collect(ctx, cfg.ResticPath) + if err != nil { + return fmt.Errorf("sysinfo: %w", err) + } + slog.Info("agent starting", + "version", version, + "host_id", cfg.HostID, + "server", cfg.ServerURL, + "restic_version", snap.ResticVersion, + "protocol_version", snap.ProtocolVersion, + ) + + wsCfg := wsclient.Config{ + ServerURL: cfg.ServerURL, + AgentToken: cfg.AgentToken, + HostID: cfg.HostID, + CertPinSHA256: cfg.CertPinSHA256, + HelloPayload: api.HelloPayload{ + ProtocolVersion: snap.ProtocolVersion, + AgentVersion: version, + ResticVersion: snap.ResticVersion, + Hostname: snap.Hostname, + OS: snap.OS, + Arch: snap.Arch, + }, + } + + if err := wsclient.Run(ctx, wsCfg, dispatch); err != nil { + return fmt.Errorf("ws run: %w", err) + } + slog.Info("agent shutting down") + return nil +} + +// dispatch handles server-pushed envelopes. Phase 1's first slice +// just logs; P1-19/20/21 wire command.run to the runner. +func dispatch(_ context.Context, env api.Envelope) error { + switch env.Type { + case api.MsgCommandRun: + slog.Info("ws agent: command.run received (not yet implemented)", "id", env.ID) + case api.MsgCommandCancel: + slog.Info("ws agent: command.cancel received (not yet implemented)", "id", env.ID) + case api.MsgScheduleSet: + slog.Info("ws agent: schedule.set received (not yet implemented)", "id", env.ID) + case api.MsgConfigUpdate: + slog.Info("ws agent: config.update received (not yet implemented)", "id", env.ID) + case api.MsgAgentUpdateAvail: + slog.Info("ws agent: agent.update.available received (not yet implemented)", "id", env.ID) + default: + slog.Debug("ws agent: ignored message", "type", env.Type) + } + return nil +} + +func doEnroll(serverURL, token string, cfg *config.Config, agentVersion string) error { + ctx, cancel := context.WithTimeout(context.Background(), 60*1e9) + defer cancel() + + snap, err := sysinfo.Collect(ctx, cfg.ResticPath) + if err != nil { + return fmt.Errorf("sysinfo: %w", err) + } + res, err := wsclient.Enroll(ctx, serverURL, wsclient.EnrollRequest{ + Token: token, + HostName: snap.Hostname, + OS: snap.OS, + Arch: snap.Arch, + AgentVersion: agentVersion, + ResticVersion: snap.ResticVersion, + }) + if err != nil { + return fmt.Errorf("enroll: %w", err) + } + cfg.ServerURL = serverURL + cfg.HostID = res.HostID + cfg.AgentToken = res.AgentToken + cfg.CertPinSHA256 = res.CertPinSHA256 + if err := cfg.Save(); err != nil { + return fmt.Errorf("save config: %w", err) + } + fmt.Fprintf(os.Stderr, "enrolled as host %s on %s\n", res.HostID, serverURL) + return nil } diff --git a/cmd/server/main.go b/cmd/server/main.go index bc6d2b0..a004afb 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -14,8 +14,9 @@ import ( "gitea.dcglab.co.uk/steve/restic-manager/internal/auth" "gitea.dcglab.co.uk/steve/restic-manager/internal/crypto" - rmhttp "gitea.dcglab.co.uk/steve/restic-manager/internal/server/http" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/config" + rmhttp "gitea.dcglab.co.uk/steve/restic-manager/internal/server/http" + "gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws" "gitea.dcglab.co.uk/steve/restic-manager/internal/store" ) @@ -76,10 +77,13 @@ func run() error { } defer func() { _ = st.Close() }() + hub := ws.NewHub() + deps := rmhttp.Deps{ Cfg: cfg, Store: st, AEAD: aead, + Hub: hub, } // First-run bootstrap: if the users table is empty, mint a one-time @@ -117,21 +121,31 @@ func run() error { errCh <- srv.Start() }() - // Background sweeper for expired sessions and enrollment tokens. - tick := time.NewTicker(15 * time.Minute) - defer tick.Stop() + // Background sweepers: + // - sessions/tokens purge: 15 min + // - host offline-after-90s mark: every 30s (matches heartbeat + // cadence — agent sends every 30s, P1-12) + purgeTick := time.NewTicker(15 * time.Minute) + defer purgeTick.Stop() + offlineTick := time.NewTicker(30 * time.Second) + defer offlineTick.Stop() go func() { for { select { case <-ctx.Done(): return - case <-tick.C: + case <-purgeTick.C: if n, err := st.PurgeExpiredSessions(ctx); err == nil && n > 0 { slog.Info("purged expired sessions", "n", n) } if n, err := st.PurgeExpiredEnrollmentTokens(ctx); err == nil && n > 0 { slog.Info("purged expired enrollment tokens", "n", n) } + case <-offlineTick.C: + cutoff := time.Now().Add(-90 * time.Second) + if n, err := st.MarkHostsOfflineStale(ctx, cutoff); err == nil && n > 0 { + slog.Info("marked hosts offline (stale heartbeat)", "n", n) + } } } }() diff --git a/go.mod b/go.mod index 0c283ac..9af0d28 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( ) require ( + github.com/coder/websocket v1.8.14 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect diff --git a/go.sum b/go.sum index 18dc058..e241a4c 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= diff --git a/internal/agent/config/config.go b/internal/agent/config/config.go new file mode 100644 index 0000000..48bba00 --- /dev/null +++ b/internal/agent/config/config.go @@ -0,0 +1,110 @@ +// Package config loads the agent's persistent configuration. After +// enrollment, the file holds the bearer token + server URL; it is +// only ever written via Save (which replaces atomically). +package config + +import ( + "fmt" + "os" + "path/filepath" + + "gopkg.in/yaml.v3" +) + +// Config is the on-disk shape of the agent's config file. +type Config struct { + // ServerURL is the base URL of the control plane, e.g. + // https://restic.lab.example. The agent appends /ws/agent and + // /api/agents/enroll. + ServerURL string `yaml:"server_url"` + + // AgentToken is the bearer credential issued at enrollment. + // Empty means "not yet enrolled." + AgentToken string `yaml:"agent_token"` + + // HostID is what the server thinks this host is. + HostID string `yaml:"host_id"` + + // CertPinSHA256 (optional) is the SHA-256 of the server's TLS + // cert. When set, the agent refuses to connect to a server + // whose cert hash doesn't match. + CertPinSHA256 string `yaml:"cert_pin_sha256,omitempty"` + + // ResticPath overrides the auto-detected restic binary path. + ResticPath string `yaml:"restic_path,omitempty"` + + // path is the file we loaded from. Used by Save. + path string `yaml:"-"` +} + +// DefaultPath returns the canonical config path for the current OS. +// Phase 1 ships Linux only; Windows path lives in the spec for P2. +func DefaultPath() string { + return "/etc/restic-manager/agent.yaml" +} + +// Load reads and parses the config file at path. A missing file is +// returned as an empty Config (not an error) — first-run agents +// haven't been enrolled yet. +func Load(path string) (*Config, error) { + c := &Config{path: path} + body, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return c, nil + } + return nil, fmt.Errorf("agent config: read %q: %w", path, err) + } + if err := yaml.Unmarshal(body, c); err != nil { + return nil, fmt.Errorf("agent config: parse %q: %w", path, err) + } + c.path = path + return c, nil +} + +// Save writes the config back atomically: write to .tmp, fsync, +// rename. A crash mid-write either leaves the old file or the new one, +// never a half-written one. +func (c *Config) Save() error { + if c.path == "" { + return fmt.Errorf("agent config: no path set") + } + dir := filepath.Dir(c.path) + if err := os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("agent config: mkdir %q: %w", dir, err) + } + body, err := yaml.Marshal(c) + if err != nil { + return fmt.Errorf("agent config: marshal: %w", err) + } + + tmp := c.path + ".tmp" + f, err := os.OpenFile(tmp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) + if err != nil { + return fmt.Errorf("agent config: create tmp: %w", err) + } + if _, err := f.Write(body); err != nil { + _ = f.Close() + _ = os.Remove(tmp) + return fmt.Errorf("agent config: write tmp: %w", err) + } + if err := f.Sync(); err != nil { + _ = f.Close() + _ = os.Remove(tmp) + return fmt.Errorf("agent config: fsync tmp: %w", err) + } + if err := f.Close(); err != nil { + _ = os.Remove(tmp) + return fmt.Errorf("agent config: close tmp: %w", err) + } + if err := os.Rename(tmp, c.path); err != nil { + _ = os.Remove(tmp) + return fmt.Errorf("agent config: rename: %w", err) + } + return nil +} + +// Enrolled reports whether the agent has finished enrollment. +func (c *Config) Enrolled() bool { + return c.AgentToken != "" && c.HostID != "" && c.ServerURL != "" +} diff --git a/internal/agent/sysinfo/sysinfo.go b/internal/agent/sysinfo/sysinfo.go new file mode 100644 index 0000000..c4b9c62 --- /dev/null +++ b/internal/agent/sysinfo/sysinfo.go @@ -0,0 +1,78 @@ +// Package sysinfo collects host metadata at agent startup: OS, arch, +// hostname, restic version. The agent sends this in `hello` so the +// server's Host row stays current. +package sysinfo + +import ( + "context" + "fmt" + "os" + "os/exec" + "runtime" + "strings" + "time" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" +) + +// Snapshot is the bundle of metadata reported in `hello`. +type Snapshot struct { + Hostname string + OS api.HostOS + Arch api.HostArch + ResticVersion string + ProtocolVersion int + BootTime time.Time +} + +// Collect probes the running host. resticPath, if non-empty, +// overrides PATH lookup. +func Collect(ctx context.Context, resticPath string) (Snapshot, error) { + hn, err := os.Hostname() + if err != nil { + return Snapshot{}, fmt.Errorf("sysinfo: hostname: %w", err) + } + + osTag := api.HostOS(runtime.GOOS) + archTag := api.HostArch(runtime.GOARCH) + + resticVer, _ := detectResticVersion(ctx, resticPath) // empty on failure is fine + return Snapshot{ + Hostname: hn, + OS: osTag, + Arch: archTag, + ResticVersion: resticVer, + ProtocolVersion: api.CurrentProtocolVersion, + }, nil +} + +// detectResticVersion runs `restic version` and parses the first line. +// Output looks like: +// restic 0.17.1 compiled with go1.22.5 on linux/amd64 +// Returns the version token (e.g. "0.17.1") or "" if restic isn't +// found. We never block startup on a missing restic — the operator +// might not have installed it yet, and the agent should still be +// able to connect and report. +func detectResticVersion(ctx context.Context, override string) (string, error) { + bin := override + if bin == "" { + var err error + bin, err = exec.LookPath("restic") + if err != nil { + return "", err + } + } + + versionCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + out, err := exec.CommandContext(versionCtx, bin, "version").Output() + if err != nil { + return "", err + } + first := strings.SplitN(strings.TrimSpace(string(out)), "\n", 2)[0] + parts := strings.Fields(first) + if len(parts) >= 2 && parts[0] == "restic" { + return parts[1], nil + } + return "", fmt.Errorf("sysinfo: unrecognised restic version output: %q", first) +} diff --git a/internal/agent/wsclient/client.go b/internal/agent/wsclient/client.go new file mode 100644 index 0000000..f1bd81d --- /dev/null +++ b/internal/agent/wsclient/client.go @@ -0,0 +1,246 @@ +// Package wsclient is the agent's outbound WebSocket connection to +// the control plane: dial with bearer auth, perform the hello +// handshake, send heartbeats, dispatch server-pushed commands. +// +// The Run loop is a forever-loop with exponential backoff on dial +// failures, capped at 60s. Disconnected agents keep retrying. +package wsclient + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "log/slog" + "math/rand" + stdhttp "net/http" + "net/url" + "strings" + "time" + + "github.com/coder/websocket" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" +) + +// Config holds the agent's connection settings. +type Config struct { + ServerURL string + AgentToken string + HostID string + CertPinSHA256 string // hex; empty disables pinning + HeartbeatPeriod time.Duration + HelloPayload api.HelloPayload +} + +// Handler is invoked for every server-sent message. The agent's main +// program supplies one that knows how to dispatch command.run etc. +// to the runner package. +type Handler func(ctx context.Context, env api.Envelope) error + +// Run keeps the agent connected indefinitely. Returns when ctx is +// cancelled. Errors during a single connection attempt are logged and +// trigger reconnect-with-backoff; only ctx.Done() ends the loop. +func Run(ctx context.Context, cfg Config, handle Handler) error { + if cfg.HeartbeatPeriod <= 0 { + cfg.HeartbeatPeriod = 30 * time.Second + } + + backoff := newBackoff(time.Second, 60*time.Second) + for { + err := connectOnce(ctx, cfg, handle) + if errors.Is(err, context.Canceled) { + return nil + } + if err != nil { + slog.Warn("ws agent disconnect", "err", err) + } + if err := sleepCtx(ctx, backoff.next()); err != nil { + return nil + } + } +} + +// connectOnce performs one full connection lifecycle: dial → hello → +// heartbeat loop + read loop → close. Returns when either side closes +// the socket. +func connectOnce(ctx context.Context, cfg Config, handle Handler) error { + wsURL, err := buildWSURL(cfg.ServerURL) + if err != nil { + return fmt.Errorf("ws agent: bad server url: %w", err) + } + + dialOpts := &websocket.DialOptions{ + HTTPHeader: stdhttp.Header{ + "Authorization": []string{"Bearer " + cfg.AgentToken}, + }, + } + if cfg.CertPinSHA256 != "" && strings.HasPrefix(wsURL, "wss") { + dialOpts.HTTPClient = &stdhttp.Client{ + Transport: &stdhttp.Transport{ + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + VerifyPeerCertificate: pinChecker(cfg.CertPinSHA256), + }, + }, + } + } + + dialCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + conn, _, err := websocket.Dial(dialCtx, wsURL, dialOpts) + cancel() + if err != nil { + return fmt.Errorf("dial: %w", err) + } + defer conn.CloseNow() //nolint:errcheck + + // Send hello. + helloEnv, err := api.Marshal(api.MsgHello, "", cfg.HelloPayload) + if err != nil { + return fmt.Errorf("marshal hello: %w", err) + } + if err := writeEnv(ctx, conn, helloEnv); err != nil { + return fmt.Errorf("write hello: %w", err) + } + slog.Info("ws agent connected", "server", wsURL) + + // Heartbeat goroutine. + heartbeatCtx, cancelHeartbeat := context.WithCancel(ctx) + defer cancelHeartbeat() + go heartbeatLoop(heartbeatCtx, conn, cfg.HeartbeatPeriod) + + // Read loop. A read error returns and closes the conn. + for { + mt, raw, err := conn.Read(ctx) + if err != nil { + return fmt.Errorf("read: %w", err) + } + if mt != websocket.MessageText { + continue + } + var env api.Envelope + if err := json.Unmarshal(raw, &env); err != nil { + slog.Warn("ws agent: bad envelope from server", "err", err) + continue + } + if env.Type == api.MsgError { + var ep api.ErrorPayload + _ = env.UnmarshalPayload(&ep) + slog.Error("ws agent: server reported error", + "code", ep.Code, "message", ep.Message, "help", ep.HelpURL) + // protocol_too_old is fatal — keep retrying won't help. + if ep.Code == api.ErrProtocolTooOld { + return fmt.Errorf("protocol too old: %s", ep.Message) + } + continue + } + if handle != nil { + if err := handle(ctx, env); err != nil { + slog.Warn("ws agent: handler returned error", "type", env.Type, "err", err) + } + } + } +} + +func heartbeatLoop(ctx context.Context, conn *websocket.Conn, period time.Duration) { + t := time.NewTicker(period) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + env, err := api.Marshal(api.MsgHeartbeat, "", + api.HeartbeatPayload{SentAt: time.Now().UTC()}) + if err != nil { + continue + } + if err := writeEnv(ctx, conn, env); err != nil { + slog.Warn("ws agent: heartbeat write failed", "err", err) + return + } + } + } +} + +func writeEnv(ctx context.Context, conn *websocket.Conn, env api.Envelope) error { + raw, err := json.Marshal(env) + if err != nil { + return err + } + return conn.Write(ctx, websocket.MessageText, raw) +} + +func buildWSURL(serverURL string) (string, error) { + u, err := url.Parse(serverURL) + if err != nil { + return "", err + } + switch u.Scheme { + case "https": + u.Scheme = "wss" + case "http": + u.Scheme = "ws" + case "ws", "wss": + // already correct + default: + return "", fmt.Errorf("unsupported scheme %q", u.Scheme) + } + u.Path = strings.TrimRight(u.Path, "/") + "/ws/agent" + return u.String(), nil +} + +// pinChecker returns a VerifyPeerCertificate callback that requires +// the leaf cert's SHA-256 to match wantHex. We do this *in addition* +// to the OS root verification (we don't replace it). +func pinChecker(wantHex string) func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + return func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + if len(rawCerts) == 0 { + return errors.New("ws agent: no peer certs") + } + got := sha256Hex(rawCerts[0]) + if got != wantHex { + return fmt.Errorf("ws agent: cert pin mismatch (got %s want %s)", got, wantHex) + } + return nil + } +} + +func sha256Hex(b []byte) string { + // avoid pulling in crypto/sha256 in this top-level file twice; + // indirection through hex-encode is the classic shape. + h := newSHA256() + h.Write(b) + return hex.EncodeToString(h.Sum(nil)) +} + +// ----- backoff ------------------------------------------------------- + +type backoff struct { + cur, max time.Duration +} + +func newBackoff(base, max time.Duration) *backoff { return &backoff{cur: base, max: max} } + +func (b *backoff) next() time.Duration { + d := b.cur + // 20% jitter, deterministic-enough randomness. + jitter := time.Duration(rand.Int63n(int64(d) / 5)) //nolint:gosec + b.cur *= 2 + if b.cur > b.max { + b.cur = b.max + } + return d + jitter +} + +func sleepCtx(ctx context.Context, d time.Duration) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(d): + return nil + } +} diff --git a/internal/agent/wsclient/enroll.go b/internal/agent/wsclient/enroll.go new file mode 100644 index 0000000..6ba00bd --- /dev/null +++ b/internal/agent/wsclient/enroll.go @@ -0,0 +1,67 @@ +package wsclient + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + stdhttp "net/http" + "strings" + "time" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" +) + +// EnrollRequest is what we POST to /api/agents/enroll. +type EnrollRequest struct { + Token string `json:"token"` + HostName string `json:"hostname"` + OS api.HostOS `json:"os"` + Arch api.HostArch `json:"arch"` + AgentVersion string `json:"agent_version"` + ResticVersion string `json:"restic_version"` +} + +// EnrollResponse is what the server hands back. +type EnrollResponse struct { + HostID string `json:"host_id"` + AgentToken string `json:"agent_token"` + CertPinSHA256 string `json:"cert_pin_sha256,omitempty"` +} + +// Enroll exchanges a one-time enrollment token for persistent agent +// credentials. Called by the install script on first run. +func Enroll(ctx context.Context, serverURL string, req EnrollRequest) (*EnrollResponse, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("agent enroll: marshal: %w", err) + } + postURL := strings.TrimRight(serverURL, "/") + "/api/agents/enroll" + + httpReq, err := stdhttp.NewRequestWithContext(ctx, stdhttp.MethodPost, postURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("agent enroll: build request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + + client := &stdhttp.Client{Timeout: 30 * time.Second} + res, err := client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("agent enroll: post: %w", err) + } + defer res.Body.Close() + rawRes, _ := io.ReadAll(res.Body) + if res.StatusCode != stdhttp.StatusCreated { + return nil, fmt.Errorf("agent enroll: server returned %d: %s", + res.StatusCode, rawRes) + } + var er EnrollResponse + if err := json.Unmarshal(rawRes, &er); err != nil { + return nil, fmt.Errorf("agent enroll: parse response: %w", err) + } + if er.AgentToken == "" || er.HostID == "" { + return nil, fmt.Errorf("agent enroll: incomplete response: %+v", er) + } + return &er, nil +} diff --git a/internal/agent/wsclient/sha256.go b/internal/agent/wsclient/sha256.go new file mode 100644 index 0000000..0d86a11 --- /dev/null +++ b/internal/agent/wsclient/sha256.go @@ -0,0 +1,8 @@ +package wsclient + +import ( + "crypto/sha256" + "hash" +) + +func newSHA256() hash.Hash { return sha256.New() } diff --git a/internal/server/http/enrollment.go b/internal/server/http/enrollment.go new file mode 100644 index 0000000..0ea6ce7 --- /dev/null +++ b/internal/server/http/enrollment.go @@ -0,0 +1,165 @@ +package http + +import ( + "encoding/json" + stdhttp "net/http" + "strings" + "time" + + "github.com/oklog/ulid/v2" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" + "gitea.dcglab.co.uk/steve/restic-manager/internal/auth" + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" +) + +// enrollRequest is the body posted by the agent installer. The token +// was issued by the operator via the UI ("Add host" → P1-27); the +// host metadata comes from the agent's own sysinfo collection. +type enrollRequest struct { + Token string `json:"token"` + HostName string `json:"hostname"` + OS api.HostOS `json:"os"` + Arch api.HostArch `json:"arch"` + AgentVersion string `json:"agent_version"` + ResticVersion string `json:"restic_version"` +} + +// enrollResponse hands the agent the credentials it'll use forever. +// AgentToken is shown exactly once; the server stores its hash. +// CertPinSHA256 is the SHA-256 of the server's certificate, base64; +// the agent pins this on every reconnect so a stolen DB at the +// control plane can't be replayed against an attacker's TLS endpoint. +type enrollResponse struct { + HostID string `json:"host_id"` + AgentToken string `json:"agent_token"` + CertPinSHA256 string `json:"cert_pin_sha256,omitempty"` +} + +// enrollOperatorRequest creates a one-time enrollment token for an +// operator who is about to install an agent. Authenticated UI route. +type enrollOperatorRequest struct { + HostName string `json:"hostname"` + Tags []string `json:"tags,omitempty"` +} + +type enrollOperatorResponse struct { + Token string `json:"token"` + ExpiresAt time.Time `json:"expires_at"` +} + +// handleAgentEnroll consumes a one-time token, persists a Host row, +// and returns persistent agent credentials. Open endpoint (no +// session) — the token is the credential. +func (s *Server) handleAgentEnroll(w stdhttp.ResponseWriter, r *stdhttp.Request) { + var req enrollRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error()) + return + } + if req.Token == "" || req.HostName == "" || req.OS == "" || req.Arch == "" { + writeJSONError(w, stdhttp.StatusBadRequest, "missing_field", + "token, hostname, os, arch all required") + return + } + + hostID := ulid.Make().String() + + // Atomically: validate + consume token, then create the host. + // We do these in two statements; if create-host fails, the token + // is already burned. That's acceptable — operator just regens. + tokHash := auth.HashToken(req.Token) + if err := s.deps.Store.ConsumeEnrollmentToken(r.Context(), tokHash, hostID); err != nil { + writeJSONError(w, stdhttp.StatusUnauthorized, "invalid_token", + "token unknown, expired, or already used") + return + } + + // Mint the persistent agent bearer. + agentToken, err := auth.NewToken() + if err != nil { + writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "") + return + } + + host := store.Host{ + ID: hostID, + Name: strings.TrimSpace(req.HostName), + OS: string(req.OS), + Arch: string(req.Arch), + AgentVersion: req.AgentVersion, + ResticVersion: req.ResticVersion, + EnrolledAt: time.Now().UTC(), + } + if err := s.deps.Store.CreateHost(r.Context(), host, + auth.HashToken(agentToken), ""); err != nil { + writeJSONError(w, stdhttp.StatusConflict, "host_exists", err.Error()) + return + } + + _ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{ + ID: ulid.Make().String(), + Actor: "system", + Action: "host.enrolled", + TargetKind: ptr("host"), + TargetID: &hostID, + TS: host.EnrolledAt, + }) + + writeJSON(w, stdhttp.StatusCreated, enrollResponse{ + HostID: hostID, + AgentToken: agentToken, + // CertPinSHA256 is populated by a TLS-aware future revision. + // For now (HTTP-or-TLS-by-Caddy) we leave it empty and rely + // on the agent trusting its OS root store. + }) +} + +// handleCreateEnrollmentToken (operator-facing) — generates a +// short-lived token for a new host. Authenticated; admin/operator only. +// +// TODO: gate by authn middleware once login session lookup lands. +// For Phase 1's first slice, we accept the bootstrap-shipped admin +// session cookie and trust it, validating the cookie via store. +func (s *Server) handleCreateEnrollmentToken(w stdhttp.ResponseWriter, r *stdhttp.Request) { + if !s.authedUser(r) { + writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "") + return + } + + var req enrollOperatorRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error()) + return + } + + token, err := auth.NewToken() + if err != nil { + writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "") + return + } + const ttl = time.Hour + if err := s.deps.Store.CreateEnrollmentToken(r.Context(), auth.HashToken(token), ttl); err != nil { + writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "") + return + } + + writeJSON(w, stdhttp.StatusCreated, enrollOperatorResponse{ + Token: token, + ExpiresAt: time.Now().Add(ttl).UTC(), + }) +} + +// authedUser returns true iff the request carries a valid session +// cookie. Minimal stub for now; full RBAC middleware lands with +// P4-03. +func (s *Server) authedUser(r *stdhttp.Request) bool { + c, err := r.Cookie(sessionCookieName) + if err != nil { + return false + } + _, err = s.deps.Store.LookupSession(r.Context(), auth.HashToken(c.Value)) + return err == nil +} + +func ptr(s string) *string { return &s } diff --git a/internal/server/http/enrollment_test.go b/internal/server/http/enrollment_test.go new file mode 100644 index 0000000..cc71604 --- /dev/null +++ b/internal/server/http/enrollment_test.go @@ -0,0 +1,118 @@ +package http + +import ( + "bytes" + "context" + "encoding/json" + "io" + stdhttp "net/http" + "net/http/httptest" + "path/filepath" + "testing" + "time" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" + "gitea.dcglab.co.uk/steve/restic-manager/internal/auth" + "gitea.dcglab.co.uk/steve/restic-manager/internal/crypto" + "gitea.dcglab.co.uk/steve/restic-manager/internal/server/config" + "gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws" + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" +) + +// newTestServerWithHub mirrors newTestServer but plugs in a real +// ws.Hub so /ws/agent is available. +func newTestServerWithHub(t *testing.T) (*Server, string, *store.Store) { + t.Helper() + dir := t.TempDir() + st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db")) + if err != nil { + t.Fatalf("store: %v", err) + } + t.Cleanup(func() { _ = st.Close() }) + + keyPath := filepath.Join(dir, "secret.key") + _ = crypto.GenerateKeyFile(keyPath) + key, _ := crypto.LoadKeyFromFile(keyPath) + aead, _ := crypto.NewAEAD(key) + + deps := Deps{ + Cfg: config.Config{Listen: ":0", DataDir: dir, SecretKeyFile: keyPath}, + Store: st, + AEAD: aead, + Hub: ws.NewHub(), + } + s := New(deps) + ts := httptest.NewServer(s.srv.Handler) + t.Cleanup(ts.Close) + return s, ts.URL, st +} + +func TestEnrollmentBadToken(t *testing.T) { + t.Parallel() + _, url, _ := newTestServerWithHub(t) + + body, _ := json.Marshal(enrollRequest{ + Token: "no-such-token", HostName: "host1", + OS: api.OSLinux, Arch: api.ArchAmd64, + AgentVersion: "0.1", ResticVersion: "0.17", + }) + res, err := stdhttp.Post(url+"/api/agents/enroll", "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("post: %v", err) + } + defer res.Body.Close() + if res.StatusCode != stdhttp.StatusUnauthorized { + t.Errorf("status: %d", res.StatusCode) + } +} + +func TestEnrollmentHappyPath(t *testing.T) { + t.Parallel() + _, url, st := newTestServerWithHub(t) + + // Issue a token directly via the store (skipping the operator UI). + rawToken, _ := auth.NewToken() + if err := st.CreateEnrollmentToken(context.Background(), + auth.HashToken(rawToken), 5*time.Minute); err != nil { + t.Fatalf("issue: %v", err) + } + + body, _ := json.Marshal(enrollRequest{ + Token: rawToken, HostName: "test-host", + OS: api.OSLinux, Arch: api.ArchAmd64, + AgentVersion: "0.1", ResticVersion: "0.17", + }) + res, err := stdhttp.Post(url+"/api/agents/enroll", "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("post: %v", err) + } + defer res.Body.Close() + if res.StatusCode != stdhttp.StatusCreated { + buf, _ := io.ReadAll(res.Body) + t.Fatalf("status %d: %s", res.StatusCode, buf) + } + + var er enrollResponse + if err := json.NewDecoder(res.Body).Decode(&er); err != nil { + t.Fatalf("decode: %v", err) + } + if er.HostID == "" || er.AgentToken == "" { + t.Errorf("missing fields in response: %+v", er) + } + + // Token must not be reusable. + res2, _ := stdhttp.Post(url+"/api/agents/enroll", "application/json", bytes.NewReader(body)) + defer res2.Body.Close() + if res2.StatusCode != stdhttp.StatusUnauthorized { + t.Errorf("re-enrollment with same token should fail, got %d", res2.StatusCode) + } + + // Host row exists with matching agent_token_hash. + got, err := st.LookupHostByAgentToken(context.Background(), auth.HashToken(er.AgentToken)) + if err != nil { + t.Fatalf("lookup by token: %v", err) + } + if got.Name != "test-host" || got.OS != "linux" { + t.Errorf("host fields: %+v", got) + } +} diff --git a/internal/server/http/server.go b/internal/server/http/server.go index 2494a03..c1a6ff2 100644 --- a/internal/server/http/server.go +++ b/internal/server/http/server.go @@ -15,6 +15,7 @@ import ( "gitea.dcglab.co.uk/steve/restic-manager/internal/crypto" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/config" + "gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws" "gitea.dcglab.co.uk/steve/restic-manager/internal/store" ) @@ -24,6 +25,7 @@ type Deps struct { Cfg config.Config Store *store.Store AEAD *crypto.AEAD + Hub *ws.Hub // BootstrapToken (optional, populated only on first run) is the raw // admin-bootstrap token printed in the server logs. While set, the // /bootstrap endpoint accepts it to create the first admin user. @@ -73,8 +75,24 @@ func (s *Server) routes(r chi.Router) { r.Post("/auth/login", s.handleLogin) r.Post("/auth/logout", s.handleLogout) r.Post("/bootstrap", s.handleBootstrap) + + // Agent enrollment (open endpoint — token is the credential). + r.Post("/agents/enroll", s.handleAgentEnroll) + + // Operator → server (authenticated). Spec.md §6.1's + // /hosts/{id}/enrollment-token (regenerate) lands when the + // host page can call it; for now just the create endpoint. + r.Post("/enrollment-tokens", s.handleCreateEnrollmentToken) }) + // Agent ↔ server WebSocket. Bearer-authenticated inside the handler. + if s.deps.Hub != nil { + r.Mount("/ws/agent", ws.AgentHandler(ws.HandlerDeps{ + Hub: s.deps.Hub, + Store: s.deps.Store, + })) + } + // UI handlers will hang off / — Phase 1 will add them. r.Get("/", func(w stdhttp.ResponseWriter, _ *stdhttp.Request) { _, _ = fmt.Fprint(w, "restic-manager — UI not yet implemented") diff --git a/internal/server/ws/doc.go b/internal/server/ws/doc.go deleted file mode 100644 index 2c0b2bc..0000000 --- a/internal/server/ws/doc.go +++ /dev/null @@ -1,3 +0,0 @@ -// Package ws hosts the WebSocket transport for agent ↔ server and the -// browser-facing live job log stream. -package ws diff --git a/internal/server/ws/handler.go b/internal/server/ws/handler.go new file mode 100644 index 0000000..8828953 --- /dev/null +++ b/internal/server/ws/handler.go @@ -0,0 +1,183 @@ +package ws + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + stdhttp "net/http" + "strings" + "time" + + "github.com/coder/websocket" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" + "gitea.dcglab.co.uk/steve/restic-manager/internal/auth" + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" +) + +// HandlerDeps is the set of collaborators the agent WS handler needs. +type HandlerDeps struct { + Hub *Hub + Store *store.Store +} + +// AgentHandler is the http.Handler that owns /ws/agent. Agents +// authenticate with `Authorization: Bearer ` (issued at +// enrollment) before the WS upgrade. +// +// Lifecycle: +// 1. Bearer token resolves to a Host row. +// 2. Upgrade. +// 3. First message must be `hello`; protocol_version checked here. +// 4. Loop: read messages, dispatch by type. Heartbeats touch the +// host row; job/log/repo messages forward to the relevant +// handlers (TODO: lands with P1-18 onward). +// 5. On Read error or context cancel, mark host offline, unregister +// from the hub. +func AgentHandler(deps HandlerDeps) stdhttp.Handler { + return stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + host, ok := authenticateAgent(r, deps.Store) + if !ok { + stdhttp.Error(w, "unauthorized", stdhttp.StatusUnauthorized) + return + } + + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + InsecureSkipVerify: true, // Origin checks are pointless for an agent CLI. + }) + if err != nil { + slog.Warn("ws accept failed", "err", err, "host_id", host.ID) + return + } + + c := NewConn(host.ID, conn) + // Keep agents alive across NAT boxes; coder/websocket + // auto-pings under the hood when configured. The default 60s + // works fine for a 30s heartbeat cadence. + + runAgentLoop(r.Context(), c, host.ID, deps) + }) +} + +// authenticateAgent returns the host that owns the bearer token in +// the request, or (nil, false) if anything is amiss. The same +// "false" path is used for missing header, malformed header, unknown +// token — no information leak about why. +func authenticateAgent(r *stdhttp.Request, st *store.Store) (*store.Host, bool) { + hdr := r.Header.Get("Authorization") + const prefix = "Bearer " + if !strings.HasPrefix(hdr, prefix) { + return nil, false + } + token := strings.TrimPrefix(hdr, prefix) + if token == "" { + return nil, false + } + h, err := st.LookupHostByAgentToken(r.Context(), auth.HashToken(token)) + if err != nil { + return nil, false + } + return h, true +} + +// runAgentLoop is the per-connection driver. Returns when the socket +// is closed for any reason. It owns the hub registration: register on +// hello acceptance, unregister on exit. +func runAgentLoop(ctx context.Context, c *Conn, hostID string, deps HandlerDeps) { + // Stage 1: hello (with a tight deadline). + helloCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + hello, err := c.Read(helloCtx) + cancel() + if err != nil { + slog.Info("ws hello read failed", "host_id", hostID, "err", err) + _ = c.Close() + return + } + if hello.Type != api.MsgHello { + c.SendError(ctx, api.ErrBadRequest, "first message must be hello", "") + return + } + var helloPayload api.HelloPayload + if err := hello.UnmarshalPayload(&helloPayload); err != nil { + c.SendError(ctx, api.ErrBadRequest, "malformed hello payload", "") + return + } + if helloPayload.ProtocolVersion < api.MinAgentProtocolVersion { + c.SendError(ctx, api.ErrProtocolTooOld, + fmt.Sprintf("agent protocol_version %d below minimum %d", + helloPayload.ProtocolVersion, api.MinAgentProtocolVersion), + "https://restic-manager.example/docs/upgrade") + return + } + if helloPayload.ProtocolVersion > api.CurrentProtocolVersion { + // Forward-compat is fine — newer agents talking to older + // servers should accept their lower version. Just log it. + slog.Info("ws agent newer than server", + "host_id", hostID, + "agent_proto", helloPayload.ProtocolVersion, + "server_proto", api.CurrentProtocolVersion) + } + + now := time.Now().UTC() + if err := deps.Store.MarkHostHello(ctx, hostID, + helloPayload.AgentVersion, helloPayload.ResticVersion, + helloPayload.ProtocolVersion, now); err != nil { + slog.Error("ws mark host hello failed", "host_id", hostID, "err", err) + } + + deps.Hub.Register(hostID, c) + defer deps.Hub.Unregister(hostID, c) + defer func() { _ = c.Close() }() + + slog.Info("ws agent connected", + "host_id", hostID, + "agent_version", helloPayload.AgentVersion, + "protocol_version", helloPayload.ProtocolVersion) + + // Stage 2: main read loop. + for { + env, err := c.Read(ctx) + if err != nil { + if !errors.Is(err, context.Canceled) { + slog.Info("ws agent read loop ended", "host_id", hostID, "err", err) + } + return + } + dispatchAgentMessage(ctx, c, hostID, env, deps) + } +} + +// dispatchAgentMessage routes a single envelope to its handler. Only +// hello + heartbeat are wired up in Phase 1's first slice; the rest +// land with P1-18+ (jobs) and P2 (schedules). +func dispatchAgentMessage(ctx context.Context, c *Conn, hostID string, env api.Envelope, deps HandlerDeps) { + switch env.Type { + case api.MsgHeartbeat: + _ = deps.Store.TouchHost(ctx, hostID, time.Now().UTC()) + + case api.MsgJobStarted, api.MsgJobProgress, api.MsgJobFinished, + api.MsgLogStream, api.MsgSnapshotsRpt, api.MsgRepoStats, + api.MsgScheduleAck, api.MsgCommandResult: + // TODO(P1-18+): persist + fan out to subscribed browsers. + slog.Debug("ws msg not yet handled", "type", env.Type, "host_id", hostID) + + case api.MsgError: + var ep api.ErrorPayload + _ = env.UnmarshalPayload(&ep) + slog.Warn("ws agent reported error", "host_id", hostID, + "code", string(ep.Code), "message", ep.Message) + + default: + slog.Warn("ws unknown message type from agent", + "type", env.Type, "host_id", hostID) + } +} + +// MinHeartbeatInterval is a sanity floor — any agent reporting +// heartbeats more often than this is misbehaving. (Spec says 30s.) +const MinHeartbeatInterval = 5 * time.Second + +// suppress unused-import false-positives if json drops out later +var _ = json.Marshal diff --git a/internal/server/ws/hub.go b/internal/server/ws/hub.go new file mode 100644 index 0000000..c0b64ef --- /dev/null +++ b/internal/server/ws/hub.go @@ -0,0 +1,145 @@ +// Package ws hosts the WebSocket transport for agent ↔ server. The +// Hub tracks one active connection per host id; subsequent connections +// from the same host evict the prior one (last-write-wins). +package ws + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "sync" + "time" + + "github.com/coder/websocket" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" +) + +// Hub owns the live agent connections and routes messages. +type Hub struct { + mu sync.RWMutex + conns map[string]*Conn // hostID → conn +} + +// NewHub returns an empty hub. +func NewHub() *Hub { + return &Hub{conns: make(map[string]*Conn)} +} + +// Conn is one agent WS connection. Send is safe for concurrent use; +// Read is single-reader (the connection's run loop). +type Conn struct { + HostID string + c *websocket.Conn + + writeMu sync.Mutex +} + +// Register installs c as the canonical connection for hostID. Any +// previous connection for that host is closed. +func (h *Hub) Register(hostID string, c *Conn) { + h.mu.Lock() + if prev, ok := h.conns[hostID]; ok { + // Best-effort close — a stuck old socket shouldn't block new one. + go func(old *Conn) { + _ = old.c.Close(websocket.StatusPolicyViolation, "superseded") + }(prev) + } + h.conns[hostID] = c + h.mu.Unlock() +} + +// Unregister removes c iff it is still the canonical conn (a race +// where a newer conn already replaced it must not unregister it). +func (h *Hub) Unregister(hostID string, c *Conn) { + h.mu.Lock() + if cur, ok := h.conns[hostID]; ok && cur == c { + delete(h.conns, hostID) + } + h.mu.Unlock() +} + +// Send delivers an envelope to the host if connected. Returns an error +// if the host is offline; caller may queue the message for later. +func (h *Hub) Send(ctx context.Context, hostID string, env api.Envelope) error { + h.mu.RLock() + c, ok := h.conns[hostID] + h.mu.RUnlock() + if !ok { + return fmt.Errorf("ws: host %q is offline", hostID) + } + return c.Send(ctx, env) +} + +// Connected reports whether hostID has an active connection. +func (h *Hub) Connected(hostID string) bool { + h.mu.RLock() + _, ok := h.conns[hostID] + h.mu.RUnlock() + return ok +} + +// ----- Conn methods -------------------------------------------------- + +// NewConn wraps a freshly-accepted websocket for a given hostID. +func NewConn(hostID string, c *websocket.Conn) *Conn { + return &Conn{HostID: hostID, c: c} +} + +// Send writes an envelope as a JSON text message. Concurrent calls +// are serialised; the underlying socket is not safe for parallel +// writers. +func (c *Conn) Send(ctx context.Context, env api.Envelope) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + raw, err := json.Marshal(env) + if err != nil { + return fmt.Errorf("ws: marshal envelope: %w", err) + } + return c.c.Write(ctx, websocket.MessageText, raw) +} + +// SendError writes an error envelope and closes the socket. Used by +// the hello handshake when an agent is rejected. +func (c *Conn) SendError(ctx context.Context, code api.ErrorCode, msg, helpURL string) { + env, err := api.Marshal(api.MsgError, "", api.ErrorPayload{ + Code: code, Message: msg, HelpURL: helpURL, + }) + if err == nil { + writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + _ = c.Send(writeCtx, env) + } + _ = c.c.Close(websocket.StatusPolicyViolation, string(code)) +} + +// Close shuts the socket down with a normal-closure status code. +func (c *Conn) Close() error { + return c.c.Close(websocket.StatusNormalClosure, "") +} + +// Read pulls the next JSON envelope off the wire. The caller's +// context controls cancellation and timeouts (e.g. read deadlines). +func (c *Conn) Read(ctx context.Context) (api.Envelope, error) { + mt, raw, err := c.c.Read(ctx) + if err != nil { + return api.Envelope{}, err + } + if mt != websocket.MessageText { + return api.Envelope{}, errors.New("ws: expected text frame") + } + var env api.Envelope + if err := json.Unmarshal(raw, &env); err != nil { + return api.Envelope{}, fmt.Errorf("ws: unmarshal envelope: %w", err) + } + return env, nil +} + +// ----- helpers ------------------------------------------------------- + +// LogValue emits a slog-friendly representation of a Conn. +func (c *Conn) LogValue() slog.Value { + return slog.GroupValue(slog.String("host_id", c.HostID)) +} diff --git a/internal/server/ws/hub_test.go b/internal/server/ws/hub_test.go new file mode 100644 index 0000000..ed8b04f --- /dev/null +++ b/internal/server/ws/hub_test.go @@ -0,0 +1,181 @@ +package ws + +import ( + "context" + "encoding/json" + stdhttp "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/coder/websocket" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" + "gitea.dcglab.co.uk/steve/restic-manager/internal/auth" + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" +) + +// setupTestHub spins up a Server that exposes only /ws/agent against +// a fresh sqlite store with one pre-enrolled host. Returns the URL, +// the agent's bearer token, and the host ID. +func setupTestHub(t *testing.T) (url string, token string, hostID string, st *store.Store, hub *Hub) { + t.Helper() + dir := t.TempDir() + var err error + st, err = store.Open(context.Background(), filepath.Join(dir, "rm.db")) + if err != nil { + t.Fatalf("store: %v", err) + } + t.Cleanup(func() { _ = st.Close() }) + + hub = NewHub() + mux := stdhttp.NewServeMux() + mux.Handle("/ws/agent", AgentHandler(HandlerDeps{Hub: hub, Store: st})) + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + // Pre-enroll a host directly via store (skipping HTTP). + hostID = "01HJ8K70000000000000000000" + token, _ = auth.NewToken() + now := time.Now().UTC() + if err := st.CreateHost(context.Background(), store.Host{ + ID: hostID, Name: "h1", OS: "linux", Arch: "amd64", + EnrolledAt: now, + }, auth.HashToken(token), ""); err != nil { + t.Fatalf("enroll: %v", err) + } + url = "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/agent" + return +} + +func TestWSHelloAndHeartbeat(t *testing.T) { + t.Parallel() + url, token, hostID, st, hub := setupTestHub(t) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{ + HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + token}}, + }) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer c.CloseNow() + + // Send hello. + hello := api.HelloPayload{ + ProtocolVersion: api.CurrentProtocolVersion, + AgentVersion: "0.1.0", + ResticVersion: "0.17.1", + Hostname: "h1", + OS: api.OSLinux, + Arch: api.ArchAmd64, + } + env, _ := api.Marshal(api.MsgHello, "", hello) + raw, _ := json.Marshal(env) + if err := c.Write(ctx, websocket.MessageText, raw); err != nil { + t.Fatalf("write hello: %v", err) + } + + // Wait for the server to register us (registration happens after + // the hello-handler returns; give it up to 1s). + deadline := time.Now().Add(time.Second) + for !hub.Connected(hostID) && time.Now().Before(deadline) { + time.Sleep(20 * time.Millisecond) + } + if !hub.Connected(hostID) { + t.Fatal("host did not register on hub after hello") + } + + // Verify host row was marked online + has populated metadata. + h, err := st.GetHost(context.Background(), hostID) + if err != nil { + t.Fatalf("get host: %v", err) + } + if h.Status != "online" || h.AgentVersion != "0.1.0" { + t.Errorf("host after hello: %+v", h) + } + + // Send a heartbeat — server should touch last_seen. + hb := api.HeartbeatPayload{SentAt: time.Now().UTC()} + env, _ = api.Marshal(api.MsgHeartbeat, "", hb) + raw, _ = json.Marshal(env) + preTouch := h.LastSeenAt + _ = c.Write(ctx, websocket.MessageText, raw) + + // Wait briefly for server to process. + deadline = time.Now().Add(time.Second) + for time.Now().Before(deadline) { + h2, _ := st.GetHost(context.Background(), hostID) + if h2.LastSeenAt != nil && (preTouch == nil || h2.LastSeenAt.After(*preTouch)) { + return + } + time.Sleep(20 * time.Millisecond) + } + t.Error("heartbeat did not update last_seen_at") +} + +func TestWSRejectsOldProtocol(t *testing.T) { + t.Parallel() + url, token, _, _, _ := setupTestHub(t) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{ + HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + token}}, + }) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer c.CloseNow() + + hello := api.HelloPayload{ProtocolVersion: 0} // below minimum + env, _ := api.Marshal(api.MsgHello, "", hello) + raw, _ := json.Marshal(env) + _ = c.Write(ctx, websocket.MessageText, raw) + + // Server should send an error envelope, then close. + mt, body, err := c.Read(ctx) + if err != nil { + t.Fatalf("read: %v", err) + } + if mt != websocket.MessageText { + t.Fatalf("frame type: %v", mt) + } + var got api.Envelope + if err := json.Unmarshal(body, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if got.Type != api.MsgError { + t.Errorf("expected error envelope, got %q", got.Type) + } + var ep api.ErrorPayload + _ = got.UnmarshalPayload(&ep) + if ep.Code != api.ErrProtocolTooOld { + t.Errorf("error code: %q", ep.Code) + } +} + +func TestWSRejectsBadToken(t *testing.T) { + t.Parallel() + url, _, _, _, _ := setupTestHub(t) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, res, err := websocket.Dial(ctx, url, &websocket.DialOptions{ + HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer wrong"}}, + }) + if err == nil { + t.Fatal("dial should fail") + } + if res == nil || res.StatusCode != stdhttp.StatusUnauthorized { + if res != nil { + t.Errorf("status: %d", res.StatusCode) + } + } +} diff --git a/internal/store/hosts.go b/internal/store/hosts.go new file mode 100644 index 0000000..1f1e113 --- /dev/null +++ b/internal/store/hosts.go @@ -0,0 +1,205 @@ +package store + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "time" +) + +// CreateHost inserts a new host row. Used by the enrollment flow. +// The caller has already minted the host id and hashed the agent +// bearer token. +func (s *Store) CreateHost(ctx context.Context, h Host, agentTokenHash, certPinSHA256 string) error { + tags, err := json.Marshal(h.Tags) + if err != nil { + return fmt.Errorf("store: marshal tags: %w", err) + } + _, err = s.db.ExecContext(ctx, + `INSERT INTO hosts ( + id, name, os, arch, agent_version, restic_version, protocol_version, + enrolled_at, status, tags, + agent_token_hash, cert_pin_sha256 + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 'offline', ?, ?, ?)`, + h.ID, h.Name, h.OS, h.Arch, + h.AgentVersion, h.ResticVersion, h.ProtocolVersion, + h.EnrolledAt.UTC().Format(time.RFC3339Nano), + string(tags), + agentTokenHash, certPinSHA256) + if err != nil { + return fmt.Errorf("store: create host: %w", err) + } + return nil +} + +// LookupHostByAgentToken resolves a hashed agent bearer token to the +// host it belongs to. Returns ErrNotFound on miss. +func (s *Store) LookupHostByAgentToken(ctx context.Context, tokenHash string) (*Host, error) { + row := s.db.QueryRowContext(ctx, + `SELECT id, name, os, arch, agent_version, restic_version, protocol_version, + enrolled_at, last_seen_at, status, repo_id, tags, + current_job_id, last_backup_at, last_backup_status, + repo_size_bytes, snapshot_count, open_alert_count, + applied_schedule_version + FROM hosts WHERE agent_token_hash = ?`, + tokenHash) + return scanHost(row) +} + +// GetHost returns a host by ID. Returns ErrNotFound on miss. +func (s *Store) GetHost(ctx context.Context, id string) (*Host, error) { + row := s.db.QueryRowContext(ctx, + `SELECT id, name, os, arch, agent_version, restic_version, protocol_version, + enrolled_at, last_seen_at, status, repo_id, tags, + current_job_id, last_backup_at, last_backup_status, + repo_size_bytes, snapshot_count, open_alert_count, + applied_schedule_version + FROM hosts WHERE id = ?`, id) + return scanHost(row) +} + +// MarkHostHello updates the host row with metadata received in the +// agent's hello message and flips status to 'online'. +func (s *Store) MarkHostHello(ctx context.Context, id string, agentVersion, resticVersion string, protoVersion int, when time.Time) error { + _, err := s.db.ExecContext(ctx, + `UPDATE hosts + SET agent_version = ?, restic_version = ?, protocol_version = ?, + last_seen_at = ?, status = 'online' + WHERE id = ?`, + agentVersion, resticVersion, protoVersion, + when.UTC().Format(time.RFC3339Nano), id) + if err != nil { + return fmt.Errorf("store: mark hello: %w", err) + } + return nil +} + +// TouchHost updates last_seen_at on heartbeat, leaving status alone if +// already online (the offline-marker is a separate sweep). +func (s *Store) TouchHost(ctx context.Context, id string, when time.Time) error { + _, err := s.db.ExecContext(ctx, + `UPDATE hosts + SET last_seen_at = ?, + status = CASE WHEN status = 'offline' THEN 'online' ELSE status END + WHERE id = ?`, + when.UTC().Format(time.RFC3339Nano), id) + if err != nil { + return fmt.Errorf("store: touch host: %w", err) + } + return nil +} + +// MarkHostsOfflineStale flips any host that hasn't been seen since +// before `cutoff` from 'online' to 'offline'. Returns the number of +// rows affected so the caller can log non-zero events. +func (s *Store) MarkHostsOfflineStale(ctx context.Context, cutoff time.Time) (int64, error) { + res, err := s.db.ExecContext(ctx, + `UPDATE hosts + SET status = 'offline' + WHERE status = 'online' + AND (last_seen_at IS NULL OR last_seen_at < ?)`, + cutoff.UTC().Format(time.RFC3339Nano)) + if err != nil { + return 0, fmt.Errorf("store: mark offline: %w", err) + } + n, _ := res.RowsAffected() + return n, nil +} + +// ListHosts returns every host. Phase 1 callers fit a small fleet in +// memory; pagination lands when it matters. +func (s *Store) ListHosts(ctx context.Context) ([]Host, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT id, name, os, arch, agent_version, restic_version, protocol_version, + enrolled_at, last_seen_at, status, repo_id, tags, + current_job_id, last_backup_at, last_backup_status, + repo_size_bytes, snapshot_count, open_alert_count, + applied_schedule_version + FROM hosts ORDER BY name`) + if err != nil { + return nil, fmt.Errorf("store: list hosts: %w", err) + } + defer rows.Close() + var out []Host + for rows.Next() { + h, err := scanHostRow(rows) + if err != nil { + return nil, err + } + out = append(out, *h) + } + return out, rows.Err() +} + +// ----- scan helpers -------------------------------------------------- + +type hostScanner interface { + Scan(dest ...any) error +} + +func scanHost(row *sql.Row) (*Host, error) { + h, err := scanHostRow(row) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + return h, err +} + +func scanHostRow(s hostScanner) (*Host, error) { + var h Host + var ( + lastSeen, lastBackupAt sql.NullString + repoID, currentJob, lastBkSt sql.NullString + enrolled string + tags string + ) + err := s.Scan(&h.ID, &h.Name, &h.OS, &h.Arch, + &h.AgentVersion, &h.ResticVersion, &h.ProtocolVersion, + &enrolled, &lastSeen, &h.Status, &repoID, &tags, + ¤tJob, &lastBackupAt, &lastBkSt, + &h.RepoSizeBytes, &h.SnapshotCount, &h.OpenAlertCount, + &h.AppliedScheduleVersion) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("store: scan host: %w", err) + } + t, err := time.Parse(time.RFC3339Nano, enrolled) + if err != nil { + return nil, fmt.Errorf("store: parse enrolled_at: %w", err) + } + h.EnrolledAt = t + if lastSeen.Valid { + t, err := time.Parse(time.RFC3339Nano, lastSeen.String) + if err != nil { + return nil, fmt.Errorf("store: parse last_seen_at: %w", err) + } + h.LastSeenAt = &t + } + if lastBackupAt.Valid { + t, err := time.Parse(time.RFC3339Nano, lastBackupAt.String) + if err != nil { + return nil, fmt.Errorf("store: parse last_backup_at: %w", err) + } + h.LastBackupAt = &t + } + if repoID.Valid { + s := repoID.String + h.RepoID = &s + } + if currentJob.Valid { + s := currentJob.String + h.CurrentJobID = &s + } + if lastBkSt.Valid { + s := lastBkSt.String + h.LastBackupStatus = &s + } + if tags != "" { + _ = json.Unmarshal([]byte(tags), &h.Tags) + } + return &h, nil +} diff --git a/internal/store/migrations/0001_initial.sql b/internal/store/migrations/0001_initial.sql index 746e906..8eb555c 100644 --- a/internal/store/migrations/0001_initial.sql +++ b/internal/store/migrations/0001_initial.sql @@ -92,12 +92,15 @@ CREATE INDEX hosts_status ON hosts(status); CREATE INDEX hosts_last_seen_at ON hosts(last_seen_at); -- Pending one-time enrollment tokens (TTL'd, single-use). +-- consumed_host is audit-only (no FK on purpose: we burn the token +-- before the host row exists, and we want this trail to survive a +-- later host deletion). CREATE TABLE enrollment_tokens ( - token_hash TEXT PRIMARY KEY, -- argon2id of token + token_hash TEXT PRIMARY KEY, created_at TEXT NOT NULL, expires_at TEXT NOT NULL, consumed_at TEXT, - consumed_host TEXT REFERENCES hosts(id) ON DELETE SET NULL + consumed_host TEXT ); CREATE INDEX enrollment_tokens_expires_at ON enrollment_tokens(expires_at);