package ws

import (
	"context"
	"encoding/json"
	stdhttp "net/http"
	"net/http/httptest"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"github.com/coder/websocket"

	"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
	"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
	"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)

// setupTestHub spins up an httptest server that exposes only /ws/agent,
// backed by a fresh SQLite store (under t.TempDir) with one pre-enrolled
// host. It returns the ws:// URL of the agent endpoint, the host's
// plaintext bearer token, the host ID, the store, and the hub passed to
// the handler. The server and store are closed via t.Cleanup.
func setupTestHub(t *testing.T) (url string, token string, hostID string, st *store.Store, hub *Hub) {
	t.Helper()
	ctx := context.Background()

	var err error
	st, err = store.Open(ctx, filepath.Join(t.TempDir(), "rm.db"))
	if err != nil {
		t.Fatalf("store: %v", err)
	}
	t.Cleanup(func() { _ = st.Close() })

	hub = NewHub()
	mux := stdhttp.NewServeMux()
	mux.Handle("/ws/agent", AgentHandler(HandlerDeps{Hub: hub, Store: st}))
	srv := httptest.NewServer(mux)
	t.Cleanup(srv.Close)

	// Pre-enroll a host directly via the store (skipping HTTP); the store
	// is handed the hashed token, the test keeps the plaintext one.
	hostID = "01HJ8K70000000000000000000"
	token, err = auth.NewToken()
	if err != nil {
		t.Fatalf("new token: %v", err)
	}
	host := store.Host{
		ID:         hostID,
		Name:       "h1",
		OS:         "linux",
		Arch:       "amd64",
		EnrolledAt: time.Now().UTC(),
	}
	if err := st.CreateHost(ctx, host, auth.HashToken(token), ""); err != nil {
		t.Fatalf("enroll: %v", err)
	}

	url = "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/agent"
	return url, token, hostID, st, hub
}

// mustDialAgent opens a WebSocket connection to url using token as the
// bearer credential and fails the test if the handshake fails. The
// connection is force-closed via t.Cleanup. The handshake response body is
// deliberately not closed: coder/websocket documents that callers never
// need to close it.
func mustDialAgent(ctx context.Context, t *testing.T, url, token string) *websocket.Conn {
	t.Helper()
	c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
		HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + token}},
	})
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	t.Cleanup(func() { _ = c.CloseNow() })
	return c
}

// mustWriteEnvelope JSON-encodes env and sends it as one text frame on c.
// label names the message in failure output (e.g. "hello").
func mustWriteEnvelope(ctx context.Context, t *testing.T, c *websocket.Conn, label string, env any) {
	t.Helper()
	raw, err := json.Marshal(env)
	if err != nil {
		t.Fatalf("encode %s: %v", label, err)
	}
	if err := c.Write(ctx, websocket.MessageText, raw); err != nil {
		t.Fatalf("write %s: %v", label, err)
	}
}

// waitForCondition polls cond every 20ms until it returns true or timeout
// elapses, and reports whether cond was satisfied. cond is always
// evaluated at least once.
func waitForCondition(timeout time.Duration, cond func() bool) bool {
	deadline := time.Now().Add(timeout)
	for {
		if cond() {
			return true
		}
		if !time.Now().Before(deadline) {
			return false
		}
		time.Sleep(20 * time.Millisecond)
	}
}

func TestWSHelloAndHeartbeat(t *testing.T) {
	t.Parallel()
	url, token, hostID, st, hub := setupTestHub(t)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	c := mustDialAgent(ctx, t, url, token)

	// Send hello.
	env, err := api.Marshal(api.MsgHello, "", api.HelloPayload{
		ProtocolVersion: api.CurrentProtocolVersion,
		AgentVersion:    "0.1.0",
		ResticVersion:   "0.17.1",
		Hostname:        "h1",
		OS:              api.OSLinux,
		Arch:            api.ArchAmd64,
	})
	if err != nil {
		t.Fatalf("marshal hello: %v", err)
	}
	mustWriteEnvelope(ctx, t, c, "hello", env)

	// Registration on the hub happens after the hello handler returns,
	// so poll for it (up to 1s) rather than checking once.
	if !waitForCondition(time.Second, func() bool { return hub.Connected(hostID) }) {
		t.Fatal("host did not register on hub after hello")
	}

	// Verify host row was marked online + has populated metadata.
	h, err := st.GetHost(ctx, hostID)
	if err != nil {
		t.Fatalf("get host: %v", err)
	}
	if h.Status != "online" || h.AgentVersion != "0.1.0" {
		t.Errorf("host after hello: %+v", h)
	}

	// Send a heartbeat — server should touch last_seen.
	preTouch := h.LastSeenAt
	env, err = api.Marshal(api.MsgHeartbeat, "", api.HeartbeatPayload{SentAt: time.Now().UTC()})
	if err != nil {
		t.Fatalf("marshal heartbeat: %v", err)
	}
	mustWriteEnvelope(ctx, t, c, "heartbeat", env)

	// NOTE(review): the strict After comparison assumes the store keeps
	// sub-second precision on last_seen_at; if it truncates timestamps, a
	// heartbeat landing in the same tick as the hello touch could flake.
	updated := waitForCondition(time.Second, func() bool {
		h2, err := st.GetHost(ctx, hostID)
		if err != nil {
			t.Fatalf("get host after heartbeat: %v", err)
		}
		return h2.LastSeenAt != nil && (preTouch == nil || h2.LastSeenAt.After(*preTouch))
	})
	if !updated {
		t.Error("heartbeat did not update last_seen_at")
	}
}

func TestWSRejectsOldProtocol(t *testing.T) {
	t.Parallel()
	url, token, _, _, _ := setupTestHub(t)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	c := mustDialAgent(ctx, t, url, token)

	env, err := api.Marshal(api.MsgHello, "", api.HelloPayload{ProtocolVersion: 0}) // below minimum
	if err != nil {
		t.Fatalf("marshal hello: %v", err)
	}
	mustWriteEnvelope(ctx, t, c, "hello", env)

	// Server should send an error envelope, then close.
	mt, body, err := c.Read(ctx)
	if err != nil {
		t.Fatalf("read: %v", err)
	}
	if mt != websocket.MessageText {
		t.Fatalf("frame type: %v", mt)
	}
	var got api.Envelope
	if err := json.Unmarshal(body, &got); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	if got.Type != api.MsgError {
		// Decoding the payload as an ErrorPayload is meaningless otherwise.
		t.Fatalf("expected error envelope, got %q", got.Type)
	}
	var ep api.ErrorPayload
	if err := got.UnmarshalPayload(&ep); err != nil {
		t.Fatalf("unmarshal error payload: %v", err)
	}
	if ep.Code != api.ErrProtocolTooOld {
		t.Errorf("error code: %q", ep.Code)
	}
}

func TestWSRejectsBadToken(t *testing.T) {
	t.Parallel()
	url, _, _, _, _ := setupTestHub(t)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	c, res, err := websocket.Dial(ctx, url, &websocket.DialOptions{
		HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer wrong"}},
	})
	if err == nil {
		// Don't leak the unexpectedly accepted connection.
		_ = c.CloseNow()
		t.Fatal("dial should fail")
	}
	// coder/websocket returns the HTTP response alongside the error for a
	// rejected (non-101) handshake. A nil response means the request never
	// got a status back, which must fail rather than silently pass.
	if res == nil {
		t.Fatalf("dial failed without an HTTP response: %v", err)
	}
	if res.StatusCode != stdhttp.StatusUnauthorized {
		t.Errorf("status: got %d, want %d", res.StatusCode, stdhttp.StatusUnauthorized)
	}
}