9d5775fb47
- alert: update_failed (per-host, dedup=hostID) + fleet_update_halted
(system-scoped, host_id NULL via new RaiseOrTouchSystem helper).
- ws: UpdateWatcher tracks in-flight command.update dispatches and
reconciles them against incoming hello envelopes — success path
marks the job succeeded and auto-resolves the alert; 90s timeout
marks the job failed and raises update_failed.
- http: POST /api/hosts/{id}/update (admin-only JSON) + the HTMX
/hosts/{id}/update form variant. Pre-checks: host exists, online,
agent_version != current, no running update job. Refactored core
into Server.dispatchHostUpdate so the fleet worker can share it
without going through HTTP.
- fleetupdate: rolling worker iterating through host slots, halting
on first failure and raising fleet_update_halted. Polling-based
version-match (re-read hosts.agent_version every 1s up to 95s) —
no extra plumbing into the WS hello path. At-most-one-running is
enforced at the store layer (ErrFleetUpdateRunning).
- cmd/server: wire UpdateWatcher and FleetWorker into the main
goroutine; the worker uses a small serverDispatcher adapter that
delegates back into Server.DispatchHostUpdate.
Tests: watcher (success/timeout/mismatch/late-hello), HTTP endpoint
(happy + four pre-check branches + RBAC), worker (two-host happy,
timeout-halt, host-offline-halt, already-at-target skip, cancel
mid-run, double-Start guard).
271 lines
7.5 KiB
Go
271 lines
7.5 KiB
Go
// host_update_test.go — covers POST /api/hosts/{id}/update.
|
|
package http
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
stdhttp "net/http"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/coder/websocket"
|
|
"github.com/oklog/ulid/v2"
|
|
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/version"
|
|
)
|
|
|
|
// stubWatcher is a test double that remembers the host ID passed to every
// Track call, so a test can check afterwards that the watcher was notified.
type stubWatcher struct {
	mu      sync.Mutex // guards tracked
	tracked []string   // host IDs, in the order Track received them
}

