diff --git a/cmd/server/main.go b/cmd/server/main.go index 8d52bb8..dcd0d38 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -17,6 +17,7 @@ import ( "gitea.dcglab.co.uk/steve/restic-manager/internal/crypto" "gitea.dcglab.co.uk/steve/restic-manager/internal/notification" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/config" + "gitea.dcglab.co.uk/steve/restic-manager/internal/server/fleetupdate" rmhttp "gitea.dcglab.co.uk/steve/restic-manager/internal/server/http" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/maintenance" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc" @@ -91,6 +92,7 @@ func run() error { notifHub := notification.NewHub(st, aead, cfg.BaseURL) alertEngine := alert.NewEngine(st, notifHub) + updateWatcher := ws.NewUpdateWatcher(st, alertEngine) renderer, err := ui.New() if err != nil { @@ -116,6 +118,7 @@ func run() error { JobHub: jobHub, AlertEngine: alertEngine, NotificationHub: notifHub, + UpdateWatcher: updateWatcher, UI: renderer, Version: version, OIDC: oidcClient, @@ -147,10 +150,17 @@ func run() error { srv := rmhttp.New(deps) + // Fleet-update worker — built after the HTTP server because the + // dispatcher delegates back into srv.DispatchHostUpdate. + fleetWorker := fleetupdate.NewWorker(st, hub, + &serverDispatcher{srv: srv}, alertEngine) + srv.SetFleetWorker(fleetWorker) + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() go alertEngine.Run(ctx) + go updateWatcher.Run(ctx) errCh := make(chan error, 1) go func() { @@ -243,3 +253,12 @@ func run() error { } return nil } + +// serverDispatcher adapts the http.Server's DispatchHostUpdate method +// to the fleetupdate.Dispatcher interface. Lives in main so the +// http and fleetupdate packages don't need to know about each other. +type serverDispatcher struct{ srv *rmhttp.Server } + +func (d *serverDispatcher) DispatchUpdate(ctx context.Context, hostID, actorUserID string) (string, string, error) { + return d.srv.DispatchHostUpdate(ctx, hostID, actorUserID) +} diff --git a/internal/alert/update_alerts.go b/internal/alert/update_alerts.go new file mode 100644 index 0000000..9a7da6e --- /dev/null +++ b/internal/alert/update_alerts.go @@ -0,0 +1,63 @@ +package alert + +import ( + "context" + "fmt" + "log/slog" + "time" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/notification" +) + +// Alert-kind constants for P6 self-update flows. +const ( + // KindUpdateFailed is raised when an agent fails to come back with + // the expected version after a command.update dispatch (timeout or + // version-mismatch). Resolved by a subsequent matching hello. + KindUpdateFailed = "update_failed" + + // KindFleetUpdateHalted is raised when the fleet-update worker + // stops mid-run because a host failed to update or went offline. + // Host-less alert (system-scoped). Manually resolved by an admin. + KindFleetUpdateHalted = "fleet_update_halted" +) + +// RaiseUpdateFailed records a per-host update failure. dedupKey is the +// hostID so a re-dispatch on the same host touches the existing alert +// rather than spawning a duplicate. +func (e *Engine) RaiseUpdateFailed(ctx context.Context, hostID, jobID, reason string, when time.Time) { + msg := fmt.Sprintf("Agent update failed (job %s): %s", jobID, reason) + e.raiseAndNotify(ctx, hostID, KindUpdateFailed, hostID, "warning", msg, when) +} + +// ResolveUpdateFailed clears any open update_failed alert for hostID. +// Called from the WS hello path when the agent reconnects with the +// target version. +func (e *Engine) ResolveUpdateFailed(ctx context.Context, hostID string, when time.Time) { + e.resolveAndNotify(ctx, hostID, KindUpdateFailed, hostID, when) +} + +// RaiseFleetUpdateHalted is host-less — the fleet update is a +// system-level concept. We persist it via the dedicated host-less +// alert path so the alerts table's host_id column carries NULL. +func (e *Engine) RaiseFleetUpdateHalted(ctx context.Context, fleetUpdateID, reason string, when time.Time) { + msg := fmt.Sprintf("Fleet update %s halted: %s", fleetUpdateID, reason) + id, didRaise, err := e.store.RaiseOrTouchSystem(ctx, KindFleetUpdateHalted, fleetUpdateID, "warning", msg, when) + if err != nil { + slog.Warn("alert: raise fleet_update_halted", "fu_id", fleetUpdateID, "err", err) + return + } + if !didRaise { + return + } + go e.hub.Dispatch(ctx, notification.Payload{ + Event: notification.EventRaised, + AlertID: id, + Severity: "warning", + Kind: KindFleetUpdateHalted, + HostID: "", + HostName: "", + Message: msg, + RaisedAt: when, + }) +} diff --git a/internal/server/fleetupdate/worker.go b/internal/server/fleetupdate/worker.go new file mode 100644 index 0000000..1442832 --- /dev/null +++ b/internal/server/fleetupdate/worker.go @@ -0,0 +1,221 @@ +// Package fleetupdate drives a rolling, sequential agent self-update +// over a list of hosts. One worker goroutine per Start() call (gated +// at the store layer to at-most-one-running-fleet-update). +package fleetupdate + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "github.com/oklog/ulid/v2" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" +) + +// Hub is the slim "is this host connected?" surface. +type Hub interface { + Connected(hostID string) bool +} + +// Dispatcher sends one command.update envelope. The implementer also +// creates the jobs row, writes audit, and registers with the update +// watcher. Pre-checks are the dispatcher's responsibility — the worker +// passes through whatever error it returns. +type Dispatcher interface { + DispatchUpdate(ctx context.Context, hostID string, actorUserID string) (jobID string, code string, err error) +} + +// AlertRaiser is the slim view of the alert engine's host-less raise +// path. Used to emit fleet_update_halted on first failure. +type AlertRaiser interface { + RaiseFleetUpdateHalted(ctx context.Context, fleetUpdateID, reason string, when time.Time) +} + +// Worker is the long-lived fleet-update orchestrator. There is at most +// one *running* fleet update at a time (enforced by the store). +type Worker struct { + store *store.Store + hub Hub + disp Dispatcher + alerts AlertRaiser + + // targetVersion is the version every dispatched agent is expected + // to come back with. Captured at Start time to avoid drift. + targetVersion string + + // pollPeriod controls the cadence at which the worker re-reads the + // host row to check for the version transition. Exposed for tests. + pollPeriod time.Duration + // hostTimeout bounds how long the worker waits for one host to + // reach the target version before halting. + hostTimeout time.Duration +} + +// NewWorker builds an unstarted worker. targetVersion is set on each +// Start call; the values here are defaults. +func NewWorker(st *store.Store, hub Hub, disp Dispatcher, alerts AlertRaiser) *Worker { + return &Worker{ + store: st, + hub: hub, + disp: disp, + alerts: alerts, + pollPeriod: 1 * time.Second, + hostTimeout: 95 * time.Second, + } +} + +// Start creates the parent + child rows, then spawns the per-host +// worker goroutine. Returns the new fleet_update_id on success. +// store.ErrFleetUpdateRunning bubbles up unchanged. +func (w *Worker) Start(ctx context.Context, userID, targetVersion string, hostIDs []string) (string, error) { + if userID == "" || targetVersion == "" { + return "", errors.New("fleetupdate: userID and targetVersion required") + } + if len(hostIDs) == 0 { + return "", errors.New("fleetupdate: at least one host required") + } + fuID := ulid.Make().String() + now := time.Now().UTC() + if err := w.store.CreateFleetUpdate(ctx, store.FleetUpdate{ + ID: fuID, + StartedAt: now, + StartedByUserID: userID, + TargetVersion: targetVersion, + Status: "running", + }, hostIDs); err != nil { + return "", err + } + + // The goroutine outlives the request that started it; carry a + // detached context so an HTTP-handler ctx cancel doesn't abort + // the long roll. + bg := context.WithoutCancel(ctx) + go w.run(bg, fuID, userID, targetVersion) + return fuID, nil +} + +// Cancel marks the fleet update cancelled. The running goroutine +// observes the new status on its next pre-check and exits without +// dispatching further hosts. The currently-dispatched job is left to +// finish on its own — cancelling agent-side is out of scope for v1. +func (w *Worker) Cancel(ctx context.Context, fuID string) error { + return w.store.CancelFleetUpdate(ctx, fuID, time.Now().UTC()) +} + +// run is the per-host loop. Halts on first failure; emits one alert +// on transition. +func (w *Worker) run(ctx context.Context, fuID, userID, targetVersion string) { + w.targetVersion = targetVersion + + for { + // Check the parent row's status — picks up Cancel. + fu, err := w.store.ActiveFleetUpdate(ctx) + if err != nil { + slog.Warn("fleetupdate: read active", "fu_id", fuID, "err", err) + return + } + if fu == nil || fu.ID != fuID { + // Cancelled, halted, or completed externally. Done. + return + } + + pending, err := w.store.ListPendingFleetUpdateHosts(ctx, fuID) + if err != nil { + slog.Warn("fleetupdate: list pending", "fu_id", fuID, "err", err) + return + } + if len(pending) == 0 { + now := time.Now().UTC() + if err := w.store.CompleteFleetUpdate(ctx, fuID, now); err != nil { + slog.Warn("fleetupdate: complete", "fu_id", fuID, "err", err) + } + return + } + + next := pending[0] + w.processHost(ctx, fuID, userID, next) + } +} + +// processHost handles one host slot. Marks it skipped, succeeded, or +// failed (and halts the fleet on failure). +func (w *Worker) processHost(ctx context.Context, fuID, userID string, slot store.FleetUpdateHost) { + hostID := slot.HostID + _ = w.store.SetFleetUpdateCurrentHost(ctx, fuID, hostID) + + // Pre-flight: re-read the host. The dispatch path repeats most of + // these checks but doing them up-front lets us emit the right + // per-host status (skipped vs failed) without consuming a job row. + host, err := w.store.GetHost(ctx, hostID) + if err != nil || host == nil { + _ = w.store.SetFleetUpdateHostStatus(ctx, fuID, hostID, "skipped", "host not found", "") + return + } + if host.AgentVersion != "" && host.AgentVersion == w.targetVersion { + _ = w.store.SetFleetUpdateHostStatus(ctx, fuID, hostID, "skipped", "already at target version", "") + return + } + if !w.hub.Connected(hostID) { + reason := fmt.Sprintf("host went offline: %s", hostID) + _ = w.store.SetFleetUpdateHostStatus(ctx, fuID, hostID, "failed", reason, "") + w.halt(ctx, fuID, reason) + return + } + + // Dispatch. + _ = w.store.SetFleetUpdateHostStatus(ctx, fuID, hostID, "running", "", "") + jobID, code, err := w.disp.DispatchUpdate(ctx, hostID, userID) + if err != nil || code != "" { + reason := dispatchErrorReason(code, err) + _ = w.store.SetFleetUpdateHostStatus(ctx, fuID, hostID, "failed", reason, jobID) + w.halt(ctx, fuID, reason) + return + } + + // Poll until the host's recorded agent_version matches target, or + // timeout. + deadline := time.Now().Add(w.hostTimeout) + for time.Now().Before(deadline) { + // Honour cancellation between polls. + fu, err := w.store.ActiveFleetUpdate(ctx) + if err == nil && (fu == nil || fu.ID != fuID) { + // Cancelled mid-host; leave the slot in 'running' for the + // admin to inspect. No further dispatches. + return + } + time.Sleep(w.pollPeriod) + h, err := w.store.GetHost(ctx, hostID) + if err == nil && h != nil && h.AgentVersion == w.targetVersion { + if err := w.store.SetFleetUpdateHostStatus(ctx, fuID, hostID, "succeeded", "", jobID); err != nil { + slog.Warn("fleetupdate: set succeeded", "fu_id", fuID, "host_id", hostID, "err", err) + } + return + } + } + reason := fmt.Sprintf("timeout waiting for %s to reach %s", hostID, w.targetVersion) + _ = w.store.SetFleetUpdateHostStatus(ctx, fuID, hostID, "failed", reason, jobID) + w.halt(ctx, fuID, reason) +} + +func (w *Worker) halt(ctx context.Context, fuID, reason string) { + now := time.Now().UTC() + if err := w.store.HaltFleetUpdate(ctx, fuID, reason, now); err != nil { + slog.Warn("fleetupdate: halt", "fu_id", fuID, "err", err) + } + if w.alerts != nil { + w.alerts.RaiseFleetUpdateHalted(ctx, fuID, reason, now) + } +} + +func dispatchErrorReason(code string, err error) string { + if code != "" { + return "dispatch failed: " + code + } + if err != nil { + return err.Error() + } + return "dispatch failed" +} diff --git a/internal/server/fleetupdate/worker_test.go b/internal/server/fleetupdate/worker_test.go new file mode 100644 index 0000000..c1cdac1 --- /dev/null +++ b/internal/server/fleetupdate/worker_test.go @@ -0,0 +1,344 @@ +package fleetupdate + +import ( + "context" + "errors" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/oklog/ulid/v2" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" +) + +type fakeHub struct { + mu sync.Mutex + online map[string]bool +} + +func (f *fakeHub) Connected(hostID string) bool { + f.mu.Lock() + defer f.mu.Unlock() + return f.online[hostID] +} + +type fakeDispatcher struct { + mu sync.Mutex + calls []string // host IDs + // after dispatch, set the host's agent_version to this on the + // store so the worker observes the version transition. + st *store.Store + target string + delayMS int + failOnHost map[string]string // host → error code +} + +func (f *fakeDispatcher) DispatchUpdate(ctx context.Context, hostID, _ string) (string, string, error) { + f.mu.Lock() + f.calls = append(f.calls, hostID) + if code, ok := f.failOnHost[hostID]; ok { + f.mu.Unlock() + return "", code, nil + } + st := f.st + target := f.target + delay := f.delayMS + f.mu.Unlock() + + jobID := ulid.Make().String() + if st != nil { + _ = st.CreateJob(context.Background(), store.Job{ + ID: jobID, HostID: hostID, Kind: "update", + ActorKind: "user", CreatedAt: time.Now().UTC(), + }) + } + if st != nil && target != "" { + go func() { + if delay > 0 { + time.Sleep(time.Duration(delay) * time.Millisecond) + } + _ = st.MarkHostHello(context.Background(), hostID, target, "0.17", api.CurrentProtocolVersion, time.Now().UTC()) + }() + } + return jobID, "", nil +} + +type recAlert struct { + mu sync.Mutex + reasons []string +} + +func (r *recAlert) RaiseFleetUpdateHalted(_ context.Context, _ string, reason string, _ time.Time) { + r.mu.Lock() + r.reasons = append(r.reasons, reason) + r.mu.Unlock() +} + +func openStore(t *testing.T) *store.Store { + t.Helper() + dir := t.TempDir() + st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db")) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { _ = st.Close() }) + return st +} + +func mustCreateAdmin(t *testing.T, st *store.Store) string { + t.Helper() + uid := ulid.Make().String() + if err := st.CreateUser(context.Background(), store.User{ + ID: uid, Username: "u-" + uid[:6], + PasswordHash: "x", Role: store.RoleAdmin, CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("user: %v", err) + } + return uid +} + +func mustCreateHost(t *testing.T, st *store.Store, name, version string) string { + t.Helper() + hostID := ulid.Make().String() + if err := st.CreateHost(context.Background(), store.Host{ + ID: hostID, Name: name, OS: "linux", Arch: "amd64", + EnrolledAt: time.Now().UTC(), + }, "deadbeef-"+hostID, ""); err != nil { + t.Fatalf("host: %v", err) + } + if version != "" { + if err := st.MarkHostHello(context.Background(), hostID, version, "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil { + t.Fatalf("hello: %v", err) + } + } + return hostID +} + +func waitForStatus(t *testing.T, st *store.Store, fuID, want string, timeout time.Duration) *store.FleetUpdate { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + fu, _, err := st.GetFleetUpdate(context.Background(), fuID) + if err == nil && fu != nil && fu.Status == want { + return fu + } + time.Sleep(20 * time.Millisecond) + } + t.Fatalf("status never reached %q", want) + return nil +} + +func TestWorkerTwoHostsBothSucceed(t *testing.T) { + st := openStore(t) + uid := mustCreateAdmin(t, st) + h1 := mustCreateHost(t, st, "h1", "v0") + h2 := mustCreateHost(t, st, "h2", "v0") + + hub := &fakeHub{online: map[string]bool{h1: true, h2: true}} + disp := &fakeDispatcher{st: st, target: "v2", delayMS: 30} + alerts := &recAlert{} + w := NewWorker(st, hub, disp, alerts) + w.pollPeriod = 20 * time.Millisecond + w.hostTimeout = 2 * time.Second + + fuID, err := w.Start(context.Background(), uid, "v2", []string{h1, h2}) + if err != nil { + t.Fatalf("start: %v", err) + } + waitForStatus(t, st, fuID, "completed", 5*time.Second) + _, hosts, _ := st.GetFleetUpdate(context.Background(), fuID) + for _, h := range hosts { + if h.Status != "succeeded" { + t.Errorf("host %s status %q want succeeded", h.HostID, h.Status) + } + } + if n := len(alerts.reasons); n != 0 { + t.Errorf("unexpected halt alert: %v", alerts.reasons) + } +} + +func TestWorkerSecondHostTimesOutHalts(t *testing.T) { + st := openStore(t) + uid := mustCreateAdmin(t, st) + h1 := mustCreateHost(t, st, "h1", "v0") + h2 := mustCreateHost(t, st, "h2", "v0") + h3 := mustCreateHost(t, st, "h3", "v0") + + hub := &fakeHub{online: map[string]bool{h1: true, h2: true, h3: true}} + // h1 dispatches normally (transitions to v2). h2 dispatch returns + // success but never transitions. + disp := &fakeDispatcher{st: st, target: "v2", delayMS: 20, failOnHost: map[string]string{ + h2: "", // not a code-failure; simulate by clearing target on this disp run + }} + // Actually: drop h2 from the auto-transition by faking with a + // per-host store setter. Easiest: subclass via a wrapper. + _ = disp + customDisp := &perHostDispatcher{base: disp, st: st, target: "v2", noTransition: map[string]bool{h2: true}} + + alerts := &recAlert{} + w := NewWorker(st, hub, customDisp, alerts) + w.pollPeriod = 20 * time.Millisecond + w.hostTimeout = 200 * time.Millisecond + + fuID, err := w.Start(context.Background(), uid, "v2", []string{h1, h2, h3}) + if err != nil { + t.Fatalf("start: %v", err) + } + waitForStatus(t, st, fuID, "halted", 3*time.Second) + _, hosts, _ := st.GetFleetUpdate(context.Background(), fuID) + gotStatus := map[string]string{} + for _, h := range hosts { + gotStatus[h.HostID] = h.Status + } + if gotStatus[h1] != "succeeded" { + t.Errorf("h1: %q", gotStatus[h1]) + } + if gotStatus[h2] != "failed" { + t.Errorf("h2: %q", gotStatus[h2]) + } + if gotStatus[h3] != "pending" { + t.Errorf("h3: %q", gotStatus[h3]) + } + alerts.mu.Lock() + defer alerts.mu.Unlock() + if len(alerts.reasons) != 1 { + t.Errorf("alert reasons: %v", alerts.reasons) + } +} + +// perHostDispatcher lets a test omit the auto-transition for selected +// hosts so we can simulate timeout. +type perHostDispatcher struct { + mu sync.Mutex + base *fakeDispatcher + st *store.Store + target string + noTransition map[string]bool +} + +func (p *perHostDispatcher) DispatchUpdate(_ context.Context, hostID, _ string) (string, string, error) { + p.mu.Lock() + skip := p.noTransition[hostID] + p.mu.Unlock() + jobID := ulid.Make().String() + _ = p.st.CreateJob(context.Background(), store.Job{ + ID: jobID, HostID: hostID, Kind: "update", + ActorKind: "user", CreatedAt: time.Now().UTC(), + }) + if !skip { + go func() { + time.Sleep(20 * time.Millisecond) + _ = p.st.MarkHostHello(context.Background(), hostID, p.target, "0.17", api.CurrentProtocolVersion, time.Now().UTC()) + }() + } + return jobID, "", nil +} + +func TestWorkerHostOfflineHalts(t *testing.T) { + st := openStore(t) + uid := mustCreateAdmin(t, st) + h1 := mustCreateHost(t, st, "h1", "v0") + h2 := mustCreateHost(t, st, "h2", "v0") + hub := &fakeHub{online: map[string]bool{h1: false, h2: true}} + disp := &fakeDispatcher{st: st, target: "v2"} + alerts := &recAlert{} + w := NewWorker(st, hub, disp, alerts) + w.pollPeriod = 20 * time.Millisecond + w.hostTimeout = 500 * time.Millisecond + + fuID, err := w.Start(context.Background(), uid, "v2", []string{h1, h2}) + if err != nil { + t.Fatalf("start: %v", err) + } + waitForStatus(t, st, fuID, "halted", 2*time.Second) + _, hosts, _ := st.GetFleetUpdate(context.Background(), fuID) + if hosts[0].Status != "failed" { + t.Errorf("h1 status: %q", hosts[0].Status) + } + if hosts[1].Status != "pending" { + t.Errorf("h2 status: %q", hosts[1].Status) + } +} + +func TestWorkerAlreadyAtTargetSkipped(t *testing.T) { + st := openStore(t) + uid := mustCreateAdmin(t, st) + h1 := mustCreateHost(t, st, "h1", "v2") + h2 := mustCreateHost(t, st, "h2", "v0") + hub := &fakeHub{online: map[string]bool{h1: true, h2: true}} + disp := &fakeDispatcher{st: st, target: "v2", delayMS: 20} + alerts := &recAlert{} + w := NewWorker(st, hub, disp, alerts) + w.pollPeriod = 20 * time.Millisecond + w.hostTimeout = 2 * time.Second + + fuID, err := w.Start(context.Background(), uid, "v2", []string{h1, h2}) + if err != nil { + t.Fatalf("start: %v", err) + } + waitForStatus(t, st, fuID, "completed", 4*time.Second) + _, hosts, _ := st.GetFleetUpdate(context.Background(), fuID) + want := map[string]string{h1: "skipped", h2: "succeeded"} + for _, h := range hosts { + if h.Status != want[h.HostID] { + t.Errorf("host %s: got %q want %q", h.HostID, h.Status, want[h.HostID]) + } + } +} + +func TestWorkerCancelMidRun(t *testing.T) { + st := openStore(t) + uid := mustCreateAdmin(t, st) + h1 := mustCreateHost(t, st, "h1", "v0") + h2 := mustCreateHost(t, st, "h2", "v0") + hub := &fakeHub{online: map[string]bool{h1: true, h2: true}} + // h1's transition is delayed long enough that we can cancel + // before it lands; h2 should never be touched. + disp := &fakeDispatcher{st: st, target: "v2", delayMS: 500} + alerts := &recAlert{} + w := NewWorker(st, hub, disp, alerts) + w.pollPeriod = 50 * time.Millisecond + w.hostTimeout = 5 * time.Second + + fuID, err := w.Start(context.Background(), uid, "v2", []string{h1, h2}) + if err != nil { + t.Fatalf("start: %v", err) + } + // Give the worker a moment to dispatch h1. + time.Sleep(100 * time.Millisecond) + if err := w.Cancel(context.Background(), fuID); err != nil { + t.Fatalf("cancel: %v", err) + } + waitForStatus(t, st, fuID, "cancelled", 2*time.Second) + + // h2 should never be dispatched. + disp.mu.Lock() + defer disp.mu.Unlock() + for _, c := range disp.calls { + if c == h2 { + t.Errorf("h2 dispatched after cancel") + } + } +} + +func TestWorkerStartWhileActiveErrors(t *testing.T) { + st := openStore(t) + uid := mustCreateAdmin(t, st) + h1 := mustCreateHost(t, st, "h1", "v0") + h2 := mustCreateHost(t, st, "h2", "v0") + hub := &fakeHub{online: map[string]bool{h1: true, h2: true}} + disp := &fakeDispatcher{st: st, target: "v2", delayMS: 5_000} + w := NewWorker(st, hub, disp, &recAlert{}) + w.pollPeriod = 50 * time.Millisecond + w.hostTimeout = 2 * time.Second + if _, err := w.Start(context.Background(), uid, "v2", []string{h1}); err != nil { + t.Fatalf("first start: %v", err) + } + _, err := w.Start(context.Background(), uid, "v2", []string{h2}) + if !errors.Is(err, store.ErrFleetUpdateRunning) { + t.Fatalf("err: %v want ErrFleetUpdateRunning", err) + } +} diff --git a/internal/server/http/host_update.go b/internal/server/http/host_update.go new file mode 100644 index 0000000..b1a2033 --- /dev/null +++ b/internal/server/http/host_update.go @@ -0,0 +1,217 @@ +package http + +import ( + "context" + "encoding/json" + stdhttp "net/http" + "time" + + "github.com/go-chi/chi/v5" + "github.com/oklog/ulid/v2" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" + "gitea.dcglab.co.uk/steve/restic-manager/internal/version" +) + +// UpdateWatcher is the slim view of the ws.updateWatcher this package +// uses for tracking in-flight update dispatches. Defined as an +// interface so a test can inject a stub. +type UpdateWatcher interface { + Track(jobID, hostID string) +} + +// FleetWorker is the slim view of the fleetupdate.Worker this package +// uses. Kept here for forward compatibility with P6-15 — the host +// update endpoint itself does not use it. +type FleetWorker interface { + Start(ctx context.Context, userID, targetVersion string, hostIDs []string) (string, error) + Cancel(ctx context.Context, fleetUpdateID string) error +} + +// dispatchHostUpdateResult communicates structured outcomes from the +// shared dispatch path so both the HTTP handler and the fleet worker +// can format errors in their own idiom. +type dispatchHostUpdateResult struct { + JobID string + Code string // "" on success + Status int // HTTP status the JSON handler should use on error + Msg string // human-readable detail (optional) +} + +// dispatchHostUpdate is the shared "send command.update to one host" +// path. It performs every pre-check (host exists, online, version +// mismatch, no in-flight update) and on success creates the jobs row, +// audits, dispatches the WS envelope, and tracks the watcher entry. +// +// Pre-checks are returned as structured codes rather than HTTP errors +// so the fleet worker can map them onto its own per-host status enum +// without parsing strings. +func (s *Server) dispatchHostUpdate(ctx context.Context, hostID string, actorKind string, actorID *string) dispatchHostUpdateResult { + host, err := s.deps.Store.GetHost(ctx, hostID) + if err != nil || host == nil { + return dispatchHostUpdateResult{Code: "host_not_found", Status: stdhttp.StatusNotFound} + } + if !s.deps.Hub.Connected(host.ID) { + return dispatchHostUpdateResult{ + Code: "host_offline", Status: stdhttp.StatusConflict, + Msg: "agent is not currently connected", + } + } + if host.AgentVersion != "" && host.AgentVersion == version.Version { + return dispatchHostUpdateResult{ + Code: "already_up_to_date", Status: stdhttp.StatusConflict, + Msg: "agent already running version " + version.Version, + } + } + existing, err := s.deps.Store.RunningUpdateJobForHost(ctx, hostID) + if err != nil { + return dispatchHostUpdateResult{Code: "internal", Status: stdhttp.StatusInternalServerError, Msg: err.Error()} + } + if existing != "" { + return dispatchHostUpdateResult{ + Code: "update_in_progress", Status: stdhttp.StatusConflict, + Msg: "an update job is already in flight for this host", + JobID: existing, + } + } + + jobID := ulid.Make().String() + now := time.Now().UTC() + if err := s.deps.Store.CreateJob(ctx, store.Job{ + ID: jobID, HostID: hostID, Kind: "update", + ActorKind: actorKind, ActorID: actorID, + CreatedAt: now, + }); err != nil { + return dispatchHostUpdateResult{Code: "internal", Status: stdhttp.StatusInternalServerError, Msg: err.Error()} + } + env, err := api.Marshal(api.MsgCommandUpdate, ulid.Make().String(), api.CommandUpdatePayload{ + JobID: jobID, + }) + if err != nil { + return dispatchHostUpdateResult{Code: "internal", Status: stdhttp.StatusInternalServerError, Msg: err.Error()} + } + if err := s.deps.Hub.Send(ctx, hostID, env); err != nil { + // Roll the job to failed so we don't leak a queued row. + _ = s.deps.Store.MarkJobFinished(ctx, jobID, "failed", -1, nil, err.Error(), time.Now().UTC()) + return dispatchHostUpdateResult{ + Code: "host_offline", Status: stdhttp.StatusConflict, Msg: err.Error(), + } + } + if s.deps.UpdateWatcher != nil { + s.deps.UpdateWatcher.Track(jobID, hostID) + } + + auditPayload, _ := json.Marshal(map[string]string{ + "job_id": jobID, + "target_version": version.Version, + }) + _ = s.deps.Store.AppendAudit(ctx, store.AuditEntry{ + ID: ulid.Make().String(), + UserID: actorID, + Actor: actorKind, + Action: "host.update_dispatched", + TargetKind: ptr("host"), + TargetID: &hostID, + TS: now, + Payload: auditPayload, + }) + + return dispatchHostUpdateResult{JobID: jobID} +} + +// handleHostUpdate is POST /api/hosts/{id}/update — JSON, admin-only. +func (s *Server) handleHostUpdate(w stdhttp.ResponseWriter, r *stdhttp.Request) { + user, ok := s.requireUser(r) + if !ok { + writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "") + return + } + hostID := chi.URLParam(r, "id") + if hostID == "" { + writeJSONError(w, stdhttp.StatusBadRequest, "missing_host_id", "") + return + } + actor := "user" + var actorID *string + if user != nil { + actorID = &user.ID + } + res := s.dispatchHostUpdate(r.Context(), hostID, actor, actorID) + if res.Code != "" { + writeJSONError(w, res.Status, res.Code, res.Msg) + return + } + writeJSON(w, stdhttp.StatusAccepted, map[string]string{"job_id": res.JobID}) +} + +// handleHostUpdateForm is the HTMX-friendly POST /hosts/{id}/update +// variant. On success it sets HX-Redirect to the job detail page; on +// pre-check failures it renders an inline error banner. +func (s *Server) handleHostUpdateForm(w stdhttp.ResponseWriter, r *stdhttp.Request) { + user, ok := s.requireUser(r) + if !ok { + stdhttp.Error(w, "unauthorised", stdhttp.StatusUnauthorized) + return + } + hostID := chi.URLParam(r, "id") + if hostID == "" { + stdhttp.Error(w, "missing host_id", stdhttp.StatusBadRequest) + return + } + actor := "user" + var actorID *string + if user != nil { + actorID = &user.ID + } + res := s.dispatchHostUpdate(r.Context(), hostID, actor, actorID) + if res.Code != "" { + // Inline banner for HTMX swaps. Mirrors what host_credentials + // returns on validation errors — small text/html fragment. + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(res.Status) + msg := hostUpdateErrorMessage(res.Code, res.Msg) + _, _ = w.Write([]byte(``)) + return + } + w.Header().Set("HX-Redirect", "/jobs/"+res.JobID) + w.WriteHeader(stdhttp.StatusOK) +} + +func hostUpdateErrorMessage(code, msg string) string { + switch code { + case "host_not_found": + return "Host not found." + case "host_offline": + return "Agent is offline; can't deliver the update command." + case "already_up_to_date": + return "Agent is already running the current version." + case "update_in_progress": + return "An update is already in progress for this host." + } + if msg != "" { + return msg + } + return "Update dispatch failed." +} + +// htmlEscape is a minimal HTML-attr-safe escaper. Avoids pulling html/template +// for a one-shot inline banner. +func htmlEscape(s string) string { + out := make([]byte, 0, len(s)) + for i := 0; i < len(s); i++ { + switch s[i] { + case '&': + out = append(out, []byte("&")...) + case '<': + out = append(out, []byte("<")...) + case '>': + out = append(out, []byte(">")...) + case '"': + out = append(out, []byte(""")...) + default: + out = append(out, s[i]) + } + } + return string(out) +} diff --git a/internal/server/http/host_update_test.go b/internal/server/http/host_update_test.go new file mode 100644 index 0000000..30cc0ce --- /dev/null +++ b/internal/server/http/host_update_test.go @@ -0,0 +1,270 @@ +// host_update_test.go — covers POST /api/hosts/{id}/update. +package http + +import ( + "context" + "encoding/json" + "io" + stdhttp "net/http" + "strings" + "sync" + "testing" + "time" + + "github.com/coder/websocket" + "github.com/oklog/ulid/v2" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" + "gitea.dcglab.co.uk/steve/restic-manager/internal/version" +) + +// stubWatcher records Track calls so tests can assert the watcher was +// notified. +type stubWatcher struct { + mu sync.Mutex + tracked []string // hostIDs +} + +func (s *stubWatcher) Track(_, hostID string) { + s.mu.Lock() + defer s.mu.Unlock() + s.tracked = append(s.tracked, hostID) +} + +func TestHostUpdateHappyPath(t *testing.T) { + t.Parallel() + srv, ts, st := rawTestServer(t) + watcher := &stubWatcher{} + srv.deps.UpdateWatcher = watcher + hostID, token := enrolHostForWS(t, srv, st, "upd-host") + c := agentDial(t, srv, ts, hostID, token) + sendHello(t, c, "upd-host") + _ = drainUntil(t, c, api.MsgScheduleSet) + + // Force a version mismatch so the dispatch isn't short-circuited. + if err := st.MarkHostHello(context.Background(), hostID, "v0", "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil { + t.Fatalf("mark hello: %v", err) + } + + cookie := loginAsAdmin(t, st) + req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil) + req.AddCookie(cookie) + res, err := stdhttp.DefaultClient.Do(req) + if err != nil { + t.Fatalf("do: %v", err) + } + defer res.Body.Close() + if res.StatusCode != stdhttp.StatusAccepted { + t.Fatalf("status: got %d, want 202", res.StatusCode) + } + var out struct { + JobID string `json:"job_id"` + } + if err := json.NewDecoder(res.Body).Decode(&out); err != nil { + t.Fatalf("decode: %v", err) + } + if out.JobID == "" { + t.Fatal("missing job_id in response") + } + + // command.update envelope arrives. + deadline := time.Now().Add(2 * time.Second) + var got api.Envelope + for time.Now().Before(deadline) { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + mt, raw, rerr := c.Read(ctx) + cancel() + if rerr != nil { + break + } + if mt != websocket.MessageText { + continue + } + if !strings.Contains(string(raw), `"command.update"`) { + continue + } + _ = json.Unmarshal(raw, &got) + break + } + if got.Type != api.MsgCommandUpdate { + t.Fatal("never received command.update envelope") + } + var cp api.CommandUpdatePayload + if err := got.UnmarshalPayload(&cp); err != nil { + t.Fatalf("payload: %v", err) + } + if cp.JobID != out.JobID { + t.Fatalf("payload job_id: got %q want %q", cp.JobID, out.JobID) + } + + // Watcher tracked. + watcher.mu.Lock() + defer watcher.mu.Unlock() + if len(watcher.tracked) != 1 || watcher.tracked[0] != hostID { + t.Fatalf("watcher tracked: %v", watcher.tracked) + } + + // Audit row exists. + var n int + if err := st.DB().QueryRow( + `SELECT COUNT(*) FROM audit_log WHERE action = 'host.update_dispatched' AND target_id = ?`, + hostID).Scan(&n); err != nil { + t.Fatalf("audit count: %v", err) + } + if n != 1 { + t.Fatalf("audit rows: got %d, want 1", n) + } +} + +func TestHostUpdateNotFound(t *testing.T) { + t.Parallel() + _, ts, st := rawTestServer(t) + cookie := loginAsAdmin(t, st) + req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/no-such/update", nil) + req.AddCookie(cookie) + res, err := stdhttp.DefaultClient.Do(req) + if err != nil { + t.Fatalf("do: %v", err) + } + defer res.Body.Close() + if res.StatusCode != stdhttp.StatusNotFound { + t.Fatalf("status: got %d want 404", res.StatusCode) + } +} + +func TestHostUpdateOffline(t *testing.T) { + t.Parallel() + _, ts, st := rawTestServer(t) + hostID := ulid.Make().String() + if err := st.CreateHost(context.Background(), store.Host{ + ID: hostID, Name: "off", OS: "linux", Arch: "amd64", + EnrolledAt: time.Now().UTC(), + }, "deadbeef", ""); err != nil { + t.Fatalf("create: %v", err) + } + cookie := loginAsAdmin(t, st) + req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil) + req.AddCookie(cookie) + res, err := stdhttp.DefaultClient.Do(req) + if err != nil { + t.Fatalf("do: %v", err) + } + defer res.Body.Close() + if res.StatusCode != stdhttp.StatusConflict { + t.Fatalf("status: got %d want 409", res.StatusCode) + } + body := readJSONError(t, res.Body) + if body.Code != "host_offline" { + t.Fatalf("code: %q", body.Code) + } +} + +func TestHostUpdateAlreadyUpToDate(t *testing.T) { + t.Parallel() + srv, ts, st := rawTestServer(t) + hostID, token := enrolHostForWS(t, srv, st, "uptodate-host") + c := agentDial(t, srv, ts, hostID, token) + sendHello(t, c, "uptodate-host") + _ = drainUntil(t, c, api.MsgScheduleSet) + + // Force agent_version == version.Version. + if err := st.MarkHostHello(context.Background(), hostID, version.Version, "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil { + t.Fatalf("mark hello: %v", err) + } + + cookie := loginAsAdmin(t, st) + req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil) + req.AddCookie(cookie) + res, err := stdhttp.DefaultClient.Do(req) + if err != nil { + t.Fatalf("do: %v", err) + } + defer res.Body.Close() + if res.StatusCode != stdhttp.StatusConflict { + t.Fatalf("status: got %d want 409", res.StatusCode) + } + body := readJSONError(t, res.Body) + if body.Code != "already_up_to_date" { + t.Fatalf("code: %q", body.Code) + } +} + +func TestHostUpdateInProgress(t *testing.T) { + t.Parallel() + srv, ts, st := rawTestServer(t) + hostID, token := enrolHostForWS(t, srv, st, "inprog-host") + c := agentDial(t, srv, ts, hostID, token) + sendHello(t, c, "inprog-host") + _ = drainUntil(t, c, api.MsgScheduleSet) + if err := st.MarkHostHello(context.Background(), hostID, "v0", "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil { + t.Fatalf("mark hello: %v", err) + } + + // Pre-seed an in-flight update job. + jobID := ulid.Make().String() + if err := st.CreateJob(context.Background(), store.Job{ + ID: jobID, HostID: hostID, Kind: "update", + ActorKind: "user", CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("seed job: %v", err) + } + + cookie := loginAsAdmin(t, st) + req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil) + req.AddCookie(cookie) + res, err := stdhttp.DefaultClient.Do(req) + if err != nil { + t.Fatalf("do: %v", err) + } + defer res.Body.Close() + if res.StatusCode != stdhttp.StatusConflict { + t.Fatalf("status: got %d want 409", res.StatusCode) + } + body := readJSONError(t, res.Body) + if body.Code != "update_in_progress" { + t.Fatalf("code: %q", body.Code) + } +} + +func TestHostUpdateRBAC(t *testing.T) { + t.Parallel() + _, ts, st := rawTestServer(t) + hostID := ulid.Make().String() + if err := st.CreateHost(context.Background(), store.Host{ + ID: hostID, Name: "rbac-host", OS: "linux", Arch: "amd64", + EnrolledAt: time.Now().UTC(), + }, "deadbeef", ""); err != nil { + t.Fatalf("create: %v", err) + } + for _, role := range []store.Role{store.RoleViewer, store.RoleOperator} { + role := role + t.Run(string(role), func(t *testing.T) { + cookie := loginAsRole(t, st, role) + req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil) + req.AddCookie(cookie) + res, err := stdhttp.DefaultClient.Do(req) + if err != nil { + t.Fatalf("do: %v", err) + } + defer res.Body.Close() + if res.StatusCode != stdhttp.StatusForbidden { + t.Fatalf("status for %s: got %d want 403", role, res.StatusCode) + } + }) + } +} + +type jsonErrBody struct { + Code string `json:"code"` + Message string `json:"message,omitempty"` +} + +func readJSONError(t *testing.T, body io.Reader) jsonErrBody { + t.Helper() + var out jsonErrBody + if err := json.NewDecoder(body).Decode(&out); err != nil { + t.Fatalf("decode error body: %v", err) + } + return out +} diff --git a/internal/server/http/server.go b/internal/server/http/server.go index f9c42c5..67aeaf4 100644 --- a/internal/server/http/server.go +++ b/internal/server/http/server.go @@ -39,6 +39,13 @@ type Deps struct { // NotificationHub (optional, wired in G1) is used by the test-fire // endpoint to dispatch a single synthetic payload through a channel. NotificationHub *notification.Hub + // UpdateWatcher tracks in-flight agent self-update dispatches and + // reconciles them against incoming hello envelopes. Optional; + // nil = no-op (handlers degrade by skipping the Track call). + UpdateWatcher UpdateWatcher + // FleetWorker drives the rolling fleet-update worker. Optional; + // nil = fleet update endpoints (P6-15) report unavailable. + FleetWorker FleetWorker // Version is the binary's build version, surfaced in the chrome. // Empty falls back to "dev". Version string @@ -125,7 +132,7 @@ func (s *Server) routes(r chi.Router) { r.Get("/install/*", s.handleInstallAsset) r.Get("/api/version", s.handleVersion) if s.deps.Hub != nil { - r.Mount("/ws/agent", ws.AgentHandler(ws.HandlerDeps{ + hd := ws.HandlerDeps{ Hub: s.deps.Hub, Store: s.deps.Store, JobHub: s.deps.JobHub, @@ -133,7 +140,11 @@ func (s *Server) routes(r chi.Router) { OnHello: s.onAgentHello, OnScheduleAck: s.applyScheduleAck, OnScheduleFire: s.dispatchScheduledJob, - })) + } + if w, ok := s.deps.UpdateWatcher.(*ws.UpdateWatcher); ok && w != nil { + hd.UpdateWatcher = w + } + r.Mount("/ws/agent", ws.AgentHandler(hd)) } r.Get("/ws/agent/pending", s.handlePendingWS) r.Mount("/static/", staticHandler()) @@ -271,6 +282,9 @@ func (s *Server) routes(r chi.Router) { r.Group(func(r chi.Router) { r.Use(s.requireRole(store.RoleAdmin)) + r.Post("/api/hosts/{id}/update", s.handleHostUpdate) + r.Post("/hosts/{id}/update", s.handleHostUpdateForm) + r.Get("/api/users", s.handleAPIUsersList) r.Post("/api/users", s.handleAPIUserCreate) r.Get("/api/users/{id}", s.handleAPIUserGet) @@ -322,6 +336,27 @@ func (s *Server) Shutdown(ctx context.Context) error { return s.srv.Shutdown(ctx) } +// SetFleetWorker installs the fleet-update worker post-construction. +// Used to break the wiring loop in cmd/server (the worker depends on a +// dispatcher that delegates back into the server's host-update path). +func (s *Server) SetFleetWorker(fw FleetWorker) { s.deps.FleetWorker = fw } + +// DispatchHostUpdate is the public entry point for callers (the fleet +// worker) that need to drive the same dispatch path the HTTP handler +// uses, without going through HTTP. Returns the structured result so +// the caller can map error codes to its own status enum. +func (s *Server) DispatchHostUpdate(ctx context.Context, hostID, actorUserID string) (jobID string, code string, err error) { + var actorID *string + if actorUserID != "" { + actorID = &actorUserID + } + res := s.dispatchHostUpdate(ctx, hostID, "user", actorID) + if res.Code != "" { + return res.JobID, res.Code, nil + } + return res.JobID, "", nil +} + // Addr returns the configured listen address. Useful in tests when // the caller passes :0 to get a random port. func (s *Server) Addr() string { return s.srv.Addr } diff --git a/internal/server/ws/handler.go b/internal/server/ws/handler.go index df74332..312f568 100644 --- a/internal/server/ws/handler.go +++ b/internal/server/ws/handler.go @@ -16,6 +16,7 @@ import ( "gitea.dcglab.co.uk/steve/restic-manager/internal/api" "gitea.dcglab.co.uk/steve/restic-manager/internal/auth" "gitea.dcglab.co.uk/steve/restic-manager/internal/store" + "gitea.dcglab.co.uk/steve/restic-manager/internal/version" ) // HandlerDeps is the set of collaborators the agent WS handler needs. @@ -26,6 +27,9 @@ type HandlerDeps struct { // AlertEngine receives job-finished and host-online events so the // alert engine can evaluate its rules. Optional; nil = no-op. AlertEngine *alert.Engine + // UpdateWatcher reconciles in-flight agent-update dispatches against + // hello envelopes. Optional; nil = no-op. + UpdateWatcher *UpdateWatcher // OnHello is called once per successful hello, after the host row // has been touched and the conn registered. Used by the HTTP // layer to push host_credentials down as a config.update before @@ -147,6 +151,9 @@ func runAgentLoop(ctx context.Context, c *Conn, hostID string, deps HandlerDeps) if deps.AlertEngine != nil { deps.AlertEngine.NotifyHostOnline(hostID) } + if deps.UpdateWatcher != nil { + deps.UpdateWatcher.OnHello(ctx, hostID, helloPayload.AgentVersion, version.Version) + } deps.Hub.Register(hostID, c) defer deps.Hub.Unregister(hostID, c) diff --git a/internal/server/ws/update_watch.go b/internal/server/ws/update_watch.go new file mode 100644 index 0000000..be2fef8 --- /dev/null +++ b/internal/server/ws/update_watch.go @@ -0,0 +1,151 @@ +package ws + +import ( + "context" + "fmt" + "log/slog" + "sync" + "time" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" +) + +// updateTimeout bounds how long the watcher waits for an agent to come +// back with its new version after a command.update dispatch. var (not +// const) so tests can shrink it. +var updateTimeout = 90 * time.Second + +// AlertRaiser is the slim subset of *alert.Engine the update watcher +// touches. Defined here (not in the alert package) so the dependency +// arrow points the right way. +type AlertRaiser interface { + RaiseUpdateFailed(ctx context.Context, hostID, jobID, reason string, when time.Time) + ResolveUpdateFailed(ctx context.Context, hostID string, when time.Time) +} + +// UpdateWatcher tracks in-flight agent-update dispatches and reconciles +// them against incoming hello envelopes. Entries land on Track and +// resolve via OnHello (success path) or the periodic sweep (timeout). +type UpdateWatcher struct { + store *store.Store + alerts AlertRaiser + + mu sync.Mutex + entries map[string]*updateEntry // hostID → entry + + tickPeriod time.Duration +} + +type updateEntry struct { + jobID string + startedAt time.Time + // terminated is set once the entry has reached a terminal state so + // late OnHellos don't resurrect it. + terminated bool +} + +// NewUpdateWatcher builds an unstarted watcher. Call Run in a goroutine +// to start the periodic sweep. +func NewUpdateWatcher(st *store.Store, alerts AlertRaiser) *UpdateWatcher { + return &UpdateWatcher{ + store: st, + alerts: alerts, + entries: make(map[string]*updateEntry), + tickPeriod: 5 * time.Second, + } +} + +// Track registers a freshly-dispatched update job. A subsequent Track +// for the same host replaces the prior entry (last-write-wins). +func (w *UpdateWatcher) Track(jobID, hostID string) { + if w == nil { + return + } + w.mu.Lock() + w.entries[hostID] = &updateEntry{jobID: jobID, startedAt: time.Now()} + w.mu.Unlock() +} + +// OnHello is called by the WS handler after a successful hello has been +// persisted. If a tracked update for the host matches the targetVersion, +// the job is marked succeeded and any open update_failed alert is +// auto-resolved. A non-matching version is a no-op (the watcher keeps +// waiting until the timeout). +func (w *UpdateWatcher) OnHello(ctx context.Context, hostID, agentVersion, targetVersion string) { + if w == nil { + return + } + w.mu.Lock() + e, ok := w.entries[hostID] + if !ok || e.terminated { + w.mu.Unlock() + return + } + if agentVersion != targetVersion { + // Not the version we asked for — keep waiting. + w.mu.Unlock() + return + } + e.terminated = true + jobID := e.jobID + delete(w.entries, hostID) + w.mu.Unlock() + + now := time.Now().UTC() + if err := w.store.MarkJobFinished(ctx, jobID, "succeeded", 0, nil, "", now); err != nil { + slog.Warn("ws update watcher: mark succeeded", "job_id", jobID, "host_id", hostID, "err", err) + } + if w.alerts != nil { + w.alerts.ResolveUpdateFailed(ctx, hostID, now) + } +} + +// Run drives the periodic sweep. Returns when ctx is done. +func (w *UpdateWatcher) Run(ctx context.Context) { + if w == nil { + return + } + t := time.NewTicker(w.tickPeriod) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case now := <-t.C: + w.sweep(ctx, now) + } + } +} + +func (w *UpdateWatcher) sweep(ctx context.Context, now time.Time) { + type expired struct { + hostID string + jobID string + age time.Duration + } + var toFail []expired + w.mu.Lock() + for hostID, e := range w.entries { + if e.terminated { + continue + } + if now.Sub(e.startedAt) >= updateTimeout { + toFail = append(toFail, expired{hostID: hostID, jobID: e.jobID, age: now.Sub(e.startedAt)}) + e.terminated = true + delete(w.entries, hostID) + } + } + w.mu.Unlock() + + for _, x := range toFail { + reason := fmt.Sprintf("timeout: agent did not reconnect within %s", updateTimeout) + stamp := now.UTC() + errMsg := reason + if err := w.store.MarkJobFinished(ctx, x.jobID, "failed", -1, nil, errMsg, stamp); err != nil { + slog.Warn("ws update watcher: mark failed", "job_id", x.jobID, "host_id", x.hostID, "err", err) + } + if w.alerts != nil { + w.alerts.RaiseUpdateFailed(ctx, x.hostID, x.jobID, reason, stamp) + } + } +} diff --git a/internal/server/ws/update_watch_test.go b/internal/server/ws/update_watch_test.go new file mode 100644 index 0000000..4081501 --- /dev/null +++ b/internal/server/ws/update_watch_test.go @@ -0,0 +1,161 @@ +package ws + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/oklog/ulid/v2" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" +) + +type fakeAlerts struct { + mu sync.Mutex + raised []string // hostIDs + resolved []string + reasons []string +} + +func (f *fakeAlerts) RaiseUpdateFailed(_ context.Context, hostID, _ /*jobID*/, reason string, _ time.Time) { + f.mu.Lock() + defer f.mu.Unlock() + f.raised = append(f.raised, hostID) + f.reasons = append(f.reasons, reason) +} + +func (f *fakeAlerts) ResolveUpdateFailed(_ context.Context, hostID string, _ time.Time) { + f.mu.Lock() + defer f.mu.Unlock() + f.resolved = append(f.resolved, hostID) +} + +func seedJob(t *testing.T, st *store.Store, hostID string) string { + t.Helper() + jobID := ulid.Make().String() + if err := st.CreateJob(context.Background(), store.Job{ + ID: jobID, HostID: hostID, Kind: "update", + ActorKind: "user", CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("create job: %v", err) + } + return jobID +} + +func TestUpdateWatcherOnHelloSuccess(t *testing.T) { + st := openWSTestStore(t) + hostID := ulid.Make().String() + seedHostWS(t, st, hostID) + jobID := seedJob(t, st, hostID) + + a := &fakeAlerts{} + w := NewUpdateWatcher(st, a) + w.Track(jobID, hostID) + + w.OnHello(context.Background(), hostID, "v2", "v2") + + job, err := st.GetJob(context.Background(), jobID) + if err != nil { + t.Fatalf("get job: %v", err) + } + if job.Status != "succeeded" { + t.Fatalf("status: got %q want succeeded", job.Status) + } + a.mu.Lock() + defer a.mu.Unlock() + if len(a.resolved) != 1 || a.resolved[0] != hostID { + t.Fatalf("resolve calls: %v", a.resolved) + } + if len(a.raised) != 0 { + t.Fatalf("unexpected raises: %v", a.raised) + } +} + +func TestUpdateWatcherTimeout(t *testing.T) { + prev := updateTimeout + updateTimeout = 50 * time.Millisecond + t.Cleanup(func() { updateTimeout = prev }) + + st := openWSTestStore(t) + hostID := ulid.Make().String() + seedHostWS(t, st, hostID) + jobID := seedJob(t, st, hostID) + + a := &fakeAlerts{} + w := NewUpdateWatcher(st, a) + w.Track(jobID, hostID) + + time.Sleep(80 * time.Millisecond) + w.sweep(context.Background(), time.Now()) + + job, err := st.GetJob(context.Background(), jobID) + if err != nil { + t.Fatalf("get job: %v", err) + } + if job.Status != "failed" { + t.Fatalf("status: got %q want failed", job.Status) + } + a.mu.Lock() + defer a.mu.Unlock() + if len(a.raised) != 1 || a.raised[0] != hostID { + t.Fatalf("raise calls: %v", a.raised) + } + if len(a.reasons) == 0 || a.reasons[0] == "" { + t.Fatalf("missing reason") + } +} + +func TestUpdateWatcherMismatchedVersionNoOp(t *testing.T) { + st := openWSTestStore(t) + hostID := ulid.Make().String() + seedHostWS(t, st, hostID) + jobID := seedJob(t, st, hostID) + + a := &fakeAlerts{} + w := NewUpdateWatcher(st, a) + w.Track(jobID, hostID) + + w.OnHello(context.Background(), hostID, "v1", "v2") + + job, _ := st.GetJob(context.Background(), jobID) + if job.Status == "succeeded" || job.Status == "failed" { + t.Fatalf("status flipped on mismatched hello: %q", job.Status) + } + a.mu.Lock() + defer a.mu.Unlock() + if len(a.raised) != 0 || len(a.resolved) != 0 { + t.Fatalf("unexpected alert calls raised=%v resolved=%v", a.raised, a.resolved) + } +} + +func TestUpdateWatcherHelloAfterTimeoutIsNoOp(t *testing.T) { + prev := updateTimeout + updateTimeout = 50 * time.Millisecond + t.Cleanup(func() { updateTimeout = prev }) + + st := openWSTestStore(t) + hostID := ulid.Make().String() + seedHostWS(t, st, hostID) + jobID := seedJob(t, st, hostID) + + a := &fakeAlerts{} + w := NewUpdateWatcher(st, a) + w.Track(jobID, hostID) + + time.Sleep(80 * time.Millisecond) + w.sweep(context.Background(), time.Now()) + + // Hello arrives after sweep — entry already gone, must be no-op. + w.OnHello(context.Background(), hostID, "v2", "v2") + + job, _ := st.GetJob(context.Background(), jobID) + if job.Status != "failed" { + t.Fatalf("status flipped from failed → %q", job.Status) + } + a.mu.Lock() + defer a.mu.Unlock() + if len(a.resolved) != 0 { + t.Fatalf("late hello triggered ResolveUpdateFailed: %v", a.resolved) + } +} diff --git a/internal/store/alerts.go b/internal/store/alerts.go index b12d6fa..f16b9bc 100644 --- a/internal/store/alerts.go +++ b/internal/store/alerts.go @@ -77,6 +77,56 @@ func (s *Store) RaiseOrTouch(ctx context.Context, hostID, kind, dedupKey, severi return id, true, nil } +// RaiseOrTouchSystem is the host-less variant of RaiseOrTouch — the +// alert row's host_id is stored as NULL, so the FK to hosts is bypassed. +// Used by fleet-wide alerts (e.g. fleet_update_halted) where the +// failure surface isn't pinned to a single host. +func (s *Store) RaiseOrTouchSystem(ctx context.Context, kind, dedupKey, severity, message string, when time.Time) (id string, didRaise bool, err error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return "", false, fmt.Errorf("store: begin: %w", err) + } + defer func() { _ = tx.Rollback() }() + + row := tx.QueryRowContext(ctx, + `SELECT id FROM alerts + WHERE host_id IS NULL AND kind = ? AND dedup_key = ? AND resolved_at IS NULL + LIMIT 1`, + kind, dedupKey) + var existing string + switch err := row.Scan(&existing); { + case err == nil: + _, uerr := tx.ExecContext(ctx, + `UPDATE alerts SET last_seen_at = ?, message = ? WHERE id = ?`, + when.UTC().Format(time.RFC3339Nano), message, existing) + if uerr != nil { + return "", false, fmt.Errorf("store: touch alert: %w", uerr) + } + if err := tx.Commit(); err != nil { + return "", false, err + } + return existing, false, nil + case errors.Is(err, sql.ErrNoRows): + // fall through to insert + default: + return "", false, fmt.Errorf("store: lookup alert: %w", err) + } + + id = ulid.Make().String() + whenStr := when.UTC().Format(time.RFC3339Nano) + _, err = tx.ExecContext(ctx, + `INSERT INTO alerts (id, host_id, kind, dedup_key, severity, message, created_at, last_seen_at) + VALUES (?, NULL, ?, ?, ?, ?, ?, ?)`, + id, kind, dedupKey, severity, message, whenStr, whenStr) + if err != nil { + return "", false, fmt.Errorf("store: insert alert: %w", err) + } + if err := tx.Commit(); err != nil { + return "", false, err + } + return id, true, nil +} + // refreshHostOpenAlertCount recomputes hosts.open_alert_count from the // alerts table for one host. Self-healing: idempotent and survives // out-of-order edits. Best-effort — errors are returned but callers