// fleet_update_test.go — coverage for the P6-15 fleet-update HTTP // surface: start/cancel/get JSON endpoints + RBAC. package http import ( "bytes" "context" "encoding/json" stdhttp "net/http" "sync" "testing" "time" "github.com/oklog/ulid/v2" "gitea.dcglab.co.uk/steve/restic-manager/internal/api" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws" "gitea.dcglab.co.uk/steve/restic-manager/internal/store" "gitea.dcglab.co.uk/steve/restic-manager/internal/version" ) // fakeFleetWorker stands in for *fleetupdate.Worker in HTTP tests. // It records what was passed to Start/Cancel and lets tests inject // canned errors. Satisfies the FleetWorker interface in // host_update.go. type fakeFleetWorker struct { mu sync.Mutex startCalls []fakeStartCall startID string startErr error cancelCalls []string cancelErr error } type fakeStartCall struct { UserID string Target string HostIDs []string } func (f *fakeFleetWorker) Start(_ context.Context, userID, target string, hostIDs []string) (string, error) { f.mu.Lock() defer f.mu.Unlock() f.startCalls = append(f.startCalls, fakeStartCall{userID, target, append([]string(nil), hostIDs...)}) if f.startErr != nil { return "", f.startErr } return f.startID, nil } func (f *fakeFleetWorker) Cancel(_ context.Context, id string) error { f.mu.Lock() defer f.mu.Unlock() f.cancelCalls = append(f.cancelCalls, id) return f.cancelErr } // helloOnlineHost is the smallest setup that lets the dispatch / // derivation logic see a host as "online + version mismatch". // Returns the host id. func helloOnlineHost(t *testing.T, srv *Server, st *store.Store, name, agentVer string) string { t.Helper() id := makeHost(t, st, name) if err := st.MarkHostHello(context.Background(), id, agentVer, "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil { t.Fatalf("mark hello: %v", err) } // Mark connected on the hub so deriveOutOfDateOnlineHostIDs // considers it online without needing a real WS handshake. The // Conn has a nil websocket pointer — tests never call Send on it. srv.deps.Hub.Register(id, ws.NewConn(id, nil)) return id } func TestFleetUpdateStartHappyPath(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) worker := &fakeFleetWorker{startID: ulid.Make().String()} srv.deps.FleetWorker = worker cookie, uid := loginAsAdminWithID(t, st) hostID := helloOnlineHost(t, srv, st, "fu-host", "v0") body := map[string]any{"host_ids": []string{hostID}} raw, _ := json.Marshal(body) req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/fleet/update", bytes.NewReader(raw)) req.AddCookie(cookie) req.Header.Set("Content-Type", "application/json") res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusAccepted { t.Fatalf("status: got %d, want 202", res.StatusCode) } var out struct { FleetUpdateID string `json:"fleet_update_id"` } if err := json.NewDecoder(res.Body).Decode(&out); err != nil { t.Fatalf("decode: %v", err) } if out.FleetUpdateID != worker.startID { t.Fatalf("fleet_update_id: got %q, want %q", out.FleetUpdateID, worker.startID) } worker.mu.Lock() if len(worker.startCalls) != 1 || worker.startCalls[0].UserID != uid { t.Fatalf("start calls: %+v", worker.startCalls) } if got := worker.startCalls[0].HostIDs; len(got) != 1 || got[0] != hostID { t.Fatalf("host_ids: %v", got) } worker.mu.Unlock() // Audit row. var n int if err := st.DB().QueryRow( `SELECT COUNT(*) FROM audit_log WHERE action = 'fleet.update_started' AND target_id = ?`, out.FleetUpdateID).Scan(&n); err != nil { t.Fatalf("audit count: %v", err) } if n != 1 { t.Fatalf("audit rows: got %d, want 1", n) } } func TestFleetUpdateStartConflictWhenAlreadyRunning(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) worker := &fakeFleetWorker{startErr: store.ErrFleetUpdateRunning} srv.deps.FleetWorker = worker cookie := loginAsAdmin(t, st) _ = helloOnlineHost(t, srv, st, "fu-host", "v0") req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/fleet/update", bytes.NewReader([]byte(`{}`))) req.AddCookie(cookie) req.Header.Set("Content-Type", "application/json") res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusConflict { t.Fatalf("status: got %d, want 409", res.StatusCode) } body := readJSONError(t, res.Body) if body.Code != "fleet_update_in_progress" { t.Fatalf("code: %q", body.Code) } } func TestFleetUpdateStartDerivesHostIDsWhenEmpty(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) worker := &fakeFleetWorker{startID: ulid.Make().String()} srv.deps.FleetWorker = worker cookie := loginAsAdmin(t, st) // Two online + out-of-date, one online + at-target, one offline. a := helloOnlineHost(t, srv, st, "behind-a", "v0") b := helloOnlineHost(t, srv, st, "behind-b", "v0") _ = helloOnlineHost(t, srv, st, "uptodate", version.Version) offlineID := makeHost(t, st, "offline-host") if err := st.MarkHostHello(context.Background(), offlineID, "v0", "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil { t.Fatalf("mark hello: %v", err) } // Don't MarkOnline → derivation should skip. req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/fleet/update", bytes.NewReader([]byte(`{}`))) req.AddCookie(cookie) req.Header.Set("Content-Type", "application/json") res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusAccepted { t.Fatalf("status: got %d, want 202", res.StatusCode) } worker.mu.Lock() defer worker.mu.Unlock() if len(worker.startCalls) != 1 { t.Fatalf("start calls: %d", len(worker.startCalls)) } got := worker.startCalls[0].HostIDs want := map[string]bool{a: true, b: true} if len(got) != 2 || !want[got[0]] || !want[got[1]] { t.Fatalf("derived host_ids: got %v, want both of %v", got, []string{a, b}) } } func TestFleetUpdateCancelHappyPath(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) worker := &fakeFleetWorker{} srv.deps.FleetWorker = worker cookie := loginAsAdmin(t, st) // Seed a running fleet update directly. fuID := ulid.Make().String() uid := ulid.Make().String() if err := st.CreateUser(context.Background(), store.User{ ID: uid, Username: "starter", PasswordHash: "x", Role: store.RoleAdmin, CreatedAt: time.Now().UTC(), }); err != nil { t.Fatalf("seed user: %v", err) } hostID := makeHost(t, st, "fu-cancel-host") if err := st.CreateFleetUpdate(context.Background(), store.FleetUpdate{ID: fuID, StartedByUserID: uid, TargetVersion: "v1"}, []string{hostID}); err != nil { t.Fatalf("seed fleet update: %v", err) } req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/fleet-updates/"+fuID+"/cancel", nil) req.AddCookie(cookie) res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusNoContent { t.Fatalf("status: got %d, want 204", res.StatusCode) } worker.mu.Lock() if len(worker.cancelCalls) != 1 || worker.cancelCalls[0] != fuID { t.Fatalf("cancel calls: %v", worker.cancelCalls) } worker.mu.Unlock() } func TestFleetUpdateCancelNotRunning(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) srv.deps.FleetWorker = &fakeFleetWorker{} cookie := loginAsAdmin(t, st) // Seed + complete one so it's no longer running. fuID := ulid.Make().String() uid := ulid.Make().String() _ = st.CreateUser(context.Background(), store.User{ ID: uid, Username: "starter2", PasswordHash: "x", Role: store.RoleAdmin, CreatedAt: time.Now().UTC(), }) hostID := makeHost(t, st, "fu-done-host") _ = st.CreateFleetUpdate(context.Background(), store.FleetUpdate{ID: fuID, StartedByUserID: uid, TargetVersion: "v1"}, []string{hostID}) if err := st.CompleteFleetUpdate(context.Background(), fuID, time.Now().UTC()); err != nil { t.Fatalf("complete: %v", err) } req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/fleet-updates/"+fuID+"/cancel", nil) req.AddCookie(cookie) res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusConflict { t.Fatalf("status: got %d, want 409", res.StatusCode) } body := readJSONError(t, res.Body) if body.Code != "fleet_update_not_running" { t.Fatalf("code: %q", body.Code) } } func TestFleetUpdateGetHydrates(t *testing.T) { t.Parallel() _, ts, st := rawTestServer(t) cookie := loginAsAdmin(t, st) uid := ulid.Make().String() _ = st.CreateUser(context.Background(), store.User{ ID: uid, Username: "starter3", PasswordHash: "x", Role: store.RoleAdmin, CreatedAt: time.Now().UTC(), }) hostID := makeHost(t, st, "fu-get-host") fuID := ulid.Make().String() if err := st.CreateFleetUpdate(context.Background(), store.FleetUpdate{ID: fuID, StartedByUserID: uid, TargetVersion: "v1.2.3"}, []string{hostID}); err != nil { t.Fatalf("seed: %v", err) } req, _ := stdhttp.NewRequest("GET", ts.URL+"/api/fleet-updates/"+fuID, nil) req.AddCookie(cookie) res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusOK { t.Fatalf("status: got %d, want 200", res.StatusCode) } var got fleetUpdateView if err := json.NewDecoder(res.Body).Decode(&got); err != nil { t.Fatalf("decode: %v", err) } if got.ID != fuID || got.TargetVersion != "v1.2.3" || got.Status != "running" { t.Fatalf("parent: %+v", got) } if len(got.Hosts) != 1 || got.Hosts[0].HostID != hostID || got.Hosts[0].HostName != "fu-get-host" { t.Fatalf("hosts: %+v", got.Hosts) } } func TestFleetUpdateRBAC(t *testing.T) { t.Parallel() _, ts, st := rawTestServer(t) for _, role := range []store.Role{store.RoleViewer, store.RoleOperator} { role := role t.Run(string(role), func(t *testing.T) { cookie := loginAsRole(t, st, role) req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/fleet/update", bytes.NewReader([]byte(`{}`))) req.AddCookie(cookie) req.Header.Set("Content-Type", "application/json") res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusForbidden { t.Fatalf("status: got %d, want 403", res.StatusCode) } }) } } // Sanity check that fakeFleetWorker satisfies the FleetWorker iface. var _ FleetWorker = (*fakeFleetWorker)(nil)