Files
restic-manager/internal/server/ws/hub_test.go
T
steve e871b05b38
CI / Test (linux/amd64) (pull_request) Successful in 34s
CI / Lint (pull_request) Failing after 16s
CI / Build (windows/amd64) (pull_request) Successful in 22s
CI / Build (linux/amd64) (pull_request) Successful in 20s
CI / Build (linux/arm64) (pull_request) Successful in 21s
lint: drive baseline to zero, drop only-new-issues gate
Cleanup pass over the repo so CI can enforce lint going forward
without the only-new-issues escape hatch:

* gofumpt -w across the tree (31 hits, all formatting)
* misspell --fix (25 hits, US-locale spelling) — but reverted on
  api.JobCancelled = "cancelled" since that literal is the wire +
  DB CHECK constraint value, plus matched the case in store/fleet.go
  back to "cancelled" and added //nolint:misspell on both for the
  next time someone reaches for the auto-fix
* Wrap every `defer rows.Close()` / `defer stmt.Close()` /
  `defer res.Body.Close()` in `defer func() { _ = x.Close() }()`
  to satisfy errcheck without losing the close itself
* websocket.Dial callers (1 prod, 4 tests) now capture + close the
  upgrade response Body — coder/websocket can return res with a nil
  Body on success, so the tests' deferred closes nil-check res and
  res.Body before closing
* Annotate the two genuine-by-design nilerr cases with //nolint
  comments explaining why nil-on-error is the contract (cookie
  missing = no session; ctx cancelled mid-backoff = clean shutdown)
* Add brief godoc on the 10 exported const groups + types that
  revive flagged (api.HostOS/HostArch/JobKind/JobStatus/LogStream/
  ErrorCode, restic.EventKind, store.Role, web.FS)
* Drop the unused (*Server).userByID method
* Inline the unparam baseView(active) — every UI page is under
  the dashboard primary nav today

Result: `golangci-lint run ./...` reports 0 issues. CI lint job
no longer needs only-new-issues: true; X-06 follow-up entry in
tasks.md removed.
2026-05-03 16:15:17 +01:00

199 lines
5.3 KiB
Go

