// pending_ws_test.go — end-to-end test of the announce → pending WS // → admin accept → bearer push round trip (P2-18b/c). package http import ( "context" "crypto/ed25519" "crypto/rand" "encoding/base64" "encoding/json" stdhttp "net/http" "net/url" "strings" "testing" "time" "github.com/coder/websocket" "github.com/oklog/ulid/v2" "gitea.dcglab.co.uk/steve/restic-manager/internal/store" ) // TestPendingWSNonceSignAcceptFlow: simulate an agent. Announce → // open pending WS → sign nonce → admin accept (with repo creds) → // expect 'enrolled' message with bearer. func TestPendingWSNonceSignAcceptFlow(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) pub, priv, err := ed25519.GenerateKey(rand.Reader) if err != nil { t.Fatalf("ed25519: %v", err) } // Pre-seed pending row directly (bypass the announce HTTP path // since announce coverage lives in announce_test.go). pendingID := ulid.Make().String() if err := st.CreatePendingHost(context.Background(), &store.PendingHost{ ID: pendingID, Hostname: "ann-host", OS: "linux", Arch: "amd64", AgentVersion: "1.0", ResticVersion: "0.17", PublicKey: pub, Fingerprint: store.FingerprintForKey(pub), AnnouncedFromIP: "127.0.0.1", FirstSeenAt: time.Now().UTC(), LastSeenAt: time.Now().UTC(), ExpiresAt: time.Now().UTC().Add(time.Hour), }); err != nil { t.Fatalf("seed: %v", err) } // Open the pending WS. wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/agent/pending?pending_id=" + pendingID dialCtx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second) defer dialCancel() c, res, err := websocket.Dial(dialCtx, wsURL, nil) if err != nil { t.Fatalf("dial pending ws: %v", err) } if res != nil && res.Body != nil { _ = res.Body.Close() } t.Cleanup(func() { _ = c.CloseNow() }) // Read nonce. rctx, rcancel := context.WithTimeout(context.Background(), 3*time.Second) _, raw, err := c.Read(rctx) rcancel() if err != nil { t.Fatalf("read nonce: %v", err) } var nm nonceMessage if err := json.Unmarshal(raw, &nm); err != nil { t.Fatalf("unmarshal nonce: %v", err) } nonce, _ := base64.StdEncoding.DecodeString(nm.Nonce) // Sign + reply. sig := ed25519.Sign(priv, nonce) reply, _ := json.Marshal(signedNonceMessage{ Type: "signed_nonce", Signature: base64.StdEncoding.EncodeToString(sig), }) wctx, wcancel := context.WithTimeout(context.Background(), 3*time.Second) if err := c.Write(wctx, websocket.MessageText, reply); err != nil { wcancel() t.Fatalf("write signed nonce: %v", err) } wcancel() // Wait briefly so the server's hub.register completes before we // fire accept. deadline := time.Now().Add(2 * time.Second) for time.Now().Before(deadline) { if srv.pendingHub.get(pendingID) != nil { break } time.Sleep(20 * time.Millisecond) } // Admin POST accept (form-encoded, with cookie). cookie := loginAsAdmin(t, st) form := url.Values{ "repo_url": {"rest:http://r/x"}, "repo_username": {"u"}, "repo_password": {"p"}, } req, _ := stdhttp.NewRequest("POST", ts.URL+"/api/pending-hosts/"+pendingID+"/accept", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.AddCookie(cookie) resAccept, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("accept: %v", err) } defer resAccept.Body.Close() if resAccept.StatusCode != stdhttp.StatusOK { t.Fatalf("accept status: %d", resAccept.StatusCode) } // Expect 'enrolled' message + close. rctx2, rcancel2 := context.WithTimeout(context.Background(), 3*time.Second) _, raw2, err := c.Read(rctx2) rcancel2() if err != nil { t.Fatalf("read enrolled: %v", err) } var em enrolledMessage if err := json.Unmarshal(raw2, &em); err != nil { t.Fatalf("unmarshal enrolled: %v", err) } if em.Type != "enrolled" || em.Bearer == "" || em.HostID == "" { t.Fatalf("enrolled payload bad: %+v", em) } // Pending row should be gone. if _, err := st.GetPendingHost(context.Background(), pendingID); err == nil { t.Error("pending row should have been deleted on accept") } // Real host row should exist. if _, err := st.GetHost(context.Background(), em.HostID); err != nil { t.Errorf("host row not created: %v", err) } } // TestPendingWSBadSignatureClosed: server closes the WS when the // signature does not verify against the row's public key. func TestPendingWSBadSignatureClosed(t *testing.T) { t.Parallel() srv, ts, st := rawTestServer(t) _ = srv // Two distinct keypairs — agent signs with the wrong one. pubReal, _, _ := ed25519.GenerateKey(rand.Reader) _, privAttacker, _ := ed25519.GenerateKey(rand.Reader) pendingID := ulid.Make().String() if err := st.CreatePendingHost(context.Background(), &store.PendingHost{ ID: pendingID, Hostname: "bad-host", OS: "linux", Arch: "amd64", PublicKey: pubReal, Fingerprint: store.FingerprintForKey(pubReal), AnnouncedFromIP: "127.0.0.1", FirstSeenAt: time.Now().UTC(), LastSeenAt: time.Now().UTC(), ExpiresAt: time.Now().UTC().Add(time.Hour), }); err != nil { t.Fatalf("seed: %v", err) } wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/agent/pending?pending_id=" + pendingID dialCtx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second) defer dialCancel() c, res, err := websocket.Dial(dialCtx, wsURL, nil) if err != nil { t.Fatalf("dial: %v", err) } if res != nil && res.Body != nil { _ = res.Body.Close() } defer func() { _ = c.CloseNow() }() // Read nonce. rctx, rcancel := context.WithTimeout(context.Background(), 3*time.Second) _, raw, _ := c.Read(rctx) rcancel() var nm nonceMessage _ = json.Unmarshal(raw, &nm) nonce, _ := base64.StdEncoding.DecodeString(nm.Nonce) // Sign with the wrong key. sig := ed25519.Sign(privAttacker, nonce) reply, _ := json.Marshal(signedNonceMessage{ Type: "signed_nonce", Signature: base64.StdEncoding.EncodeToString(sig), }) wctx, wcancel := context.WithTimeout(context.Background(), 3*time.Second) _ = c.Write(wctx, websocket.MessageText, reply) wcancel() // Server should close. Read until error. rctx2, rcancel2 := context.WithTimeout(context.Background(), 3*time.Second) _, _, err = c.Read(rctx2) rcancel2() if err == nil { t.Fatal("expected ws to close on bad signature") } }