phase 1: WS transport, enrollment, agent that hellos and heartbeats
Lands the protocol layer end-to-end: an agent can be enrolled through the operator UI, store credentials, dial back to the server over WS, complete the protocol_version handshake, and stay connected with periodic heartbeats.

Server side:
- P1-09 ws.Hub: one Conn per host_id, last-write-wins eviction, json envelope writer with a write mutex, reader, error envelopes.
- P1-09 ws.AgentHandler: bearer-auth, accept upgrade, hello-stage (10s deadline, protocol_version checked against api.MinAgentProtocolVersion → ErrProtocolTooOld with help URL on reject), main read loop, defer hub register/unregister.
- P1-10 POST /api/agents/enroll consumes a one-time token, mints a persistent agent bearer (sha-256 stored), creates a host row.
- P1-10 POST /api/enrollment-tokens (operator, session-auth) issues a 1h one-time token.
- P1-11 hello upserts agent_version + restic_version + protocol_version on the host row, flips status to online.
- P1-12 heartbeat touches last_seen_at; background sweeper marks hosts offline after 90s without one.
- store: hosts table accessors, host_schedule_version, enrollment_tokens FK on consumed_host dropped (audit-only field; the token gets burned before the host row exists).

Agent side:
- P1-13 internal/agent/config: yaml at /etc/restic-manager/agent.yaml, atomic Save (tmp+fsync+rename), Enrolled() helper.
- P1-15 internal/agent/wsclient: dial with bearer + optional TLS cert pinning (sha-256 of leaf), exponential backoff with jitter (1s → 60s cap), heartbeat goroutine, fatal handling for ErrProtocolTooOld.
- P1-15 wsclient.Enroll: HTTP POST /api/agents/enroll with sysinfo.
- P1-17 internal/agent/sysinfo: hostname/OS/arch/restic-version collection. restic detected by `restic version` parse; absent restic doesn't block startup.
- cmd/agent: -enroll-server / -enroll-token flags drive first-run enrollment then exit (so the install script can hand off to systemd to run the persistent service).
End-to-end smoke verified: bootstrap → login → issue token → enroll → run agent → server logs `ws agent connected` with the right host_id and protocol_version 1. All tests still pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,181 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
stdhttp "net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// setupTestHub spins up a Server that exposes only /ws/agent against
|
||||
// a fresh sqlite store with one pre-enrolled host. Returns the URL,
|
||||
// the agent's bearer token, and the host ID.
|
||||
func setupTestHub(t *testing.T) (url string, token string, hostID string, st *store.Store, hub *Hub) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
var err error
|
||||
st, err = store.Open(context.Background(), filepath.Join(dir, "rm.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("store: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = st.Close() })
|
||||
|
||||
hub = NewHub()
|
||||
mux := stdhttp.NewServeMux()
|
||||
mux.Handle("/ws/agent", AgentHandler(HandlerDeps{Hub: hub, Store: st}))
|
||||
srv := httptest.NewServer(mux)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
// Pre-enroll a host directly via store (skipping HTTP).
|
||||
hostID = "01HJ8K70000000000000000000"
|
||||
token, _ = auth.NewToken()
|
||||
now := time.Now().UTC()
|
||||
if err := st.CreateHost(context.Background(), store.Host{
|
||||
ID: hostID, Name: "h1", OS: "linux", Arch: "amd64",
|
||||
EnrolledAt: now,
|
||||
}, auth.HashToken(token), ""); err != nil {
|
||||
t.Fatalf("enroll: %v", err)
|
||||
}
|
||||
url = "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/agent"
|
||||
return
|
||||
}
|
||||
|
||||
func TestWSHelloAndHeartbeat(t *testing.T) {
|
||||
t.Parallel()
|
||||
url, token, hostID, st, hub := setupTestHub(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
|
||||
HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + token}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
defer c.CloseNow()
|
||||
|
||||
// Send hello.
|
||||
hello := api.HelloPayload{
|
||||
ProtocolVersion: api.CurrentProtocolVersion,
|
||||
AgentVersion: "0.1.0",
|
||||
ResticVersion: "0.17.1",
|
||||
Hostname: "h1",
|
||||
OS: api.OSLinux,
|
||||
Arch: api.ArchAmd64,
|
||||
}
|
||||
env, _ := api.Marshal(api.MsgHello, "", hello)
|
||||
raw, _ := json.Marshal(env)
|
||||
if err := c.Write(ctx, websocket.MessageText, raw); err != nil {
|
||||
t.Fatalf("write hello: %v", err)
|
||||
}
|
||||
|
||||
// Wait for the server to register us (registration happens after
|
||||
// the hello-handler returns; give it up to 1s).
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for !hub.Connected(hostID) && time.Now().Before(deadline) {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
if !hub.Connected(hostID) {
|
||||
t.Fatal("host did not register on hub after hello")
|
||||
}
|
||||
|
||||
// Verify host row was marked online + has populated metadata.
|
||||
h, err := st.GetHost(context.Background(), hostID)
|
||||
if err != nil {
|
||||
t.Fatalf("get host: %v", err)
|
||||
}
|
||||
if h.Status != "online" || h.AgentVersion != "0.1.0" {
|
||||
t.Errorf("host after hello: %+v", h)
|
||||
}
|
||||
|
||||
// Send a heartbeat — server should touch last_seen.
|
||||
hb := api.HeartbeatPayload{SentAt: time.Now().UTC()}
|
||||
env, _ = api.Marshal(api.MsgHeartbeat, "", hb)
|
||||
raw, _ = json.Marshal(env)
|
||||
preTouch := h.LastSeenAt
|
||||
_ = c.Write(ctx, websocket.MessageText, raw)
|
||||
|
||||
// Wait briefly for server to process.
|
||||
deadline = time.Now().Add(time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
h2, _ := st.GetHost(context.Background(), hostID)
|
||||
if h2.LastSeenAt != nil && (preTouch == nil || h2.LastSeenAt.After(*preTouch)) {
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Error("heartbeat did not update last_seen_at")
|
||||
}
|
||||
|
||||
func TestWSRejectsOldProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
url, token, _, _, _ := setupTestHub(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
|
||||
HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer " + token}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
defer c.CloseNow()
|
||||
|
||||
hello := api.HelloPayload{ProtocolVersion: 0} // below minimum
|
||||
env, _ := api.Marshal(api.MsgHello, "", hello)
|
||||
raw, _ := json.Marshal(env)
|
||||
_ = c.Write(ctx, websocket.MessageText, raw)
|
||||
|
||||
// Server should send an error envelope, then close.
|
||||
mt, body, err := c.Read(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
if mt != websocket.MessageText {
|
||||
t.Fatalf("frame type: %v", mt)
|
||||
}
|
||||
var got api.Envelope
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if got.Type != api.MsgError {
|
||||
t.Errorf("expected error envelope, got %q", got.Type)
|
||||
}
|
||||
var ep api.ErrorPayload
|
||||
_ = got.UnmarshalPayload(&ep)
|
||||
if ep.Code != api.ErrProtocolTooOld {
|
||||
t.Errorf("error code: %q", ep.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSRejectsBadToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
url, _, _, _, _ := setupTestHub(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, res, err := websocket.Dial(ctx, url, &websocket.DialOptions{
|
||||
HTTPHeader: stdhttp.Header{"Authorization": []string{"Bearer wrong"}},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("dial should fail")
|
||||
}
|
||||
if res == nil || res.StatusCode != stdhttp.StatusUnauthorized {
|
||||
if res != nil {
|
||||
t.Errorf("status: %d", res.StatusCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user