package fleetupdate

import (
	"context"
	"errors"
	"path/filepath"
	"sync"
	"testing"
	"time"

	"github.com/oklog/ulid/v2"

	"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
	"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)

// fakeHub is an in-memory stand-in for the connection hub: a host is
// reported as connected iff its ID maps to true in online.
type fakeHub struct {
	mu     sync.Mutex
	online map[string]bool
}

// Connected reports whether hostID is marked online. Unknown hosts are
// reported as offline (zero value of the map lookup).
func (f *fakeHub) Connected(hostID string) bool {
	f.mu.Lock()
	defer f.mu.Unlock()
	return f.online[hostID]
}

// fakeDispatcher records every dispatch request and, when configured
// with a store and a target version, simulates the agent coming back
// on that version after an optional delay.
type fakeDispatcher struct {
	mu    sync.Mutex
	calls []string // host IDs, in dispatch order

	// After a dispatch, the host's version is set to target on st so
	// the worker observes the version transition. Either may be left
	// zero to disable the transition.
	st      *store.Store
	target  string
	delayMS int // delay before the simulated transition, in milliseconds

	failOnHost map[string]string // host → error code
}

// DispatchUpdate records the call and then either:
//   - returns ("", code, nil) when hostID has an entry in failOnHost, or
//   - creates an "update" job row (when a store is configured), schedules
//     the simulated version transition, and returns (jobID, "", nil).
//
// The returned error is always nil.
func (f *fakeDispatcher) DispatchUpdate(ctx context.Context, hostID, _ string) (string, string, error) {
	// Snapshot everything under a single lock so the rest of the
	// function runs without holding the mutex.
	f.mu.Lock()
	f.calls = append(f.calls, hostID)
	code, fail := f.failOnHost[hostID]
	st, target, delay := f.st, f.target, f.delayMS
	f.mu.Unlock()

	if fail {
		return "", code, nil
	}

	jobID := ulid.Make().String()
	if st != nil {
		// Error deliberately ignored: this fake has no way to act on it.
		_ = st.CreateJob(context.Background(), store.Job{
			ID:        jobID,
			HostID:    hostID,
			Kind:      "update",
			ActorKind: "user",
			CreatedAt: time.Now().UTC(),
		})
	}
	if st != nil && target != "" {
		// Simulate the agent reporting the new version asynchronously.
		go func() {
			if delay > 0 {
				time.Sleep(time.Duration(delay) * time.Millisecond)
			}
			// Error deliberately ignored: nothing in this goroutine can report it.
			_ = st.MarkHostHello(context.Background(), hostID, target, "0.17", api.CurrentProtocolVersion, time.Now().UTC())
		}()
	}
	return jobID, "", nil
}

// recAlert records the reason of every halt alert raised by the worker.
// Readers must hold mu while inspecting reasons.
type recAlert struct {
	mu      sync.Mutex
	reasons []string
}

// RaiseFleetUpdateHalted appends reason to the recorded list; the other
// arguments are ignored.
func (r *recAlert) RaiseFleetUpdateHalted(_ context.Context, _ string, reason string, _ time.Time) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.reasons = append(r.reasons, reason)
}

// openStore opens a fresh store backed by a file in a per-test temp
// directory and closes it when the test finishes.
func openStore(t *testing.T) *store.Store {
	t.Helper()
	dir := t.TempDir()
	st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db"))
	if err != nil {
		t.Fatalf("open: %v", err)
	}
	t.Cleanup(func() { _ = st.Close() })
	return st
}

// mustCreateAdmin inserts an admin user with a generated ID and returns
// that ID, failing the test on error.
func mustCreateAdmin(t *testing.T, st *store.Store) string {
	t.Helper()
	uid := ulid.Make().String()
	err := st.CreateUser(context.Background(), store.User{
		ID:           uid,
		Username:     "u-" + uid[:6],
		PasswordHash: "x",
		Role:         store.RoleAdmin,
		CreatedAt:    time.Now().UTC(),
	})
	if err != nil {
		t.Fatalf("user: %v", err)
	}
	return uid
}

// mustCreateHost inserts a linux/amd64 host named name and returns its
// generated ID. When version is non-empty the host is additionally
// marked as having reported that version via MarkHostHello.
func mustCreateHost(t *testing.T, st *store.Store, name, version string) string {
	t.Helper()
	hostID := ulid.Make().String()
	host := store.Host{
		ID:         hostID,
		Name:       name,
		OS:         "linux",
		Arch:       "amd64",
		EnrolledAt: time.Now().UTC(),
	}
	if err := st.CreateHost(context.Background(), host, "deadbeef-"+hostID, ""); err != nil {
		t.Fatalf("host: %v", err)
	}
	if version == "" {
		return hostID
	}
	if err := st.MarkHostHello(context.Background(), hostID, version, "0.17", api.CurrentProtocolVersion, time.Now().UTC()); err != nil {
		t.Fatalf("hello: %v", err)
	}
	return hostID
}

