231 lines
5.8 KiB
Go
231 lines
5.8 KiB
Go
package ws
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/oklog/ulid/v2"
|
|
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
|
)
|
|
|
|
// fakeAlerts is an in-memory alert sink used by the UpdateWatcher tests.
// It records the arguments of every raise/resolve call so assertions can
// inspect them afterwards.
type fakeAlerts struct {
	mu sync.Mutex // held by the recording methods and by tests reading the slices

	raised   []string // hostIDs passed to RaiseUpdateFailed, in call order
	resolved []string // hostIDs passed to ResolveUpdateFailed, in call order
	reasons  []string // reasons passed to RaiseUpdateFailed, index-aligned with raised
}
|
|
|
|
func (f *fakeAlerts) RaiseUpdateFailed(_ context.Context, hostID, _ /*jobID*/, reason string, _ time.Time) {
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
f.raised = append(f.raised, hostID)
|
|
f.reasons = append(f.reasons, reason)
|
|
}
|
|
|
|
func (f *fakeAlerts) ResolveUpdateFailed(_ context.Context, hostID string, _ time.Time) {
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
f.resolved = append(f.resolved, hostID)
|
|
}
|
|
|
|
func seedJob(t *testing.T, st *store.Store, hostID string) string {
|
|
t.Helper()
|
|
jobID := ulid.Make().String()
|
|
if err := st.CreateJob(context.Background(), store.Job{
|
|
ID: jobID, HostID: hostID, Kind: "update",
|
|
ActorKind: "user", CreatedAt: time.Now().UTC(),
|
|
}); err != nil {
|
|
t.Fatalf("create job: %v", err)
|
|
}
|
|
return jobID
|
|
}
|
|
|
|
func TestUpdateWatcherOnHelloSuccess(t *testing.T) {
|
|
st := openWSTestStore(t)
|
|
hostID := ulid.Make().String()
|
|
seedHostWS(t, st, hostID)
|
|
jobID := seedJob(t, st, hostID)
|
|
|
|
a := &fakeAlerts{}
|
|
w := NewUpdateWatcher(st, a, nil)
|
|
w.Track(jobID, hostID)
|
|
|
|
w.OnHello(context.Background(), hostID, "v2", "v2")
|
|
|
|
job, err := st.GetJob(context.Background(), jobID)
|
|
if err != nil {
|
|
t.Fatalf("get job: %v", err)
|
|
}
|
|
if job.Status != "succeeded" {
|
|
t.Fatalf("status: got %q want succeeded", job.Status)
|
|
}
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
if len(a.resolved) != 1 || a.resolved[0] != hostID {
|
|
t.Fatalf("resolve calls: %v", a.resolved)
|
|
}
|
|
if len(a.raised) != 0 {
|
|
t.Fatalf("unexpected raises: %v", a.raised)
|
|
}
|
|
}
|
|
|
|
func TestUpdateWatcherTimeout(t *testing.T) {
|
|
prev := updateTimeout
|
|
updateTimeout = 50 * time.Millisecond
|
|
t.Cleanup(func() { updateTimeout = prev })
|
|
|
|
st := openWSTestStore(t)
|
|
hostID := ulid.Make().String()
|
|
seedHostWS(t, st, hostID)
|
|
jobID := seedJob(t, st, hostID)
|
|
|
|
a := &fakeAlerts{}
|
|
w := NewUpdateWatcher(st, a, nil)
|
|
w.Track(jobID, hostID)
|
|
|
|
time.Sleep(80 * time.Millisecond)
|
|
w.sweep(context.Background(), time.Now())
|
|
|
|
job, err := st.GetJob(context.Background(), jobID)
|
|
if err != nil {
|
|
t.Fatalf("get job: %v", err)
|
|
}
|
|
if job.Status != "failed" {
|
|
t.Fatalf("status: got %q want failed", job.Status)
|
|
}
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
if len(a.raised) != 1 || a.raised[0] != hostID {
|
|
t.Fatalf("raise calls: %v", a.raised)
|
|
}
|
|
if len(a.reasons) == 0 || a.reasons[0] == "" {
|
|
t.Fatalf("missing reason")
|
|
}
|
|
}
|
|
|
|
func TestUpdateWatcherMismatchedVersionNoOp(t *testing.T) {
|
|
st := openWSTestStore(t)
|
|
hostID := ulid.Make().String()
|
|
seedHostWS(t, st, hostID)
|
|
jobID := seedJob(t, st, hostID)
|
|
|
|
a := &fakeAlerts{}
|
|
w := NewUpdateWatcher(st, a, nil)
|
|
w.Track(jobID, hostID)
|
|
|
|
w.OnHello(context.Background(), hostID, "v1", "v2")
|
|
|
|
job, _ := st.GetJob(context.Background(), jobID)
|
|
if job.Status == "succeeded" || job.Status == "failed" {
|
|
t.Fatalf("status flipped on mismatched hello: %q", job.Status)
|
|
}
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
if len(a.raised) != 0 || len(a.resolved) != 0 {
|
|
t.Fatalf("unexpected alert calls raised=%v resolved=%v", a.raised, a.resolved)
|
|
}
|
|
}
|
|
|
|
func TestUpdateWatcherHelloAfterTimeoutIsNoOp(t *testing.T) {
|
|
prev := updateTimeout
|
|
updateTimeout = 50 * time.Millisecond
|
|
t.Cleanup(func() { updateTimeout = prev })
|
|
|
|
st := openWSTestStore(t)
|
|
hostID := ulid.Make().String()
|
|
seedHostWS(t, st, hostID)
|
|
jobID := seedJob(t, st, hostID)
|
|
|
|
a := &fakeAlerts{}
|
|
w := NewUpdateWatcher(st, a, nil)
|
|
w.Track(jobID, hostID)
|
|
|
|
time.Sleep(80 * time.Millisecond)
|
|
w.sweep(context.Background(), time.Now())
|
|
|
|
// Hello arrives after sweep — entry already gone, must be no-op.
|
|
w.OnHello(context.Background(), hostID, "v2", "v2")
|
|
|
|
job, _ := st.GetJob(context.Background(), jobID)
|
|
if job.Status != "failed" {
|
|
t.Fatalf("status flipped from failed → %q", job.Status)
|
|
}
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
if len(a.resolved) != 0 {
|
|
t.Fatalf("late hello triggered ResolveUpdateFailed: %v", a.resolved)
|
|
}
|
|
}
|
|
|
|
func TestUpdateWatcherOnHelloBroadcastsJobFinished(t *testing.T) {
|
|
st := openWSTestStore(t)
|
|
hostID := ulid.Make().String()
|
|
seedHostWS(t, st, hostID)
|
|
jobID := seedJob(t, st, hostID)
|
|
|
|
hub := NewJobHub()
|
|
sub := hub.Register(jobID)
|
|
defer sub.unregister()
|
|
|
|
w := NewUpdateWatcher(st, &fakeAlerts{}, hub)
|
|
w.Track(jobID, hostID)
|
|
w.OnHello(context.Background(), hostID, "v2", "v2")
|
|
|
|
select {
|
|
case env := <-sub.ch:
|
|
if env.Type != api.MsgJobFinished {
|
|
t.Fatalf("envelope type: got %q want %q", env.Type, api.MsgJobFinished)
|
|
}
|
|
var p api.JobFinishedPayload
|
|
if err := env.UnmarshalPayload(&p); err != nil {
|
|
t.Fatalf("unmarshal payload: %v", err)
|
|
}
|
|
if p.JobID != jobID || p.Status != api.JobSucceeded {
|
|
t.Fatalf("payload: got %+v", p)
|
|
}
|
|
case <-time.After(time.Second):
|
|
t.Fatal("expected synthetic job.finished broadcast, got nothing")
|
|
}
|
|
}
|
|
|
|
func TestUpdateWatcherTimeoutBroadcastsJobFinished(t *testing.T) {
|
|
prev := updateTimeout
|
|
updateTimeout = 50 * time.Millisecond
|
|
t.Cleanup(func() { updateTimeout = prev })
|
|
|
|
st := openWSTestStore(t)
|
|
hostID := ulid.Make().String()
|
|
seedHostWS(t, st, hostID)
|
|
jobID := seedJob(t, st, hostID)
|
|
|
|
hub := NewJobHub()
|
|
sub := hub.Register(jobID)
|
|
defer sub.unregister()
|
|
|
|
w := NewUpdateWatcher(st, &fakeAlerts{}, hub)
|
|
w.Track(jobID, hostID)
|
|
|
|
time.Sleep(80 * time.Millisecond)
|
|
w.sweep(context.Background(), time.Now())
|
|
|
|
select {
|
|
case env := <-sub.ch:
|
|
if env.Type != api.MsgJobFinished {
|
|
t.Fatalf("envelope type: got %q want %q", env.Type, api.MsgJobFinished)
|
|
}
|
|
var p api.JobFinishedPayload
|
|
if err := env.UnmarshalPayload(&p); err != nil {
|
|
t.Fatalf("unmarshal payload: %v", err)
|
|
}
|
|
if p.JobID != jobID || p.Status != api.JobFailed {
|
|
t.Fatalf("payload: got %+v", p)
|
|
}
|
|
case <-time.After(time.Second):
|
|
t.Fatal("expected synthetic job.finished broadcast, got nothing")
|
|
}
|
|
}
|