server: serialize DrainPending per host (avoid drain double-dispatch)
Add a per-host drain mutex (drainLocks map guarded by drainLocksMu) on the Server struct. DrainPending acquires it with TryLock: if a drain is already in-flight for this host, the call returns immediately — the running drain will see every pending row. This prevents the on-hello goroutine and the 30s tick from both listing the same host's rows and dispatching them twice. Update three existing tests that called srv.DrainPending explicitly after the on-hello goroutine had already been spawned: replace the now-redundant direct call with a waitForPendingCount poll so they don't race the goroutine's mutex ownership. Add TestDrainPendingSerializesPerHost which fires 10 concurrent DrainPending goroutines against a 5-row queue and asserts exactly 5 job rows result.
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oklog/ulid/v2"
|
"github.com/oklog/ulid/v2"
|
||||||
@@ -34,7 +35,18 @@ const (
|
|||||||
// row is dropped (audit-logged as abandoned). Otherwise we attempt
|
// row is dropped (audit-logged as abandoned). Otherwise we attempt
|
||||||
// dispatch; success deletes the row, failure bumps the attempt and
|
// dispatch; success deletes the row, failure bumps the attempt and
|
||||||
// reschedules with exponential backoff.
|
// reschedules with exponential backoff.
|
||||||
|
//
|
||||||
|
// A per-host mutex (hostDrainMutex) ensures that the on-hello goroutine
|
||||||
|
// and the 30s tick cannot process the same host concurrently. If a drain
|
||||||
|
// is already in-flight for this host, the call returns immediately — the
|
||||||
|
// running drain will see any rows we'd have processed.
|
||||||
func (s *Server) DrainPending(ctx context.Context, hostID string) {
|
func (s *Server) DrainPending(ctx context.Context, hostID string) {
|
||||||
|
mu := s.hostDrainMutex(hostID)
|
||||||
|
if !mu.TryLock() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
runs, err := s.deps.Store.ListPendingRunsForHost(ctx, hostID)
|
runs, err := s.deps.Store.ListPendingRunsForHost(ctx, hostID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("drain pending: list", "host_id", hostID, "err", err)
|
slog.Warn("drain pending: list", "host_id", hostID, "err", err)
|
||||||
@@ -148,6 +160,24 @@ func (s *Server) abandonPending(ctx context.Context, p store.PendingRun, reason
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// hostDrainMutex returns the per-host mutex for DrainPending,
|
||||||
|
// creating it on first request. The map is guarded by drainLocksMu.
|
||||||
|
// Mutex objects are never deleted from the map — there are at most
|
||||||
|
// len(hosts) entries, which is bounded by the fleet size.
|
||||||
|
func (s *Server) hostDrainMutex(hostID string) *sync.Mutex {
|
||||||
|
s.drainLocksMu.Lock()
|
||||||
|
defer s.drainLocksMu.Unlock()
|
||||||
|
if s.drainLocks == nil {
|
||||||
|
s.drainLocks = make(map[string]*sync.Mutex)
|
||||||
|
}
|
||||||
|
mu, ok := s.drainLocks[hostID]
|
||||||
|
if !ok {
|
||||||
|
mu = &sync.Mutex{}
|
||||||
|
s.drainLocks[hostID] = mu
|
||||||
|
}
|
||||||
|
return mu
|
||||||
|
}
|
||||||
|
|
||||||
// DrainAllDue is the 30s-ticker entrypoint. Walks rows whose
|
// DrainAllDue is the 30s-ticker entrypoint. Walks rows whose
|
||||||
// next_attempt_at <= now (DuePendingRuns), dedupes by host, and calls
|
// next_attempt_at <= now (DuePendingRuns), dedupes by host, and calls
|
||||||
// DrainPending per host. The DrainPending then re-walks the host's
|
// DrainPending per host. The DrainPending then re-walks the host's
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ package http
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -57,6 +58,23 @@ func countPendingForHost(t *testing.T, st *store.Store, hostID string) int {
|
|||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// waitForPendingCount polls until the pending_runs count for hostID
|
||||||
|
// reaches wantN or the deadline expires. Use this instead of calling
|
||||||
|
// DrainPending synchronously when the test relies on the on-hello
|
||||||
|
// goroutine (which holds the per-host drain mutex) to process rows.
|
||||||
|
func waitForPendingCount(t *testing.T, st *store.Store, hostID string, wantN int, timeout time.Duration) {
|
||||||
|
t.Helper()
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if countPendingForHost(t, st, hostID) == wantN {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
}
|
||||||
|
t.Errorf("pending count for host %s: want %d after %v, got %d",
|
||||||
|
hostID, wantN, timeout, countPendingForHost(t, st, hostID))
|
||||||
|
}
|
||||||
|
|
||||||
// countAuditAction returns the number of audit_log rows with the given action.
|
// countAuditAction returns the number of audit_log rows with the given action.
|
||||||
func countAuditAction(t *testing.T, st *store.Store, action string) int {
|
func countAuditAction(t *testing.T, st *store.Store, action string) int {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
@@ -164,10 +182,11 @@ func TestDrainPendingAbandonsOnRetryMax(t *testing.T) {
|
|||||||
sendHello(t, c, "abandon-retry-host")
|
sendHello(t, c, "abandon-retry-host")
|
||||||
_ = drainUntil(t, c, api.MsgScheduleSet)
|
_ = drainUntil(t, c, api.MsgScheduleSet)
|
||||||
|
|
||||||
// Call DrainPending directly — gives us deterministic completion.
|
// The on-hello goroutine processes the row (retry_max exceeded → abandon).
|
||||||
conn := connFromHub(t, srv, hostID)
|
// Wait for it to finish rather than calling DrainPending directly, which
|
||||||
_ = conn // just to ensure conn was registered
|
// would be a no-op while the goroutine holds the per-host drain mutex.
|
||||||
srv.DrainPending(context.Background(), hostID)
|
_ = connFromHub(t, srv, hostID) // ensure hub registration
|
||||||
|
waitForPendingCount(t, st, hostID, 0, 2*time.Second)
|
||||||
|
|
||||||
if n := countPendingForHost(t, st, hostID); n != 0 {
|
if n := countPendingForHost(t, st, hostID); n != 0 {
|
||||||
t.Errorf("pending rows after abandon: got %d, want 0", n)
|
t.Errorf("pending rows after abandon: got %d, want 0", n)
|
||||||
@@ -309,7 +328,10 @@ func TestDrainPendingDropsRowsForGoneSchedule(t *testing.T) {
|
|||||||
sendHello(t, c, "gone-sched-host")
|
sendHello(t, c, "gone-sched-host")
|
||||||
_ = drainUntil(t, c, api.MsgScheduleSet)
|
_ = drainUntil(t, c, api.MsgScheduleSet)
|
||||||
|
|
||||||
srv.DrainPending(context.Background(), hostID)
|
// The on-hello goroutine processes the row (disabled schedule → abandon).
|
||||||
|
// Poll for completion instead of calling DrainPending, which would return
|
||||||
|
// immediately while the goroutine holds the per-host drain mutex.
|
||||||
|
waitForPendingCount(t, st, hostID, 0, 2*time.Second)
|
||||||
|
|
||||||
if n := countPendingForHost(t, st, hostID); n != 0 {
|
if n := countPendingForHost(t, st, hostID); n != 0 {
|
||||||
t.Errorf("pending rows after schedule-gone abandon: got %d, want 0", n)
|
t.Errorf("pending rows after schedule-gone abandon: got %d, want 0", n)
|
||||||
@@ -387,7 +409,10 @@ func TestDrainPendingDropsRowsForGoneSourceGroup(t *testing.T) {
|
|||||||
sendHello(t, c, "gone-group-host")
|
sendHello(t, c, "gone-group-host")
|
||||||
_ = drainUntil(t, c, api.MsgScheduleSet)
|
_ = drainUntil(t, c, api.MsgScheduleSet)
|
||||||
|
|
||||||
srv.DrainPending(context.Background(), hostID)
|
// The on-hello goroutine processes the row (source group gone → abandon).
|
||||||
|
// Poll for completion instead of calling DrainPending, which would return
|
||||||
|
// immediately while the goroutine holds the per-host drain mutex.
|
||||||
|
waitForPendingCount(t, st, hostID, 0, 2*time.Second)
|
||||||
|
|
||||||
if n := countPendingForHost(t, st, hostID); n != 0 {
|
if n := countPendingForHost(t, st, hostID); n != 0 {
|
||||||
t.Errorf("pending rows after source-group-gone abandon: got %d, want 0", n)
|
t.Errorf("pending rows after source-group-gone abandon: got %d, want 0", n)
|
||||||
@@ -469,3 +494,74 @@ func TestEnqueueOnDispatchFailure(t *testing.T) {
|
|||||||
t.Errorf("last_error empty")
|
t.Errorf("last_error empty")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestDrainPendingSerializesPerHost verifies that concurrent DrainPending
|
||||||
|
// calls for the same host do not double-dispatch pending rows. The per-host
|
||||||
|
// mutex (TryLock semantics) means exactly one drain processes each row.
|
||||||
|
func TestDrainPendingSerializesPerHost(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
srv, ts, st := rawTestServer(t)
|
||||||
|
hostID, token := enrolHostForWS(t, srv, st, "serialize-host")
|
||||||
|
gid, sid := seedSchedAndGroup(t, st, hostID, 10)
|
||||||
|
|
||||||
|
// Connect the agent so DrainPending can dispatch.
|
||||||
|
c := agentDial(t, srv, ts, hostID, token)
|
||||||
|
sendHello(t, c, "serialize-host")
|
||||||
|
// Drain the on-hello goroutine's pass first (no pending rows yet),
|
||||||
|
// then wait for the schedule.set so the connection is fully settled.
|
||||||
|
_ = drainUntil(t, c, api.MsgScheduleSet)
|
||||||
|
|
||||||
|
// Insert 5 pending rows now that the on-hello drain has already run.
|
||||||
|
now := time.Now().UTC()
|
||||||
|
for i := range 5 {
|
||||||
|
pid := ulid.Make().String()
|
||||||
|
if err := st.EnqueuePendingRun(context.Background(), &store.PendingRun{
|
||||||
|
ID: pid,
|
||||||
|
ScheduleID: sid,
|
||||||
|
SourceGroupID: gid,
|
||||||
|
HostID: hostID,
|
||||||
|
Attempt: 1,
|
||||||
|
NextAttemptAt: now.Add(-time.Second),
|
||||||
|
ScheduledAt: now.Add(-time.Duration(i+1) * time.Minute),
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("enqueue row %d: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Spawn 10 goroutines all calling DrainPending concurrently.
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for range 10 {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
srv.DrainPending(context.Background(), hostID)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Drain any envelopes the agent received so we don't block below.
|
||||||
|
// We read with short timeouts and stop when the connection goes quiet.
|
||||||
|
drainDeadline := time.Now().Add(500 * time.Millisecond)
|
||||||
|
for time.Now().Before(drainDeadline) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
_, _, err := c.Read(ctx)
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All 5 pending rows must be gone.
|
||||||
|
if n := countPendingForHost(t, st, hostID); n != 0 {
|
||||||
|
t.Errorf("pending rows after concurrent drain: got %d, want 0", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exactly 5 backup job rows (one per pending row), not 10+ from a race.
|
||||||
|
var n int
|
||||||
|
_ = st.DB().QueryRow(
|
||||||
|
`SELECT COUNT(*) FROM jobs WHERE host_id = ? AND kind = 'backup' AND actor_kind = 'schedule'`,
|
||||||
|
hostID).Scan(&n)
|
||||||
|
if n != 5 {
|
||||||
|
t.Errorf("backup job rows: got %d, want 5 (per-host mutex must prevent double-dispatch)", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
stdhttp "net/http"
|
stdhttp "net/http"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
@@ -41,6 +42,13 @@ type Deps struct {
|
|||||||
type Server struct {
|
type Server struct {
|
||||||
srv *stdhttp.Server
|
srv *stdhttp.Server
|
||||||
deps Deps
|
deps Deps
|
||||||
|
|
||||||
|
// drainLocks serializes DrainPending per host. The on-hello
|
||||||
|
// goroutine and the 30s ticker can otherwise race for the same
|
||||||
|
// host, double-dispatching every pending row. Map of hostID →
|
||||||
|
// sync.Mutex; checked-and-locked atomically via drainLocksMu.
|
||||||
|
drainLocksMu sync.Mutex
|
||||||
|
drainLocks map[string]*sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// New builds a configured but not-yet-started server.
|
// New builds a configured but not-yet-started server.
|
||||||
@@ -59,7 +67,7 @@ func New(deps Deps) *Server {
|
|||||||
w.WriteHeader(stdhttp.StatusNoContent)
|
w.WriteHeader(stdhttp.StatusNoContent)
|
||||||
})
|
})
|
||||||
|
|
||||||
s := &Server{deps: deps}
|
s := &Server{deps: deps, drainLocks: make(map[string]*sync.Mutex)}
|
||||||
s.routes(r)
|
s.routes(r)
|
||||||
|
|
||||||
s.srv = &stdhttp.Server{
|
s.srv = &stdhttp.Server{
|
||||||
|
|||||||
Reference in New Issue
Block a user