// waitForStatus polls the store every 20ms until the fleet update fuID
// reports status want, and returns it. If timeout elapses first the
// test fails, reporting the last status and error observed so a
// failure shows where the update got stuck.
func waitForStatus(t *testing.T, st *store.Store, fuID, want string, timeout time.Duration) *store.FleetUpdate {
	t.Helper()
	var (
		lastStatus string
		lastErr    error
	)
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		fu, _, err := st.GetFleetUpdate(context.Background(), fuID)
		lastErr = err
		if err == nil && fu != nil {
			if fu.Status == want {
				return fu
			}
			lastStatus = fu.Status
		}
		time.Sleep(20 * time.Millisecond)
	}
	t.Fatalf("status never reached %q (last status %q, last error %v)", want, lastStatus, lastErr)
	return nil
}

// TestWorkerTwoHostsBothSucceed runs an update over two online hosts
// that both transition to the target version, and expects the update
// to complete with every host succeeded and no halt alert.
func TestWorkerTwoHostsBothSucceed(t *testing.T) {
	st := openStore(t)
	uid := mustCreateAdmin(t, st)
	h1 := mustCreateHost(t, st, "h1", "v0")
	h2 := mustCreateHost(t, st, "h2", "v0")

	hub := &fakeHub{online: map[string]bool{h1: true, h2: true}}
	disp := &fakeDispatcher{st: st, target: "v2", delayMS: 30}
	alerts := &recAlert{}

	w := NewWorker(st, hub, disp, alerts)
	w.pollPeriod = 20 * time.Millisecond
	w.hostTimeout = 2 * time.Second

	fuID, err := w.Start(context.Background(), uid, "v2", []string{h1, h2})
	if err != nil {
		t.Fatalf("start: %v", err)
	}
	waitForStatus(t, st, fuID, "completed", 5*time.Second)

	_, hosts, err := st.GetFleetUpdate(context.Background(), fuID)
	if err != nil {
		t.Fatalf("get fleet update: %v", err)
	}
	// Guard against the loop below passing vacuously on an empty result.
	if len(hosts) != 2 {
		t.Fatalf("got %d host rows, want 2", len(hosts))
	}
	for _, h := range hosts {
		if h.Status != "succeeded" {
			t.Errorf("host %s status %q want succeeded", h.HostID, h.Status)
		}
	}

	// reasons is written from the worker's side under mu; read it the same way.
	alerts.mu.Lock()
	defer alerts.mu.Unlock()
	if len(alerts.reasons) != 0 {
		t.Errorf("unexpected halt alert: %v", alerts.reasons)
	}
}

// TestWorkerSecondHostTimesOutHalts runs an update over three hosts
// where the second never reports the target version. It expects the
// update to halt with h1 succeeded, h2 failed, h3 still pending, and
// exactly one halt alert.
func TestWorkerSecondHostTimesOutHalts(t *testing.T) {
	st := openStore(t)
	uid := mustCreateAdmin(t, st)
	h1 := mustCreateHost(t, st, "h1", "v0")
	h2 := mustCreateHost(t, st, "h2", "v0")
	h3 := mustCreateHost(t, st, "h3", "v0")

	hub := &fakeHub{online: map[string]bool{h1: true, h2: true, h3: true}}
	// h1 dispatches normally (transitions to v2). h2's dispatch returns
	// success but the host never transitions, so the worker has to hit
	// hostTimeout on it.
	disp := &perHostDispatcher{
		st:           st,
		target:       "v2",
		noTransition: map[string]bool{h2: true},
	}
	alerts := &recAlert{}

	w := NewWorker(st, hub, disp, alerts)
	w.pollPeriod = 20 * time.Millisecond
	w.hostTimeout = 200 * time.Millisecond

	fuID, err := w.Start(context.Background(), uid, "v2", []string{h1, h2, h3})
	if err != nil {
		t.Fatalf("start: %v", err)
	}
	waitForStatus(t, st, fuID, "halted", 3*time.Second)

	_, hosts, err := st.GetFleetUpdate(context.Background(), fuID)
	if err != nil {
		t.Fatalf("get fleet update: %v", err)
	}
	gotStatus := make(map[string]string, len(hosts))
	for _, h := range hosts {
		gotStatus[h.HostID] = h.Status
	}
	if gotStatus[h1] != "succeeded" {
		t.Errorf("h1: %q", gotStatus[h1])
	}
	if gotStatus[h2] != "failed" {
		t.Errorf("h2: %q", gotStatus[h2])
	}
	if gotStatus[h3] != "pending" {
		t.Errorf("h3: %q", gotStatus[h3])
	}

	alerts.mu.Lock()
	defer alerts.mu.Unlock()
	if len(alerts.reasons) != 1 {
		t.Errorf("alert reasons: %v", alerts.reasons)
	}
}

