package http import ( "bytes" "context" "encoding/json" "io" stdhttp "net/http" "net/http/httptest" "path/filepath" "testing" "time" "gitea.dcglab.co.uk/steve/restic-manager/internal/api" "gitea.dcglab.co.uk/steve/restic-manager/internal/auth" "gitea.dcglab.co.uk/steve/restic-manager/internal/crypto" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/config" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws" "gitea.dcglab.co.uk/steve/restic-manager/internal/store" ) // newTestServerWithHub mirrors newTestServer but plugs in a real // ws.Hub so /ws/agent is available. func newTestServerWithHub(t *testing.T) (*Server, string, *store.Store) { t.Helper() dir := t.TempDir() st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db")) if err != nil { t.Fatalf("store: %v", err) } t.Cleanup(func() { _ = st.Close() }) keyPath := filepath.Join(dir, "secret.key") _ = crypto.GenerateKeyFile(keyPath) key, _ := crypto.LoadKeyFromFile(keyPath) aead, _ := crypto.NewAEAD(key) deps := Deps{ Cfg: config.Config{Listen: ":0", DataDir: dir, SecretKeyFile: keyPath}, Store: st, AEAD: aead, Hub: ws.NewHub(), } s := New(deps) ts := httptest.NewServer(s.srv.Handler) t.Cleanup(ts.Close) return s, ts.URL, st } func TestEnrollmentBadToken(t *testing.T) { t.Parallel() _, url, _ := newTestServerWithHub(t) body, _ := json.Marshal(enrollRequest{ Token: "no-such-token", HostName: "host1", OS: api.OSLinux, Arch: api.ArchAmd64, AgentVersion: "0.1", ResticVersion: "0.17", }) res, err := stdhttp.Post(url+"/api/agents/enroll", "application/json", bytes.NewReader(body)) if err != nil { t.Fatalf("post: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusUnauthorized { t.Errorf("status: %d", res.StatusCode) } } func TestEnrollmentHappyPath(t *testing.T) { t.Parallel() _, url, st := newTestServerWithHub(t) // Issue a token directly via the store (skipping the operator UI). rawToken, _ := auth.NewToken() if err := st.CreateEnrollmentToken(context.Background(), auth.HashToken(rawToken), 5*time.Minute, "", ""); err != nil { t.Fatalf("issue: %v", err) } body, _ := json.Marshal(enrollRequest{ Token: rawToken, HostName: "test-host", OS: api.OSLinux, Arch: api.ArchAmd64, AgentVersion: "0.1", ResticVersion: "0.17", }) res, err := stdhttp.Post(url+"/api/agents/enroll", "application/json", bytes.NewReader(body)) if err != nil { t.Fatalf("post: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusCreated { buf, _ := io.ReadAll(res.Body) t.Fatalf("status %d: %s", res.StatusCode, buf) } var er enrollResponse if err := json.NewDecoder(res.Body).Decode(&er); err != nil { t.Fatalf("decode: %v", err) } if er.HostID == "" || er.AgentToken == "" { t.Errorf("missing fields in response: %+v", er) } // Token must not be reusable. res2, _ := stdhttp.Post(url+"/api/agents/enroll", "application/json", bytes.NewReader(body)) defer res2.Body.Close() if res2.StatusCode != stdhttp.StatusUnauthorized { t.Errorf("re-enrollment with same token should fail, got %d", res2.StatusCode) } // Host row exists with matching agent_token_hash. got, err := st.LookupHostByAgentToken(context.Background(), auth.HashToken(er.AgentToken)) if err != nil { t.Fatalf("lookup by token: %v", err) } if got.Name != "test-host" || got.OS != "linux" { t.Errorf("host fields: %+v", got) } }