package ws
import (
"context"
"encoding/json"
stdhttp "net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"time"
"github.com/coder/websocket"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
// setupTestHub spins up an httptest server that exposes only /ws/agent,
// backed by a fresh sqlite store (created under t.TempDir) containing one
// pre-enrolled host.
//
// It returns:
//   - url: the ws:// URL of the agent endpoint
//   - token: the host's plaintext bearer token
//   - hostID: the pre-enrolled host's ID
//   - st: the store the handler reads and writes
//   - hub: the hub the handler registers connections on
//
// The store and the server are closed through t.Cleanup. Any setup
// failure aborts the test immediately.
func setupTestHub(t *testing.T) (url string, token string, hostID string, st *store.Store, hub *Hub) {
	t.Helper()
	ctx := context.Background()

	var err error
	st, err = store.Open(ctx, filepath.Join(t.TempDir(), "rm.db"))
	if err != nil {
		t.Fatalf("store: %v", err)
	}
	t.Cleanup(func() { _ = st.Close() })

	hub = NewHub()
	mux := stdhttp.NewServeMux()
	mux.Handle("/ws/agent", AgentHandler(HandlerDeps{Hub: hub, Store: st}))
	srv := httptest.NewServer(mux)
	t.Cleanup(srv.Close)

	// Pre-enroll a host directly via the store, skipping the HTTP
	// enrollment flow. CreateHost receives the token's hash, while the
	// caller gets the plaintext for the Authorization header.
	hostID = "01HJ8K70000000000000000000"
	token, err = auth.NewToken()
	if err != nil {
		// Without this check an empty or invalid token would be enrolled
		// and every dial would fail with a misleading auth error.
		t.Fatalf("new token: %v", err)
	}
	host := store.Host{
		ID: hostID, Name: "h1", OS: "linux", Arch: "amd64",
		EnrolledAt: time.Now().UTC(),
	}
	if err = st.CreateHost(ctx, host, auth.HashToken(token), ""); err != nil {
		t.Fatalf("enroll: %v", err)
	}

	// Turn http://127.0.0.1:port into ws://127.0.0.1:port/ws/agent.
	url = "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/agent"
	return url, token, hostID, st, hub
}
// TestWSHelloAndHeartbeat dials /ws/agent with a valid bearer token and
// sends a hello. It checks that the host registers on the hub and that its
// store row is marked online with the reported agent version. It then
// sends a heartbeat and waits for last_seen_at to advance.
func TestWSHelloAndHeartbeat(t *testing.T) {
	t.Parallel()
	url, token, hostID, st, hub := setupTestHub(t)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	c, res, err := websocket.Dial(ctx, url, &websocket.DialOptions{
		HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + token}},
	})
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	defer c.CloseNow()
	defer func() {
		// The upgrade response, or its Body, may be nil; only close what exists.
		if res != nil && res.Body != nil {
			_ = res.Body.Close()
		}
	}()

	// Send hello.
	hello := api.HelloPayload{
		ProtocolVersion: api.CurrentProtocolVersion,
		AgentVersion:    "0.1.0",
		ResticVersion:   "0.17.1",
		Hostname:        "h1",
		OS:              api.OSLinux,
		Arch:            api.ArchAmd64,
	}
	env, err := api.Marshal(api.MsgHello, "", hello)
	if err != nil {
		t.Fatalf("marshal hello: %v", err)
	}
	raw, err := json.Marshal(env)
	if err != nil {
		t.Fatalf("encode hello: %v", err)
	}
	if err := c.Write(ctx, websocket.MessageText, raw); err != nil {
		t.Fatalf("write hello: %v", err)
	}

	// Registration happens after the hello handler returns, so poll for
	// up to 1s rather than checking once.
	deadline := time.Now().Add(time.Second)
	for !hub.Connected(hostID) && time.Now().Before(deadline) {
		time.Sleep(20 * time.Millisecond)
	}
	if !hub.Connected(hostID) {
		t.Fatal("host did not register on hub after hello")
	}

	// Verify the host row was marked online and has populated metadata.
	h, err := st.GetHost(context.Background(), hostID)
	if err != nil {
		t.Fatalf("get host: %v", err)
	}
	if h.Status != "online" || h.AgentVersion != "0.1.0" {
		t.Errorf("host after hello: %+v", h)
	}

	// Send a heartbeat; the server should touch last_seen_at.
	hb := api.HeartbeatPayload{SentAt: time.Now().UTC()}
	env, err = api.Marshal(api.MsgHeartbeat, "", hb)
	if err != nil {
		t.Fatalf("marshal heartbeat: %v", err)
	}
	raw, err = json.Marshal(env)
	if err != nil {
		t.Fatalf("encode heartbeat: %v", err)
	}
	preTouch := h.LastSeenAt
	if err := c.Write(ctx, websocket.MessageText, raw); err != nil {
		// Fail here so a transport error is not misreported below as
		// "did not update last_seen_at".
		t.Fatalf("write heartbeat: %v", err)
	}

	// Poll briefly for the server to process the heartbeat.
	deadline = time.Now().Add(time.Second)
	for time.Now().Before(deadline) {
		h2, err := st.GetHost(context.Background(), hostID)
		if err != nil {
			// Don't dereference h2 when the lookup failed.
			t.Fatalf("get host after heartbeat: %v", err)
		}
		if h2.LastSeenAt != nil && (preTouch == nil || h2.LastSeenAt.After(*preTouch)) {
			return
		}
		time.Sleep(20 * time.Millisecond)
	}
	t.Error("heartbeat did not update last_seen_at")
}
// TestWSRejectsOldProtocol sends a hello with ProtocolVersion 0 and
// expects the server to answer with a text-frame error envelope carrying
// api.ErrProtocolTooOld.
func TestWSRejectsOldProtocol(t *testing.T) {
	t.Parallel()
	url, token, _, _, _ := setupTestHub(t)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	c, res, err := websocket.Dial(ctx, url, &websocket.DialOptions{
		HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + token}},
	})
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	defer c.CloseNow()
	defer func() {
		// The upgrade response, or its Body, may be nil; only close what exists.
		if res != nil && res.Body != nil {
			_ = res.Body.Close()
		}
	}()

	hello := api.HelloPayload{ProtocolVersion: 0} // below minimum
	env, err := api.Marshal(api.MsgHello, "", hello)
	if err != nil {
		t.Fatalf("marshal hello: %v", err)
	}
	raw, err := json.Marshal(env)
	if err != nil {
		t.Fatalf("encode hello: %v", err)
	}
	if err := c.Write(ctx, websocket.MessageText, raw); err != nil {
		t.Fatalf("write hello: %v", err)
	}

	// The server should send an error envelope before closing.
	mt, body, err := c.Read(ctx)
	if err != nil {
		t.Fatalf("read: %v", err)
	}
	if mt != websocket.MessageText {
		t.Fatalf("frame type: %v", mt)
	}
	var got api.Envelope
	if err := json.Unmarshal(body, &got); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	if got.Type != api.MsgError {
		t.Errorf("expected error envelope, got %q", got.Type)
	}
	var ep api.ErrorPayload
	if err := got.UnmarshalPayload(&ep); err != nil {
		// Otherwise a decode failure would show up as a confusing
		// empty error code below.
		t.Fatalf("unmarshal error payload: %v", err)
	}
	if ep.Code != api.ErrProtocolTooOld {
		t.Errorf("error code: %q", ep.Code)
	}
}
// TestWSRejectsBadToken dials /ws/agent with an unknown bearer token and
// requires the handshake to fail with an HTTP 401 response.
func TestWSRejectsBadToken(t *testing.T) {
	t.Parallel()
	url, _, _, _, _ := setupTestHub(t)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	c, res, err := websocket.Dial(ctx, url, &websocket.DialOptions{
		HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer wrong"}},
	})
	// A failed handshake can still return the HTTP response; close its
	// body whenever one is present.
	if res != nil && res.Body != nil {
		defer func() { _ = res.Body.Close() }()
	}
	if err == nil {
		// Don't leak the unexpectedly accepted connection.
		_ = c.CloseNow()
		t.Fatal("dial should fail")
	}
	if res == nil {
		// Previously this case passed silently; without a response we
		// cannot confirm the rejection was an auth failure.
		t.Fatalf("dial failed without an HTTP response (want status %d): %v",
			stdhttp.StatusUnauthorized, err)
	}
	if res.StatusCode != stdhttp.StatusUnauthorized {
		t.Errorf("status: got %d, want %d", res.StatusCode, stdhttp.StatusUnauthorized)
	}
}