// p2r01_ws_test.go — integration tests for the WS-touching pieces of // P2R-01: auto-init dispatch on hello, and dispatchScheduledJob's // schedule.fire → command.run-per-group resolution. package http import ( "context" "encoding/json" stdhttp "net/http" "net/http/httptest" "strings" "testing" "time" "github.com/coder/websocket" "github.com/oklog/ulid/v2" "gitea.dcglab.co.uk/steve/restic-manager/internal/api" "gitea.dcglab.co.uk/steve/restic-manager/internal/auth" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws" "gitea.dcglab.co.uk/steve/restic-manager/internal/store" ) // agentDial connects to the server's /ws/agent as a freshly-enrolled // host and returns the conn + a cleanup. Caller is expected to send // hello. func agentDial(t *testing.T, srv *Server, ts *httptest.Server, hostID, token string) *websocket.Conn { t.Helper() url := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/agent" ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() c, res, err := websocket.Dial(ctx, url, &websocket.DialOptions{ HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + token}}, }) if err != nil { t.Fatalf("dial: %v", err) } t.Cleanup(func() { _ = c.CloseNow() if res != nil && res.Body != nil { _ = res.Body.Close() } }) return c } // readEnvelope blocks until one envelope arrives or the test times out. func readEnvelope(t *testing.T, c *websocket.Conn) api.Envelope { t.Helper() ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() mt, raw, err := c.Read(ctx) if err != nil { t.Fatalf("ws read: %v", err) } if mt != websocket.MessageText { t.Fatalf("ws read: non-text frame %v", mt) } var env api.Envelope if err := json.Unmarshal(raw, &env); err != nil { t.Fatalf("ws unmarshal: %v: %s", err, raw) } return env } // drainUntil reads envelopes until a wantType arrives or the test // times out. Returns the matching envelope (and the others read along // the way, ignored). On-hello pushes config.update + schedule.set + // (sometimes) command.run; tests want to skip past the prefix to the // envelope they care about. func drainUntil(t *testing.T, c *websocket.Conn, wantType api.MessageType) api.Envelope { t.Helper() deadline := time.Now().Add(3 * time.Second) for time.Now().Before(deadline) { env := readEnvelope(t, c) if env.Type == wantType { return env } } t.Fatalf("timed out waiting for %s", wantType) return api.Envelope{} } // enrolHostForWS pre-enrolls a host with bound repo creds so the server // will treat it as ready to receive command.run. func enrolHostForWS(t *testing.T, srv *Server, st *store.Store, name string) (hostID, token string) { t.Helper() hostID = ulid.Make().String() token, _ = auth.NewToken() if err := st.CreateHost(context.Background(), store.Host{ ID: hostID, Name: name, OS: "linux", Arch: "amd64", EnrolledAt: time.Now().UTC(), }, auth.HashToken(token), ""); err != nil { t.Fatalf("create host: %v", err) } enc, err := srv.encryptRepoCreds(repoCredsBlob{ RepoURL: "rest:http://r/x", RepoUsername: "u", RepoPassword: "p", }, []byte("host:"+hostID)) if err != nil { t.Fatalf("encrypt: %v", err) } if err := st.SetHostCredentials(context.Background(), hostID, enc); err != nil { t.Fatalf("set creds: %v", err) } return } func sendHello(t *testing.T, c *websocket.Conn, hostname string) { t.Helper() env, _ := api.Marshal(api.MsgHello, "", api.HelloPayload{ ProtocolVersion: api.CurrentProtocolVersion, AgentVersion: "test", ResticVersion: "0.17", Hostname: hostname, OS: api.OSLinux, Arch: api.ArchAmd64, }) raw, _ := json.Marshal(env) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() if err := c.Write(ctx, websocket.MessageText, raw); err != nil { t.Fatalf("write hello: %v", err) } } // rawTestServer wires Server up against an httptest server and returns // the inner Server pointer + the URL. func rawTestServer(t *testing.T) (*Server, *httptest.Server, *store.Store) { t.Helper() srv, _, st := newTestServerWithHub(t) ts := httptest.NewServer(srv.srv.Handler) t.Cleanup(ts.Close) return srv, ts, st } // connFromHub fetches the live *ws.Conn for hostID from the hub. // Polls briefly because the WS handler registers the conn just after // the OnHello callback returns. func connFromHub(t *testing.T, srv *Server, hostID string) *ws.Conn { t.Helper() deadline := time.Now().Add(2 * time.Second) for time.Now().Before(deadline) { if c := srv.deps.Hub.Conn(hostID); c != nil { return c } time.Sleep(20 * time.Millisecond) } t.Fatalf("hub never registered conn for %s", hostID) return nil } // ----- auto-init dispatch ----------------------------------------- func TestAutoInitDispatchedOnFirstHelloOnly(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) hostID, token := enrolHostForWS(t, srv, st, "auto-init-host") c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "auto-init-host") // Expect config.update + schedule.set + command.run(init) in some // order. drainUntil walks past the first two to find the init. env := drainUntil(t, c, api.MsgCommandRun) var p api.CommandRunPayload if err := env.UnmarshalPayload(&p); err != nil { t.Fatalf("unmarshal: %v", err) } if p.Kind != api.JobInit { t.Fatalf("first command.run kind: %s, want init", p.Kind) } // Mark the init job succeeded so HasJobOfKind sees terminal state. if err := st.MarkJobFinished(context.Background(), p.JobID, "succeeded", 0, nil, "", time.Now().UTC()); err != nil { t.Fatalf("mark finished: %v", err) } // Reconnect — the second hello must NOT dispatch another init. _ = c.Close(websocket.StatusNormalClosure, "test") // Brief wait so the hub unregisters the old conn before we open a // new one (otherwise Register supersedes the old one, which is a // race the production code already handles but the test doesn't // need to fight). time.Sleep(50 * time.Millisecond) c2 := agentDial(t, srv, ts, hostID, token) sendHello(t, c2, "auto-init-host") // Expect config.update + schedule.set, then a quiet read that // times out — no second init. deadline := time.Now().Add(1500 * time.Millisecond) for time.Now().Before(deadline) { ctx, cancel := context.WithTimeout(context.Background(), 400*time.Millisecond) mt, raw, err := c2.Read(ctx) cancel() if err != nil { break } if mt != websocket.MessageText { continue } var env api.Envelope if err := json.Unmarshal(raw, &env); err != nil { continue } if env.Type == api.MsgCommandRun { var p api.CommandRunPayload _ = env.UnmarshalPayload(&p) if p.Kind == api.JobInit { t.Fatalf("second hello re-dispatched init (job_id=%s) — gate is broken", p.JobID) } } } } func TestAutoInitSkippedWhenNoCreds(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) hostID := ulid.Make().String() token, _ := auth.NewToken() if err := st.CreateHost(context.Background(), store.Host{ ID: hostID, Name: "no-creds-host", OS: "linux", Arch: "amd64", EnrolledAt: time.Now().UTC(), }, auth.HashToken(token), ""); err != nil { t.Fatalf("create host: %v", err) } c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "no-creds-host") // On-hello sends a schedule.set (config.update is skipped because // no creds). We should NOT see a command.run(init). deadline := time.Now().Add(1500 * time.Millisecond) for time.Now().Before(deadline) { ctx, cancel := context.WithTimeout(context.Background(), 400*time.Millisecond) mt, raw, err := c.Read(ctx) cancel() if err != nil { break } if mt != websocket.MessageText { continue } var env api.Envelope _ = json.Unmarshal(raw, &env) if env.Type == api.MsgCommandRun { var p api.CommandRunPayload _ = env.UnmarshalPayload(&p) t.Fatalf("auto-init dispatched without creds (kind=%s)", p.Kind) } } } // ----- dispatchScheduledJob --------------------------------------- func TestDispatchScheduledJobIteratesGroups(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) hostID, token := enrolHostForWS(t, srv, st, "fire-host") // Create two source groups + one schedule covering both. g1 := ulid.Make().String() g2 := ulid.Make().String() for _, g := range []*store.SourceGroup{ {ID: g1, HostID: hostID, Name: "etc", Includes: []string{"/etc"}}, {ID: g2, HostID: hostID, Name: "home", Includes: []string{"/home"}}, } { if err := st.CreateSourceGroup(context.Background(), g); err != nil { t.Fatalf("group: %v", err) } } sid := ulid.Make().String() if err := st.CreateSchedule(context.Background(), &store.Schedule{ ID: sid, HostID: hostID, CronExpr: "0 3 * * *", Enabled: true, SourceGroupIDs: []string{g1, g2}, }); err != nil { t.Fatalf("schedule: %v", err) } // Mark a successful init job up front so the auto-init path // doesn't fire and pollute the envelope sequence we're measuring. if err := st.CreateJob(context.Background(), store.Job{ ID: ulid.Make().String(), HostID: hostID, Kind: "init", ActorKind: "system", CreatedAt: time.Now().UTC(), }); err != nil { t.Fatalf("seed init: %v", err) } c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "fire-host") // Wait for the schedule.set push so we know the hub is connected. _ = drainUntil(t, c, api.MsgScheduleSet) // Resolve the conn from the hub and call dispatchScheduledJob // directly — same path the WS handler invokes on schedule.fire. conn := connFromHub(t, srv, hostID) srv.dispatchScheduledJob(context.Background(), hostID, conn, sid, time.Now().UTC()) // Two backups should be queued, one per group. Read both. got := map[string]api.CommandRunPayload{} deadline := time.Now().Add(3 * time.Second) for len(got) < 2 && time.Now().Before(deadline) { ctx, cancel := context.WithTimeout(context.Background(), 800*time.Millisecond) mt, raw, err := c.Read(ctx) cancel() if err != nil { break } if mt != websocket.MessageText { continue } var env api.Envelope _ = json.Unmarshal(raw, &env) if env.Type != api.MsgCommandRun { continue } var p api.CommandRunPayload _ = env.UnmarshalPayload(&p) if p.Kind != api.JobBackup { continue } got[p.Tag] = p } if len(got) != 2 { t.Fatalf("expected 2 backups (one per group), got %d: %+v", len(got), got) } if !equalStrings(got["etc"].Includes, []string{"/etc"}) { t.Errorf("etc backup includes: %v", got["etc"].Includes) } if !equalStrings(got["home"].Includes, []string{"/home"}) { t.Errorf("home backup includes: %v", got["home"].Includes) } // Two job rows should exist for this host with kind=backup, // actor_kind=schedule, scheduled_id=sid. var n int if err := st.DB().QueryRow( `SELECT COUNT(*) FROM jobs WHERE host_id = ? AND kind = 'backup' AND actor_kind = 'schedule' AND scheduled_id = ?`, hostID, sid).Scan(&n); err != nil { t.Fatalf("count: %v", err) } if n != 2 { t.Errorf("scheduled backup jobs: got %d, want 2", n) } } func TestDispatchScheduledJobDisabledNoOp(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) hostID, token := enrolHostForWS(t, srv, st, "disabled-host") gid := ulid.Make().String() if err := st.CreateSourceGroup(context.Background(), &store.SourceGroup{ ID: gid, HostID: hostID, Name: "default", Includes: []string{"/etc"}, }); err != nil { t.Fatalf("group: %v", err) } sid := ulid.Make().String() if err := st.CreateSchedule(context.Background(), &store.Schedule{ ID: sid, HostID: hostID, CronExpr: "0 3 * * *", Enabled: false, // disabled SourceGroupIDs: []string{gid}, }); err != nil { t.Fatalf("schedule: %v", err) } if err := st.CreateJob(context.Background(), store.Job{ ID: ulid.Make().String(), HostID: hostID, Kind: "init", ActorKind: "system", CreatedAt: time.Now().UTC(), }); err != nil { t.Fatalf("seed init: %v", err) } c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "disabled-host") _ = drainUntil(t, c, api.MsgScheduleSet) conn := connFromHub(t, srv, hostID) srv.dispatchScheduledJob(context.Background(), hostID, conn, sid, time.Now().UTC()) // No backups should be queued. deadline := time.Now().Add(800 * time.Millisecond) for time.Now().Before(deadline) { ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) mt, raw, err := c.Read(ctx) cancel() if err != nil { break } if mt != websocket.MessageText { continue } var env api.Envelope _ = json.Unmarshal(raw, &env) if env.Type == api.MsgCommandRun { var p api.CommandRunPayload _ = env.UnmarshalPayload(&p) if p.Kind == api.JobBackup { t.Fatalf("disabled schedule still dispatched a backup: %+v", p) } } } var n int _ = st.DB().QueryRow(`SELECT COUNT(*) FROM jobs WHERE host_id = ? AND kind = 'backup'`, hostID).Scan(&n) if n != 0 { t.Errorf("disabled schedule produced %d backup rows", n) } }