// Track appends hostID to the recorded list. The first argument is
// discarded by this stub.
func (w *stubWatcher) Track(_, hostID string) {
	w.mu.Lock()
	w.tracked = append(w.tracked, hostID)
	w.mu.Unlock()
}
|
|
|
|
func TestHostUpdateHappyPath(t *testing.T) {
|
|
t.Parallel()
|
|
srv, ts, st := rawTestServer(t)
|
|
watcher := &stubWatcher{}
|
|
srv.deps.UpdateWatcher = watcher
|
|
hostID, token := enrolHostForWS(t, srv, st, "upd-host")
|
|
c := agentDial(t, srv, ts, hostID, token)
|
|
sendHello(t, c, "upd-host")
|
|
_ = drainUntil(t, c, api.MsgScheduleSet)
|
|
|
|
// Force a version mismatch so the dispatch isn't short-circuited.
|
|
if err := st.MarkHostHello(context.Background(), hostID, "v0", "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil {
|
|
t.Fatalf("mark hello: %v", err)
|
|
}
|
|
|
|
cookie := loginAsAdmin(t, st)
|
|
req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil)
|
|
req.AddCookie(cookie)
|
|
res, err := stdhttp.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("do: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != stdhttp.StatusAccepted {
|
|
t.Fatalf("status: got %d, want 202", res.StatusCode)
|
|
}
|
|
var out struct {
|
|
JobID string `json:"job_id"`
|
|
}
|
|
if err := json.NewDecoder(res.Body).Decode(&out); err != nil {
|
|
t.Fatalf("decode: %v", err)
|
|
}
|
|
if out.JobID == "" {
|
|
t.Fatal("missing job_id in response")
|
|
}
|
|
|
|
// command.update envelope arrives.
|
|
deadline := time.Now().Add(2 * time.Second)
|
|
var got api.Envelope
|
|
for time.Now().Before(deadline) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
|
mt, raw, rerr := c.Read(ctx)
|
|
cancel()
|
|
if rerr != nil {
|
|
break
|
|
}
|
|
if mt != websocket.MessageText {
|
|
continue
|
|
}
|
|
if !strings.Contains(string(raw), `"command.update"`) {
|
|
continue
|
|
}
|
|
_ = json.Unmarshal(raw, &got)
|
|
break
|
|
}
|
|
if got.Type != api.MsgCommandUpdate {
|
|
t.Fatal("never received command.update envelope")
|
|
}
|
|
var cp api.CommandUpdatePayload
|
|
if err := got.UnmarshalPayload(&cp); err != nil {
|
|
t.Fatalf("payload: %v", err)
|
|
}
|
|
if cp.JobID != out.JobID {
|
|
t.Fatalf("payload job_id: got %q want %q", cp.JobID, out.JobID)
|
|
}
|
|
|
|
// Watcher tracked.
|
|
watcher.mu.Lock()
|
|
defer watcher.mu.Unlock()
|
|
if len(watcher.tracked) != 1 || watcher.tracked[0] != hostID {
|
|
t.Fatalf("watcher tracked: %v", watcher.tracked)
|
|
}
|
|
|
|
// Audit row exists.
|
|
var n int
|
|
if err := st.DB().QueryRow(
|
|
`SELECT COUNT(*) FROM audit_log WHERE action = 'host.update_dispatched' AND target_id = ?`,
|
|
hostID).Scan(&n); err != nil {
|
|
t.Fatalf("audit count: %v", err)
|
|
}
|
|
if n != 1 {
|
|
t.Fatalf("audit rows: got %d, want 1", n)
|
|
}
|
|
}
|
|
|
|
func TestHostUpdateNotFound(t *testing.T) {
|
|
t.Parallel()
|
|
_, ts, st := rawTestServer(t)
|
|
cookie := loginAsAdmin(t, st)
|
|
req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/no-such/update", nil)
|
|
req.AddCookie(cookie)
|
|
res, err := stdhttp.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("do: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != stdhttp.StatusNotFound {
|
|
t.Fatalf("status: got %d want 404", res.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHostUpdateOffline(t *testing.T) {
|
|
t.Parallel()
|
|
_, ts, st := rawTestServer(t)
|
|
hostID := ulid.Make().String()
|
|
if err := st.CreateHost(context.Background(), store.Host{
|
|
ID: hostID, Name: "off", OS: "linux", Arch: "amd64",
|
|
EnrolledAt: time.Now().UTC(),
|
|
}, "deadbeef", ""); err != nil {
|
|
t.Fatalf("create: %v", err)
|
|
}
|
|
cookie := loginAsAdmin(t, st)
|
|
req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil)
|
|
req.AddCookie(cookie)
|
|
res, err := stdhttp.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("do: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != stdhttp.StatusConflict {
|
|
t.Fatalf("status: got %d want 409", res.StatusCode)
|
|
}
|
|
body := readJSONError(t, res.Body)
|
|
if body.Code != "host_offline" {
|
|
t.Fatalf("code: %q", body.Code)
|
|
}
|
|
}
|
|
|
|
func TestHostUpdateAlreadyUpToDate(t *testing.T) {
|
|
t.Parallel()
|
|
srv, ts, st := rawTestServer(t)
|
|
hostID, token := enrolHostForWS(t, srv, st, "uptodate-host")
|
|
c := agentDial(t, srv, ts, hostID, token)
|
|
sendHello(t, c, "uptodate-host")
|
|
_ = drainUntil(t, c, api.MsgScheduleSet)
|
|
|
|
// Force agent_version == version.Version.
|
|
if err := st.MarkHostHello(context.Background(), hostID, version.Version, "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil {
|
|
t.Fatalf("mark hello: %v", err)
|
|
}
|
|
|
|
cookie := loginAsAdmin(t, st)
|
|
req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil)
|
|
req.AddCookie(cookie)
|
|
res, err := stdhttp.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("do: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != stdhttp.StatusConflict {
|
|
t.Fatalf("status: got %d want 409", res.StatusCode)
|
|
}
|
|
body := readJSONError(t, res.Body)
|
|
if body.Code != "already_up_to_date" {
|
|
t.Fatalf("code: %q", body.Code)
|
|
}
|
|
}
|
|
|
|
func TestHostUpdateInProgress(t *testing.T) {
|
|
t.Parallel()
|
|
srv, ts, st := rawTestServer(t)
|
|
hostID, token := enrolHostForWS(t, srv, st, "inprog-host")
|
|
c := agentDial(t, srv, ts, hostID, token)
|
|
sendHello(t, c, "inprog-host")
|
|
_ = drainUntil(t, c, api.MsgScheduleSet)
|
|
if err := st.MarkHostHello(context.Background(), hostID, "v0", "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil {
|
|
t.Fatalf("mark hello: %v", err)
|
|
}
|
|
|
|
// Pre-seed an in-flight update job.
|
|
jobID := ulid.Make().String()
|
|
if err := st.CreateJob(context.Background(), store.Job{
|
|
ID: jobID, HostID: hostID, Kind: "update",
|
|
ActorKind: "user", CreatedAt: time.Now().UTC(),
|
|
}); err != nil {
|
|
t.Fatalf("seed job: %v", err)
|
|
}
|
|
|
|
cookie := loginAsAdmin(t, st)
|
|
req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil)
|
|
req.AddCookie(cookie)
|
|
res, err := stdhttp.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("do: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != stdhttp.StatusConflict {
|
|
t.Fatalf("status: got %d want 409", res.StatusCode)
|
|
}
|
|
body := readJSONError(t, res.Body)
|
|
if body.Code != "update_in_progress" {
|
|
t.Fatalf("code: %q", body.Code)
|
|
}
|
|
}
|
|
|
|
func TestHostUpdateRBAC(t *testing.T) {
|
|
t.Parallel()
|
|
_, ts, st := rawTestServer(t)
|
|
hostID := ulid.Make().String()
|
|
if err := st.CreateHost(context.Background(), store.Host{
|
|
ID: hostID, Name: "rbac-host", OS: "linux", Arch: "amd64",
|
|
EnrolledAt: time.Now().UTC(),
|
|
}, "deadbeef", ""); err != nil {
|
|
t.Fatalf("create: %v", err)
|
|
}
|
|
for _, role := range []store.Role{store.RoleViewer, store.RoleOperator} {
|
|
role := role
|
|
t.Run(string(role), func(t *testing.T) {
|
|
cookie := loginAsRole(t, st, role)
|
|
req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil)
|
|
req.AddCookie(cookie)
|
|
res, err := stdhttp.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("do: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != stdhttp.StatusForbidden {
|
|
t.Fatalf("status for %s: got %d want 403", role, res.StatusCode)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// jsonErrBody is the shape the tests in this file decode JSON error
// responses into: a "code" string plus an optional "message".
type jsonErrBody struct {
	Code    string `json:"code"`
	Message string `json:"message,omitempty"`
}

// readJSONError decodes one JSON value from body into a jsonErrBody and
// returns it. A decode failure aborts the calling test; t.Helper makes the
// failure point at the caller's line.
func readJSONError(t *testing.T, body io.Reader) jsonErrBody {
	t.Helper()
	dec := json.NewDecoder(body)
	var parsed jsonErrBody
	if decodeErr := dec.Decode(&parsed); decodeErr != nil {
		t.Fatalf("decode error body: %v", decodeErr)
	}
	return parsed
}
|