// pending_drain_test.go — covers DrainPending / DrainAllDue and the // onAgentHello goroutine spawn that drains a freshly-reconnected // host's queue. package http import ( "context" "encoding/json" "sync" "testing" "time" "github.com/coder/websocket" "github.com/oklog/ulid/v2" "gitea.dcglab.co.uk/steve/restic-manager/internal/api" "gitea.dcglab.co.uk/steve/restic-manager/internal/store" ) // seedSchedAndGroup wires up a host with one source group + one // schedule pointing at it. Returns (groupID, scheduleID). func seedSchedAndGroup(t *testing.T, st *store.Store, hostID string, retryMax int) (string, string) { t.Helper() gid := ulid.Make().String() if err := st.CreateSourceGroup(context.Background(), &store.SourceGroup{ ID: gid, HostID: hostID, Name: "default", Includes: []string{"/etc"}, RetryMax: retryMax, RetryBackoffSeconds: 60, }); err != nil { t.Fatalf("create group: %v", err) } sid := ulid.Make().String() if err := st.CreateSchedule(context.Background(), &store.Schedule{ ID: sid, HostID: hostID, CronExpr: "0 3 * * *", Enabled: true, SourceGroupIDs: []string{gid}, }); err != nil { t.Fatalf("create schedule: %v", err) } // Mark a successful init job so auto-init doesn't pollute reads. if err := st.CreateJob(context.Background(), store.Job{ ID: ulid.Make().String(), HostID: hostID, Kind: "init", ActorKind: "system", CreatedAt: time.Now().UTC(), }); err != nil { t.Fatalf("seed init: %v", err) } return gid, sid } // countPendingForHost returns the number of pending_runs rows for hostID. func countPendingForHost(t *testing.T, st *store.Store, hostID string) int { t.Helper() var n int if err := st.DB().QueryRow( `SELECT COUNT(*) FROM pending_runs WHERE host_id = ?`, hostID).Scan(&n); err != nil { t.Fatalf("count pending: %v", err) } return n } // waitForPendingCount polls until the pending_runs count for hostID // reaches wantN or the deadline expires. Use this instead of calling // DrainPending synchronously when the test relies on the on-hello // goroutine (which holds the per-host drain mutex) to process rows. func waitForPendingCount(t *testing.T, st *store.Store, hostID string, wantN int, timeout time.Duration) { t.Helper() deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { if countPendingForHost(t, st, hostID) == wantN { return } time.Sleep(20 * time.Millisecond) } t.Errorf("pending count for host %s: want %d after %v, got %d", hostID, wantN, timeout, countPendingForHost(t, st, hostID)) } // countAuditAction returns the number of audit_log rows with the given action. func countAuditAction(t *testing.T, st *store.Store, action string) int { t.Helper() var n int if err := st.DB().QueryRow( `SELECT COUNT(*) FROM audit_log WHERE action = ?`, action).Scan(&n); err != nil { t.Fatalf("count audit: %v", err) } return n } func TestDrainPendingDispatchesOnReconnect(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) hostID, token := enrolHostForWS(t, srv, st, "drain-host") gid, sid := seedSchedAndGroup(t, st, hostID, 5) // Pre-insert a pending row that's already due. The on-hello // goroutine should drain it after we connect. pendingID := ulid.Make().String() now := time.Now().UTC() if err := st.EnqueuePendingRun(context.Background(), &store.PendingRun{ ID: pendingID, ScheduleID: sid, SourceGroupID: gid, HostID: hostID, Attempt: 1, NextAttemptAt: now.Add(-time.Second), ScheduledAt: now.Add(-time.Minute), }); err != nil { t.Fatalf("enqueue: %v", err) } c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "drain-host") // Walk envelopes looking for a backup command.run carrying the // group's includes. var got *api.CommandRunPayload deadline := time.Now().Add(3 * time.Second) for time.Now().Before(deadline) { ctx, cancel := context.WithTimeout(context.Background(), 800*time.Millisecond) mt, raw, err := c.Read(ctx) cancel() if err != nil { break } if mt != websocket.MessageText { continue } var env api.Envelope if err := json.Unmarshal(raw, &env); err != nil { continue } if env.Type != api.MsgCommandRun { continue } var p api.CommandRunPayload _ = env.UnmarshalPayload(&p) if p.Kind == api.JobBackup { got = &p break } } if got == nil { t.Fatalf("no backup command.run dispatched after reconnect drain") } if !equalStrings(got.Includes, []string{"/etc"}) { t.Errorf("backup includes: %v", got.Includes) } if got.Tag != "default" { t.Errorf("backup tag: %q", got.Tag) } // Pending row should be gone. if n := countPendingForHost(t, st, hostID); n != 0 { t.Errorf("pending rows after drain: got %d, want 0", n) } // One backup job row landed (in addition to the seeded init). var n int _ = st.DB().QueryRow( `SELECT COUNT(*) FROM jobs WHERE host_id = ? AND kind = 'backup' AND actor_kind = 'schedule'`, hostID).Scan(&n) if n != 1 { t.Errorf("backup job rows: got %d, want 1", n) } } func TestDrainPendingAbandonsOnRetryMax(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) hostID, token := enrolHostForWS(t, srv, st, "abandon-retry-host") gid, sid := seedSchedAndGroup(t, st, hostID, 2) pendingID := ulid.Make().String() now := time.Now().UTC() if err := st.EnqueuePendingRun(context.Background(), &store.PendingRun{ ID: pendingID, ScheduleID: sid, SourceGroupID: gid, HostID: hostID, Attempt: 2, NextAttemptAt: now.Add(-time.Second), ScheduledAt: now.Add(-time.Minute), }); err != nil { t.Fatalf("enqueue: %v", err) } auditBefore := countAuditAction(t, st, "pending_run.abandoned") c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "abandon-retry-host") _ = drainUntil(t, c, api.MsgScheduleSet) // The on-hello goroutine processes the row (retry_max exceeded → abandon). // Wait for it to finish rather than calling DrainPending directly, which // would be a no-op while the goroutine holds the per-host drain mutex. _ = connFromHub(t, srv, hostID) // ensure hub registration waitForPendingCount(t, st, hostID, 0, 2*time.Second) if n := countPendingForHost(t, st, hostID); n != 0 { t.Errorf("pending rows after abandon: got %d, want 0", n) } if d := countAuditAction(t, st, "pending_run.abandoned") - auditBefore; d != 1 { t.Errorf("audit pending_run.abandoned delta: got %d, want 1", d) } // No backup command.run should have been sent. deadline := time.Now().Add(400 * time.Millisecond) for time.Now().Before(deadline) { ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) mt, raw, err := c.Read(ctx) cancel() if err != nil { break } if mt != websocket.MessageText { continue } var env api.Envelope _ = json.Unmarshal(raw, &env) if env.Type == api.MsgCommandRun { var p api.CommandRunPayload _ = env.UnmarshalPayload(&p) if p.Kind == api.JobBackup { t.Fatalf("abandoned row still dispatched a backup: %+v", p) } } } // No backup job row. var n int _ = st.DB().QueryRow( `SELECT COUNT(*) FROM jobs WHERE host_id = ? AND kind = 'backup'`, hostID).Scan(&n) if n != 0 { t.Errorf("abandon path created a backup job: %d rows", n) } } func TestDrainPendingBumpsOnSendFailure(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) hostID, token := enrolHostForWS(t, srv, st, "bump-host") gid, sid := seedSchedAndGroup(t, st, hostID, 5) c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "bump-host") _ = drainUntil(t, c, api.MsgScheduleSet) // Capture the conn before closing the client side. Hub.Conn still // returns it after the client-side close — the server's Unregister // fires when its read loop sees the close, but the conn ptr remains // valid; subsequent Sends just fail. conn := connFromHub(t, srv, hostID) if conn == nil { t.Fatal("conn never registered") } // Insert the pending row AFTER the on-hello drain goroutine has // already scanned (an empty list) — otherwise we race the on-hello // drain dispatching the row over the still-live socket. pendingID := ulid.Make().String() now := time.Now().UTC() if err := c.Close(websocket.StatusNormalClosure, "test"); err != nil { t.Fatalf("close: %v", err) } // Brief settle so the close is observed by the server's read loop. time.Sleep(150 * time.Millisecond) if err := st.EnqueuePendingRun(context.Background(), &store.PendingRun{ ID: pendingID, ScheduleID: sid, SourceGroupID: gid, HostID: hostID, Attempt: 1, NextAttemptAt: now.Add(-time.Second), ScheduledAt: now.Add(-time.Minute), }); err != nil { t.Fatalf("enqueue: %v", err) } // DrainPending uses Hub.Conn(hostID); after the client close the // server may have unregistered already. Call drainOne directly // against the captured conn so we deterministically exercise the // "Send fails" branch rather than the "host gone" branch. srv.drainOne(context.Background(), conn, store.PendingRun{ ID: pendingID, ScheduleID: sid, SourceGroupID: gid, HostID: hostID, Attempt: 1, NextAttemptAt: now.Add(-time.Second), ScheduledAt: now.Add(-time.Minute), }) // The original row must be bumped to attempt=2 with a non-empty // last_error. Critically, NO duplicate row should have been created: // drainOne calls dispatchBackupForGroupCore (not dispatchBackupForGroup) // so the enqueue-on-failure path is bypassed and the count stays at 1. if n := countPendingForHost(t, st, hostID); n != 1 { t.Errorf("pending rows after send failure: got %d, want 1 (no duplicate enqueue)", n) } var attempt int var lastErr string if err := st.DB().QueryRow( `SELECT attempt, COALESCE(last_error,'') FROM pending_runs WHERE id = ?`, pendingID).Scan(&attempt, &lastErr); err != nil { t.Fatalf("scan original row: %v", err) } if attempt != 2 { t.Errorf("attempt after bump: got %d, want 2", attempt) } if lastErr == "" { t.Errorf("last_error empty after bump") } } func TestDrainPendingDropsRowsForGoneSchedule(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) hostID, token := enrolHostForWS(t, srv, st, "gone-sched-host") gid, sid := seedSchedAndGroup(t, st, hostID, 5) pendingID := ulid.Make().String() now := time.Now().UTC() if err := st.EnqueuePendingRun(context.Background(), &store.PendingRun{ ID: pendingID, ScheduleID: sid, SourceGroupID: gid, HostID: hostID, Attempt: 1, NextAttemptAt: now.Add(-time.Second), ScheduledAt: now.Add(-time.Minute), }); err != nil { t.Fatalf("enqueue: %v", err) } // Disable the schedule. (Deleting it would FK-cascade-delete the // pending_runs row out from under the drainer, which is fine for // production but defeats the point of the test. The // disabled-schedule path goes through the same abandonPending code, // so it's an equivalent assertion.) if _, err := st.DB().Exec( `UPDATE schedules SET enabled = 0 WHERE id = ?`, sid); err != nil { t.Fatalf("disable schedule: %v", err) } auditBefore := countAuditAction(t, st, "pending_run.abandoned") c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "gone-sched-host") _ = drainUntil(t, c, api.MsgScheduleSet) // The on-hello goroutine processes the row (disabled schedule → abandon). // Poll for completion instead of calling DrainPending, which would return // immediately while the goroutine holds the per-host drain mutex. waitForPendingCount(t, st, hostID, 0, 2*time.Second) if n := countPendingForHost(t, st, hostID); n != 0 { t.Errorf("pending rows after schedule-gone abandon: got %d, want 0", n) } if d := countAuditAction(t, st, "pending_run.abandoned") - auditBefore; d != 1 { t.Errorf("audit delta: got %d, want 1", d) } // Drain produced no backup envelope. deadline := time.Now().Add(400 * time.Millisecond) for time.Now().Before(deadline) { ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) mt, raw, err := c.Read(ctx) cancel() if err != nil { break } if mt != websocket.MessageText { continue } var env api.Envelope _ = json.Unmarshal(raw, &env) if env.Type == api.MsgCommandRun { var p api.CommandRunPayload _ = env.UnmarshalPayload(&p) if p.Kind == api.JobBackup { t.Fatalf("gone-schedule abandon still dispatched: %+v", p) } } } } // TestDrainPendingDropsRowsForGoneSourceGroup verifies that when a // source group is gone (ErrNotFound) the pending row is abandoned and // an audit entry is written. Transient-error paths (SQLITE_BUSY, // context cancellation) are not covered here because the real *Store // doesn't expose a fault-injection seam; the code-review check above // is the gate for that path. func TestDrainPendingDropsRowsForGoneSourceGroup(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) hostID, token := enrolHostForWS(t, srv, st, "gone-group-host") _, sid := seedSchedAndGroup(t, st, hostID, 5) // Use a source_group_id that never existed. pending_runs carries a // FK to source_groups, so we must bypass FK enforcement for this // insert. PRAGMA foreign_keys is connection-scoped and can only be // changed outside a transaction; DB().Exec runs on an arbitrary // pooled connection, so we pin it with a dedicated *sql.Conn. fakeGroupID := ulid.Make().String() pendingID := ulid.Make().String() now := time.Now().UTC() conn, err := st.DB().Conn(context.Background()) if err != nil { t.Fatalf("db conn: %v", err) } defer conn.Close() if _, err := conn.ExecContext(context.Background(), `PRAGMA foreign_keys = OFF`); err != nil { t.Fatalf("fk off: %v", err) } if _, err := conn.ExecContext(context.Background(), `INSERT INTO pending_runs (id, schedule_id, source_group_id, host_id, attempt, next_attempt_at, scheduled_at) VALUES (?, ?, ?, ?, 1, ?, ?)`, pendingID, sid, fakeGroupID, hostID, now.Add(-time.Second), now.Add(-time.Minute), ); err != nil { t.Fatalf("insert pending: %v", err) } if _, err := conn.ExecContext(context.Background(), `PRAGMA foreign_keys = ON`); err != nil { t.Fatalf("fk on: %v", err) } auditBefore := countAuditAction(t, st, "pending_run.abandoned") c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "gone-group-host") _ = drainUntil(t, c, api.MsgScheduleSet) // The on-hello goroutine processes the row (source group gone → abandon). // Poll for completion instead of calling DrainPending, which would return // immediately while the goroutine holds the per-host drain mutex. waitForPendingCount(t, st, hostID, 0, 2*time.Second) if n := countPendingForHost(t, st, hostID); n != 0 { t.Errorf("pending rows after source-group-gone abandon: got %d, want 0", n) } if d := countAuditAction(t, st, "pending_run.abandoned") - auditBefore; d != 1 { t.Errorf("audit delta: got %d, want 1", d) } } func TestDrainAllDueSkipsOfflineHosts(t *testing.T) { t.Parallel() srv, _, st := rawTestServer(t) // Don't dial — host is enrolled but never connected. hostID, _ := enrolHostForWS(t, srv, st, "offline-host") gid, sid := seedSchedAndGroup(t, st, hostID, 5) pendingID := ulid.Make().String() now := time.Now().UTC() if err := st.EnqueuePendingRun(context.Background(), &store.PendingRun{ ID: pendingID, ScheduleID: sid, SourceGroupID: gid, HostID: hostID, Attempt: 1, NextAttemptAt: now.Add(-time.Second), ScheduledAt: now.Add(-time.Minute), }); err != nil { t.Fatalf("enqueue: %v", err) } auditBefore := countAuditAction(t, st, "pending_run.abandoned") srv.DrainAllDue(context.Background()) // Row still there (host offline, drainer skips). if n := countPendingForHost(t, st, hostID); n != 1 { t.Errorf("pending rows after DrainAllDue against offline host: got %d, want 1", n) } if d := countAuditAction(t, st, "pending_run.abandoned") - auditBefore; d != 0 { t.Errorf("audit unexpectedly changed: delta %d", d) } } func TestEnqueueOnDispatchFailure(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) hostID, token := enrolHostForWS(t, srv, st, "enqueue-host") _, sid := seedSchedAndGroup(t, st, hostID, 5) c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "enqueue-host") _ = drainUntil(t, c, api.MsgScheduleSet) conn := connFromHub(t, srv, hostID) _ = conn // Close the client side so the server's next Send errors. if err := c.Close(websocket.StatusNormalClosure, "test"); err != nil { t.Fatalf("close: %v", err) } time.Sleep(100 * time.Millisecond) scheduledAt := time.Now().UTC().Add(-30 * time.Second) srv.dispatchScheduledJob(context.Background(), hostID, conn, sid, scheduledAt) // One pending row should have been enqueued (attempt=1) with the // scheduled_at preserved. rows, err := st.ListPendingRunsForHost(context.Background(), hostID) if err != nil { t.Fatalf("list: %v", err) } if len(rows) != 1 { t.Fatalf("pending rows: got %d, want 1", len(rows)) } if rows[0].Attempt != 1 { t.Errorf("attempt: got %d, want 1", rows[0].Attempt) } // scheduled_at preserved (within RFC3339Nano round-trip tolerance). if rows[0].ScheduledAt.Sub(scheduledAt).Abs() > time.Microsecond { t.Errorf("scheduled_at drift: %v vs %v", rows[0].ScheduledAt, scheduledAt) } if rows[0].LastError == "" { t.Errorf("last_error empty") } } // TestDrainPendingSerializesPerHost verifies that concurrent DrainPending // calls for the same host do not double-dispatch pending rows. The per-host // mutex (TryLock semantics) means exactly one drain processes each row. func TestDrainPendingSerializesPerHost(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) hostID, token := enrolHostForWS(t, srv, st, "serialize-host") gid, sid := seedSchedAndGroup(t, st, hostID, 10) // Connect the agent so DrainPending can dispatch. c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "serialize-host") // Drain the on-hello goroutine's pass first (no pending rows yet), // then wait for the schedule.set so the connection is fully settled. _ = drainUntil(t, c, api.MsgScheduleSet) // Insert 5 pending rows now that the on-hello drain has already run. now := time.Now().UTC() for i := range 5 { pid := ulid.Make().String() if err := st.EnqueuePendingRun(context.Background(), &store.PendingRun{ ID: pid, ScheduleID: sid, SourceGroupID: gid, HostID: hostID, Attempt: 1, NextAttemptAt: now.Add(-time.Second), ScheduledAt: now.Add(-time.Duration(i+1) * time.Minute), }); err != nil { t.Fatalf("enqueue row %d: %v", i, err) } } // Spawn 10 goroutines all calling DrainPending concurrently. var wg sync.WaitGroup for range 10 { wg.Add(1) go func() { defer wg.Done() srv.DrainPending(context.Background(), hostID) }() } wg.Wait() // Drain any envelopes the agent received so we don't block below. // We read with short timeouts and stop when the connection goes quiet. drainDeadline := time.Now().Add(500 * time.Millisecond) for time.Now().Before(drainDeadline) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) _, _, err := c.Read(ctx) cancel() if err != nil { break } } // All 5 pending rows must be gone. if n := countPendingForHost(t, st, hostID); n != 0 { t.Errorf("pending rows after concurrent drain: got %d, want 0", n) } // Exactly 5 backup job rows (one per pending row), not 10+ from a race. var n int _ = st.DB().QueryRow( `SELECT COUNT(*) FROM jobs WHERE host_id = ? AND kind = 'backup' AND actor_kind = 'schedule'`, hostID).Scan(&n) if n != 5 { t.Errorf("backup job rows: got %d, want 5 (per-host mutex must prevent double-dispatch)", n) } }