package http

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"io"
	stdhttp "net/http"
	"net/url"
	"strings"
	"testing"
	"time"

	"github.com/oklog/ulid/v2"

	"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)

// sha256Hex returns the lowercase hex encoding of the SHA-256 digest of s.
// The tests below use it to derive the TokenHash stored for a raw setup token.
func sha256Hex(s string) string {
	h := sha256.Sum256([]byte(s))
	return hex.EncodeToString(h[:])
}

// TestSetupGetValidToken checks that GET /setup with an unexpired token
// responds 200 and renders the invited user's username.
func TestSetupGetValidToken(t *testing.T) {
	t.Parallel()
	// /setup renders HTML, so we need a real UI renderer.
	srv, ts, _ := rawTestServerWithUI(t)
	urlBase := ts.URL
	now := time.Now().UTC()

	uid := ulid.Make().String()
	if err := srv.deps.Store.CreateUser(t.Context(), store.User{
		ID: uid, Username: "newbie", PasswordHash: "",
		Role: store.RoleOperator, CreatedAt: now, MustChangePassword: true,
	}); err != nil {
		t.Fatalf("create: %v", err)
	}
	raw := "raw-token-1234567890"
	hash := sha256Hex(raw)
	if err := srv.deps.Store.SetSetupToken(context.Background(), store.SetupToken{
		UserID: uid, TokenHash: hash,
		ExpiresAt: now.Add(time.Hour), CreatedAt: now,
	}); err != nil {
		t.Fatalf("set token: %v", err)
	}

	// Escape the token so the test stays correct if token formats change.
	res, err := stdhttp.Get(urlBase + "/setup?token=" + url.QueryEscape(raw))
	if err != nil {
		t.Fatalf("GET: %v", err)
	}
	defer res.Body.Close()
	if res.StatusCode != stdhttp.StatusOK {
		t.Errorf("status: got %d want 200", res.StatusCode)
	}
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("read body: %v", err)
	}
	if !strings.Contains(string(body), "newbie") {
		t.Errorf("expected username in body: %s", body)
	}
}

// TestSetupGetExpiredToken checks that GET /setup with a token whose
// ExpiresAt is in the past responds 410 Gone.
func TestSetupGetExpiredToken(t *testing.T) {
	t.Parallel()
	// /setup renders HTML, so we need a real UI renderer.
	srv, ts, _ := rawTestServerWithUI(t)
	urlBase := ts.URL
	now := time.Now().UTC()

	uid := ulid.Make().String()
	// Fixture failures must abort the test; otherwise they surface as a
	// misleading status-code mismatch below.
	if err := srv.deps.Store.CreateUser(t.Context(), store.User{
		ID: uid, Username: "stale", PasswordHash: "",
		Role: store.RoleViewer, CreatedAt: now, MustChangePassword: true,
	}); err != nil {
		t.Fatalf("create: %v", err)
	}
	raw := "expired-token"
	if err := srv.deps.Store.SetSetupToken(context.Background(), store.SetupToken{
		UserID: uid, TokenHash: sha256Hex(raw),
		ExpiresAt: now.Add(-time.Minute), CreatedAt: now.Add(-2 * time.Hour),
	}); err != nil {
		t.Fatalf("set token: %v", err)
	}

	res, err := stdhttp.Get(urlBase + "/setup?token=" + url.QueryEscape(raw))
	if err != nil {
		t.Fatalf("GET: %v", err)
	}
	defer res.Body.Close()
	if res.StatusCode != stdhttp.StatusGone {
		t.Errorf("status: got %d want 410", res.StatusCode)
	}
}

// TestSetupPostHappyPath submits a matching password pair with a valid token
// to POST /setup. It expects a 303 redirect to "/", expects the token to be
// gone from the store afterwards, and expects the user to be able to log in
// through /api/auth/login with the new password.
func TestSetupPostHappyPath(t *testing.T) {
	t.Parallel()
	srv, ts, _ := rawTestServerWithUI(t)
	urlBase := ts.URL
	now := time.Now().UTC()

	uid := ulid.Make().String()
	if err := srv.deps.Store.CreateUser(t.Context(), store.User{
		ID: uid, Username: "newbie", PasswordHash: "",
		Role: store.RoleOperator, CreatedAt: now, MustChangePassword: true,
	}); err != nil {
		t.Fatalf("create: %v", err)
	}
	raw := "happy-token"
	if err := srv.deps.Store.SetSetupToken(t.Context(), store.SetupToken{
		UserID: uid, TokenHash: sha256Hex(raw),
		ExpiresAt: now.Add(time.Hour), CreatedAt: now,
	}); err != nil {
		t.Fatalf("set token: %v", err)
	}

	form := url.Values{}
	form.Set("token", raw)
	form.Set("password", "averylongpassword")
	form.Set("password_confirm", "averylongpassword")
	req, err := stdhttp.NewRequestWithContext(t.Context(), stdhttp.MethodPost,
		urlBase+"/setup", strings.NewReader(form.Encode()))
	if err != nil {
		t.Fatalf("build request: %v", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	// Do not follow redirects: the test asserts on the 303 itself.
	c := &stdhttp.Client{CheckRedirect: func(*stdhttp.Request, []*stdhttp.Request) error {
		return stdhttp.ErrUseLastResponse
	}}
	res, err := c.Do(req)
	if err != nil {
		t.Fatalf("POST: %v", err)
	}
	defer res.Body.Close()
	if res.StatusCode != stdhttp.StatusSeeOther {
		t.Errorf("status: got %d want 303", res.StatusCode)
	}
	if loc := res.Header.Get("Location"); loc != "/" {
		t.Errorf("location: got %q want /", loc)
	}

	// Token is consumed.
	if _, err := srv.deps.Store.LookupSetupToken(t.Context(), sha256Hex(raw)); err == nil {
		t.Error("token should be deleted after consumption")
	}

	// User can now log in via the normal route.
	logBody, err := json.Marshal(map[string]string{
		"username": "newbie",
		"password": "averylongpassword",
	})
	if err != nil {
		t.Fatalf("marshal login body: %v", err)
	}
	// The error must be checked before touching loginRes: on a transport
	// failure loginRes is nil and dereferencing it would panic.
	loginRes, err := stdhttp.Post(urlBase+"/api/auth/login", "application/json", bytes.NewReader(logBody))
	if err != nil {
		t.Fatalf("login POST: %v", err)
	}
	defer loginRes.Body.Close()
	if loginRes.StatusCode != stdhttp.StatusOK {
		body, err := io.ReadAll(loginRes.Body)
		if err != nil {
			t.Errorf("login: %d (reading body: %v)", loginRes.StatusCode, err)
			return
		}
		t.Errorf("login: %d %s", loginRes.StatusCode, body)
	}
}