// perHostDispatcher lets a test omit the auto-transition for selected
// hosts so we can simulate timeout.
type perHostDispatcher struct { mu sync.Mutex base *fakeDispatcher st *store.Store target string noTransition map[string]bool } func (p *perHostDispatcher) DispatchUpdate(_ context.Context, hostID, _ string) (string, string, error) { p.mu.Lock() skip := p.noTransition[hostID] p.mu.Unlock() jobID := ulid.Make().String() _ = p.st.CreateJob(context.Background(), store.Job{ ID: jobID, HostID: hostID, Kind: "update", ActorKind: "user", CreatedAt: time.Now().UTC(), }) if !skip { go func() { time.Sleep(20 * time.Millisecond) _ = p.st.MarkHostHello(context.Background(), hostID, p.target, "0.17", api.CurrentProtocolVersion, time.Now().UTC()) }() } return jobID, "", nil } func TestWorkerHostOfflineHalts(t *testing.T) { st := openStore(t) uid := mustCreateAdmin(t, st) h1 := mustCreateHost(t, st, "h1", "v0") h2 := mustCreateHost(t, st, "h2", "v0") hub := &fakeHub{online: map[string]bool{h1: false, h2: true}} disp := &fakeDispatcher{st: st, target: "v2"} alerts := &recAlert{} w := NewWorker(st, hub, disp, alerts) w.pollPeriod = 20 * time.Millisecond w.hostTimeout = 500 * time.Millisecond fuID, err := w.Start(context.Background(), uid, "v2", []string{h1, h2}) if err != nil { t.Fatalf("start: %v", err) } waitForStatus(t, st, fuID, "halted", 2*time.Second) _, hosts, _ := st.GetFleetUpdate(context.Background(), fuID) if hosts[0].Status != "failed" { t.Errorf("h1 status: %q", hosts[0].Status) } if hosts[1].Status != "pending" { t.Errorf("h2 status: %q", hosts[1].Status) } } func TestWorkerAlreadyAtTargetSkipped(t *testing.T) { st := openStore(t) uid := mustCreateAdmin(t, st) h1 := mustCreateHost(t, st, "h1", "v2") h2 := mustCreateHost(t, st, "h2", "v0") hub := &fakeHub{online: map[string]bool{h1: true, h2: true}} disp := &fakeDispatcher{st: st, target: "v2", delayMS: 20} alerts := &recAlert{} w := NewWorker(st, hub, disp, alerts) w.pollPeriod = 20 * time.Millisecond w.hostTimeout = 2 * time.Second fuID, err := w.Start(context.Background(), uid, "v2", []string{h1, h2}) if err != nil 
{ t.Fatalf("start: %v", err) } waitForStatus(t, st, fuID, "completed", 4*time.Second) _, hosts, _ := st.GetFleetUpdate(context.Background(), fuID) want := map[string]string{h1: "skipped", h2: "succeeded"} for _, h := range hosts { if h.Status != want[h.HostID] { t.Errorf("host %s: got %q want %q", h.HostID, h.Status, want[h.HostID]) } } } func TestWorkerCancelMidRun(t *testing.T) { st := openStore(t) uid := mustCreateAdmin(t, st) h1 := mustCreateHost(t, st, "h1", "v0") h2 := mustCreateHost(t, st, "h2", "v0") hub := &fakeHub{online: map[string]bool{h1: true, h2: true}} // h1's transition is delayed long enough that we can cancel // before it lands; h2 should never be touched. disp := &fakeDispatcher{st: st, target: "v2", delayMS: 500} alerts := &recAlert{} w := NewWorker(st, hub, disp, alerts) w.pollPeriod = 50 * time.Millisecond w.hostTimeout = 5 * time.Second fuID, err := w.Start(context.Background(), uid, "v2", []string{h1, h2}) if err != nil { t.Fatalf("start: %v", err) } // Give the worker a moment to dispatch h1. time.Sleep(100 * time.Millisecond) if err := w.Cancel(context.Background(), fuID); err != nil { t.Fatalf("cancel: %v", err) } waitForStatus(t, st, fuID, "cancelled", 2*time.Second) // h2 should never be dispatched. 
disp.mu.Lock() defer disp.mu.Unlock() for _, c := range disp.calls { if c == h2 { t.Errorf("h2 dispatched after cancel") } } } func TestWorkerStartWhileActiveErrors(t *testing.T) { st := openStore(t) uid := mustCreateAdmin(t, st) h1 := mustCreateHost(t, st, "h1", "v0") h2 := mustCreateHost(t, st, "h2", "v0") hub := &fakeHub{online: map[string]bool{h1: true, h2: true}} disp := &fakeDispatcher{st: st, target: "v2", delayMS: 5_000} w := NewWorker(st, hub, disp, &recAlert{}) w.pollPeriod = 50 * time.Millisecond w.hostTimeout = 2 * time.Second if _, err := w.Start(context.Background(), uid, "v2", []string{h1}); err != nil { t.Fatalf("first start: %v", err) } _, err := w.Start(context.Background(), uid, "v2", []string{h2}) if !errors.Is(err, store.ErrFleetUpdateRunning) { t.Fatalf("err: %v want ErrFleetUpdateRunning", err) } }