package ws

import (
	"context"
	"sync"
	"testing"
	"time"

	"github.com/oklog/ulid/v2"

	"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
	"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)

// fakeAlerts is a test double that records every RaiseUpdateFailed and
// ResolveUpdateFailed call so tests can assert on them afterwards.
// All slices are guarded by mu.
type fakeAlerts struct {
	mu       sync.Mutex
	raised   []string // hostIDs passed to RaiseUpdateFailed
	resolved []string // hostIDs passed to ResolveUpdateFailed
	reasons  []string // reasons passed to RaiseUpdateFailed, parallel to raised
}

// RaiseUpdateFailed records the host ID and reason of a raised alert.
// The context, job ID and timestamp arguments are ignored.
func (f *fakeAlerts) RaiseUpdateFailed(_ context.Context, hostID, _ /*jobID*/, reason string, _ time.Time) {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.raised = append(f.raised, hostID)
	f.reasons = append(f.reasons, reason)
}

// ResolveUpdateFailed records the host ID of a resolved alert.
// The context and timestamp arguments are ignored.
func (f *fakeAlerts) ResolveUpdateFailed(_ context.Context, hostID string, _ time.Time) {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.resolved = append(f.resolved, hostID)
}

// seedJob inserts a job of kind "update" (actor kind "user") for hostID
// with a freshly generated ULID and returns that ID. It fails the test
// immediately if the store rejects the insert.
func seedJob(t *testing.T, st *store.Store, hostID string) string {
	t.Helper()
	jobID := ulid.Make().String()
	if err := st.CreateJob(context.Background(), store.Job{
		ID:        jobID,
		HostID:    hostID,
		Kind:      "update",
		ActorKind: "user",
		CreatedAt: time.Now().UTC(),
	}); err != nil {
		t.Fatalf("create job: %v", err)
	}
	return jobID
}

// TestUpdateWatcherOnHelloSuccess checks that a tracked job is marked
// "succeeded" and the update-failed alert is resolved (and never raised)
// when OnHello is called with two equal version strings.
func TestUpdateWatcherOnHelloSuccess(t *testing.T) {
	st := openWSTestStore(t)
	hostID := ulid.Make().String()
	seedHostWS(t, st, hostID)
	jobID := seedJob(t, st, hostID)

	a := &fakeAlerts{}
	w := NewUpdateWatcher(st, a, nil)
	w.Track(jobID, hostID)

	w.OnHello(context.Background(), hostID, "v2", "v2")

	job, err := st.GetJob(context.Background(), jobID)
	if err != nil {
		t.Fatalf("get job: %v", err)
	}
	if job.Status != "succeeded" {
		t.Fatalf("status: got %q want succeeded", job.Status)
	}

	a.mu.Lock()
	defer a.mu.Unlock()
	if len(a.resolved) != 1 || a.resolved[0] != hostID {
		t.Fatalf("resolve calls: %v", a.resolved)
	}
	if len(a.raised) != 0 {
		t.Fatalf("unexpected raises: %v", a.raised)
	}
}

// TestUpdateWatcherTimeout checks that a tracked job which outlives
// updateTimeout is marked "failed" by sweep and that exactly one alert
// with a non-empty reason is raised for the host.
func TestUpdateWatcherTimeout(t *testing.T) {
	// Shrink the package-level timeout for the duration of this test only.
	prev := updateTimeout
	updateTimeout = 50 * time.Millisecond
	t.Cleanup(func() { updateTimeout = prev })

	st := openWSTestStore(t)
	hostID := ulid.Make().String()
	seedHostWS(t, st, hostID)
	jobID := seedJob(t, st, hostID)

	a := &fakeAlerts{}
	w := NewUpdateWatcher(st, a, nil)
	w.Track(jobID, hostID)

	// Wait past the shortened timeout before sweeping.
	time.Sleep(80 * time.Millisecond)
	w.sweep(context.Background(), time.Now())

	job, err := st.GetJob(context.Background(), jobID)
	if err != nil {
		t.Fatalf("get job: %v", err)
	}
	if job.Status != "failed" {
		t.Fatalf("status: got %q want failed", job.Status)
	}

	a.mu.Lock()
	defer a.mu.Unlock()
	if len(a.raised) != 1 || a.raised[0] != hostID {
		t.Fatalf("raise calls: %v", a.raised)
	}
	if len(a.reasons) == 0 || a.reasons[0] == "" {
		t.Fatalf("missing reason")
	}
}

// TestUpdateWatcherMismatchedVersionNoOp checks that OnHello with two
// differing version strings leaves the job in a non-terminal state and
// triggers neither a raise nor a resolve.
func TestUpdateWatcherMismatchedVersionNoOp(t *testing.T) {
	st := openWSTestStore(t)
	hostID := ulid.Make().String()
	seedHostWS(t, st, hostID)
	jobID := seedJob(t, st, hostID)

	a := &fakeAlerts{}
	w := NewUpdateWatcher(st, a, nil)
	w.Track(jobID, hostID)

	w.OnHello(context.Background(), hostID, "v1", "v2")

	// A failed lookup must fail the test rather than let the status
	// assertion below run against a meaningless result.
	job, err := st.GetJob(context.Background(), jobID)
	if err != nil {
		t.Fatalf("get job: %v", err)
	}
	if job.Status == "succeeded" || job.Status == "failed" {
		t.Fatalf("status flipped on mismatched hello: %q", job.Status)
	}

	a.mu.Lock()
	defer a.mu.Unlock()
	if len(a.raised) != 0 || len(a.resolved) != 0 {
		t.Fatalf("unexpected alert calls raised=%v resolved=%v", a.raised, a.resolved)
	}
}

// TestUpdateWatcherHelloAfterTimeoutIsNoOp checks that once sweep has
// failed a timed-out job, a later OnHello with equal version strings
// neither changes the job status nor resolves the alert.
func TestUpdateWatcherHelloAfterTimeoutIsNoOp(t *testing.T) {
	// Shrink the package-level timeout for the duration of this test only.
	prev := updateTimeout
	updateTimeout = 50 * time.Millisecond
	t.Cleanup(func() { updateTimeout = prev })

	st := openWSTestStore(t)
	hostID := ulid.Make().String()
	seedHostWS(t, st, hostID)
	jobID := seedJob(t, st, hostID)

	a := &fakeAlerts{}
	w := NewUpdateWatcher(st, a, nil)
	w.Track(jobID, hostID)

	time.Sleep(80 * time.Millisecond)
	w.sweep(context.Background(), time.Now())

	// Hello arrives after sweep — entry already gone, must be no-op.
	w.OnHello(context.Background(), hostID, "v2", "v2")

	// A failed lookup must fail the test rather than let the status
	// assertion below run against a meaningless result.
	job, err := st.GetJob(context.Background(), jobID)
	if err != nil {
		t.Fatalf("get job: %v", err)
	}
	if job.Status != "failed" {
		t.Fatalf("status flipped from failed → %q", job.Status)
	}

	a.mu.Lock()
	defer a.mu.Unlock()
	if len(a.resolved) != 0 {
		t.Fatalf("late hello triggered ResolveUpdateFailed: %v", a.resolved)
	}
}

// TestUpdateWatcherOnHelloBroadcastsJobFinished checks that a successful
// OnHello publishes a job.finished envelope with status JobSucceeded to
// subscribers registered on the hub for that job.
func TestUpdateWatcherOnHelloBroadcastsJobFinished(t *testing.T) {
	st := openWSTestStore(t)
	hostID := ulid.Make().String()
	seedHostWS(t, st, hostID)
	jobID := seedJob(t, st, hostID)

	hub := NewJobHub()
	sub := hub.Register(jobID)
	defer sub.unregister()

	w := NewUpdateWatcher(st, &fakeAlerts{}, hub)
	w.Track(jobID, hostID)
	w.OnHello(context.Background(), hostID, "v2", "v2")

	select {
	case env := <-sub.ch:
		if env.Type != api.MsgJobFinished {
			t.Fatalf("envelope type: got %q want %q", env.Type, api.MsgJobFinished)
		}
		var p api.JobFinishedPayload
		if err := env.UnmarshalPayload(&p); err != nil {
			t.Fatalf("unmarshal payload: %v", err)
		}
		if p.JobID != jobID || p.Status != api.JobSucceeded {
			t.Fatalf("payload: got %+v", p)
		}
	case <-time.After(time.Second):
		// Bound the wait so a missing broadcast fails instead of hanging.
		t.Fatal("expected synthetic job.finished broadcast, got nothing")
	}
}

// TestUpdateWatcherTimeoutBroadcastsJobFinished checks that sweeping a
// timed-out job publishes a job.finished envelope with status JobFailed
// to subscribers registered on the hub for that job.
func TestUpdateWatcherTimeoutBroadcastsJobFinished(t *testing.T) {
	// Shrink the package-level timeout for the duration of this test only.
	prev := updateTimeout
	updateTimeout = 50 * time.Millisecond
	t.Cleanup(func() { updateTimeout = prev })

	st := openWSTestStore(t)
	hostID := ulid.Make().String()
	seedHostWS(t, st, hostID)
	jobID := seedJob(t, st, hostID)

	hub := NewJobHub()
	sub := hub.Register(jobID)
	defer sub.unregister()

	w := NewUpdateWatcher(st, &fakeAlerts{}, hub)
	w.Track(jobID, hostID)

	time.Sleep(80 * time.Millisecond)
	w.sweep(context.Background(), time.Now())

	select {
	case env := <-sub.ch:
		if env.Type != api.MsgJobFinished {
			t.Fatalf("envelope type: got %q want %q", env.Type, api.MsgJobFinished)
		}
		var p api.JobFinishedPayload
		if err := env.UnmarshalPayload(&p); err != nil {
			t.Fatalf("unmarshal payload: %v", err)
		}
		if p.JobID != jobID || p.Status != api.JobFailed {
			t.Fatalf("payload: got %+v", p)
		}
	case <-time.After(time.Second):
		// Bound the wait so a missing broadcast fails instead of hanging.
		t.Fatal("expected synthetic job.finished broadcast, got nothing")
	}
}