package http import ( "bytes" "context" "encoding/json" "io" stdhttp "net/http" "strings" "testing" "time" "github.com/coder/websocket" "gitea.dcglab.co.uk/steve/restic-manager/internal/api" "gitea.dcglab.co.uk/steve/restic-manager/internal/auth" "gitea.dcglab.co.uk/steve/restic-manager/internal/store" ) // makePushHost is like makeHTTPHost but mints a known agent token so // the test can dial /ws/agent as the host. Returns (hostID, raw token). func makePushHost(t *testing.T, st *store.Store) (string, string) { t.Helper() const id = "01HSCHEDPUSH00000000000000" tok, _ := auth.NewToken() if err := st.CreateHost(context.Background(), store.Host{ ID: id, Name: "ph", OS: "linux", Arch: "amd64", AgentVersion: "dev", ResticVersion: "0.16.0", ProtocolVersion: 1, EnrolledAt: time.Now().UTC(), }, auth.HashToken(tok), ""); err != nil { t.Fatalf("create host: %v", err) } return id, tok } // readUntilType pumps messages from the WS until one of the wanted // types arrives or ctx times out. Returns the matched envelope. // Useful because the on-hello path may push several messages // (config.update first if creds exist, schedule.set, …). func readUntilType(ctx context.Context, t *testing.T, c *websocket.Conn, want api.MessageType) api.Envelope { t.Helper() for { _, raw, err := c.Read(ctx) if err != nil { t.Fatalf("ws read waiting for %s: %v", want, err) } var env api.Envelope if err := json.Unmarshal(raw, &env); err != nil { t.Fatalf("envelope: %v (raw=%s)", err, raw) } t.Logf("recv: type=%s payload=%s", env.Type, env.Payload) if env.Type == want { return env } } } func TestSchedulePushOnHelloAndAckRoundtrip(t *testing.T) { t.Parallel() srv, url, st := newTestServerWithHub(t) _ = srv cookie := loginAndCookie(t, url) hostID, agentToken := makePushHost(t, st) // Pre-populate one schedule so we have something to push. body, _ := json.Marshal(scheduleAPI{ Kind: "backup", CronExpr: "@hourly", Paths: []string{"/etc"}, Enabled: true, }) req, _ := stdhttp.NewRequest("POST", url+"/api/hosts/"+hostID+"/schedules", bytes.NewReader(body)) req.AddCookie(cookie) req.Header.Set("Content-Type", "application/json") res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("create schedule: %v", err) } got, _ := io.ReadAll(res.Body) res.Body.Close() if res.StatusCode != stdhttp.StatusCreated { t.Fatalf("create schedule: %d %s", res.StatusCode, got) } var created scheduleAPI _ = json.Unmarshal(got, &created) // Dial the WS as the agent and send hello. wsURL := "ws" + strings.TrimPrefix(url, "http") + "/ws/agent" ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() c, _, err := websocket.Dial(ctx, wsURL, &websocket.DialOptions{ HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + agentToken}}, }) if err != nil { t.Fatalf("dial: %v", err) } defer c.CloseNow() helloEnv, _ := api.Marshal(api.MsgHello, "", api.HelloPayload{ ProtocolVersion: api.CurrentProtocolVersion, AgentVersion: "test", ResticVersion: "test", Hostname: "ph", OS: api.OSLinux, Arch: api.ArchAmd64, }) raw, _ := json.Marshal(helloEnv) if err := c.Write(ctx, websocket.MessageText, raw); err != nil { t.Fatalf("write hello: %v", err) } // Server should push schedule.set (our host has no creds, so the // config.update branch is silently skipped). pushedEnv := readUntilType(ctx, t, c, api.MsgScheduleSet) var pushed api.ScheduleSetPayload if err := pushedEnv.UnmarshalPayload(&pushed); err != nil { t.Fatalf("decode payload: %v", err) } if pushed.Version != 1 { t.Fatalf("pushed version: got %d, want 1", pushed.Version) } if len(pushed.Schedules) != 1 || pushed.Schedules[0].ID != created.ID { t.Fatalf("pushed schedules: %+v", pushed.Schedules) } if pushed.Schedules[0].CronExpr != "@hourly" || len(pushed.Schedules[0].Paths) != 1 { t.Fatalf("schedule contents: %+v", pushed.Schedules[0]) } // Ack the version. Server should record it on the host row. ackEnv, _ := api.Marshal(api.MsgScheduleAck, "", api.ScheduleAckPayload{ Version: pushed.Version, AppliedAt: time.Now().UTC(), }) raw, _ = json.Marshal(ackEnv) if err := c.Write(ctx, websocket.MessageText, raw); err != nil { t.Fatalf("write ack: %v", err) } // Wait for applied_schedule_version to flip. deadline := time.Now().Add(2 * time.Second) for time.Now().Before(deadline) { h, err := st.GetHost(context.Background(), hostID) if err == nil && h.AppliedScheduleVersion == pushed.Version { return } time.Sleep(20 * time.Millisecond) } h, _ := st.GetHost(context.Background(), hostID) t.Fatalf("applied_schedule_version did not advance: got %d, want %d", h.AppliedScheduleVersion, pushed.Version) } func TestSchedulePushOnCRUD(t *testing.T) { t.Parallel() srv, url, st := newTestServerWithHub(t) _ = srv cookie := loginAndCookie(t, url) hostID, agentToken := makePushHost(t, st) // Connect first so the CRUD push has somewhere to land. wsURL := "ws" + strings.TrimPrefix(url, "http") + "/ws/agent" ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() c, _, err := websocket.Dial(ctx, wsURL, &websocket.DialOptions{ HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + agentToken}}, }) if err != nil { t.Fatalf("dial: %v", err) } defer c.CloseNow() helloEnv, _ := api.Marshal(api.MsgHello, "", api.HelloPayload{ ProtocolVersion: api.CurrentProtocolVersion, AgentVersion: "test", ResticVersion: "test", Hostname: "ph", OS: api.OSLinux, Arch: api.ArchAmd64, }) raw, _ := json.Marshal(helloEnv) _ = c.Write(ctx, websocket.MessageText, raw) // Drain the on-hello schedule.set (will be version 0, empty list). first := readUntilType(ctx, t, c, api.MsgScheduleSet) var initial api.ScheduleSetPayload _ = first.UnmarshalPayload(&initial) if initial.Version != 0 || len(initial.Schedules) != 0 { t.Fatalf("initial push: %+v", initial) } // Now create a schedule via REST. The handler should fire a // schedule.set push asynchronously. body, _ := json.Marshal(scheduleAPI{ Kind: "backup", CronExpr: "*/30 * * * *", Paths: []string{"/var/lib"}, Enabled: true, }) req, _ := stdhttp.NewRequest("POST", url+"/api/hosts/"+hostID+"/schedules", bytes.NewReader(body)) req.AddCookie(cookie) req.Header.Set("Content-Type", "application/json") res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("create: %v", err) } res.Body.Close() if res.StatusCode != stdhttp.StatusCreated { t.Fatalf("create: %d", res.StatusCode) } // Wait for the pushed schedule.set with version 1. pushed := readUntilType(ctx, t, c, api.MsgScheduleSet) var pl api.ScheduleSetPayload _ = pushed.UnmarshalPayload(&pl) if pl.Version != 1 || len(pl.Schedules) != 1 { t.Fatalf("push after create: %+v", pl) } }