Files
restic-manager/internal/server/ws/hub_test.go
T
steve dd7b37a5c1
CI / Test (linux/amd64) (pull_request) Successful in 21s
CI / Lint (pull_request) Successful in 24s
CI / Build (windows/amd64) (pull_request) Successful in 20s
CI / Build (linux/amd64) (pull_request) Successful in 21s
CI / Build (linux/arm64) (pull_request) Successful in 20s
lint: align local gofumpt rules with golangci-lint v2.5.0
Bumping CI to golangci-lint v2.5.0 surfaced two new gofumpt findings (in
two test files that the gofumpt bundled with golangci-lint v2.1.6
considered fine). Re-formatting locally with the matching tool version
brings them in line.

Pre-commit hook config: prepend $GOPATH/bin to PATH inside the hook
entry so gofumpt + golangci-lint resolve when ~/go/bin isn't on the
operator's interactive shell PATH (a common setup: `go install` places
the binaries there, but PATH configuration varies). Without this, the
hooks fail with 'Executable not found' even when the tools are installed.

Pin the Makefile setup target to v2.5.0 so a fresh clone gets the
same binary CI runs — keeps pre-commit and CI from drifting again.
2026-05-03 21:31:47 +01:00

199 lines
5.3 KiB
Go

package ws
import (
"context"
"encoding/json"
stdhttp "net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"time"
"github.com/coder/websocket"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
// setupTestHub starts an httptest server that exposes only the /ws/agent
// route (served by AgentHandler) against a fresh sqlite store containing
// one pre-enrolled host.
//
// It returns the ws:// URL of the agent endpoint, the host's plaintext
// bearer token, the host ID, and the underlying store and hub so tests can
// inspect server-side state. Teardown is registered with t.Cleanup; since
// cleanups run in reverse registration order, the HTTP server is closed
// before the store.
func setupTestHub(t *testing.T) (url string, token string, hostID string, st *store.Store, hub *Hub) {
	t.Helper()

	ctx := context.Background()
	dir := t.TempDir()

	var err error
	st, err = store.Open(ctx, filepath.Join(dir, "rm.db"))
	if err != nil {
		t.Fatalf("store: %v", err)
	}
	// Best-effort close: the temp dir is removed by the testing package anyway.
	t.Cleanup(func() { _ = st.Close() })

	hub = NewHub()
	mux := stdhttp.NewServeMux()
	mux.Handle("/ws/agent", AgentHandler(HandlerDeps{Hub: hub, Store: st}))
	srv := httptest.NewServer(mux)
	t.Cleanup(srv.Close)

	// Pre-enroll a host directly via the store, skipping the HTTP enrollment
	// flow. The store is handed only the token's hash; the plaintext token is
	// returned to the caller for use as the agent's bearer credential.
	hostID = "01HJ8K70000000000000000000"
	token, err = auth.NewToken()
	if err != nil {
		t.Fatalf("token: %v", err)
	}
	host := store.Host{
		ID:         hostID,
		Name:       "h1",
		OS:         "linux",
		Arch:       "amd64",
		EnrolledAt: time.Now().UTC(),
	}
	if err = st.CreateHost(ctx, host, auth.HashToken(token), ""); err != nil {
		t.Fatalf("enroll: %v", err)
	}

	// httptest serves plain http://, so swapping the scheme prefix yields ws://.
	url = "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/agent"
	return url, token, hostID, st, hub
}
// TestWSHelloAndHeartbeat dials /ws/agent with a valid bearer token and
// sends a hello. It checks that the host registers on the hub and that its
// store row is marked online with the reported agent version. It then sends
// a heartbeat and waits for the host's last_seen_at to advance.
func TestWSHelloAndHeartbeat(t *testing.T) {
	t.Parallel()
	url, token, hostID, st, hub := setupTestHub(t)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	c, res, err := websocket.Dial(ctx, url, &websocket.DialOptions{
		HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + token}},
	})
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	defer c.CloseNow()
	defer func() {
		if res != nil && res.Body != nil {
			_ = res.Body.Close()
		}
	}()

	// Send hello.
	hello := api.HelloPayload{
		ProtocolVersion: api.CurrentProtocolVersion,
		AgentVersion:    "0.1.0",
		ResticVersion:   "0.17.1",
		Hostname:        "h1",
		OS:              api.OSLinux,
		Arch:            api.ArchAmd64,
	}
	env, err := api.Marshal(api.MsgHello, "", hello)
	if err != nil {
		t.Fatalf("marshal hello: %v", err)
	}
	raw, err := json.Marshal(env)
	if err != nil {
		t.Fatalf("encode hello: %v", err)
	}
	if err = c.Write(ctx, websocket.MessageText, raw); err != nil {
		t.Fatalf("write hello: %v", err)
	}

	// Wait for the server to register us (registration happens after
	// the hello-handler returns; give it up to 1s).
	deadline := time.Now().Add(time.Second)
	for !hub.Connected(hostID) && time.Now().Before(deadline) {
		time.Sleep(20 * time.Millisecond)
	}
	if !hub.Connected(hostID) {
		t.Fatal("host did not register on hub after hello")
	}

	// Verify the host row was marked online and has populated metadata.
	h, err := st.GetHost(ctx, hostID)
	if err != nil {
		t.Fatalf("get host: %v", err)
	}
	if h.Status != "online" || h.AgentVersion != "0.1.0" {
		t.Errorf("host after hello: %+v", h)
	}

	// Send a heartbeat; the server should touch last_seen.
	hb := api.HeartbeatPayload{SentAt: time.Now().UTC()}
	env, err = api.Marshal(api.MsgHeartbeat, "", hb)
	if err != nil {
		t.Fatalf("marshal heartbeat: %v", err)
	}
	raw, err = json.Marshal(env)
	if err != nil {
		t.Fatalf("encode heartbeat: %v", err)
	}
	preTouch := h.LastSeenAt
	if err = c.Write(ctx, websocket.MessageText, raw); err != nil {
		t.Fatalf("write heartbeat: %v", err)
	}

	// Poll briefly for the server to process the heartbeat. A store error
	// fails the test outright rather than masquerading as "not updated".
	deadline = time.Now().Add(time.Second)
	for time.Now().Before(deadline) {
		h, err = st.GetHost(ctx, hostID)
		if err != nil {
			t.Fatalf("get host after heartbeat: %v", err)
		}
		if h.LastSeenAt != nil && (preTouch == nil || h.LastSeenAt.After(*preTouch)) {
			return
		}
		time.Sleep(20 * time.Millisecond)
	}
	t.Errorf("heartbeat did not update last_seen_at (before: %v, after: %v)", preTouch, h.LastSeenAt)
}
// TestWSRejectsOldProtocol sends a hello whose protocol version is 0 and
// expects the server to reply with an error envelope carrying the
// ErrProtocolTooOld code.
func TestWSRejectsOldProtocol(t *testing.T) {
	t.Parallel()
	url, token, _, _, _ := setupTestHub(t)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	c, res, err := websocket.Dial(ctx, url, &websocket.DialOptions{
		HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + token}},
	})
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	defer c.CloseNow()
	defer func() {
		if res != nil && res.Body != nil {
			_ = res.Body.Close()
		}
	}()

	hello := api.HelloPayload{ProtocolVersion: 0} // below minimum
	env, err := api.Marshal(api.MsgHello, "", hello)
	if err != nil {
		t.Fatalf("marshal hello: %v", err)
	}
	raw, err := json.Marshal(env)
	if err != nil {
		t.Fatalf("encode hello: %v", err)
	}
	if err = c.Write(ctx, websocket.MessageText, raw); err != nil {
		t.Fatalf("write hello: %v", err)
	}

	// Server should send an error envelope, then close.
	mt, body, err := c.Read(ctx)
	if err != nil {
		t.Fatalf("read: %v", err)
	}
	if mt != websocket.MessageText {
		t.Fatalf("frame type: %v", mt)
	}

	var got api.Envelope
	if err = json.Unmarshal(body, &got); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	// Stop here on the wrong type: decoding its payload as an
	// ErrorPayload would only produce a confusing secondary failure.
	if got.Type != api.MsgError {
		t.Fatalf("expected error envelope, got %q", got.Type)
	}

	var ep api.ErrorPayload
	if err = got.UnmarshalPayload(&ep); err != nil {
		t.Fatalf("unmarshal error payload: %v", err)
	}
	if ep.Code != api.ErrProtocolTooOld {
		t.Errorf("error code: got %q, want %q", ep.Code, api.ErrProtocolTooOld)
	}
}
// TestWSRejectsBadToken checks that the /ws/agent handshake is refused
// with 401 Unauthorized when the bearer token does not match the enrolled
// host's token.
func TestWSRejectsBadToken(t *testing.T) {
	t.Parallel()
	url, _, _, _, _ := setupTestHub(t)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	c, res, err := websocket.Dial(ctx, url, &websocket.DialOptions{
		HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer wrong"}},
	})
	if res != nil && res.Body != nil {
		defer func() { _ = res.Body.Close() }()
	}
	if err == nil {
		// Don't leak the unexpectedly-accepted connection.
		_ = c.CloseNow()
		t.Fatal("dial should fail")
	}

	// A rejected handshake must still yield the HTTP response; without it
	// the 401 cannot be confirmed, so that is a failure too (previously
	// this case passed silently).
	if res == nil {
		t.Fatalf("dial failed without an HTTP response: %v", err)
	}
	if res.StatusCode != stdhttp.StatusUnauthorized {
		t.Errorf("status: got %d, want %d", res.StatusCode, stdhttp.StatusUnauthorized)
	}
}