p6-01/02: agent self-update + fleet update server cluster
- alert: update_failed (per-host, dedup=hostID) + fleet_update_halted
(system-scoped, host_id NULL via new RaiseOrTouchSystem helper).
- ws: UpdateWatcher tracks in-flight command.update dispatches and
reconciles them against incoming hello envelopes — success path
marks the job succeeded and auto-resolves the alert; 90s timeout
marks the job failed and raises update_failed.
- http: POST /api/hosts/{id}/update (admin-only JSON) + the HTMX
/hosts/{id}/update form variant. Pre-checks: host exists, online,
agent_version != current, no running update job. Refactored core
into Server.dispatchHostUpdate so the fleet worker can share it
without going through HTTP.
- fleetupdate: rolling worker iterating through host slots, halting
on first failure and raising fleet_update_halted. Polling-based
version-match (re-read hosts.agent_version every 1s up to 95s) —
no extra plumbing into the WS hello path. At-most-one-running is
enforced at the store layer (ErrFleetUpdateRunning).
- cmd/server: wire UpdateWatcher and FleetWorker into the main
goroutine; the worker uses a small serverDispatcher adapter that
delegates back into Server.DispatchHostUpdate.
Tests: watcher (success/timeout/mismatch/late-hello), HTTP endpoint
(happy + four pre-check branches + RBAC), worker (two-host happy,
timeout-halt, host-offline-halt, already-at-target skip, cancel
mid-run, double-Start guard).
This commit is contained in:
@@ -17,6 +17,7 @@ import (
|
|||||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
|
||||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/notification"
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/notification"
|
||||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/config"
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/config"
|
||||||
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/fleetupdate"
|
||||||
rmhttp "gitea.dcglab.co.uk/steve/restic-manager/internal/server/http"
|
rmhttp "gitea.dcglab.co.uk/steve/restic-manager/internal/server/http"
|
||||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/maintenance"
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/maintenance"
|
||||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc"
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc"
|
||||||
@@ -91,6 +92,7 @@ func run() error {
|
|||||||
|
|
||||||
notifHub := notification.NewHub(st, aead, cfg.BaseURL)
|
notifHub := notification.NewHub(st, aead, cfg.BaseURL)
|
||||||
alertEngine := alert.NewEngine(st, notifHub)
|
alertEngine := alert.NewEngine(st, notifHub)
|
||||||
|
updateWatcher := ws.NewUpdateWatcher(st, alertEngine)
|
||||||
|
|
||||||
renderer, err := ui.New()
|
renderer, err := ui.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -116,6 +118,7 @@ func run() error {
|
|||||||
JobHub: jobHub,
|
JobHub: jobHub,
|
||||||
AlertEngine: alertEngine,
|
AlertEngine: alertEngine,
|
||||||
NotificationHub: notifHub,
|
NotificationHub: notifHub,
|
||||||
|
UpdateWatcher: updateWatcher,
|
||||||
UI: renderer,
|
UI: renderer,
|
||||||
Version: version,
|
Version: version,
|
||||||
OIDC: oidcClient,
|
OIDC: oidcClient,
|
||||||
@@ -147,10 +150,17 @@ func run() error {
|
|||||||
|
|
||||||
srv := rmhttp.New(deps)
|
srv := rmhttp.New(deps)
|
||||||
|
|
||||||
|
// Fleet-update worker — built after the HTTP server because the
|
||||||
|
// dispatcher delegates back into srv.DispatchHostUpdate.
|
||||||
|
fleetWorker := fleetupdate.NewWorker(st, hub,
|
||||||
|
&serverDispatcher{srv: srv}, alertEngine)
|
||||||
|
srv.SetFleetWorker(fleetWorker)
|
||||||
|
|
||||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||||
defer stop()
|
defer stop()
|
||||||
|
|
||||||
go alertEngine.Run(ctx)
|
go alertEngine.Run(ctx)
|
||||||
|
go updateWatcher.Run(ctx)
|
||||||
|
|
||||||
errCh := make(chan error, 1)
|
errCh := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -243,3 +253,12 @@ func run() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// serverDispatcher adapts the http.Server's DispatchHostUpdate method
|
||||||
|
// to the fleetupdate.Dispatcher interface. Lives in main so the
|
||||||
|
// http and fleetupdate packages don't need to know about each other.
|
||||||
|
type serverDispatcher struct{ srv *rmhttp.Server }
|
||||||
|
|
||||||
|
func (d *serverDispatcher) DispatchUpdate(ctx context.Context, hostID, actorUserID string) (string, string, error) {
|
||||||
|
return d.srv.DispatchHostUpdate(ctx, hostID, actorUserID)
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,63 @@
|
|||||||
|
package alert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/notification"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Alert-kind constants for P6 self-update flows.
|
||||||
|
const (
|
||||||
|
// KindUpdateFailed is raised when an agent fails to come back with
|
||||||
|
// the expected version after a command.update dispatch (timeout or
|
||||||
|
// version-mismatch). Resolved by a subsequent matching hello.
|
||||||
|
KindUpdateFailed = "update_failed"
|
||||||
|
|
||||||
|
// KindFleetUpdateHalted is raised when the fleet-update worker
|
||||||
|
// stops mid-run because a host failed to update or went offline.
|
||||||
|
// Host-less alert (system-scoped). Manually resolved by an admin.
|
||||||
|
KindFleetUpdateHalted = "fleet_update_halted"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RaiseUpdateFailed records a per-host update failure. dedupKey is the
|
||||||
|
// hostID so a re-dispatch on the same host touches the existing alert
|
||||||
|
// rather than spawning a duplicate.
|
||||||
|
func (e *Engine) RaiseUpdateFailed(ctx context.Context, hostID, jobID, reason string, when time.Time) {
|
||||||
|
msg := fmt.Sprintf("Agent update failed (job %s): %s", jobID, reason)
|
||||||
|
e.raiseAndNotify(ctx, hostID, KindUpdateFailed, hostID, "warning", msg, when)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveUpdateFailed clears any open update_failed alert for hostID.
|
||||||
|
// Called from the WS hello path when the agent reconnects with the
|
||||||
|
// target version.
|
||||||
|
func (e *Engine) ResolveUpdateFailed(ctx context.Context, hostID string, when time.Time) {
|
||||||
|
e.resolveAndNotify(ctx, hostID, KindUpdateFailed, hostID, when)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RaiseFleetUpdateHalted is host-less — the fleet update is a
|
||||||
|
// system-level concept. We persist it via the dedicated host-less
|
||||||
|
// alert path so the alerts table's host_id column carries NULL.
|
||||||
|
func (e *Engine) RaiseFleetUpdateHalted(ctx context.Context, fleetUpdateID, reason string, when time.Time) {
|
||||||
|
msg := fmt.Sprintf("Fleet update %s halted: %s", fleetUpdateID, reason)
|
||||||
|
id, didRaise, err := e.store.RaiseOrTouchSystem(ctx, KindFleetUpdateHalted, fleetUpdateID, "warning", msg, when)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("alert: raise fleet_update_halted", "fu_id", fleetUpdateID, "err", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !didRaise {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go e.hub.Dispatch(ctx, notification.Payload{
|
||||||
|
Event: notification.EventRaised,
|
||||||
|
AlertID: id,
|
||||||
|
Severity: "warning",
|
||||||
|
Kind: KindFleetUpdateHalted,
|
||||||
|
HostID: "",
|
||||||
|
HostName: "",
|
||||||
|
Message: msg,
|
||||||
|
RaisedAt: when,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,221 @@
|
|||||||
|
// Package fleetupdate drives a rolling, sequential agent self-update
|
||||||
|
// over a list of hosts. One worker goroutine per Start() call (gated
|
||||||
|
// at the store layer to at-most-one-running-fleet-update).
|
||||||
|
package fleetupdate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oklog/ulid/v2"
|
||||||
|
|
||||||
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Hub is the slim "is this host connected?" surface.
|
||||||
|
type Hub interface {
|
||||||
|
Connected(hostID string) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dispatcher sends one command.update envelope. The implementer also
|
||||||
|
// creates the jobs row, writes audit, and registers with the update
|
||||||
|
// watcher. Pre-checks are the dispatcher's responsibility — the worker
|
||||||
|
// passes through whatever error it returns.
|
||||||
|
type Dispatcher interface {
|
||||||
|
DispatchUpdate(ctx context.Context, hostID string, actorUserID string) (jobID string, code string, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AlertRaiser is the slim view of the alert engine's host-less raise
|
||||||
|
// path. Used to emit fleet_update_halted on first failure.
|
||||||
|
type AlertRaiser interface {
|
||||||
|
RaiseFleetUpdateHalted(ctx context.Context, fleetUpdateID, reason string, when time.Time)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Worker is the long-lived fleet-update orchestrator. There is at most
|
||||||
|
// one *running* fleet update at a time (enforced by the store).
|
||||||
|
type Worker struct {
|
||||||
|
store *store.Store
|
||||||
|
hub Hub
|
||||||
|
disp Dispatcher
|
||||||
|
alerts AlertRaiser
|
||||||
|
|
||||||
|
// targetVersion is the version every dispatched agent is expected
|
||||||
|
// to come back with. Captured at Start time to avoid drift.
|
||||||
|
targetVersion string
|
||||||
|
|
||||||
|
// pollPeriod controls the cadence at which the worker re-reads the
|
||||||
|
// host row to check for the version transition. Exposed for tests.
|
||||||
|
pollPeriod time.Duration
|
||||||
|
// hostTimeout bounds how long the worker waits for one host to
|
||||||
|
// reach the target version before halting.
|
||||||
|
hostTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWorker builds an unstarted worker. targetVersion is set on each
|
||||||
|
// Start call; the values here are defaults.
|
||||||
|
func NewWorker(st *store.Store, hub Hub, disp Dispatcher, alerts AlertRaiser) *Worker {
|
||||||
|
return &Worker{
|
||||||
|
store: st,
|
||||||
|
hub: hub,
|
||||||
|
disp: disp,
|
||||||
|
alerts: alerts,
|
||||||
|
pollPeriod: 1 * time.Second,
|
||||||
|
hostTimeout: 95 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start creates the parent + child rows, then spawns the per-host
|
||||||
|
// worker goroutine. Returns the new fleet_update_id on success.
|
||||||
|
// store.ErrFleetUpdateRunning bubbles up unchanged.
|
||||||
|
func (w *Worker) Start(ctx context.Context, userID, targetVersion string, hostIDs []string) (string, error) {
|
||||||
|
if userID == "" || targetVersion == "" {
|
||||||
|
return "", errors.New("fleetupdate: userID and targetVersion required")
|
||||||
|
}
|
||||||
|
if len(hostIDs) == 0 {
|
||||||
|
return "", errors.New("fleetupdate: at least one host required")
|
||||||
|
}
|
||||||
|
fuID := ulid.Make().String()
|
||||||
|
now := time.Now().UTC()
|
||||||
|
if err := w.store.CreateFleetUpdate(ctx, store.FleetUpdate{
|
||||||
|
ID: fuID,
|
||||||
|
StartedAt: now,
|
||||||
|
StartedByUserID: userID,
|
||||||
|
TargetVersion: targetVersion,
|
||||||
|
Status: "running",
|
||||||
|
}, hostIDs); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// The goroutine outlives the request that started it; carry a
|
||||||
|
// detached context so an HTTP-handler ctx cancel doesn't abort
|
||||||
|
// the long roll.
|
||||||
|
bg := context.WithoutCancel(ctx)
|
||||||
|
go w.run(bg, fuID, userID, targetVersion)
|
||||||
|
return fuID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel marks the fleet update cancelled. The running goroutine
|
||||||
|
// observes the new status on its next pre-check and exits without
|
||||||
|
// dispatching further hosts. The currently-dispatched job is left to
|
||||||
|
// finish on its own — cancelling agent-side is out of scope for v1.
|
||||||
|
func (w *Worker) Cancel(ctx context.Context, fuID string) error {
|
||||||
|
return w.store.CancelFleetUpdate(ctx, fuID, time.Now().UTC())
|
||||||
|
}
|
||||||
|
|
||||||
|
// run is the per-host loop. Halts on first failure; emits one alert
|
||||||
|
// on transition.
|
||||||
|
func (w *Worker) run(ctx context.Context, fuID, userID, targetVersion string) {
|
||||||
|
w.targetVersion = targetVersion
|
||||||
|
|
||||||
|
for {
|
||||||
|
// Check the parent row's status — picks up Cancel.
|
||||||
|
fu, err := w.store.ActiveFleetUpdate(ctx)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("fleetupdate: read active", "fu_id", fuID, "err", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if fu == nil || fu.ID != fuID {
|
||||||
|
// Cancelled, halted, or completed externally. Done.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pending, err := w.store.ListPendingFleetUpdateHosts(ctx, fuID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("fleetupdate: list pending", "fu_id", fuID, "err", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(pending) == 0 {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
if err := w.store.CompleteFleetUpdate(ctx, fuID, now); err != nil {
|
||||||
|
slog.Warn("fleetupdate: complete", "fu_id", fuID, "err", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next := pending[0]
|
||||||
|
w.processHost(ctx, fuID, userID, next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processHost handles one host slot. Marks it skipped, succeeded, or
|
||||||
|
// failed (and halts the fleet on failure).
|
||||||
|
func (w *Worker) processHost(ctx context.Context, fuID, userID string, slot store.FleetUpdateHost) {
|
||||||
|
hostID := slot.HostID
|
||||||
|
_ = w.store.SetFleetUpdateCurrentHost(ctx, fuID, hostID)
|
||||||
|
|
||||||
|
// Pre-flight: re-read the host. The dispatch path repeats most of
|
||||||
|
// these checks but doing them up-front lets us emit the right
|
||||||
|
// per-host status (skipped vs failed) without consuming a job row.
|
||||||
|
host, err := w.store.GetHost(ctx, hostID)
|
||||||
|
if err != nil || host == nil {
|
||||||
|
_ = w.store.SetFleetUpdateHostStatus(ctx, fuID, hostID, "skipped", "host not found", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if host.AgentVersion != "" && host.AgentVersion == w.targetVersion {
|
||||||
|
_ = w.store.SetFleetUpdateHostStatus(ctx, fuID, hostID, "skipped", "already at target version", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !w.hub.Connected(hostID) {
|
||||||
|
reason := fmt.Sprintf("host went offline: %s", hostID)
|
||||||
|
_ = w.store.SetFleetUpdateHostStatus(ctx, fuID, hostID, "failed", reason, "")
|
||||||
|
w.halt(ctx, fuID, reason)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dispatch.
|
||||||
|
_ = w.store.SetFleetUpdateHostStatus(ctx, fuID, hostID, "running", "", "")
|
||||||
|
jobID, code, err := w.disp.DispatchUpdate(ctx, hostID, userID)
|
||||||
|
if err != nil || code != "" {
|
||||||
|
reason := dispatchErrorReason(code, err)
|
||||||
|
_ = w.store.SetFleetUpdateHostStatus(ctx, fuID, hostID, "failed", reason, jobID)
|
||||||
|
w.halt(ctx, fuID, reason)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Poll until the host's recorded agent_version matches target, or
|
||||||
|
// timeout.
|
||||||
|
deadline := time.Now().Add(w.hostTimeout)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
// Honour cancellation between polls.
|
||||||
|
fu, err := w.store.ActiveFleetUpdate(ctx)
|
||||||
|
if err == nil && (fu == nil || fu.ID != fuID) {
|
||||||
|
// Cancelled mid-host; leave the slot in 'running' for the
|
||||||
|
// admin to inspect. No further dispatches.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(w.pollPeriod)
|
||||||
|
h, err := w.store.GetHost(ctx, hostID)
|
||||||
|
if err == nil && h != nil && h.AgentVersion == w.targetVersion {
|
||||||
|
if err := w.store.SetFleetUpdateHostStatus(ctx, fuID, hostID, "succeeded", "", jobID); err != nil {
|
||||||
|
slog.Warn("fleetupdate: set succeeded", "fu_id", fuID, "host_id", hostID, "err", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reason := fmt.Sprintf("timeout waiting for %s to reach %s", hostID, w.targetVersion)
|
||||||
|
_ = w.store.SetFleetUpdateHostStatus(ctx, fuID, hostID, "failed", reason, jobID)
|
||||||
|
w.halt(ctx, fuID, reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Worker) halt(ctx context.Context, fuID, reason string) {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
if err := w.store.HaltFleetUpdate(ctx, fuID, reason, now); err != nil {
|
||||||
|
slog.Warn("fleetupdate: halt", "fu_id", fuID, "err", err)
|
||||||
|
}
|
||||||
|
if w.alerts != nil {
|
||||||
|
w.alerts.RaiseFleetUpdateHalted(ctx, fuID, reason, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func dispatchErrorReason(code string, err error) string {
|
||||||
|
if code != "" {
|
||||||
|
return "dispatch failed: " + code
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err.Error()
|
||||||
|
}
|
||||||
|
return "dispatch failed"
|
||||||
|
}
|
||||||
@@ -0,0 +1,344 @@
|
|||||||
|
package fleetupdate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oklog/ulid/v2"
|
||||||
|
|
||||||
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||||
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeHub struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
online map[string]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeHub) Connected(hostID string) bool {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
return f.online[hostID]
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeDispatcher struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
calls []string // host IDs
|
||||||
|
// after dispatch, set the host's agent_version to this on the
|
||||||
|
// store so the worker observes the version transition.
|
||||||
|
st *store.Store
|
||||||
|
target string
|
||||||
|
delayMS int
|
||||||
|
failOnHost map[string]string // host → error code
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDispatcher) DispatchUpdate(ctx context.Context, hostID, _ string) (string, string, error) {
|
||||||
|
f.mu.Lock()
|
||||||
|
f.calls = append(f.calls, hostID)
|
||||||
|
if code, ok := f.failOnHost[hostID]; ok {
|
||||||
|
f.mu.Unlock()
|
||||||
|
return "", code, nil
|
||||||
|
}
|
||||||
|
st := f.st
|
||||||
|
target := f.target
|
||||||
|
delay := f.delayMS
|
||||||
|
f.mu.Unlock()
|
||||||
|
|
||||||
|
jobID := ulid.Make().String()
|
||||||
|
if st != nil {
|
||||||
|
_ = st.CreateJob(context.Background(), store.Job{
|
||||||
|
ID: jobID, HostID: hostID, Kind: "update",
|
||||||
|
ActorKind: "user", CreatedAt: time.Now().UTC(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if st != nil && target != "" {
|
||||||
|
go func() {
|
||||||
|
if delay > 0 {
|
||||||
|
time.Sleep(time.Duration(delay) * time.Millisecond)
|
||||||
|
}
|
||||||
|
_ = st.MarkHostHello(context.Background(), hostID, target, "0.17", api.CurrentProtocolVersion, time.Now().UTC())
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
return jobID, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type recAlert struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
reasons []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *recAlert) RaiseFleetUpdateHalted(_ context.Context, _ string, reason string, _ time.Time) {
|
||||||
|
r.mu.Lock()
|
||||||
|
r.reasons = append(r.reasons, reason)
|
||||||
|
r.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func openStore(t *testing.T) *store.Store {
|
||||||
|
t.Helper()
|
||||||
|
dir := t.TempDir()
|
||||||
|
st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = st.Close() })
|
||||||
|
return st
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustCreateAdmin(t *testing.T, st *store.Store) string {
|
||||||
|
t.Helper()
|
||||||
|
uid := ulid.Make().String()
|
||||||
|
if err := st.CreateUser(context.Background(), store.User{
|
||||||
|
ID: uid, Username: "u-" + uid[:6],
|
||||||
|
PasswordHash: "x", Role: store.RoleAdmin, CreatedAt: time.Now().UTC(),
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("user: %v", err)
|
||||||
|
}
|
||||||
|
return uid
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustCreateHost(t *testing.T, st *store.Store, name, version string) string {
|
||||||
|
t.Helper()
|
||||||
|
hostID := ulid.Make().String()
|
||||||
|
if err := st.CreateHost(context.Background(), store.Host{
|
||||||
|
ID: hostID, Name: name, OS: "linux", Arch: "amd64",
|
||||||
|
EnrolledAt: time.Now().UTC(),
|
||||||
|
}, "deadbeef-"+hostID, ""); err != nil {
|
||||||
|
t.Fatalf("host: %v", err)
|
||||||
|
}
|
||||||
|
if version != "" {
|
||||||
|
if err := st.MarkHostHello(context.Background(), hostID, version, "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil {
|
||||||
|
t.Fatalf("hello: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return hostID
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForStatus(t *testing.T, st *store.Store, fuID, want string, timeout time.Duration) *store.FleetUpdate {
|
||||||
|
t.Helper()
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
fu, _, err := st.GetFleetUpdate(context.Background(), fuID)
|
||||||
|
if err == nil && fu != nil && fu.Status == want {
|
||||||
|
return fu
|
||||||
|
}
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
}
|
||||||
|
t.Fatalf("status never reached %q", want)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWorkerTwoHostsBothSucceed(t *testing.T) {
|
||||||
|
st := openStore(t)
|
||||||
|
uid := mustCreateAdmin(t, st)
|
||||||
|
h1 := mustCreateHost(t, st, "h1", "v0")
|
||||||
|
h2 := mustCreateHost(t, st, "h2", "v0")
|
||||||
|
|
||||||
|
hub := &fakeHub{online: map[string]bool{h1: true, h2: true}}
|
||||||
|
disp := &fakeDispatcher{st: st, target: "v2", delayMS: 30}
|
||||||
|
alerts := &recAlert{}
|
||||||
|
w := NewWorker(st, hub, disp, alerts)
|
||||||
|
w.pollPeriod = 20 * time.Millisecond
|
||||||
|
w.hostTimeout = 2 * time.Second
|
||||||
|
|
||||||
|
fuID, err := w.Start(context.Background(), uid, "v2", []string{h1, h2})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("start: %v", err)
|
||||||
|
}
|
||||||
|
waitForStatus(t, st, fuID, "completed", 5*time.Second)
|
||||||
|
_, hosts, _ := st.GetFleetUpdate(context.Background(), fuID)
|
||||||
|
for _, h := range hosts {
|
||||||
|
if h.Status != "succeeded" {
|
||||||
|
t.Errorf("host %s status %q want succeeded", h.HostID, h.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if n := len(alerts.reasons); n != 0 {
|
||||||
|
t.Errorf("unexpected halt alert: %v", alerts.reasons)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWorkerSecondHostTimesOutHalts(t *testing.T) {
|
||||||
|
st := openStore(t)
|
||||||
|
uid := mustCreateAdmin(t, st)
|
||||||
|
h1 := mustCreateHost(t, st, "h1", "v0")
|
||||||
|
h2 := mustCreateHost(t, st, "h2", "v0")
|
||||||
|
h3 := mustCreateHost(t, st, "h3", "v0")
|
||||||
|
|
||||||
|
hub := &fakeHub{online: map[string]bool{h1: true, h2: true, h3: true}}
|
||||||
|
// h1 dispatches normally (transitions to v2). h2 dispatch returns
|
||||||
|
// success but never transitions.
|
||||||
|
disp := &fakeDispatcher{st: st, target: "v2", delayMS: 20, failOnHost: map[string]string{
|
||||||
|
h2: "", // not a code-failure; simulate by clearing target on this disp run
|
||||||
|
}}
|
||||||
|
// Actually: drop h2 from the auto-transition by faking with a
|
||||||
|
// per-host store setter. Easiest: subclass via a wrapper.
|
||||||
|
_ = disp
|
||||||
|
customDisp := &perHostDispatcher{base: disp, st: st, target: "v2", noTransition: map[string]bool{h2: true}}
|
||||||
|
|
||||||
|
alerts := &recAlert{}
|
||||||
|
w := NewWorker(st, hub, customDisp, alerts)
|
||||||
|
w.pollPeriod = 20 * time.Millisecond
|
||||||
|
w.hostTimeout = 200 * time.Millisecond
|
||||||
|
|
||||||
|
fuID, err := w.Start(context.Background(), uid, "v2", []string{h1, h2, h3})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("start: %v", err)
|
||||||
|
}
|
||||||
|
waitForStatus(t, st, fuID, "halted", 3*time.Second)
|
||||||
|
_, hosts, _ := st.GetFleetUpdate(context.Background(), fuID)
|
||||||
|
gotStatus := map[string]string{}
|
||||||
|
for _, h := range hosts {
|
||||||
|
gotStatus[h.HostID] = h.Status
|
||||||
|
}
|
||||||
|
if gotStatus[h1] != "succeeded" {
|
||||||
|
t.Errorf("h1: %q", gotStatus[h1])
|
||||||
|
}
|
||||||
|
if gotStatus[h2] != "failed" {
|
||||||
|
t.Errorf("h2: %q", gotStatus[h2])
|
||||||
|
}
|
||||||
|
if gotStatus[h3] != "pending" {
|
||||||
|
t.Errorf("h3: %q", gotStatus[h3])
|
||||||
|
}
|
||||||
|
alerts.mu.Lock()
|
||||||
|
defer alerts.mu.Unlock()
|
||||||
|
if len(alerts.reasons) != 1 {
|
||||||
|
t.Errorf("alert reasons: %v", alerts.reasons)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// perHostDispatcher lets a test omit the auto-transition for selected
|
||||||
|
// hosts so we can simulate timeout.
|
||||||
|
type perHostDispatcher struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
base *fakeDispatcher
|
||||||
|
st *store.Store
|
||||||
|
target string
|
||||||
|
noTransition map[string]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *perHostDispatcher) DispatchUpdate(_ context.Context, hostID, _ string) (string, string, error) {
|
||||||
|
p.mu.Lock()
|
||||||
|
skip := p.noTransition[hostID]
|
||||||
|
p.mu.Unlock()
|
||||||
|
jobID := ulid.Make().String()
|
||||||
|
_ = p.st.CreateJob(context.Background(), store.Job{
|
||||||
|
ID: jobID, HostID: hostID, Kind: "update",
|
||||||
|
ActorKind: "user", CreatedAt: time.Now().UTC(),
|
||||||
|
})
|
||||||
|
if !skip {
|
||||||
|
go func() {
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
_ = p.st.MarkHostHello(context.Background(), hostID, p.target, "0.17", api.CurrentProtocolVersion, time.Now().UTC())
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
return jobID, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWorkerHostOfflineHalts(t *testing.T) {
|
||||||
|
st := openStore(t)
|
||||||
|
uid := mustCreateAdmin(t, st)
|
||||||
|
h1 := mustCreateHost(t, st, "h1", "v0")
|
||||||
|
h2 := mustCreateHost(t, st, "h2", "v0")
|
||||||
|
hub := &fakeHub{online: map[string]bool{h1: false, h2: true}}
|
||||||
|
disp := &fakeDispatcher{st: st, target: "v2"}
|
||||||
|
alerts := &recAlert{}
|
||||||
|
w := NewWorker(st, hub, disp, alerts)
|
||||||
|
w.pollPeriod = 20 * time.Millisecond
|
||||||
|
w.hostTimeout = 500 * time.Millisecond
|
||||||
|
|
||||||
|
fuID, err := w.Start(context.Background(), uid, "v2", []string{h1, h2})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("start: %v", err)
|
||||||
|
}
|
||||||
|
waitForStatus(t, st, fuID, "halted", 2*time.Second)
|
||||||
|
_, hosts, _ := st.GetFleetUpdate(context.Background(), fuID)
|
||||||
|
if hosts[0].Status != "failed" {
|
||||||
|
t.Errorf("h1 status: %q", hosts[0].Status)
|
||||||
|
}
|
||||||
|
if hosts[1].Status != "pending" {
|
||||||
|
t.Errorf("h2 status: %q", hosts[1].Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWorkerAlreadyAtTargetSkipped(t *testing.T) {
|
||||||
|
st := openStore(t)
|
||||||
|
uid := mustCreateAdmin(t, st)
|
||||||
|
h1 := mustCreateHost(t, st, "h1", "v2")
|
||||||
|
h2 := mustCreateHost(t, st, "h2", "v0")
|
||||||
|
hub := &fakeHub{online: map[string]bool{h1: true, h2: true}}
|
||||||
|
disp := &fakeDispatcher{st: st, target: "v2", delayMS: 20}
|
||||||
|
alerts := &recAlert{}
|
||||||
|
w := NewWorker(st, hub, disp, alerts)
|
||||||
|
w.pollPeriod = 20 * time.Millisecond
|
||||||
|
w.hostTimeout = 2 * time.Second
|
||||||
|
|
||||||
|
fuID, err := w.Start(context.Background(), uid, "v2", []string{h1, h2})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("start: %v", err)
|
||||||
|
}
|
||||||
|
waitForStatus(t, st, fuID, "completed", 4*time.Second)
|
||||||
|
_, hosts, _ := st.GetFleetUpdate(context.Background(), fuID)
|
||||||
|
want := map[string]string{h1: "skipped", h2: "succeeded"}
|
||||||
|
for _, h := range hosts {
|
||||||
|
if h.Status != want[h.HostID] {
|
||||||
|
t.Errorf("host %s: got %q want %q", h.HostID, h.Status, want[h.HostID])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWorkerCancelMidRun(t *testing.T) {
|
||||||
|
st := openStore(t)
|
||||||
|
uid := mustCreateAdmin(t, st)
|
||||||
|
h1 := mustCreateHost(t, st, "h1", "v0")
|
||||||
|
h2 := mustCreateHost(t, st, "h2", "v0")
|
||||||
|
hub := &fakeHub{online: map[string]bool{h1: true, h2: true}}
|
||||||
|
// h1's transition is delayed long enough that we can cancel
|
||||||
|
// before it lands; h2 should never be touched.
|
||||||
|
disp := &fakeDispatcher{st: st, target: "v2", delayMS: 500}
|
||||||
|
alerts := &recAlert{}
|
||||||
|
w := NewWorker(st, hub, disp, alerts)
|
||||||
|
w.pollPeriod = 50 * time.Millisecond
|
||||||
|
w.hostTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
fuID, err := w.Start(context.Background(), uid, "v2", []string{h1, h2})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("start: %v", err)
|
||||||
|
}
|
||||||
|
// Give the worker a moment to dispatch h1.
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
if err := w.Cancel(context.Background(), fuID); err != nil {
|
||||||
|
t.Fatalf("cancel: %v", err)
|
||||||
|
}
|
||||||
|
waitForStatus(t, st, fuID, "cancelled", 2*time.Second)
|
||||||
|
|
||||||
|
// h2 should never be dispatched.
|
||||||
|
disp.mu.Lock()
|
||||||
|
defer disp.mu.Unlock()
|
||||||
|
for _, c := range disp.calls {
|
||||||
|
if c == h2 {
|
||||||
|
t.Errorf("h2 dispatched after cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWorkerStartWhileActiveErrors(t *testing.T) {
|
||||||
|
st := openStore(t)
|
||||||
|
uid := mustCreateAdmin(t, st)
|
||||||
|
h1 := mustCreateHost(t, st, "h1", "v0")
|
||||||
|
h2 := mustCreateHost(t, st, "h2", "v0")
|
||||||
|
hub := &fakeHub{online: map[string]bool{h1: true, h2: true}}
|
||||||
|
disp := &fakeDispatcher{st: st, target: "v2", delayMS: 5_000}
|
||||||
|
w := NewWorker(st, hub, disp, &recAlert{})
|
||||||
|
w.pollPeriod = 50 * time.Millisecond
|
||||||
|
w.hostTimeout = 2 * time.Second
|
||||||
|
if _, err := w.Start(context.Background(), uid, "v2", []string{h1}); err != nil {
|
||||||
|
t.Fatalf("first start: %v", err)
|
||||||
|
}
|
||||||
|
_, err := w.Start(context.Background(), uid, "v2", []string{h2})
|
||||||
|
if !errors.Is(err, store.ErrFleetUpdateRunning) {
|
||||||
|
t.Fatalf("err: %v want ErrFleetUpdateRunning", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,217 @@
|
|||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
stdhttp "net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/oklog/ulid/v2"
|
||||||
|
|
||||||
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||||
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||||
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UpdateWatcher is the slim view of the ws.updateWatcher this package
|
||||||
|
// uses for tracking in-flight update dispatches. Defined as an
|
||||||
|
// interface so a test can inject a stub.
|
||||||
|
type UpdateWatcher interface {
|
||||||
|
Track(jobID, hostID string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FleetWorker is the slim view of the fleetupdate.Worker this package
|
||||||
|
// uses. Kept here for forward compatibility with P6-15 — the host
|
||||||
|
// update endpoint itself does not use it.
|
||||||
|
type FleetWorker interface {
|
||||||
|
Start(ctx context.Context, userID, targetVersion string, hostIDs []string) (string, error)
|
||||||
|
Cancel(ctx context.Context, fleetUpdateID string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatchHostUpdateResult communicates structured outcomes from the
|
||||||
|
// shared dispatch path so both the HTTP handler and the fleet worker
|
||||||
|
// can format errors in their own idiom.
|
||||||
|
type dispatchHostUpdateResult struct {
|
||||||
|
JobID string
|
||||||
|
Code string // "" on success
|
||||||
|
Status int // HTTP status the JSON handler should use on error
|
||||||
|
Msg string // human-readable detail (optional)
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatchHostUpdate is the shared "send command.update to one host"
|
||||||
|
// path. It performs every pre-check (host exists, online, version
|
||||||
|
// mismatch, no in-flight update) and on success creates the jobs row,
|
||||||
|
// audits, dispatches the WS envelope, and tracks the watcher entry.
|
||||||
|
//
|
||||||
|
// Pre-checks are returned as structured codes rather than HTTP errors
|
||||||
|
// so the fleet worker can map them onto its own per-host status enum
|
||||||
|
// without parsing strings.
|
||||||
|
func (s *Server) dispatchHostUpdate(ctx context.Context, hostID string, actorKind string, actorID *string) dispatchHostUpdateResult {
|
||||||
|
host, err := s.deps.Store.GetHost(ctx, hostID)
|
||||||
|
if err != nil || host == nil {
|
||||||
|
return dispatchHostUpdateResult{Code: "host_not_found", Status: stdhttp.StatusNotFound}
|
||||||
|
}
|
||||||
|
if !s.deps.Hub.Connected(host.ID) {
|
||||||
|
return dispatchHostUpdateResult{
|
||||||
|
Code: "host_offline", Status: stdhttp.StatusConflict,
|
||||||
|
Msg: "agent is not currently connected",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if host.AgentVersion != "" && host.AgentVersion == version.Version {
|
||||||
|
return dispatchHostUpdateResult{
|
||||||
|
Code: "already_up_to_date", Status: stdhttp.StatusConflict,
|
||||||
|
Msg: "agent already running version " + version.Version,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
existing, err := s.deps.Store.RunningUpdateJobForHost(ctx, hostID)
|
||||||
|
if err != nil {
|
||||||
|
return dispatchHostUpdateResult{Code: "internal", Status: stdhttp.StatusInternalServerError, Msg: err.Error()}
|
||||||
|
}
|
||||||
|
if existing != "" {
|
||||||
|
return dispatchHostUpdateResult{
|
||||||
|
Code: "update_in_progress", Status: stdhttp.StatusConflict,
|
||||||
|
Msg: "an update job is already in flight for this host",
|
||||||
|
JobID: existing,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
jobID := ulid.Make().String()
|
||||||
|
now := time.Now().UTC()
|
||||||
|
if err := s.deps.Store.CreateJob(ctx, store.Job{
|
||||||
|
ID: jobID, HostID: hostID, Kind: "update",
|
||||||
|
ActorKind: actorKind, ActorID: actorID,
|
||||||
|
CreatedAt: now,
|
||||||
|
}); err != nil {
|
||||||
|
return dispatchHostUpdateResult{Code: "internal", Status: stdhttp.StatusInternalServerError, Msg: err.Error()}
|
||||||
|
}
|
||||||
|
env, err := api.Marshal(api.MsgCommandUpdate, ulid.Make().String(), api.CommandUpdatePayload{
|
||||||
|
JobID: jobID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return dispatchHostUpdateResult{Code: "internal", Status: stdhttp.StatusInternalServerError, Msg: err.Error()}
|
||||||
|
}
|
||||||
|
if err := s.deps.Hub.Send(ctx, hostID, env); err != nil {
|
||||||
|
// Roll the job to failed so we don't leak a queued row.
|
||||||
|
_ = s.deps.Store.MarkJobFinished(ctx, jobID, "failed", -1, nil, err.Error(), time.Now().UTC())
|
||||||
|
return dispatchHostUpdateResult{
|
||||||
|
Code: "host_offline", Status: stdhttp.StatusConflict, Msg: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.deps.UpdateWatcher != nil {
|
||||||
|
s.deps.UpdateWatcher.Track(jobID, hostID)
|
||||||
|
}
|
||||||
|
|
||||||
|
auditPayload, _ := json.Marshal(map[string]string{
|
||||||
|
"job_id": jobID,
|
||||||
|
"target_version": version.Version,
|
||||||
|
})
|
||||||
|
_ = s.deps.Store.AppendAudit(ctx, store.AuditEntry{
|
||||||
|
ID: ulid.Make().String(),
|
||||||
|
UserID: actorID,
|
||||||
|
Actor: actorKind,
|
||||||
|
Action: "host.update_dispatched",
|
||||||
|
TargetKind: ptr("host"),
|
||||||
|
TargetID: &hostID,
|
||||||
|
TS: now,
|
||||||
|
Payload: auditPayload,
|
||||||
|
})
|
||||||
|
|
||||||
|
return dispatchHostUpdateResult{JobID: jobID}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleHostUpdate is POST /api/hosts/{id}/update — JSON, admin-only.
|
||||||
|
func (s *Server) handleHostUpdate(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||||
|
user, ok := s.requireUser(r)
|
||||||
|
if !ok {
|
||||||
|
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hostID := chi.URLParam(r, "id")
|
||||||
|
if hostID == "" {
|
||||||
|
writeJSONError(w, stdhttp.StatusBadRequest, "missing_host_id", "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
actor := "user"
|
||||||
|
var actorID *string
|
||||||
|
if user != nil {
|
||||||
|
actorID = &user.ID
|
||||||
|
}
|
||||||
|
res := s.dispatchHostUpdate(r.Context(), hostID, actor, actorID)
|
||||||
|
if res.Code != "" {
|
||||||
|
writeJSONError(w, res.Status, res.Code, res.Msg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, stdhttp.StatusAccepted, map[string]string{"job_id": res.JobID})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleHostUpdateForm is the HTMX-friendly POST /hosts/{id}/update
|
||||||
|
// variant. On success it sets HX-Redirect to the job detail page; on
|
||||||
|
// pre-check failures it renders an inline error banner.
|
||||||
|
func (s *Server) handleHostUpdateForm(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||||
|
user, ok := s.requireUser(r)
|
||||||
|
if !ok {
|
||||||
|
stdhttp.Error(w, "unauthorised", stdhttp.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hostID := chi.URLParam(r, "id")
|
||||||
|
if hostID == "" {
|
||||||
|
stdhttp.Error(w, "missing host_id", stdhttp.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
actor := "user"
|
||||||
|
var actorID *string
|
||||||
|
if user != nil {
|
||||||
|
actorID = &user.ID
|
||||||
|
}
|
||||||
|
res := s.dispatchHostUpdate(r.Context(), hostID, actor, actorID)
|
||||||
|
if res.Code != "" {
|
||||||
|
// Inline banner for HTMX swaps. Mirrors what host_credentials
|
||||||
|
// returns on validation errors — small text/html fragment.
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
w.WriteHeader(res.Status)
|
||||||
|
msg := hostUpdateErrorMessage(res.Code, res.Msg)
|
||||||
|
_, _ = w.Write([]byte(`<div class="banner banner-error" role="alert">` + htmlEscape(msg) + `</div>`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("HX-Redirect", "/jobs/"+res.JobID)
|
||||||
|
w.WriteHeader(stdhttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func hostUpdateErrorMessage(code, msg string) string {
|
||||||
|
switch code {
|
||||||
|
case "host_not_found":
|
||||||
|
return "Host not found."
|
||||||
|
case "host_offline":
|
||||||
|
return "Agent is offline; can't deliver the update command."
|
||||||
|
case "already_up_to_date":
|
||||||
|
return "Agent is already running the current version."
|
||||||
|
case "update_in_progress":
|
||||||
|
return "An update is already in progress for this host."
|
||||||
|
}
|
||||||
|
if msg != "" {
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
return "Update dispatch failed."
|
||||||
|
}
|
||||||
|
|
||||||
|
// htmlEscape is a minimal HTML-attr-safe escaper. Avoids pulling html/template
|
||||||
|
// for a one-shot inline banner.
|
||||||
|
func htmlEscape(s string) string {
|
||||||
|
out := make([]byte, 0, len(s))
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
switch s[i] {
|
||||||
|
case '&':
|
||||||
|
out = append(out, []byte("&")...)
|
||||||
|
case '<':
|
||||||
|
out = append(out, []byte("<")...)
|
||||||
|
case '>':
|
||||||
|
out = append(out, []byte(">")...)
|
||||||
|
case '"':
|
||||||
|
out = append(out, []byte(""")...)
|
||||||
|
default:
|
||||||
|
out = append(out, s[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(out)
|
||||||
|
}
|
||||||
@@ -0,0 +1,270 @@
|
|||||||
|
// host_update_test.go — covers POST /api/hosts/{id}/update.
|
||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
stdhttp "net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coder/websocket"
|
||||||
|
"github.com/oklog/ulid/v2"
|
||||||
|
|
||||||
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||||
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||||
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
// stubWatcher records Track calls so tests can assert the watcher was
|
||||||
|
// notified.
|
||||||
|
type stubWatcher struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
tracked []string // hostIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubWatcher) Track(_, hostID string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.tracked = append(s.tracked, hostID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostUpdateHappyPath(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
srv, ts, st := rawTestServer(t)
|
||||||
|
watcher := &stubWatcher{}
|
||||||
|
srv.deps.UpdateWatcher = watcher
|
||||||
|
hostID, token := enrolHostForWS(t, srv, st, "upd-host")
|
||||||
|
c := agentDial(t, srv, ts, hostID, token)
|
||||||
|
sendHello(t, c, "upd-host")
|
||||||
|
_ = drainUntil(t, c, api.MsgScheduleSet)
|
||||||
|
|
||||||
|
// Force a version mismatch so the dispatch isn't short-circuited.
|
||||||
|
if err := st.MarkHostHello(context.Background(), hostID, "v0", "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil {
|
||||||
|
t.Fatalf("mark hello: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie := loginAsAdmin(t, st)
|
||||||
|
req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil)
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
res, err := stdhttp.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("do: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != stdhttp.StatusAccepted {
|
||||||
|
t.Fatalf("status: got %d, want 202", res.StatusCode)
|
||||||
|
}
|
||||||
|
var out struct {
|
||||||
|
JobID string `json:"job_id"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(res.Body).Decode(&out); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if out.JobID == "" {
|
||||||
|
t.Fatal("missing job_id in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// command.update envelope arrives.
|
||||||
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
|
var got api.Envelope
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
mt, raw, rerr := c.Read(ctx)
|
||||||
|
cancel()
|
||||||
|
if rerr != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if mt != websocket.MessageText {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(raw), `"command.update"`) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_ = json.Unmarshal(raw, &got)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if got.Type != api.MsgCommandUpdate {
|
||||||
|
t.Fatal("never received command.update envelope")
|
||||||
|
}
|
||||||
|
var cp api.CommandUpdatePayload
|
||||||
|
if err := got.UnmarshalPayload(&cp); err != nil {
|
||||||
|
t.Fatalf("payload: %v", err)
|
||||||
|
}
|
||||||
|
if cp.JobID != out.JobID {
|
||||||
|
t.Fatalf("payload job_id: got %q want %q", cp.JobID, out.JobID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Watcher tracked.
|
||||||
|
watcher.mu.Lock()
|
||||||
|
defer watcher.mu.Unlock()
|
||||||
|
if len(watcher.tracked) != 1 || watcher.tracked[0] != hostID {
|
||||||
|
t.Fatalf("watcher tracked: %v", watcher.tracked)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Audit row exists.
|
||||||
|
var n int
|
||||||
|
if err := st.DB().QueryRow(
|
||||||
|
`SELECT COUNT(*) FROM audit_log WHERE action = 'host.update_dispatched' AND target_id = ?`,
|
||||||
|
hostID).Scan(&n); err != nil {
|
||||||
|
t.Fatalf("audit count: %v", err)
|
||||||
|
}
|
||||||
|
if n != 1 {
|
||||||
|
t.Fatalf("audit rows: got %d, want 1", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostUpdateNotFound(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
_, ts, st := rawTestServer(t)
|
||||||
|
cookie := loginAsAdmin(t, st)
|
||||||
|
req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/no-such/update", nil)
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
res, err := stdhttp.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("do: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != stdhttp.StatusNotFound {
|
||||||
|
t.Fatalf("status: got %d want 404", res.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostUpdateOffline(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
_, ts, st := rawTestServer(t)
|
||||||
|
hostID := ulid.Make().String()
|
||||||
|
if err := st.CreateHost(context.Background(), store.Host{
|
||||||
|
ID: hostID, Name: "off", OS: "linux", Arch: "amd64",
|
||||||
|
EnrolledAt: time.Now().UTC(),
|
||||||
|
}, "deadbeef", ""); err != nil {
|
||||||
|
t.Fatalf("create: %v", err)
|
||||||
|
}
|
||||||
|
cookie := loginAsAdmin(t, st)
|
||||||
|
req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil)
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
res, err := stdhttp.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("do: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != stdhttp.StatusConflict {
|
||||||
|
t.Fatalf("status: got %d want 409", res.StatusCode)
|
||||||
|
}
|
||||||
|
body := readJSONError(t, res.Body)
|
||||||
|
if body.Code != "host_offline" {
|
||||||
|
t.Fatalf("code: %q", body.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostUpdateAlreadyUpToDate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
srv, ts, st := rawTestServer(t)
|
||||||
|
hostID, token := enrolHostForWS(t, srv, st, "uptodate-host")
|
||||||
|
c := agentDial(t, srv, ts, hostID, token)
|
||||||
|
sendHello(t, c, "uptodate-host")
|
||||||
|
_ = drainUntil(t, c, api.MsgScheduleSet)
|
||||||
|
|
||||||
|
// Force agent_version == version.Version.
|
||||||
|
if err := st.MarkHostHello(context.Background(), hostID, version.Version, "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil {
|
||||||
|
t.Fatalf("mark hello: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie := loginAsAdmin(t, st)
|
||||||
|
req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil)
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
res, err := stdhttp.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("do: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != stdhttp.StatusConflict {
|
||||||
|
t.Fatalf("status: got %d want 409", res.StatusCode)
|
||||||
|
}
|
||||||
|
body := readJSONError(t, res.Body)
|
||||||
|
if body.Code != "already_up_to_date" {
|
||||||
|
t.Fatalf("code: %q", body.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostUpdateInProgress(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
srv, ts, st := rawTestServer(t)
|
||||||
|
hostID, token := enrolHostForWS(t, srv, st, "inprog-host")
|
||||||
|
c := agentDial(t, srv, ts, hostID, token)
|
||||||
|
sendHello(t, c, "inprog-host")
|
||||||
|
_ = drainUntil(t, c, api.MsgScheduleSet)
|
||||||
|
if err := st.MarkHostHello(context.Background(), hostID, "v0", "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil {
|
||||||
|
t.Fatalf("mark hello: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-seed an in-flight update job.
|
||||||
|
jobID := ulid.Make().String()
|
||||||
|
if err := st.CreateJob(context.Background(), store.Job{
|
||||||
|
ID: jobID, HostID: hostID, Kind: "update",
|
||||||
|
ActorKind: "user", CreatedAt: time.Now().UTC(),
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("seed job: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie := loginAsAdmin(t, st)
|
||||||
|
req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil)
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
res, err := stdhttp.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("do: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != stdhttp.StatusConflict {
|
||||||
|
t.Fatalf("status: got %d want 409", res.StatusCode)
|
||||||
|
}
|
||||||
|
body := readJSONError(t, res.Body)
|
||||||
|
if body.Code != "update_in_progress" {
|
||||||
|
t.Fatalf("code: %q", body.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostUpdateRBAC(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
_, ts, st := rawTestServer(t)
|
||||||
|
hostID := ulid.Make().String()
|
||||||
|
if err := st.CreateHost(context.Background(), store.Host{
|
||||||
|
ID: hostID, Name: "rbac-host", OS: "linux", Arch: "amd64",
|
||||||
|
EnrolledAt: time.Now().UTC(),
|
||||||
|
}, "deadbeef", ""); err != nil {
|
||||||
|
t.Fatalf("create: %v", err)
|
||||||
|
}
|
||||||
|
for _, role := range []store.Role{store.RoleViewer, store.RoleOperator} {
|
||||||
|
role := role
|
||||||
|
t.Run(string(role), func(t *testing.T) {
|
||||||
|
cookie := loginAsRole(t, st, role)
|
||||||
|
req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil)
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
res, err := stdhttp.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("do: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if res.StatusCode != stdhttp.StatusForbidden {
|
||||||
|
t.Fatalf("status for %s: got %d want 403", role, res.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type jsonErrBody struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func readJSONError(t *testing.T, body io.Reader) jsonErrBody {
|
||||||
|
t.Helper()
|
||||||
|
var out jsonErrBody
|
||||||
|
if err := json.NewDecoder(body).Decode(&out); err != nil {
|
||||||
|
t.Fatalf("decode error body: %v", err)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
@@ -39,6 +39,13 @@ type Deps struct {
|
|||||||
// NotificationHub (optional, wired in G1) is used by the test-fire
|
// NotificationHub (optional, wired in G1) is used by the test-fire
|
||||||
// endpoint to dispatch a single synthetic payload through a channel.
|
// endpoint to dispatch a single synthetic payload through a channel.
|
||||||
NotificationHub *notification.Hub
|
NotificationHub *notification.Hub
|
||||||
|
// UpdateWatcher tracks in-flight agent self-update dispatches and
|
||||||
|
// reconciles them against incoming hello envelopes. Optional;
|
||||||
|
// nil = no-op (handlers degrade by skipping the Track call).
|
||||||
|
UpdateWatcher UpdateWatcher
|
||||||
|
// FleetWorker drives the rolling fleet-update worker. Optional;
|
||||||
|
// nil = fleet update endpoints (P6-15) report unavailable.
|
||||||
|
FleetWorker FleetWorker
|
||||||
// Version is the binary's build version, surfaced in the chrome.
|
// Version is the binary's build version, surfaced in the chrome.
|
||||||
// Empty falls back to "dev".
|
// Empty falls back to "dev".
|
||||||
Version string
|
Version string
|
||||||
@@ -125,7 +132,7 @@ func (s *Server) routes(r chi.Router) {
|
|||||||
r.Get("/install/*", s.handleInstallAsset)
|
r.Get("/install/*", s.handleInstallAsset)
|
||||||
r.Get("/api/version", s.handleVersion)
|
r.Get("/api/version", s.handleVersion)
|
||||||
if s.deps.Hub != nil {
|
if s.deps.Hub != nil {
|
||||||
r.Mount("/ws/agent", ws.AgentHandler(ws.HandlerDeps{
|
hd := ws.HandlerDeps{
|
||||||
Hub: s.deps.Hub,
|
Hub: s.deps.Hub,
|
||||||
Store: s.deps.Store,
|
Store: s.deps.Store,
|
||||||
JobHub: s.deps.JobHub,
|
JobHub: s.deps.JobHub,
|
||||||
@@ -133,7 +140,11 @@ func (s *Server) routes(r chi.Router) {
|
|||||||
OnHello: s.onAgentHello,
|
OnHello: s.onAgentHello,
|
||||||
OnScheduleAck: s.applyScheduleAck,
|
OnScheduleAck: s.applyScheduleAck,
|
||||||
OnScheduleFire: s.dispatchScheduledJob,
|
OnScheduleFire: s.dispatchScheduledJob,
|
||||||
}))
|
}
|
||||||
|
if w, ok := s.deps.UpdateWatcher.(*ws.UpdateWatcher); ok && w != nil {
|
||||||
|
hd.UpdateWatcher = w
|
||||||
|
}
|
||||||
|
r.Mount("/ws/agent", ws.AgentHandler(hd))
|
||||||
}
|
}
|
||||||
r.Get("/ws/agent/pending", s.handlePendingWS)
|
r.Get("/ws/agent/pending", s.handlePendingWS)
|
||||||
r.Mount("/static/", staticHandler())
|
r.Mount("/static/", staticHandler())
|
||||||
@@ -271,6 +282,9 @@ func (s *Server) routes(r chi.Router) {
|
|||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
r.Use(s.requireRole(store.RoleAdmin))
|
r.Use(s.requireRole(store.RoleAdmin))
|
||||||
|
|
||||||
|
r.Post("/api/hosts/{id}/update", s.handleHostUpdate)
|
||||||
|
r.Post("/hosts/{id}/update", s.handleHostUpdateForm)
|
||||||
|
|
||||||
r.Get("/api/users", s.handleAPIUsersList)
|
r.Get("/api/users", s.handleAPIUsersList)
|
||||||
r.Post("/api/users", s.handleAPIUserCreate)
|
r.Post("/api/users", s.handleAPIUserCreate)
|
||||||
r.Get("/api/users/{id}", s.handleAPIUserGet)
|
r.Get("/api/users/{id}", s.handleAPIUserGet)
|
||||||
@@ -322,6 +336,27 @@ func (s *Server) Shutdown(ctx context.Context) error {
|
|||||||
return s.srv.Shutdown(ctx)
|
return s.srv.Shutdown(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFleetWorker installs the fleet-update worker post-construction.
|
||||||
|
// Used to break the wiring loop in cmd/server (the worker depends on a
|
||||||
|
// dispatcher that delegates back into the server's host-update path).
|
||||||
|
func (s *Server) SetFleetWorker(fw FleetWorker) { s.deps.FleetWorker = fw }
|
||||||
|
|
||||||
|
// DispatchHostUpdate is the public entry point for callers (the fleet
|
||||||
|
// worker) that need to drive the same dispatch path the HTTP handler
|
||||||
|
// uses, without going through HTTP. Returns the structured result so
|
||||||
|
// the caller can map error codes to its own status enum.
|
||||||
|
func (s *Server) DispatchHostUpdate(ctx context.Context, hostID, actorUserID string) (jobID string, code string, err error) {
|
||||||
|
var actorID *string
|
||||||
|
if actorUserID != "" {
|
||||||
|
actorID = &actorUserID
|
||||||
|
}
|
||||||
|
res := s.dispatchHostUpdate(ctx, hostID, "user", actorID)
|
||||||
|
if res.Code != "" {
|
||||||
|
return res.JobID, res.Code, nil
|
||||||
|
}
|
||||||
|
return res.JobID, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
// Addr returns the configured listen address. Useful in tests when
|
// Addr returns the configured listen address. Useful in tests when
|
||||||
// the caller passes :0 to get a random port.
|
// the caller passes :0 to get a random port.
|
||||||
func (s *Server) Addr() string { return s.srv.Addr }
|
func (s *Server) Addr() string { return s.srv.Addr }
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
|
||||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||||
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HandlerDeps is the set of collaborators the agent WS handler needs.
|
// HandlerDeps is the set of collaborators the agent WS handler needs.
|
||||||
@@ -26,6 +27,9 @@ type HandlerDeps struct {
|
|||||||
// AlertEngine receives job-finished and host-online events so the
|
// AlertEngine receives job-finished and host-online events so the
|
||||||
// alert engine can evaluate its rules. Optional; nil = no-op.
|
// alert engine can evaluate its rules. Optional; nil = no-op.
|
||||||
AlertEngine *alert.Engine
|
AlertEngine *alert.Engine
|
||||||
|
// UpdateWatcher reconciles in-flight agent-update dispatches against
|
||||||
|
// hello envelopes. Optional; nil = no-op.
|
||||||
|
UpdateWatcher *UpdateWatcher
|
||||||
// OnHello is called once per successful hello, after the host row
|
// OnHello is called once per successful hello, after the host row
|
||||||
// has been touched and the conn registered. Used by the HTTP
|
// has been touched and the conn registered. Used by the HTTP
|
||||||
// layer to push host_credentials down as a config.update before
|
// layer to push host_credentials down as a config.update before
|
||||||
@@ -147,6 +151,9 @@ func runAgentLoop(ctx context.Context, c *Conn, hostID string, deps HandlerDeps)
|
|||||||
if deps.AlertEngine != nil {
|
if deps.AlertEngine != nil {
|
||||||
deps.AlertEngine.NotifyHostOnline(hostID)
|
deps.AlertEngine.NotifyHostOnline(hostID)
|
||||||
}
|
}
|
||||||
|
if deps.UpdateWatcher != nil {
|
||||||
|
deps.UpdateWatcher.OnHello(ctx, hostID, helloPayload.AgentVersion, version.Version)
|
||||||
|
}
|
||||||
|
|
||||||
deps.Hub.Register(hostID, c)
|
deps.Hub.Register(hostID, c)
|
||||||
defer deps.Hub.Unregister(hostID, c)
|
defer deps.Hub.Unregister(hostID, c)
|
||||||
|
|||||||
@@ -0,0 +1,151 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
// updateTimeout bounds how long the watcher waits for an agent to come
|
||||||
|
// back with its new version after a command.update dispatch. var (not
|
||||||
|
// const) so tests can shrink it.
|
||||||
|
var updateTimeout = 90 * time.Second
|
||||||
|
|
||||||
|
// AlertRaiser is the slim subset of *alert.Engine the update watcher
|
||||||
|
// touches. Defined here (not in the alert package) so the dependency
|
||||||
|
// arrow points the right way.
|
||||||
|
type AlertRaiser interface {
|
||||||
|
RaiseUpdateFailed(ctx context.Context, hostID, jobID, reason string, when time.Time)
|
||||||
|
ResolveUpdateFailed(ctx context.Context, hostID string, when time.Time)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateWatcher tracks in-flight agent-update dispatches and reconciles
|
||||||
|
// them against incoming hello envelopes. Entries land on Track and
|
||||||
|
// resolve via OnHello (success path) or the periodic sweep (timeout).
|
||||||
|
type UpdateWatcher struct {
|
||||||
|
store *store.Store
|
||||||
|
alerts AlertRaiser
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
entries map[string]*updateEntry // hostID → entry
|
||||||
|
|
||||||
|
tickPeriod time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type updateEntry struct {
|
||||||
|
jobID string
|
||||||
|
startedAt time.Time
|
||||||
|
// terminated is set once the entry has reached a terminal state so
|
||||||
|
// late OnHellos don't resurrect it.
|
||||||
|
terminated bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUpdateWatcher builds an unstarted watcher. Call Run in a goroutine
|
||||||
|
// to start the periodic sweep.
|
||||||
|
func NewUpdateWatcher(st *store.Store, alerts AlertRaiser) *UpdateWatcher {
|
||||||
|
return &UpdateWatcher{
|
||||||
|
store: st,
|
||||||
|
alerts: alerts,
|
||||||
|
entries: make(map[string]*updateEntry),
|
||||||
|
tickPeriod: 5 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track registers a freshly-dispatched update job. A subsequent Track
|
||||||
|
// for the same host replaces the prior entry (last-write-wins).
|
||||||
|
func (w *UpdateWatcher) Track(jobID, hostID string) {
|
||||||
|
if w == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.mu.Lock()
|
||||||
|
w.entries[hostID] = &updateEntry{jobID: jobID, startedAt: time.Now()}
|
||||||
|
w.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnHello is called by the WS handler after a successful hello has been
|
||||||
|
// persisted. If a tracked update for the host matches the targetVersion,
|
||||||
|
// the job is marked succeeded and any open update_failed alert is
|
||||||
|
// auto-resolved. A non-matching version is a no-op (the watcher keeps
|
||||||
|
// waiting until the timeout).
|
||||||
|
func (w *UpdateWatcher) OnHello(ctx context.Context, hostID, agentVersion, targetVersion string) {
|
||||||
|
if w == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.mu.Lock()
|
||||||
|
e, ok := w.entries[hostID]
|
||||||
|
if !ok || e.terminated {
|
||||||
|
w.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if agentVersion != targetVersion {
|
||||||
|
// Not the version we asked for — keep waiting.
|
||||||
|
w.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
e.terminated = true
|
||||||
|
jobID := e.jobID
|
||||||
|
delete(w.entries, hostID)
|
||||||
|
w.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
if err := w.store.MarkJobFinished(ctx, jobID, "succeeded", 0, nil, "", now); err != nil {
|
||||||
|
slog.Warn("ws update watcher: mark succeeded", "job_id", jobID, "host_id", hostID, "err", err)
|
||||||
|
}
|
||||||
|
if w.alerts != nil {
|
||||||
|
w.alerts.ResolveUpdateFailed(ctx, hostID, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run drives the periodic sweep. Returns when ctx is done.
|
||||||
|
func (w *UpdateWatcher) Run(ctx context.Context) {
|
||||||
|
if w == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t := time.NewTicker(w.tickPeriod)
|
||||||
|
defer t.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case now := <-t.C:
|
||||||
|
w.sweep(ctx, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *UpdateWatcher) sweep(ctx context.Context, now time.Time) {
|
||||||
|
type expired struct {
|
||||||
|
hostID string
|
||||||
|
jobID string
|
||||||
|
age time.Duration
|
||||||
|
}
|
||||||
|
var toFail []expired
|
||||||
|
w.mu.Lock()
|
||||||
|
for hostID, e := range w.entries {
|
||||||
|
if e.terminated {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if now.Sub(e.startedAt) >= updateTimeout {
|
||||||
|
toFail = append(toFail, expired{hostID: hostID, jobID: e.jobID, age: now.Sub(e.startedAt)})
|
||||||
|
e.terminated = true
|
||||||
|
delete(w.entries, hostID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.mu.Unlock()
|
||||||
|
|
||||||
|
for _, x := range toFail {
|
||||||
|
reason := fmt.Sprintf("timeout: agent did not reconnect within %s", updateTimeout)
|
||||||
|
stamp := now.UTC()
|
||||||
|
errMsg := reason
|
||||||
|
if err := w.store.MarkJobFinished(ctx, x.jobID, "failed", -1, nil, errMsg, stamp); err != nil {
|
||||||
|
slog.Warn("ws update watcher: mark failed", "job_id", x.jobID, "host_id", x.hostID, "err", err)
|
||||||
|
}
|
||||||
|
if w.alerts != nil {
|
||||||
|
w.alerts.RaiseUpdateFailed(ctx, x.hostID, x.jobID, reason, stamp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,161 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oklog/ulid/v2"
|
||||||
|
|
||||||
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeAlerts struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
raised []string // hostIDs
|
||||||
|
resolved []string
|
||||||
|
reasons []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeAlerts) RaiseUpdateFailed(_ context.Context, hostID, _ /*jobID*/, reason string, _ time.Time) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
f.raised = append(f.raised, hostID)
|
||||||
|
f.reasons = append(f.reasons, reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeAlerts) ResolveUpdateFailed(_ context.Context, hostID string, _ time.Time) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
f.resolved = append(f.resolved, hostID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func seedJob(t *testing.T, st *store.Store, hostID string) string {
|
||||||
|
t.Helper()
|
||||||
|
jobID := ulid.Make().String()
|
||||||
|
if err := st.CreateJob(context.Background(), store.Job{
|
||||||
|
ID: jobID, HostID: hostID, Kind: "update",
|
||||||
|
ActorKind: "user", CreatedAt: time.Now().UTC(),
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("create job: %v", err)
|
||||||
|
}
|
||||||
|
return jobID
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateWatcherOnHelloSuccess(t *testing.T) {
|
||||||
|
st := openWSTestStore(t)
|
||||||
|
hostID := ulid.Make().String()
|
||||||
|
seedHostWS(t, st, hostID)
|
||||||
|
jobID := seedJob(t, st, hostID)
|
||||||
|
|
||||||
|
a := &fakeAlerts{}
|
||||||
|
w := NewUpdateWatcher(st, a)
|
||||||
|
w.Track(jobID, hostID)
|
||||||
|
|
||||||
|
w.OnHello(context.Background(), hostID, "v2", "v2")
|
||||||
|
|
||||||
|
job, err := st.GetJob(context.Background(), jobID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get job: %v", err)
|
||||||
|
}
|
||||||
|
if job.Status != "succeeded" {
|
||||||
|
t.Fatalf("status: got %q want succeeded", job.Status)
|
||||||
|
}
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
if len(a.resolved) != 1 || a.resolved[0] != hostID {
|
||||||
|
t.Fatalf("resolve calls: %v", a.resolved)
|
||||||
|
}
|
||||||
|
if len(a.raised) != 0 {
|
||||||
|
t.Fatalf("unexpected raises: %v", a.raised)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateWatcherTimeout(t *testing.T) {
|
||||||
|
prev := updateTimeout
|
||||||
|
updateTimeout = 50 * time.Millisecond
|
||||||
|
t.Cleanup(func() { updateTimeout = prev })
|
||||||
|
|
||||||
|
st := openWSTestStore(t)
|
||||||
|
hostID := ulid.Make().String()
|
||||||
|
seedHostWS(t, st, hostID)
|
||||||
|
jobID := seedJob(t, st, hostID)
|
||||||
|
|
||||||
|
a := &fakeAlerts{}
|
||||||
|
w := NewUpdateWatcher(st, a)
|
||||||
|
w.Track(jobID, hostID)
|
||||||
|
|
||||||
|
time.Sleep(80 * time.Millisecond)
|
||||||
|
w.sweep(context.Background(), time.Now())
|
||||||
|
|
||||||
|
job, err := st.GetJob(context.Background(), jobID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get job: %v", err)
|
||||||
|
}
|
||||||
|
if job.Status != "failed" {
|
||||||
|
t.Fatalf("status: got %q want failed", job.Status)
|
||||||
|
}
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
if len(a.raised) != 1 || a.raised[0] != hostID {
|
||||||
|
t.Fatalf("raise calls: %v", a.raised)
|
||||||
|
}
|
||||||
|
if len(a.reasons) == 0 || a.reasons[0] == "" {
|
||||||
|
t.Fatalf("missing reason")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateWatcherMismatchedVersionNoOp(t *testing.T) {
|
||||||
|
st := openWSTestStore(t)
|
||||||
|
hostID := ulid.Make().String()
|
||||||
|
seedHostWS(t, st, hostID)
|
||||||
|
jobID := seedJob(t, st, hostID)
|
||||||
|
|
||||||
|
a := &fakeAlerts{}
|
||||||
|
w := NewUpdateWatcher(st, a)
|
||||||
|
w.Track(jobID, hostID)
|
||||||
|
|
||||||
|
w.OnHello(context.Background(), hostID, "v1", "v2")
|
||||||
|
|
||||||
|
job, _ := st.GetJob(context.Background(), jobID)
|
||||||
|
if job.Status == "succeeded" || job.Status == "failed" {
|
||||||
|
t.Fatalf("status flipped on mismatched hello: %q", job.Status)
|
||||||
|
}
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
if len(a.raised) != 0 || len(a.resolved) != 0 {
|
||||||
|
t.Fatalf("unexpected alert calls raised=%v resolved=%v", a.raised, a.resolved)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateWatcherHelloAfterTimeoutIsNoOp(t *testing.T) {
|
||||||
|
prev := updateTimeout
|
||||||
|
updateTimeout = 50 * time.Millisecond
|
||||||
|
t.Cleanup(func() { updateTimeout = prev })
|
||||||
|
|
||||||
|
st := openWSTestStore(t)
|
||||||
|
hostID := ulid.Make().String()
|
||||||
|
seedHostWS(t, st, hostID)
|
||||||
|
jobID := seedJob(t, st, hostID)
|
||||||
|
|
||||||
|
a := &fakeAlerts{}
|
||||||
|
w := NewUpdateWatcher(st, a)
|
||||||
|
w.Track(jobID, hostID)
|
||||||
|
|
||||||
|
time.Sleep(80 * time.Millisecond)
|
||||||
|
w.sweep(context.Background(), time.Now())
|
||||||
|
|
||||||
|
// Hello arrives after sweep — entry already gone, must be no-op.
|
||||||
|
w.OnHello(context.Background(), hostID, "v2", "v2")
|
||||||
|
|
||||||
|
job, _ := st.GetJob(context.Background(), jobID)
|
||||||
|
if job.Status != "failed" {
|
||||||
|
t.Fatalf("status flipped from failed → %q", job.Status)
|
||||||
|
}
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
if len(a.resolved) != 0 {
|
||||||
|
t.Fatalf("late hello triggered ResolveUpdateFailed: %v", a.resolved)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -77,6 +77,56 @@ func (s *Store) RaiseOrTouch(ctx context.Context, hostID, kind, dedupKey, severi
|
|||||||
return id, true, nil
|
return id, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RaiseOrTouchSystem is the host-less variant of RaiseOrTouch — the
|
||||||
|
// alert row's host_id is stored as NULL, so the FK to hosts is bypassed.
|
||||||
|
// Used by fleet-wide alerts (e.g. fleet_update_halted) where the
|
||||||
|
// failure surface isn't pinned to a single host.
|
||||||
|
func (s *Store) RaiseOrTouchSystem(ctx context.Context, kind, dedupKey, severity, message string, when time.Time) (id string, didRaise bool, err error) {
|
||||||
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, fmt.Errorf("store: begin: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
row := tx.QueryRowContext(ctx,
|
||||||
|
`SELECT id FROM alerts
|
||||||
|
WHERE host_id IS NULL AND kind = ? AND dedup_key = ? AND resolved_at IS NULL
|
||||||
|
LIMIT 1`,
|
||||||
|
kind, dedupKey)
|
||||||
|
var existing string
|
||||||
|
switch err := row.Scan(&existing); {
|
||||||
|
case err == nil:
|
||||||
|
_, uerr := tx.ExecContext(ctx,
|
||||||
|
`UPDATE alerts SET last_seen_at = ?, message = ? WHERE id = ?`,
|
||||||
|
when.UTC().Format(time.RFC3339Nano), message, existing)
|
||||||
|
if uerr != nil {
|
||||||
|
return "", false, fmt.Errorf("store: touch alert: %w", uerr)
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
return existing, false, nil
|
||||||
|
case errors.Is(err, sql.ErrNoRows):
|
||||||
|
// fall through to insert
|
||||||
|
default:
|
||||||
|
return "", false, fmt.Errorf("store: lookup alert: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
id = ulid.Make().String()
|
||||||
|
whenStr := when.UTC().Format(time.RFC3339Nano)
|
||||||
|
_, err = tx.ExecContext(ctx,
|
||||||
|
`INSERT INTO alerts (id, host_id, kind, dedup_key, severity, message, created_at, last_seen_at)
|
||||||
|
VALUES (?, NULL, ?, ?, ?, ?, ?, ?)`,
|
||||||
|
id, kind, dedupKey, severity, message, whenStr, whenStr)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, fmt.Errorf("store: insert alert: %w", err)
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
return id, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
// refreshHostOpenAlertCount recomputes hosts.open_alert_count from the
|
// refreshHostOpenAlertCount recomputes hosts.open_alert_count from the
|
||||||
// alerts table for one host. Self-healing: idempotent and survives
|
// alerts table for one host. Self-healing: idempotent and survives
|
||||||
// out-of-order edits. Best-effort — errors are returned but callers
|
// out-of-order edits. Best-effort — errors are returned but callers
|
||||||
|
|||||||
Reference in New Issue
Block a user