Files
restic-manager/internal/server/http/p2r01_ws_test.go
T
steve d692272d10 P2R-01 follow-up: WS-path tests + drop unused retention from backup dispatch
Adds p2r01_ws_test.go covering the two paths the original commit's
in-process tests couldn't reach without a live conn:

- maybeAutoInit dispatches command.run(init) on first hello when creds
  are bound, skips on second hello once a job row exists, and skips
  entirely when the host has no creds.
- dispatchScheduledJob iterates a schedule's source groups and emits
  one backup per group with the right Tag/Includes; persists job rows
  with actor_kind=schedule + scheduled_id; no-ops on a disabled
  schedule.

Drops RetentionPolicy from the per-group Run-now and schedule.fire
backup payloads — the agent's RunBackup ignores it (forget is the
only consumer). Adds Hub.Conn() so tests can grab the live *Conn
post-hello.
2026-05-03 11:00:45 +01:00

403 lines
13 KiB
Go

// p2r01_ws_test.go — integration tests for the WS-touching pieces of
// P2R-01: auto-init dispatch on hello, and dispatchScheduledJob's
// schedule.fire → command.run-per-group resolution.
package http
import (
"context"
"encoding/json"
"net/http/httptest"
stdhttp "net/http"
"strings"
"testing"
"time"
"github.com/coder/websocket"
"github.com/oklog/ulid/v2"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
// agentDial opens a WebSocket to the test server's /ws/agent endpoint,
// presenting token as a bearer credential, and returns the client conn.
// The conn is force-closed through t.Cleanup, so there is no separate
// cleanup value to call; a dial failure aborts the test. The caller is
// expected to send hello. srv and hostID are accepted but not used by
// the dial itself.
func agentDial(t *testing.T, srv *Server, ts *httptest.Server, hostID, token string) *websocket.Conn {
	t.Helper()

	// ts.URL starts with "http"; replacing that prefix with "ws" yields
	// the matching WebSocket URL.
	endpoint := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/agent"

	hdr := stdhttp.Header{}
	hdr.Set("Authorization", "Bearer "+token)

	dialCtx, stop := context.WithTimeout(context.Background(), 5*time.Second)
	defer stop()

	conn, _, err := websocket.Dial(dialCtx, endpoint, &websocket.DialOptions{HTTPHeader: hdr})
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	t.Cleanup(func() { _ = conn.CloseNow() })
	return conn
}
// readEnvelope reads a single frame from c, waiting at most three
// seconds, and decodes it as an api.Envelope. A read error (the timeout
// included), a non-text frame, or a body that is not valid JSON for an
// envelope fails the test immediately.
func readEnvelope(t *testing.T, c *websocket.Conn) api.Envelope {
	t.Helper()

	readCtx, stop := context.WithTimeout(context.Background(), 3*time.Second)
	defer stop()

	frameKind, body, err := c.Read(readCtx)
	switch {
	case err != nil:
		t.Fatalf("ws read: %v", err)
	case frameKind != websocket.MessageText:
		t.Fatalf("ws read: non-text frame %v", frameKind)
	}

	var decoded api.Envelope
	if decodeErr := json.Unmarshal(body, &decoded); decodeErr != nil {
		t.Fatalf("ws unmarshal: %v: %s", decodeErr, body)
	}
	return decoded
}
// drainUntil keeps reading envelopes from c, discarding each one whose
// type is not wantType, and returns the first envelope that matches.
// Tests use it to step over the envelopes the server pushes after hello
// (config.update, schedule.set, sometimes command.run) and land on the
// one they care about.
//
// The loop gives up after three seconds. Each individual read goes
// through readEnvelope, which fails the test by itself if no frame
// arrives within its own timeout.
func drainUntil(t *testing.T, c *websocket.Conn, wantType api.MessageType) api.Envelope {
	t.Helper()
	for giveUp := time.Now().Add(3 * time.Second); time.Now().Before(giveUp); {
		if got := readEnvelope(t, c); got.Type == wantType {
			return got
		}
	}
	t.Fatalf("timed out waiting for %s", wantType)
	return api.Envelope{}
}
// enrolHostForWS pre-enrols a linux/amd64 host called name and binds
// encrypted repo credentials to it, so the server will treat it as
// ready to receive command.run.
//
// It returns the generated host ID and the plaintext agent token (only
// the token's hash is handed to the store). Any failure along the way,
// including token generation, aborts the test.
func enrolHostForWS(t *testing.T, srv *Server, st *store.Store, name string) (hostID, token string) {
	t.Helper()
	ctx := context.Background()

	hostID = ulid.Make().String()

	// A failed token mint used to be ignored, which would have enrolled
	// the host with the hash of an empty token and produced a confusing
	// auth failure much later in the test.
	token, err := auth.NewToken()
	if err != nil {
		t.Fatalf("new token: %v", err)
	}

	if err := st.CreateHost(ctx, store.Host{
		ID: hostID, Name: name, OS: "linux", Arch: "amd64",
		EnrolledAt: time.Now().UTC(),
	}, auth.HashToken(token), ""); err != nil {
		t.Fatalf("create host: %v", err)
	}

	// The second argument ties the ciphertext to this host's ID.
	enc, err := srv.encryptRepoCreds(repoCredsBlob{
		RepoURL: "rest:http://r/x", RepoUsername: "u", RepoPassword: "p",
	}, []byte("host:"+hostID))
	if err != nil {
		t.Fatalf("encrypt: %v", err)
	}
	if err := st.SetHostCredentials(ctx, hostID, enc); err != nil {
		t.Fatalf("set creds: %v", err)
	}
	return hostID, token
}
// sendHello writes a hello envelope to c as a single text frame,
// announcing hostname as a linux/amd64 agent on the current protocol
// version. Building, encoding, or writing the envelope must all
// succeed; any failure aborts the test here instead of showing up
// later as an unexplained read timeout.
func sendHello(t *testing.T, c *websocket.Conn, hostname string) {
	t.Helper()

	env, err := api.Marshal(api.MsgHello, "", api.HelloPayload{
		ProtocolVersion: api.CurrentProtocolVersion,
		AgentVersion:    "test",
		ResticVersion:   "0.17",
		Hostname:        hostname,
		OS:              api.OSLinux, Arch: api.ArchAmd64,
	})
	if err != nil {
		t.Fatalf("marshal hello: %v", err)
	}
	raw, err := json.Marshal(env)
	if err != nil {
		t.Fatalf("encode hello: %v", err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	if err := c.Write(ctx, websocket.MessageText, raw); err != nil {
		t.Fatalf("write hello: %v", err)
	}
}
// rawTestServer builds a Server with newTestServerWithHub and serves
// its handler from an httptest.Server that is shut down on test
// cleanup. It returns the Server itself, the httptest.Server (whose URL
// field is the address to dial), and the store backing the Server.
func rawTestServer(t *testing.T) (*Server, *httptest.Server, *store.Store) {
	t.Helper()
	server, _, backing := newTestServerWithHub(t)
	listener := httptest.NewServer(server.srv.Handler)
	t.Cleanup(func() { listener.Close() })
	return server, listener, backing
}
// connFromHub returns the live *ws.Conn the hub holds for hostID. It
// asks the hub every 20ms for up to two seconds and fails the test if
// no conn shows up in that window. The polling is there because the WS
// handler is understood to register the conn only just after the
// OnHello callback returns, so the conn may not be visible the instant
// the first server push arrives (registration code is not in this file).
func connFromHub(t *testing.T, srv *Server, hostID string) *ws.Conn {
	t.Helper()
	for stopAt := time.Now().Add(2 * time.Second); time.Now().Before(stopAt); time.Sleep(20 * time.Millisecond) {
		if live := srv.deps.Hub.Conn(hostID); live != nil {
			return live
		}
	}
	t.Fatalf("hub never registered conn for %s", hostID)
	return nil
}
// ----- auto-init dispatch -----------------------------------------
// TestAutoInitDispatchedOnFirstHelloOnly checks both halves of the
// auto-init gate for a host with bound creds: the first hello must be
// answered with a command.run of kind init, and a hello on a fresh
// connection after that job has been marked finished must not produce
// another one.
func TestAutoInitDispatchedOnFirstHelloOnly(t *testing.T) {
	t.Parallel()
	const hostname = "auto-init-host"

	srv, ts, st := rawTestServer(t)
	hostID, token := enrolHostForWS(t, srv, st, hostname)

	first := agentDial(t, srv, ts, hostID, token)
	sendHello(t, first, hostname)

	// The server answers hello with config.update, schedule.set and
	// command.run(init) in some order; drainUntil skips ahead to the
	// first command.run.
	initEnv := drainUntil(t, first, api.MsgCommandRun)
	var initRun api.CommandRunPayload
	if err := initEnv.UnmarshalPayload(&initRun); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	if initRun.Kind != api.JobInit {
		t.Fatalf("first command.run kind: %s, want init", initRun.Kind)
	}

	// Record the init job as succeeded so the server-side "has this
	// host already had an init job" check sees a terminal state.
	if err := st.MarkJobFinished(context.Background(), initRun.JobID, "succeeded", 0, nil, "", time.Now().UTC()); err != nil {
		t.Fatalf("mark finished: %v", err)
	}

	// Drop the first connection, then pause briefly so the hub can
	// unregister it before the reconnect. Without the pause the new
	// registration would supersede the old conn; production handles
	// that race, but this test has no reason to exercise it.
	_ = first.Close(websocket.StatusNormalClosure, "test")
	time.Sleep(50 * time.Millisecond)

	second := agentDial(t, srv, ts, hostID, token)
	sendHello(t, second, hostname)

	// Watch the second connection for up to 1.5s. config.update and
	// schedule.set may arrive; the first read that errors (normally a
	// 400ms timeout on a quiet socket) ends the watch. Seeing an init
	// command.run before then means the gate failed.
	for quietUntil := time.Now().Add(1500 * time.Millisecond); time.Now().Before(quietUntil); {
		readCtx, stop := context.WithTimeout(context.Background(), 400*time.Millisecond)
		frameKind, body, readErr := second.Read(readCtx)
		stop()
		if readErr != nil {
			break
		}
		if frameKind != websocket.MessageText {
			continue
		}
		var got api.Envelope
		if json.Unmarshal(body, &got) != nil {
			continue
		}
		if got.Type != api.MsgCommandRun {
			continue
		}
		var run api.CommandRunPayload
		_ = got.UnmarshalPayload(&run)
		if run.Kind == api.JobInit {
			t.Fatalf("second hello re-dispatched init (job_id=%s) — gate is broken", run.JobID)
		}
	}
}
// TestAutoInitSkippedWhenNoCreds enrols a host without binding any repo
// credentials and verifies that hello is not answered with a
// command.run of any kind.
func TestAutoInitSkippedWhenNoCreds(t *testing.T) {
	t.Parallel()
	srv, ts, st := rawTestServer(t)

	hostID := ulid.Make().String()
	// Token generation is checked so a failure here is reported as such
	// rather than as a later dial/auth error.
	token, err := auth.NewToken()
	if err != nil {
		t.Fatalf("new token: %v", err)
	}
	if err := st.CreateHost(context.Background(), store.Host{
		ID: hostID, Name: "no-creds-host", OS: "linux", Arch: "amd64",
		EnrolledAt: time.Now().UTC(),
	}, auth.HashToken(token), ""); err != nil {
		t.Fatalf("create host: %v", err)
	}

	c := agentDial(t, srv, ts, hostID, token)
	sendHello(t, c, "no-creds-host")

	// On-hello sends a schedule.set (config.update is skipped because
	// there are no creds). Watch for up to 1.5s; the first read error,
	// normally a 400ms timeout on a quiet socket, ends the watch. Any
	// command.run seen before then is a failure.
	deadline := time.Now().Add(1500 * time.Millisecond)
	for time.Now().Before(deadline) {
		ctx, cancel := context.WithTimeout(context.Background(), 400*time.Millisecond)
		mt, raw, err := c.Read(ctx)
		cancel()
		if err != nil {
			break
		}
		if mt != websocket.MessageText {
			continue
		}
		var env api.Envelope
		// Decode errors are deliberately ignored: only a frame that
		// decodes with type command.run counts as a failure here.
		_ = json.Unmarshal(raw, &env)
		if env.Type == api.MsgCommandRun {
			var p api.CommandRunPayload
			// The payload is decoded only to enrich the failure
			// message, so its error does not matter.
			_ = env.UnmarshalPayload(&p)
			t.Fatalf("auto-init dispatched without creds (kind=%s)", p.Kind)
		}
	}
}
// ----- dispatchScheduledJob ---------------------------------------
// TestDispatchScheduledJobIteratesGroups fires an enabled schedule that
// covers two source groups and expects one backup command.run per
// group, tagged with the group name and carrying that group's
// includes, plus two matching job rows (kind=backup,
// actor_kind=schedule, scheduled_id set) in the store.
func TestDispatchScheduledJobIteratesGroups(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	srv, ts, st := rawTestServer(t)
	hostID, token := enrolHostForWS(t, srv, st, "fire-host")

	// Two source groups, both referenced by a single schedule.
	etcID, homeID := ulid.Make().String(), ulid.Make().String()
	groups := []*store.SourceGroup{
		{ID: etcID, HostID: hostID, Name: "etc", Includes: []string{"/etc"}},
		{ID: homeID, HostID: hostID, Name: "home", Includes: []string{"/home"}},
	}
	for i := range groups {
		if err := st.CreateSourceGroup(ctx, groups[i]); err != nil {
			t.Fatalf("group: %v", err)
		}
	}

	sid := ulid.Make().String()
	schedule := &store.Schedule{
		ID: sid, HostID: hostID,
		CronExpr: "0 3 * * *", Enabled: true,
		SourceGroupIDs: []string{etcID, homeID},
	}
	if err := st.CreateSchedule(ctx, schedule); err != nil {
		t.Fatalf("schedule: %v", err)
	}

	// Seed an init job row before connecting so the auto-init path
	// stays quiet and does not mix a command.run(init) into the
	// envelopes being counted below.
	seed := store.Job{
		ID: ulid.Make().String(), HostID: hostID, Kind: "init",
		ActorKind: "system", CreatedAt: time.Now().UTC(),
	}
	if err := st.CreateJob(ctx, seed); err != nil {
		t.Fatalf("seed init: %v", err)
	}

	c := agentDial(t, srv, ts, hostID, token)
	sendHello(t, c, "fire-host")

	// Seeing schedule.set proves the server has processed hello.
	_ = drainUntil(t, c, api.MsgScheduleSet)

	// Take the server-side conn from the hub and invoke
	// dispatchScheduledJob directly, the same entry point the WS
	// handler uses for schedule.fire.
	conn := connFromHub(t, srv, hostID)
	srv.dispatchScheduledJob(ctx, hostID, conn, sid, time.Now().UTC())

	// Collect backup command.run payloads keyed by tag until both have
	// arrived, a read fails, or three seconds pass.
	byTag := map[string]api.CommandRunPayload{}
	for stopAt := time.Now().Add(3 * time.Second); len(byTag) < 2 && time.Now().Before(stopAt); {
		readCtx, stop := context.WithTimeout(ctx, 800*time.Millisecond)
		frameKind, body, readErr := c.Read(readCtx)
		stop()
		if readErr != nil {
			break
		}
		if frameKind != websocket.MessageText {
			continue
		}
		var env api.Envelope
		_ = json.Unmarshal(body, &env)
		if env.Type != api.MsgCommandRun {
			continue
		}
		var run api.CommandRunPayload
		_ = env.UnmarshalPayload(&run)
		if run.Kind == api.JobBackup {
			byTag[run.Tag] = run
		}
	}
	if len(byTag) != 2 {
		t.Fatalf("expected 2 backups (one per group), got %d: %+v", len(byTag), byTag)
	}

	if etc := byTag["etc"].Includes; !equalStrings(etc, []string{"/etc"}) {
		t.Errorf("etc backup includes: %v", etc)
	}
	if home := byTag["home"].Includes; !equalStrings(home, []string{"/home"}) {
		t.Errorf("home backup includes: %v", home)
	}

	// The store must hold exactly two backup rows attributed to this
	// schedule for this host.
	var rows int
	countErr := st.DB().QueryRow(
		`SELECT COUNT(*) FROM jobs WHERE host_id = ? AND kind = 'backup' AND actor_kind = 'schedule' AND scheduled_id = ?`,
		hostID, sid).Scan(&rows)
	if countErr != nil {
		t.Fatalf("count: %v", countErr)
	}
	if rows != 2 {
		t.Errorf("scheduled backup jobs: got %d, want 2", rows)
	}
}
// TestDispatchScheduledJobDisabledNoOp fires a schedule that is stored
// with Enabled=false and verifies that nothing happens: no backup
// command.run reaches the agent and no backup job row is written.
func TestDispatchScheduledJobDisabledNoOp(t *testing.T) {
	t.Parallel()
	srv, ts, st := rawTestServer(t)
	hostID, token := enrolHostForWS(t, srv, st, "disabled-host")

	gid := ulid.Make().String()
	if err := st.CreateSourceGroup(context.Background(), &store.SourceGroup{
		ID: gid, HostID: hostID, Name: "default", Includes: []string{"/etc"},
	}); err != nil {
		t.Fatalf("group: %v", err)
	}
	sid := ulid.Make().String()
	if err := st.CreateSchedule(context.Background(), &store.Schedule{
		ID: sid, HostID: hostID,
		CronExpr: "0 3 * * *", Enabled: false, // disabled
		SourceGroupIDs: []string{gid},
	}); err != nil {
		t.Fatalf("schedule: %v", err)
	}

	// Seed an init job row so the auto-init path does not add its own
	// command.run to the stream inspected below.
	if err := st.CreateJob(context.Background(), store.Job{
		ID: ulid.Make().String(), HostID: hostID, Kind: "init",
		ActorKind: "system", CreatedAt: time.Now().UTC(),
	}); err != nil {
		t.Fatalf("seed init: %v", err)
	}

	c := agentDial(t, srv, ts, hostID, token)
	sendHello(t, c, "disabled-host")
	_ = drainUntil(t, c, api.MsgScheduleSet)
	conn := connFromHub(t, srv, hostID)
	srv.dispatchScheduledJob(context.Background(), hostID, conn, sid, time.Now().UTC())

	// No backups should be queued. Watch for up to 800ms; the first
	// read error (normally a 300ms timeout on a quiet socket) ends the
	// watch.
	deadline := time.Now().Add(800 * time.Millisecond)
	for time.Now().Before(deadline) {
		ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
		mt, raw, err := c.Read(ctx)
		cancel()
		if err != nil {
			break
		}
		if mt != websocket.MessageText {
			continue
		}
		var env api.Envelope
		// Decode errors are deliberately ignored: only a frame that
		// decodes as a backup command.run counts as a failure here.
		_ = json.Unmarshal(raw, &env)
		if env.Type == api.MsgCommandRun {
			var p api.CommandRunPayload
			_ = env.UnmarshalPayload(&p)
			if p.Kind == api.JobBackup {
				t.Fatalf("disabled schedule still dispatched a backup: %+v", p)
			}
		}
	}

	// The query error must be checked: if it were dropped, a failing
	// query would leave n at zero and the assertion below would pass
	// without having looked at the table.
	var n int
	if err := st.DB().QueryRow(`SELECT COUNT(*) FROM jobs WHERE host_id = ? AND kind = 'backup'`, hostID).Scan(&n); err != nil {
		t.Fatalf("count: %v", err)
	}
	if n != 0 {
		t.Errorf("disabled schedule produced %d backup rows", n)
	}
}