From 12391abef0c87e3c38a72f0794f5acf46cc310cf Mon Sep 17 00:00:00 2001 From: Steve Cliff Date: Tue, 5 May 2026 09:06:54 +0100 Subject: [PATCH] store: user_setup_tokens CRUD + cleanup-expired --- internal/store/setup_tokens.go | 93 +++++++++++++++++++++ internal/store/setup_tokens_test.go | 120 ++++++++++++++++++++++++++++ 2 files changed, 213 insertions(+) create mode 100644 internal/store/setup_tokens.go create mode 100644 internal/store/setup_tokens_test.go diff --git a/internal/store/setup_tokens.go b/internal/store/setup_tokens.go new file mode 100644 index 0000000..161f8ca --- /dev/null +++ b/internal/store/setup_tokens.go @@ -0,0 +1,93 @@ +package store + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" +) + +// SetSetupToken inserts a row, replacing any existing token for +// this user (single-outstanding invariant). Caller passes a hash — +// raw tokens are never persisted. +func (s *Store) SetSetupToken(ctx context.Context, t SetupToken) error { + _, err := s.db.ExecContext(ctx, + `INSERT OR REPLACE INTO user_setup_tokens + (user_id, token_hash, expires_at, created_at, created_by) + VALUES (?, ?, ?, ?, ?)`, + t.UserID, t.TokenHash, + t.ExpiresAt.UTC().Format(time.RFC3339Nano), + t.CreatedAt.UTC().Format(time.RFC3339Nano), + nullable(t.CreatedBy)) + if err != nil { + return fmt.Errorf("store: set setup token: %w", err) + } + return nil +} + +// LookupSetupToken resolves a token hash to its row. Returns +// ErrNotFound for missing tokens. Expiry is NOT checked here — +// callers must compare ExpiresAt themselves so they can record +// 'expired' as a distinct outcome (audit-able) from 'never existed'. +func (s *Store) LookupSetupToken(ctx context.Context, tokenHash string) (*SetupToken, error) { + row := s.db.QueryRowContext(ctx, + `SELECT user_id, token_hash, expires_at, created_at, created_by + FROM user_setup_tokens WHERE token_hash = ?`, tokenHash) + return scanSetupToken(row.Scan) +} + +// GetSetupTokenByUserID returns the row for one user. Used by the +// edit page to know whether a 'Regenerate setup link' button should +// show as 'Generate' or 'Regenerate'. Returns ErrNotFound when no +// outstanding token exists. +func (s *Store) GetSetupTokenByUserID(ctx context.Context, userID string) (*SetupToken, error) { + row := s.db.QueryRowContext(ctx, + `SELECT user_id, token_hash, expires_at, created_at, created_by + FROM user_setup_tokens WHERE user_id = ?`, userID) + return scanSetupToken(row.Scan) +} + +// DeleteSetupToken removes the row for a user (single-use cleanup +// after /setup completes successfully). +func (s *Store) DeleteSetupToken(ctx context.Context, userID string) error { + _, err := s.db.ExecContext(ctx, + `DELETE FROM user_setup_tokens WHERE user_id = ?`, userID) + if err != nil { + return fmt.Errorf("store: delete setup token: %w", err) + } + return nil +} + +// CleanupExpiredSetupTokens removes rows whose expires_at has passed. +// Returns the number of rows deleted. Called from the maintenance +// ticker every minute. +func (s *Store) CleanupExpiredSetupTokens(ctx context.Context, now time.Time) (int64, error) { + res, err := s.db.ExecContext(ctx, + `DELETE FROM user_setup_tokens WHERE expires_at < ?`, + now.UTC().Format(time.RFC3339Nano)) + if err != nil { + return 0, fmt.Errorf("store: cleanup setup tokens: %w", err) + } + n, _ := res.RowsAffected() + return n, nil +} + +func scanSetupToken(scan func(...any) error) (*SetupToken, error) { + var t SetupToken + var createdBy sql.NullString + var expiresAt, createdAt string + if err := scan(&t.UserID, &t.TokenHash, &expiresAt, &createdAt, &createdBy); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("store: scan setup token: %w", err) + } + t.ExpiresAt, _ = time.Parse(time.RFC3339Nano, expiresAt) + t.CreatedAt, _ = time.Parse(time.RFC3339Nano, createdAt) + if createdBy.Valid { + v := createdBy.String + t.CreatedBy = &v + } + return &t, nil +} diff --git a/internal/store/setup_tokens_test.go b/internal/store/setup_tokens_test.go new file mode 100644 index 0000000..8aa29b5 --- /dev/null +++ b/internal/store/setup_tokens_test.go @@ -0,0 +1,120 @@ +package store + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/oklog/ulid/v2" +) + +func newSetupTokenTestStore(t *testing.T) (*Store, string, string) { + t.Helper() + st, err := Open(context.Background(), filepath.Join(t.TempDir(), "rm.db")) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { _ = st.Close() }) + uid := ulid.Make().String() + creator := ulid.Make().String() + now := time.Now().UTC() + if err := st.CreateUser(context.Background(), User{ + ID: creator, Username: "creator", PasswordHash: "x", + Role: RoleAdmin, CreatedAt: now, + }); err != nil { + t.Fatalf("create creator: %v", err) + } + if err := st.CreateUser(context.Background(), User{ + ID: uid, Username: "target", PasswordHash: "", + Role: RoleOperator, CreatedAt: now, MustChangePassword: true, + }); err != nil { + t.Fatalf("create target: %v", err) + } + return st, uid, creator +} + +func TestSetupTokenSetAndLookup(t *testing.T) { + t.Parallel() + st, uid, creator := newSetupTokenTestStore(t) + ctx := context.Background() + now := time.Now().UTC() + + if err := st.SetSetupToken(ctx, SetupToken{ + UserID: uid, TokenHash: "abc123", + ExpiresAt: now.Add(time.Hour), + CreatedAt: now, CreatedBy: &creator, + }); err != nil { + t.Fatalf("set: %v", err) + } + got, err := st.LookupSetupToken(ctx, "abc123") + if err != nil { + t.Fatalf("lookup: %v", err) + } + if got.UserID != uid { + t.Errorf("user_id: got %q want %q", got.UserID, uid) + } +} + +func TestSetupTokenReplaces(t *testing.T) { + t.Parallel() + st, uid, creator := newSetupTokenTestStore(t) + ctx := context.Background() + now := time.Now().UTC() + + _ = st.SetSetupToken(ctx, SetupToken{ + UserID: uid, TokenHash: "old", + ExpiresAt: now.Add(time.Hour), CreatedAt: now, CreatedBy: &creator, + }) + _ = st.SetSetupToken(ctx, SetupToken{ + UserID: uid, TokenHash: "new", + ExpiresAt: now.Add(time.Hour), CreatedAt: now, CreatedBy: &creator, + }) + if _, err := st.LookupSetupToken(ctx, "old"); err == nil { + t.Error("old token should be gone") + } + if _, err := st.LookupSetupToken(ctx, "new"); err != nil { + t.Errorf("new token should resolve: %v", err) + } +} + +func TestSetupTokenDelete(t *testing.T) { + t.Parallel() + st, uid, creator := newSetupTokenTestStore(t) + ctx := context.Background() + now := time.Now().UTC() + + _ = st.SetSetupToken(ctx, SetupToken{ + UserID: uid, TokenHash: "tk", + ExpiresAt: now.Add(time.Hour), CreatedAt: now, CreatedBy: &creator, + }) + if err := st.DeleteSetupToken(ctx, uid); err != nil { + t.Fatalf("delete: %v", err) + } + if _, err := st.LookupSetupToken(ctx, "tk"); err == nil { + t.Error("deleted token should not resolve") + } +} + +func TestSetupTokenCleanupExpired(t *testing.T) { + t.Parallel() + st, uid, creator := newSetupTokenTestStore(t) + ctx := context.Background() + now := time.Now().UTC() + + _ = st.SetSetupToken(ctx, SetupToken{ + UserID: uid, TokenHash: "stale", + ExpiresAt: now.Add(-time.Hour), CreatedAt: now.Add(-2 * time.Hour), + CreatedBy: &creator, + }) + n, err := st.CleanupExpiredSetupTokens(ctx, now) + if err != nil { + t.Fatalf("cleanup: %v", err) + } + if n != 1 { + t.Errorf("cleanup count: got %d want 1", n) + } + if _, err := st.LookupSetupToken(ctx, "stale"); err == nil { + t.Error("stale token should be gone") + } +}