diff --git a/internal/server/http/p2r01_ws_test.go b/internal/server/http/p2r01_ws_test.go
new file mode 100644
index 0000000..5555f9d
--- /dev/null
+++ b/internal/server/http/p2r01_ws_test.go
@@ -0,0 +1,402 @@
+// p2r01_ws_test.go — integration tests for the WS-touching pieces of
+// P2R-01: auto-init dispatch on hello, and dispatchScheduledJob's
+// schedule.fire → command.run-per-group resolution.
+package http
+
+import (
+	"context"
+	"encoding/json"
+	stdhttp "net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/coder/websocket"
+	"github.com/oklog/ulid/v2"
+
+	"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
+	"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
+	"gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws"
+	"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
+)
+
+// agentDial dials ts's /ws/agent with token as the Bearer credential
+// and returns the conn, which t.Cleanup closes. srv and hostID are
+// currently unused; the caller is expected to send hello.
+func agentDial(t *testing.T, srv *Server, ts *httptest.Server, hostID, token string) *websocket.Conn {
+	t.Helper()
+	url := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/agent"
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
+		HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + token}},
+	})
+	if err != nil {
+		t.Fatalf("dial: %v", err)
+	}
+	t.Cleanup(func() { _ = c.CloseNow() })
+	return c
+}
+
+// readEnvelope blocks until one envelope arrives or the test times out.
+func readEnvelope(t *testing.T, c *websocket.Conn) api.Envelope {
+	t.Helper()
+	// A single read gets a 3 s budget; any failure aborts the test.
+	readCtx, stop := context.WithTimeout(context.Background(), 3*time.Second)
+	defer stop()
+	frameType, payload, readErr := c.Read(readCtx)
+	switch {
+	case readErr != nil:
+		t.Fatalf("ws read: %v", readErr)
+	case frameType != websocket.MessageText:
+		t.Fatalf("ws read: non-text frame %v", frameType)
+	}
+	var decoded api.Envelope
+	if decodeErr := json.Unmarshal(payload, &decoded); decodeErr != nil {
+		t.Fatalf("ws unmarshal: %v: %s", decodeErr, payload)
+	}
+	return decoded
+}
+
+// drainUntil reads envelopes from c, discarding them, until one whose
+// Type equals wantType arrives, and returns that envelope. It gives up
+// (failing the test) once 3 s have elapsed between reads; each read is
+// additionally bounded by readEnvelope's own timeout. Tests use it to
+// skip the on-hello prefix (config.update, schedule.set, and sometimes
+// command.run) and land on the envelope they care about.
+func drainUntil(t *testing.T, c *websocket.Conn, wantType api.MessageType) api.Envelope {
+	t.Helper()
+	for cutoff := time.Now().Add(3 * time.Second); time.Now().Before(cutoff); {
+		if got := readEnvelope(t, c); got.Type == wantType {
+			return got
+		}
+	}
+	t.Fatalf("timed out waiting for %s", wantType)
+	return api.Envelope{}
+}
+
+// enrolHostForWS pre-enrols a host with bound repo creds so the server
+// will treat it as ready to receive command.run.
+func enrolHostForWS(t *testing.T, srv *Server, st *store.Store, name string) (hostID, token string) {
+	t.Helper()
+	hostID = ulid.Make().String()
+	// NOTE(review): the second result of auth.NewToken is discarded;
+	// presumably it is an error — confirm and fail the test on it.
+	token, _ = auth.NewToken()
+	if err := st.CreateHost(context.Background(), store.Host{
+		ID: hostID, Name: name, OS: "linux", Arch: "amd64",
+		EnrolledAt: time.Now().UTC(),
+	}, auth.HashToken(token), ""); err != nil {
+		t.Fatalf("create host: %v", err)
+	}
+	// Encrypt a fixed set of repo credentials, passing "host:<id>" as
+	// the second argument, and store the result against the host.
+	enc, err := srv.encryptRepoCreds(repoCredsBlob{
+		RepoURL: "rest:http://r/x", RepoUsername: "u", RepoPassword: "p",
+	}, []byte("host:"+hostID))
+	if err != nil {
+		t.Fatalf("encrypt: %v", err)
+	}
+	if err := st.SetHostCredentials(context.Background(), hostID, enc); err != nil {
+		t.Fatalf("set creds: %v", err)
+	}
+	return hostID, token
+}
+
+// sendHello writes a hello envelope for hostname to c as a text frame.
+// Any marshal or write failure aborts the test rather than sending a
+// zero-value envelope.
+func sendHello(t *testing.T, c *websocket.Conn, hostname string) {
+	t.Helper()
+	env, err := api.Marshal(api.MsgHello, "", api.HelloPayload{
+		ProtocolVersion: api.CurrentProtocolVersion,
+		AgentVersion:    "test",
+		ResticVersion:   "0.17",
+		Hostname:        hostname,
+		OS:              api.OSLinux,
+		Arch:            api.ArchAmd64,
+	})
+	if err != nil {
+		t.Fatalf("marshal hello: %v", err)
+	}
+	raw, err := json.Marshal(env)
+	if err != nil {
+		t.Fatalf("encode hello: %v", err)
+	}
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+	if err := c.Write(ctx, websocket.MessageText, raw); err != nil {
+		t.Fatalf("write hello: %v", err)
+	}
+}
+
+// rawTestServer builds a Server via newTestServerWithHub, serves its
+// handler from an httptest server (closed on test cleanup), and returns
+// the Server, the httptest server, and the backing store.
+func rawTestServer(t *testing.T) (*Server, *httptest.Server, *store.Store) {
+	t.Helper()
+	srv, _, st := newTestServerWithHub(t)
+	ts := httptest.NewServer(srv.srv.Handler)
+	t.Cleanup(ts.Close)
+	return srv, ts, st
+}
+
+// connFromHub fetches the live *ws.Conn for hostID from the hub.
+// Polls briefly because the WS handler registers the conn just after
+// the OnHello callback returns.
+func connFromHub(t *testing.T, srv *Server, hostID string) *ws.Conn {
+	t.Helper()
+	deadline := time.Now().Add(2 * time.Second)
+	for time.Now().Before(deadline) {
+		if c := srv.deps.Hub.Conn(hostID); c != nil {
+			return c
+		}
+		time.Sleep(20 * time.Millisecond)
+	}
+	t.Fatalf("hub never registered conn for %s", hostID)
+	return nil
+}
+
+// scanCommandRuns reads frames from c and passes every decoded
+// command.run payload to visit. It stops when visit returns true, when
+// window has elapsed, or when a read fails — including a read that sees
+// nothing for perRead. Non-text frames and other envelope types are
+// skipped. An envelope or payload that cannot be decoded fails the
+// test, so a malformed command.run cannot slip past a negative check.
+//
+// NOTE(review): coder/websocket presumably closes the conn when a read
+// context expires, so the first quiet perRead interval ends the scan —
+// confirm against the library docs.
+func scanCommandRuns(t *testing.T, c *websocket.Conn, window, perRead time.Duration, visit func(api.CommandRunPayload) (done bool)) {
+	t.Helper()
+	deadline := time.Now().Add(window)
+	for time.Now().Before(deadline) {
+		ctx, cancel := context.WithTimeout(context.Background(), perRead)
+		mt, raw, err := c.Read(ctx)
+		cancel()
+		if err != nil {
+			return
+		}
+		if mt != websocket.MessageText {
+			continue
+		}
+		var env api.Envelope
+		if err := json.Unmarshal(raw, &env); err != nil {
+			t.Fatalf("ws unmarshal: %v: %s", err, raw)
+		}
+		if env.Type != api.MsgCommandRun {
+			continue
+		}
+		var run api.CommandRunPayload
+		if err := env.UnmarshalPayload(&run); err != nil {
+			t.Fatalf("unmarshal command.run: %v", err)
+		}
+		if visit(run) {
+			return
+		}
+	}
+}
+
+// ----- auto-init dispatch -----------------------------------------
+
+// TestAutoInitDispatchedOnFirstHelloOnly checks that the first hello
+// from a host with credentials yields a command.run of kind init, and
+// that a second hello after that job is marked succeeded yields none.
+func TestAutoInitDispatchedOnFirstHelloOnly(t *testing.T) {
+	t.Parallel()
+	srv, ts, st := rawTestServer(t)
+	hostID, token := enrolHostForWS(t, srv, st, "auto-init-host")
+
+	c := agentDial(t, srv, ts, hostID, token)
+	sendHello(t, c, "auto-init-host")
+
+	// Expect config.update + schedule.set + command.run(init) in some
+	// order. drainUntil walks past the first two to find the init.
+	env := drainUntil(t, c, api.MsgCommandRun)
+	var p api.CommandRunPayload
+	if err := env.UnmarshalPayload(&p); err != nil {
+		t.Fatalf("unmarshal: %v", err)
+	}
+	if p.Kind != api.JobInit {
+		t.Fatalf("first command.run kind: %s, want init", p.Kind)
+	}
+
+	// Mark the init job succeeded so HasJobOfKind sees terminal state.
+	if err := st.MarkJobFinished(context.Background(), p.JobID, "succeeded", 0, nil, "", time.Now().UTC()); err != nil {
+		t.Fatalf("mark finished: %v", err)
+	}
+
+	// Reconnect — the second hello must NOT dispatch another init.
+	_ = c.Close(websocket.StatusNormalClosure, "test")
+	// Brief wait so the hub unregisters the old conn before we open a
+	// new one (otherwise Register supersedes the old one, which is a
+	// race the production code already handles but the test doesn't
+	// need to fight).
+	time.Sleep(50 * time.Millisecond)
+
+	c2 := agentDial(t, srv, ts, hostID, token)
+	sendHello(t, c2, "auto-init-host")
+	// Expect config.update + schedule.set, then a quiet read that
+	// times out — no second init.
+	scanCommandRuns(t, c2, 1500*time.Millisecond, 400*time.Millisecond, func(run api.CommandRunPayload) bool {
+		if run.Kind == api.JobInit {
+			t.Fatalf("second hello re-dispatched init (job_id=%s) — gate is broken", run.JobID)
+		}
+		return false
+	})
+}
+
+// TestAutoInitSkippedWhenNoCreds checks that a host enrolled without
+// repo credentials receives no command.run at all after hello.
+func TestAutoInitSkippedWhenNoCreds(t *testing.T) {
+	t.Parallel()
+	srv, ts, st := rawTestServer(t)
+	hostID := ulid.Make().String()
+	token, _ := auth.NewToken()
+	if err := st.CreateHost(context.Background(), store.Host{
+		ID: hostID, Name: "no-creds-host", OS: "linux", Arch: "amd64",
+		EnrolledAt: time.Now().UTC(),
+	}, auth.HashToken(token), ""); err != nil {
+		t.Fatalf("create host: %v", err)
+	}
+
+	c := agentDial(t, srv, ts, hostID, token)
+	sendHello(t, c, "no-creds-host")
+
+	// On-hello sends a schedule.set (config.update is skipped because
+	// no creds). We should NOT see a command.run(init).
+	scanCommandRuns(t, c, 1500*time.Millisecond, 400*time.Millisecond, func(run api.CommandRunPayload) bool {
+		t.Fatalf("auto-init dispatched without creds (kind=%s)", run.Kind)
+		return false
+	})
+}
+
+// ----- dispatchScheduledJob ---------------------------------------
+
+// TestDispatchScheduledJobIteratesGroups checks that firing an enabled
+// schedule covering two source groups sends one backup command.run per
+// group (tagged with the group name, carrying its includes) and writes
+// two matching job rows.
+func TestDispatchScheduledJobIteratesGroups(t *testing.T) {
+	t.Parallel()
+	srv, ts, st := rawTestServer(t)
+	hostID, token := enrolHostForWS(t, srv, st, "fire-host")
+
+	// Create two source groups + one schedule covering both.
+	g1 := ulid.Make().String()
+	g2 := ulid.Make().String()
+	for _, g := range []*store.SourceGroup{
+		{ID: g1, HostID: hostID, Name: "etc", Includes: []string{"/etc"}},
+		{ID: g2, HostID: hostID, Name: "home", Includes: []string{"/home"}},
+	} {
+		if err := st.CreateSourceGroup(context.Background(), g); err != nil {
+			t.Fatalf("group: %v", err)
+		}
+	}
+	sid := ulid.Make().String()
+	if err := st.CreateSchedule(context.Background(), &store.Schedule{
+		ID: sid, HostID: hostID,
+		CronExpr: "0 3 * * *", Enabled: true,
+		SourceGroupIDs: []string{g1, g2},
+	}); err != nil {
+		t.Fatalf("schedule: %v", err)
+	}
+
+	// Mark a successful init job up front so the auto-init path
+	// doesn't fire and pollute the envelope sequence we're measuring.
+	if err := st.CreateJob(context.Background(), store.Job{
+		ID: ulid.Make().String(), HostID: hostID, Kind: "init",
+		ActorKind: "system", CreatedAt: time.Now().UTC(),
+	}); err != nil {
+		t.Fatalf("seed init: %v", err)
+	}
+
+	c := agentDial(t, srv, ts, hostID, token)
+	sendHello(t, c, "fire-host")
+
+	// Wait for the schedule.set push so we know the hub is connected.
+	_ = drainUntil(t, c, api.MsgScheduleSet)
+
+	// Resolve the conn from the hub and call dispatchScheduledJob
+	// directly — same path the WS handler invokes on schedule.fire.
+	conn := connFromHub(t, srv, hostID)
+	srv.dispatchScheduledJob(context.Background(), hostID, conn, sid, time.Now().UTC())
+
+	// Two backups should be queued, one per group. Collect them keyed
+	// by tag and stop as soon as both have been seen.
+	got := map[string]api.CommandRunPayload{}
+	scanCommandRuns(t, c, 3*time.Second, 800*time.Millisecond, func(run api.CommandRunPayload) bool {
+		if run.Kind == api.JobBackup {
+			got[run.Tag] = run
+		}
+		return len(got) >= 2
+	})
+	if len(got) != 2 {
+		t.Fatalf("expected 2 backups (one per group), got %d: %+v", len(got), got)
+	}
+	if !equalStrings(got["etc"].Includes, []string{"/etc"}) {
+		t.Errorf("etc backup includes: %v", got["etc"].Includes)
+	}
+	if !equalStrings(got["home"].Includes, []string{"/home"}) {
+		t.Errorf("home backup includes: %v", got["home"].Includes)
+	}
+
+	// Two job rows should exist for this host with kind=backup,
+	// actor_kind=schedule, scheduled_id=sid.
+	var n int
+	if err := st.DB().QueryRow(
+		`SELECT COUNT(*) FROM jobs WHERE host_id = ? AND kind = 'backup' AND actor_kind = 'schedule' AND scheduled_id = ?`,
+		hostID, sid).Scan(&n); err != nil {
+		t.Fatalf("count: %v", err)
+	}
+	if n != 2 {
+		t.Errorf("scheduled backup jobs: got %d, want 2", n)
+	}
+}
+
+// TestDispatchScheduledJobDisabledNoOp checks that firing a disabled
+// schedule sends no backup command.run and creates no backup job rows.
+func TestDispatchScheduledJobDisabledNoOp(t *testing.T) {
+	t.Parallel()
+	srv, ts, st := rawTestServer(t)
+	hostID, token := enrolHostForWS(t, srv, st, "disabled-host")
+
+	gid := ulid.Make().String()
+	if err := st.CreateSourceGroup(context.Background(), &store.SourceGroup{
+		ID: gid, HostID: hostID, Name: "default", Includes: []string{"/etc"},
+	}); err != nil {
+		t.Fatalf("group: %v", err)
+	}
+	sid := ulid.Make().String()
+	if err := st.CreateSchedule(context.Background(), &store.Schedule{
+		ID: sid, HostID: hostID,
+		CronExpr: "0 3 * * *", Enabled: false, // disabled
+		SourceGroupIDs: []string{gid},
+	}); err != nil {
+		t.Fatalf("schedule: %v", err)
+	}
+	if err := st.CreateJob(context.Background(), store.Job{
+		ID: ulid.Make().String(), HostID: hostID, Kind: "init",
+		ActorKind: "system", CreatedAt: time.Now().UTC(),
+	}); err != nil {
+		t.Fatalf("seed init: %v", err)
+	}
+
+	c := agentDial(t, srv, ts, hostID, token)
+	sendHello(t, c, "disabled-host")
+	_ = drainUntil(t, c, api.MsgScheduleSet)
+
+	conn := connFromHub(t, srv, hostID)
+	srv.dispatchScheduledJob(context.Background(), hostID, conn, sid, time.Now().UTC())
+
+	// No backups should be queued.
+	scanCommandRuns(t, c, 800*time.Millisecond, 300*time.Millisecond, func(run api.CommandRunPayload) bool {
+		if run.Kind == api.JobBackup {
+			t.Fatalf("disabled schedule still dispatched a backup: %+v", run)
+		}
+		return false
+	})
+	// A failed query must not read as "zero rows": check the Scan error
+	// before trusting n.
+	var n int
+	if err := st.DB().QueryRow(`SELECT COUNT(*) FROM jobs WHERE host_id = ? AND kind = 'backup'`, hostID).Scan(&n); err != nil {
+		t.Fatalf("count: %v", err)
+	}
+	if n != 0 {
+		t.Errorf("disabled schedule produced %d backup rows", n)
+	}
+}
diff --git a/internal/server/http/run_group.go b/internal/server/http/run_group.go index 1f16095..1a0f35c 100644 --- a/internal/server/http/run_group.go +++ b/internal/server/http/run_group.go @@ -7,7 +7,6 @@ package http import ( - "encoding/json" "errors" stdhttp "net/http" @@ -41,13 +40,13 @@ func (s *Server) handleRunSourceGroup(w stdhttp.ResponseWriter, r *stdhttp.Reque return } - retention, _ := json.Marshal(g.RetentionPolicy) + // Backup invocations don't consume RetentionPolicy — that lives on + // forget. Sending the resolved set here would just be dead weight. 
res, status, code, msg := s.dispatchJobWithPayload(r.Context(), user, hostID, api.JobBackup, api.CommandRunPayload{ - Includes: g.Includes, - Excludes: g.Excludes, - Tag: g.Name, - RetentionPolicy: retention, + Includes: g.Includes, + Excludes: g.Excludes, + Tag: g.Name, }) if code != "" { s.runGroupError(w, r, status, code, msg) diff --git a/internal/server/http/schedule_push.go b/internal/server/http/schedule_push.go index 7196255..6c2c692 100644 --- a/internal/server/http/schedule_push.go +++ b/internal/server/http/schedule_push.go @@ -183,14 +183,15 @@ func (s *Server) dispatchBackupForGroup(ctx context.Context, conn *ws.Conn, host "schedule_id", scheduleID, "group", g.Name, "err", err) return } - retention, _ := json.Marshal(g.RetentionPolicy) + // Backup ignores RetentionPolicy — the forget cadence lives on + // host_repo_maintenance and is driven by the server-side ticker + // (P2R-06). Don't ship the field on backup dispatches. env, err := api.Marshal(api.MsgCommandRun, jobID, api.CommandRunPayload{ - JobID: jobID, - Kind: api.JobBackup, - Includes: g.Includes, - Excludes: g.Excludes, - Tag: g.Name, - RetentionPolicy: retention, + JobID: jobID, + Kind: api.JobBackup, + Includes: g.Includes, + Excludes: g.Excludes, + Tag: g.Name, }) if err != nil { slog.Warn("schedule.fire: marshal command.run", diff --git a/internal/server/ws/hub.go b/internal/server/ws/hub.go index c0b64ef..c10a85d 100644 --- a/internal/server/ws/hub.go +++ b/internal/server/ws/hub.go @@ -81,6 +81,17 @@ func (h *Hub) Connected(hostID string) bool { return ok } +// Conn returns the canonical connection for hostID, or nil if the +// host is offline. Tests use this to obtain a *Conn for direct calls +// into handlers that take one. Production code should prefer Send, +// which avoids holding a reference past the point where a supersede +// might have replaced the conn. 
+func (h *Hub) Conn(hostID string) *Conn {
+	// Read the map entry under the read lock; a missing host yields the
+	// map's zero value, i.e. nil.
+	h.mu.RLock()
+	conn := h.conns[hostID]
+	h.mu.RUnlock()
+	return conn
+}
+
 // ----- Conn methods --------------------------------------------------
 
 // NewConn wraps a freshly-accepted websocket for a given hostID.