From 99ef2b7a7195c299ff6de704abaa8f89a258c313 Mon Sep 17 00:00:00 2001 From: Steve Cliff Date: Mon, 4 May 2026 00:33:13 +0100 Subject: [PATCH] server: serialize DrainPending per host (avoid drain double-dispatch) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a per-host drain mutex (drainLocks map guarded by drainLocksMu) on the Server struct. DrainPending acquires it with TryLock: if a drain is already in-flight for this host, the call returns immediately — the running drain will see every pending row. This prevents the on-hello goroutine and the 30s tick from both listing the same host's rows and dispatching them twice. Update three existing tests that called srv.DrainPending explicitly after the on-hello goroutine had already been spawned: replace the now-redundant direct call with a waitForPendingCount poll so they don't race the goroutine's mutex ownership. Add TestDrainPendingSerializesPerHost which fires 10 concurrent DrainPending goroutines against a 5-row queue and asserts exactly 5 job rows result. --- internal/server/http/pending_drain.go | 30 ++++++ internal/server/http/pending_drain_test.go | 108 +++++++++++++++++++-- internal/server/http/server.go | 10 +- 3 files changed, 141 insertions(+), 7 deletions(-) diff --git a/internal/server/http/pending_drain.go b/internal/server/http/pending_drain.go index 5433794..a69116d 100644 --- a/internal/server/http/pending_drain.go +++ b/internal/server/http/pending_drain.go @@ -13,6 +13,7 @@ import ( "context" "errors" "log/slog" + "sync" "time" "github.com/oklog/ulid/v2" @@ -34,7 +35,18 @@ const ( // row is dropped (audit-logged as abandoned). Otherwise we attempt // dispatch; success deletes the row, failure bumps the attempt and // reschedules with exponential backoff. +// +// A per-host mutex (hostDrainMutex) ensures that the on-hello goroutine +// and the 30s tick cannot process the same host concurrently. If a drain +// is already in-flight for this host, the call returns immediately — the +// running drain will see any rows we'd have processed. func (s *Server) DrainPending(ctx context.Context, hostID string) { + mu := s.hostDrainMutex(hostID) + if !mu.TryLock() { + return + } + defer mu.Unlock() + runs, err := s.deps.Store.ListPendingRunsForHost(ctx, hostID) if err != nil { slog.Warn("drain pending: list", "host_id", hostID, "err", err) @@ -148,6 +160,24 @@ func (s *Server) abandonPending(ctx context.Context, p store.PendingRun, reason } } +// hostDrainMutex returns the per-host mutex for DrainPending, +// creating it on first request. The map is guarded by drainLocksMu. +// Mutex objects are never deleted from the map — there are at most +// len(hosts) entries, which is bounded by the fleet size. +func (s *Server) hostDrainMutex(hostID string) *sync.Mutex { + s.drainLocksMu.Lock() + defer s.drainLocksMu.Unlock() + if s.drainLocks == nil { + s.drainLocks = make(map[string]*sync.Mutex) + } + mu, ok := s.drainLocks[hostID] + if !ok { + mu = &sync.Mutex{} + s.drainLocks[hostID] = mu + } + return mu +} + // DrainAllDue is the 30s-ticker entrypoint. Walks rows whose // next_attempt_at <= now (DuePendingRuns), dedupes by host, and calls // DrainPending per host. The DrainPending then re-walks the host's diff --git a/internal/server/http/pending_drain_test.go b/internal/server/http/pending_drain_test.go index 12c435b..a216c25 100644 --- a/internal/server/http/pending_drain_test.go +++ b/internal/server/http/pending_drain_test.go @@ -6,6 +6,7 @@ package http import ( "context" "encoding/json" + "sync" "testing" "time" @@ -57,6 +58,23 @@ func countPendingForHost(t *testing.T, st *store.Store, hostID string) int { return n } +// waitForPendingCount polls until the pending_runs count for hostID +// reaches wantN or the deadline expires. Use this instead of calling +// DrainPending synchronously when the test relies on the on-hello +// goroutine (which holds the per-host drain mutex) to process rows. +func waitForPendingCount(t *testing.T, st *store.Store, hostID string, wantN int, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if countPendingForHost(t, st, hostID) == wantN { + return + } + time.Sleep(20 * time.Millisecond) + } + t.Errorf("pending count for host %s: want %d after %v, got %d", + hostID, wantN, timeout, countPendingForHost(t, st, hostID)) +} + // countAuditAction returns the number of audit_log rows with the given action. func countAuditAction(t *testing.T, st *store.Store, action string) int { t.Helper() @@ -164,10 +182,11 @@ func TestDrainPendingAbandonsOnRetryMax(t *testing.T) { sendHello(t, c, "abandon-retry-host") _ = drainUntil(t, c, api.MsgScheduleSet) - // Call DrainPending directly — gives us deterministic completion. - conn := connFromHub(t, srv, hostID) - _ = conn // just to ensure conn was registered - srv.DrainPending(context.Background(), hostID) + // The on-hello goroutine processes the row (retry_max exceeded → abandon). + // Wait for it to finish rather than calling DrainPending directly, which + // would be a no-op while the goroutine holds the per-host drain mutex. + _ = connFromHub(t, srv, hostID) // ensure hub registration + waitForPendingCount(t, st, hostID, 0, 2*time.Second) if n := countPendingForHost(t, st, hostID); n != 0 { t.Errorf("pending rows after abandon: got %d, want 0", n) @@ -309,7 +328,10 @@ func TestDrainPendingDropsRowsForGoneSchedule(t *testing.T) { sendHello(t, c, "gone-sched-host") _ = drainUntil(t, c, api.MsgScheduleSet) - srv.DrainPending(context.Background(), hostID) + // The on-hello goroutine processes the row (disabled schedule → abandon). + // Poll for completion instead of calling DrainPending, which would return + // immediately while the goroutine holds the per-host drain mutex. + waitForPendingCount(t, st, hostID, 0, 2*time.Second) if n := countPendingForHost(t, st, hostID); n != 0 { t.Errorf("pending rows after schedule-gone abandon: got %d, want 0", n) @@ -387,7 +409,10 @@ func TestDrainPendingDropsRowsForGoneSourceGroup(t *testing.T) { sendHello(t, c, "gone-group-host") _ = drainUntil(t, c, api.MsgScheduleSet) - srv.DrainPending(context.Background(), hostID) + // The on-hello goroutine processes the row (source group gone → abandon). + // Poll for completion instead of calling DrainPending, which would return + // immediately while the goroutine holds the per-host drain mutex. + waitForPendingCount(t, st, hostID, 0, 2*time.Second) if n := countPendingForHost(t, st, hostID); n != 0 { t.Errorf("pending rows after source-group-gone abandon: got %d, want 0", n) @@ -469,3 +494,74 @@ func TestEnqueueOnDispatchFailure(t *testing.T) { t.Errorf("last_error empty") } } + +// TestDrainPendingSerializesPerHost verifies that concurrent DrainPending +// calls for the same host do not double-dispatch pending rows. The per-host +// mutex (TryLock semantics) means exactly one drain processes each row. +func TestDrainPendingSerializesPerHost(t *testing.T) { + t.Parallel() + srv, ts, st := rawTestServer(t) + hostID, token := enrolHostForWS(t, srv, st, "serialize-host") + gid, sid := seedSchedAndGroup(t, st, hostID, 10) + + // Connect the agent so DrainPending can dispatch. + c := agentDial(t, srv, ts, hostID, token) + sendHello(t, c, "serialize-host") + // Drain the on-hello goroutine's pass first (no pending rows yet), + // then wait for the schedule.set so the connection is fully settled. + _ = drainUntil(t, c, api.MsgScheduleSet) + + // Insert 5 pending rows now that the on-hello drain has already run. + now := time.Now().UTC() + for i := range 5 { + pid := ulid.Make().String() + if err := st.EnqueuePendingRun(context.Background(), &store.PendingRun{ + ID: pid, + ScheduleID: sid, + SourceGroupID: gid, + HostID: hostID, + Attempt: 1, + NextAttemptAt: now.Add(-time.Second), + ScheduledAt: now.Add(-time.Duration(i+1) * time.Minute), + }); err != nil { + t.Fatalf("enqueue row %d: %v", i, err) + } + } + + // Spawn 10 goroutines all calling DrainPending concurrently. + var wg sync.WaitGroup + for range 10 { + wg.Add(1) + go func() { + defer wg.Done() + srv.DrainPending(context.Background(), hostID) + }() + } + wg.Wait() + + // Drain any envelopes the agent received so we don't block below. + // We read with short timeouts and stop when the connection goes quiet. + drainDeadline := time.Now().Add(500 * time.Millisecond) + for time.Now().Before(drainDeadline) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + _, _, err := c.Read(ctx) + cancel() + if err != nil { + break + } + } + + // All 5 pending rows must be gone. + if n := countPendingForHost(t, st, hostID); n != 0 { + t.Errorf("pending rows after concurrent drain: got %d, want 0", n) + } + + // Exactly 5 backup job rows (one per pending row), not 10+ from a race. + var n int + _ = st.DB().QueryRow( + `SELECT COUNT(*) FROM jobs WHERE host_id = ? AND kind = 'backup' AND actor_kind = 'schedule'`, + hostID).Scan(&n) + if n != 5 { + t.Errorf("backup job rows: got %d, want 5 (per-host mutex must prevent double-dispatch)", n) + } +} diff --git a/internal/server/http/server.go b/internal/server/http/server.go index 51d3774..8ef3d83 100644 --- a/internal/server/http/server.go +++ b/internal/server/http/server.go @@ -7,6 +7,7 @@ import ( "context" "errors" stdhttp "net/http" + "sync" "time" "github.com/go-chi/chi/v5" @@ -41,6 +42,13 @@ type Deps struct { type Server struct { srv *stdhttp.Server deps Deps + + // drainLocks serializes DrainPending per host. The on-hello + // goroutine and the 30s ticker can otherwise race for the same + // host, double-dispatching every pending row. Map of hostID → + // sync.Mutex; checked-and-locked atomically via drainLocksMu. + drainLocksMu sync.Mutex + drainLocks map[string]*sync.Mutex } // New builds a configured but not-yet-started server. @@ -59,7 +67,7 @@ func New(deps Deps) *Server { w.WriteHeader(stdhttp.StatusNoContent) }) - s := &Server{deps: deps} + s := &Server{deps: deps, drainLocks: make(map[string]*sync.Mutex)} s.routes(r) s.srv = &stdhttp.Server{