store: user_setup_tokens CRUD + cleanup-expired
This commit is contained in:
@@ -0,0 +1,93 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetSetupToken inserts a row, replacing any existing token for
|
||||||
|
// this user (single-outstanding invariant). Caller passes a hash —
|
||||||
|
// raw tokens are never persisted.
|
||||||
|
func (s *Store) SetSetupToken(ctx context.Context, t SetupToken) error {
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`INSERT OR REPLACE INTO user_setup_tokens
|
||||||
|
(user_id, token_hash, expires_at, created_at, created_by)
|
||||||
|
VALUES (?, ?, ?, ?, ?)`,
|
||||||
|
t.UserID, t.TokenHash,
|
||||||
|
t.ExpiresAt.UTC().Format(time.RFC3339Nano),
|
||||||
|
t.CreatedAt.UTC().Format(time.RFC3339Nano),
|
||||||
|
nullable(t.CreatedBy))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("store: set setup token: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupSetupToken resolves a token hash to its row. Returns
|
||||||
|
// ErrNotFound for missing tokens. Expiry is NOT checked here —
|
||||||
|
// callers must compare ExpiresAt themselves so they can record
|
||||||
|
// 'expired' as a distinct outcome (audit-able) from 'never existed'.
|
||||||
|
func (s *Store) LookupSetupToken(ctx context.Context, tokenHash string) (*SetupToken, error) {
|
||||||
|
row := s.db.QueryRowContext(ctx,
|
||||||
|
`SELECT user_id, token_hash, expires_at, created_at, created_by
|
||||||
|
FROM user_setup_tokens WHERE token_hash = ?`, tokenHash)
|
||||||
|
return scanSetupToken(row.Scan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSetupTokenByUserID returns the row for one user. Used by the
|
||||||
|
// edit page to know whether a 'Regenerate setup link' button should
|
||||||
|
// show as 'Generate' or 'Regenerate'. Returns ErrNotFound when no
|
||||||
|
// outstanding token exists.
|
||||||
|
func (s *Store) GetSetupTokenByUserID(ctx context.Context, userID string) (*SetupToken, error) {
|
||||||
|
row := s.db.QueryRowContext(ctx,
|
||||||
|
`SELECT user_id, token_hash, expires_at, created_at, created_by
|
||||||
|
FROM user_setup_tokens WHERE user_id = ?`, userID)
|
||||||
|
return scanSetupToken(row.Scan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSetupToken removes the row for a user (single-use cleanup
|
||||||
|
// after /setup completes successfully).
|
||||||
|
func (s *Store) DeleteSetupToken(ctx context.Context, userID string) error {
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`DELETE FROM user_setup_tokens WHERE user_id = ?`, userID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("store: delete setup token: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupExpiredSetupTokens removes rows whose expires_at has passed.
|
||||||
|
// Returns the number of rows deleted. Called from the maintenance
|
||||||
|
// ticker every minute.
|
||||||
|
func (s *Store) CleanupExpiredSetupTokens(ctx context.Context, now time.Time) (int64, error) {
|
||||||
|
res, err := s.db.ExecContext(ctx,
|
||||||
|
`DELETE FROM user_setup_tokens WHERE expires_at < ?`,
|
||||||
|
now.UTC().Format(time.RFC3339Nano))
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("store: cleanup setup tokens: %w", err)
|
||||||
|
}
|
||||||
|
n, _ := res.RowsAffected()
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanSetupToken(scan func(...any) error) (*SetupToken, error) {
|
||||||
|
var t SetupToken
|
||||||
|
var createdBy sql.NullString
|
||||||
|
var expiresAt, createdAt string
|
||||||
|
if err := scan(&t.UserID, &t.TokenHash, &expiresAt, &createdAt, &createdBy); err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("store: scan setup token: %w", err)
|
||||||
|
}
|
||||||
|
t.ExpiresAt, _ = time.Parse(time.RFC3339Nano, expiresAt)
|
||||||
|
t.CreatedAt, _ = time.Parse(time.RFC3339Nano, createdAt)
|
||||||
|
if createdBy.Valid {
|
||||||
|
v := createdBy.String
|
||||||
|
t.CreatedBy = &v
|
||||||
|
}
|
||||||
|
return &t, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,120 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oklog/ulid/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newSetupTokenTestStore(t *testing.T) (*Store, string, string) {
|
||||||
|
t.Helper()
|
||||||
|
st, err := Open(context.Background(), filepath.Join(t.TempDir(), "rm.db"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = st.Close() })
|
||||||
|
uid := ulid.Make().String()
|
||||||
|
creator := ulid.Make().String()
|
||||||
|
now := time.Now().UTC()
|
||||||
|
if err := st.CreateUser(context.Background(), User{
|
||||||
|
ID: creator, Username: "creator", PasswordHash: "x",
|
||||||
|
Role: RoleAdmin, CreatedAt: now,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("create creator: %v", err)
|
||||||
|
}
|
||||||
|
if err := st.CreateUser(context.Background(), User{
|
||||||
|
ID: uid, Username: "target", PasswordHash: "",
|
||||||
|
Role: RoleOperator, CreatedAt: now, MustChangePassword: true,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("create target: %v", err)
|
||||||
|
}
|
||||||
|
return st, uid, creator
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetupTokenSetAndLookup(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
st, uid, creator := newSetupTokenTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
if err := st.SetSetupToken(ctx, SetupToken{
|
||||||
|
UserID: uid, TokenHash: "abc123",
|
||||||
|
ExpiresAt: now.Add(time.Hour),
|
||||||
|
CreatedAt: now, CreatedBy: &creator,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("set: %v", err)
|
||||||
|
}
|
||||||
|
got, err := st.LookupSetupToken(ctx, "abc123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("lookup: %v", err)
|
||||||
|
}
|
||||||
|
if got.UserID != uid {
|
||||||
|
t.Errorf("user_id: got %q want %q", got.UserID, uid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetupTokenReplaces(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
st, uid, creator := newSetupTokenTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
_ = st.SetSetupToken(ctx, SetupToken{
|
||||||
|
UserID: uid, TokenHash: "old",
|
||||||
|
ExpiresAt: now.Add(time.Hour), CreatedAt: now, CreatedBy: &creator,
|
||||||
|
})
|
||||||
|
_ = st.SetSetupToken(ctx, SetupToken{
|
||||||
|
UserID: uid, TokenHash: "new",
|
||||||
|
ExpiresAt: now.Add(time.Hour), CreatedAt: now, CreatedBy: &creator,
|
||||||
|
})
|
||||||
|
if _, err := st.LookupSetupToken(ctx, "old"); err == nil {
|
||||||
|
t.Error("old token should be gone")
|
||||||
|
}
|
||||||
|
if _, err := st.LookupSetupToken(ctx, "new"); err != nil {
|
||||||
|
t.Errorf("new token should resolve: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetupTokenDelete(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
st, uid, creator := newSetupTokenTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
_ = st.SetSetupToken(ctx, SetupToken{
|
||||||
|
UserID: uid, TokenHash: "tk",
|
||||||
|
ExpiresAt: now.Add(time.Hour), CreatedAt: now, CreatedBy: &creator,
|
||||||
|
})
|
||||||
|
if err := st.DeleteSetupToken(ctx, uid); err != nil {
|
||||||
|
t.Fatalf("delete: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := st.LookupSetupToken(ctx, "tk"); err == nil {
|
||||||
|
t.Error("deleted token should not resolve")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetupTokenCleanupExpired(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
st, uid, creator := newSetupTokenTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
_ = st.SetSetupToken(ctx, SetupToken{
|
||||||
|
UserID: uid, TokenHash: "stale",
|
||||||
|
ExpiresAt: now.Add(-time.Hour), CreatedAt: now.Add(-2 * time.Hour),
|
||||||
|
CreatedBy: &creator,
|
||||||
|
})
|
||||||
|
n, err := st.CleanupExpiredSetupTokens(ctx, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cleanup: %v", err)
|
||||||
|
}
|
||||||
|
if n != 1 {
|
||||||
|
t.Errorf("cleanup count: got %d want 1", n)
|
||||||
|
}
|
||||||
|
if _, err := st.LookupSetupToken(ctx, "stale"); err == nil {
|
||||||
|
t.Error("stale token should be gone")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user