server: serialize DrainPending per host (avoid drain double-dispatch)

Add a per-host drain mutex (drainLocks map guarded by drainLocksMu) on
the Server struct. DrainPending acquires it with TryLock: if a drain is
already in-flight for this host, the call returns immediately — the
running drain will see every pending row. This prevents the on-hello
goroutine and the 30s tick from both listing the same host's rows and
dispatching them twice.

Update three existing tests that called srv.DrainPending explicitly
after the on-hello goroutine had already been spawned: replace the
now-redundant direct call with a waitForPendingCount poll so they don't
race the goroutine's mutex ownership. Add TestDrainPendingSerializesPerHost
which fires 10 concurrent DrainPending goroutines against a 5-row queue
and asserts exactly 5 job rows result.
This commit is contained in:
2026-05-04 00:33:13 +01:00
parent 9ec69456fe
commit adece5eb72
3 changed files with 141 additions and 7 deletions
+102 -6
View File
@@ -6,6 +6,7 @@ package http
import (
"context"
"encoding/json"
"sync"
"testing"
"time"
@@ -57,6 +58,23 @@ func countPendingForHost(t *testing.T, st *store.Store, hostID string) int {
return n
}
// waitForPendingCount polls until the pending_runs count for hostID
// reaches wantN or the deadline expires. Use this instead of calling
// DrainPending synchronously when the test relies on the on-hello
// goroutine (which holds the per-host drain mutex) to process rows.
func waitForPendingCount(t *testing.T, st *store.Store, hostID string, wantN int, timeout time.Duration) {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if countPendingForHost(t, st, hostID) == wantN {
return
}
time.Sleep(20 * time.Millisecond)
}
t.Errorf("pending count for host %s: want %d after %v, got %d",
hostID, wantN, timeout, countPendingForHost(t, st, hostID))
}
// countAuditAction returns the number of audit_log rows with the given action.
func countAuditAction(t *testing.T, st *store.Store, action string) int {
t.Helper()
@@ -164,10 +182,11 @@ func TestDrainPendingAbandonsOnRetryMax(t *testing.T) {
sendHello(t, c, "abandon-retry-host")
_ = drainUntil(t, c, api.MsgScheduleSet)
// Call DrainPending directly — gives us deterministic completion.
conn := connFromHub(t, srv, hostID)
_ = conn // just to ensure conn was registered
srv.DrainPending(context.Background(), hostID)
// The on-hello goroutine processes the row (retry_max exceeded → abandon).
// Wait for it to finish rather than calling DrainPending directly, which
// would be a no-op while the goroutine holds the per-host drain mutex.
_ = connFromHub(t, srv, hostID) // ensure hub registration
waitForPendingCount(t, st, hostID, 0, 2*time.Second)
if n := countPendingForHost(t, st, hostID); n != 0 {
t.Errorf("pending rows after abandon: got %d, want 0", n)
@@ -309,7 +328,10 @@ func TestDrainPendingDropsRowsForGoneSchedule(t *testing.T) {
sendHello(t, c, "gone-sched-host")
_ = drainUntil(t, c, api.MsgScheduleSet)
srv.DrainPending(context.Background(), hostID)
// The on-hello goroutine processes the row (disabled schedule → abandon).
// Poll for completion instead of calling DrainPending, which would return
// immediately while the goroutine holds the per-host drain mutex.
waitForPendingCount(t, st, hostID, 0, 2*time.Second)
if n := countPendingForHost(t, st, hostID); n != 0 {
t.Errorf("pending rows after schedule-gone abandon: got %d, want 0", n)
@@ -387,7 +409,10 @@ func TestDrainPendingDropsRowsForGoneSourceGroup(t *testing.T) {
sendHello(t, c, "gone-group-host")
_ = drainUntil(t, c, api.MsgScheduleSet)
srv.DrainPending(context.Background(), hostID)
// The on-hello goroutine processes the row (source group gone → abandon).
// Poll for completion instead of calling DrainPending, which would return
// immediately while the goroutine holds the per-host drain mutex.
waitForPendingCount(t, st, hostID, 0, 2*time.Second)
if n := countPendingForHost(t, st, hostID); n != 0 {
t.Errorf("pending rows after source-group-gone abandon: got %d, want 0", n)
@@ -469,3 +494,74 @@ func TestEnqueueOnDispatchFailure(t *testing.T) {
t.Errorf("last_error empty")
}
}
// TestDrainPendingSerializesPerHost verifies that concurrent DrainPending
// calls for the same host do not double-dispatch pending rows. The per-host
// mutex (TryLock semantics) means exactly one drain processes each row.
func TestDrainPendingSerializesPerHost(t *testing.T) {
t.Parallel()
srv, ts, st := rawTestServer(t)
hostID, token := enrolHostForWS(t, srv, st, "serialize-host")
gid, sid := seedSchedAndGroup(t, st, hostID, 10)
// Connect the agent so DrainPending can dispatch.
c := agentDial(t, srv, ts, hostID, token)
sendHello(t, c, "serialize-host")
// Drain the on-hello goroutine's pass first (no pending rows yet),
// then wait for the schedule.set so the connection is fully settled.
_ = drainUntil(t, c, api.MsgScheduleSet)
// Insert 5 pending rows now that the on-hello drain has already run.
now := time.Now().UTC()
for i := range 5 {
pid := ulid.Make().String()
if err := st.EnqueuePendingRun(context.Background(), &store.PendingRun{
ID: pid,
ScheduleID: sid,
SourceGroupID: gid,
HostID: hostID,
Attempt: 1,
NextAttemptAt: now.Add(-time.Second),
ScheduledAt: now.Add(-time.Duration(i+1) * time.Minute),
}); err != nil {
t.Fatalf("enqueue row %d: %v", i, err)
}
}
// Spawn 10 goroutines all calling DrainPending concurrently.
var wg sync.WaitGroup
for range 10 {
wg.Add(1)
go func() {
defer wg.Done()
srv.DrainPending(context.Background(), hostID)
}()
}
wg.Wait()
// Drain any envelopes the agent received so we don't block below.
// We read with short timeouts and stop when the connection goes quiet.
drainDeadline := time.Now().Add(500 * time.Millisecond)
for time.Now().Before(drainDeadline) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
_, _, err := c.Read(ctx)
cancel()
if err != nil {
break
}
}
// All 5 pending rows must be gone.
if n := countPendingForHost(t, st, hostID); n != 0 {
t.Errorf("pending rows after concurrent drain: got %d, want 0", n)
}
// Exactly 5 backup job rows (one per pending row), not 10+ from a race.
var n int
_ = st.DB().QueryRow(
`SELECT COUNT(*) FROM jobs WHERE host_id = ? AND kind = 'backup' AND actor_kind = 'schedule'`,
hostID).Scan(&n)
if n != 5 {
t.Errorf("backup job rows: got %d, want 5 (per-host mutex must prevent double-dispatch)", n)
}
}