// host_update_test.go — covers POST /api/hosts/{id}/update. package http import ( "context" "encoding/json" "io" stdhttp "net/http" "strings" "sync" "testing" "time" "github.com/coder/websocket" "github.com/oklog/ulid/v2" "gitea.dcglab.co.uk/steve/restic-manager/internal/api" "gitea.dcglab.co.uk/steve/restic-manager/internal/store" "gitea.dcglab.co.uk/steve/restic-manager/internal/version" ) // stubWatcher records Track calls so tests can assert the watcher was // notified. type stubWatcher struct { mu sync.Mutex tracked []string // hostIDs } func (s *stubWatcher) Track(_, hostID string) { s.mu.Lock() defer s.mu.Unlock() s.tracked = append(s.tracked, hostID) } func TestHostUpdateHappyPath(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) watcher := &stubWatcher{} srv.deps.UpdateWatcher = watcher hostID, token := enrolHostForWS(t, srv, st, "upd-host") c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "upd-host") _ = drainUntil(t, c, api.MsgScheduleSet) // Force a version mismatch so the dispatch isn't short-circuited. if err := st.MarkHostHello(context.Background(), hostID, "v0", "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil { t.Fatalf("mark hello: %v", err) } cookie := loginAsAdmin(t, st) req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil) req.AddCookie(cookie) res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusAccepted { t.Fatalf("status: got %d, want 202", res.StatusCode) } var out struct { JobID string `json:"job_id"` } if err := json.NewDecoder(res.Body).Decode(&out); err != nil { t.Fatalf("decode: %v", err) } if out.JobID == "" { t.Fatal("missing job_id in response") } // command.update envelope arrives. deadline := time.Now().Add(2 * time.Second) var got api.Envelope for time.Now().Before(deadline) { ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) mt, raw, rerr := c.Read(ctx) cancel() if rerr != nil { break } if mt != websocket.MessageText { continue } if !strings.Contains(string(raw), `"command.update"`) { continue } _ = json.Unmarshal(raw, &got) break } if got.Type != api.MsgCommandUpdate { t.Fatal("never received command.update envelope") } var cp api.CommandUpdatePayload if err := got.UnmarshalPayload(&cp); err != nil { t.Fatalf("payload: %v", err) } if cp.JobID != out.JobID { t.Fatalf("payload job_id: got %q want %q", cp.JobID, out.JobID) } // Watcher tracked. watcher.mu.Lock() defer watcher.mu.Unlock() if len(watcher.tracked) != 1 || watcher.tracked[0] != hostID { t.Fatalf("watcher tracked: %v", watcher.tracked) } // Audit row exists. var n int if err := st.DB().QueryRow( `SELECT COUNT(*) FROM audit_log WHERE action = 'host.update_dispatched' AND target_id = ?`, hostID).Scan(&n); err != nil { t.Fatalf("audit count: %v", err) } if n != 1 { t.Fatalf("audit rows: got %d, want 1", n) } } func TestHostUpdateNotFound(t *testing.T) { t.Parallel() _, ts, st := rawTestServer(t) cookie := loginAsAdmin(t, st) req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/no-such/update", nil) req.AddCookie(cookie) res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusNotFound { t.Fatalf("status: got %d want 404", res.StatusCode) } } func TestHostUpdateOffline(t *testing.T) { t.Parallel() _, ts, st := rawTestServer(t) hostID := ulid.Make().String() if err := st.CreateHost(context.Background(), store.Host{ ID: hostID, Name: "off", OS: "linux", Arch: "amd64", EnrolledAt: time.Now().UTC(), }, "deadbeef", ""); err != nil { t.Fatalf("create: %v", err) } cookie := loginAsAdmin(t, st) req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil) req.AddCookie(cookie) res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusConflict { t.Fatalf("status: got %d want 409", res.StatusCode) } body := readJSONError(t, res.Body) if body.Code != "host_offline" { t.Fatalf("code: %q", body.Code) } } func TestHostUpdateAlreadyUpToDate(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) hostID, token := enrolHostForWS(t, srv, st, "uptodate-host") c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "uptodate-host") _ = drainUntil(t, c, api.MsgScheduleSet) // Force agent_version == version.Version. if err := st.MarkHostHello(context.Background(), hostID, version.Version, "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil { t.Fatalf("mark hello: %v", err) } cookie := loginAsAdmin(t, st) req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil) req.AddCookie(cookie) res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusConflict { t.Fatalf("status: got %d want 409", res.StatusCode) } body := readJSONError(t, res.Body) if body.Code != "already_up_to_date" { t.Fatalf("code: %q", body.Code) } } func TestHostUpdateInProgress(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) hostID, token := enrolHostForWS(t, srv, st, "inprog-host") c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "inprog-host") _ = drainUntil(t, c, api.MsgScheduleSet) if err := st.MarkHostHello(context.Background(), hostID, "v0", "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil { t.Fatalf("mark hello: %v", err) } // Pre-seed an in-flight update job. jobID := ulid.Make().String() if err := st.CreateJob(context.Background(), store.Job{ ID: jobID, HostID: hostID, Kind: "update", ActorKind: "user", CreatedAt: time.Now().UTC(), }); err != nil { t.Fatalf("seed job: %v", err) } cookie := loginAsAdmin(t, st) req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil) req.AddCookie(cookie) res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusConflict { t.Fatalf("status: got %d want 409", res.StatusCode) } body := readJSONError(t, res.Body) if body.Code != "update_in_progress" { t.Fatalf("code: %q", body.Code) } } func TestHostUpdateRBAC(t *testing.T) { t.Parallel() _, ts, st := rawTestServer(t) hostID := ulid.Make().String() if err := st.CreateHost(context.Background(), store.Host{ ID: hostID, Name: "rbac-host", OS: "linux", Arch: "amd64", EnrolledAt: time.Now().UTC(), }, "deadbeef", ""); err != nil { t.Fatalf("create: %v", err) } for _, role := range []store.Role{store.RoleViewer, store.RoleOperator} { role := role t.Run(string(role), func(t *testing.T) { cookie := loginAsRole(t, st, role) req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/hosts/"+hostID+"/update", nil) req.AddCookie(cookie) res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusForbidden { t.Fatalf("status for %s: got %d want 403", role, res.StatusCode) } }) } } type jsonErrBody struct { Code string `json:"code"` Message string `json:"message,omitempty"` } func readJSONError(t *testing.T, body io.Reader) jsonErrBody { t.Helper() var out jsonErrBody if err := json.NewDecoder(body).Decode(&out); err != nil { t.Fatalf("decode error body: %v", err) } return out }