diff --git a/internal/store/users.go b/internal/store/users.go index 1a74528..06cd4a1 100644 --- a/internal/store/users.go +++ b/internal/store/users.go @@ -5,45 +5,57 @@ import ( "database/sql" "errors" "fmt" + "strings" "time" ) -// CreateUser inserts a new user. The caller is responsible for -// generating an ID (typically a ULID) and hashing the password. +// CreateUser inserts a row. Username is lowercase-normalised so the +// case-insensitive unique index from migration 0017 doesn't surprise +// callers who insert 'Alice' and look up 'alice'. func (s *Store) CreateUser(ctx context.Context, u User) error { + u.Username = strings.ToLower(strings.TrimSpace(u.Username)) + must := 0 + if u.MustChangePassword { + must = 1 + } _, err := s.db.ExecContext(ctx, - `INSERT INTO users (id, username, password_hash, role, created_at) - VALUES (?, ?, ?, ?, ?)`, - u.ID, u.Username, u.PasswordHash, string(u.Role), u.CreatedAt.UTC().Format(time.RFC3339Nano)) + `INSERT INTO users (id, username, password_hash, role, email, + must_change_password, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + u.ID, u.Username, u.PasswordHash, string(u.Role), + nullable(u.Email), must, + u.CreatedAt.UTC().Format(time.RFC3339Nano)) if err != nil { return fmt.Errorf("store: create user: %w", err) } return nil } -// GetUserByUsername looks up a user by their (case-sensitive) username. -// Returns ErrNotFound if no row matches. +// GetUserByUsername resolves a user case-insensitively. func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) { row := s.db.QueryRowContext(ctx, - `SELECT id, username, password_hash, role, created_at, last_login_at - FROM users WHERE username = ?`, username) - return scanUser(row) + `SELECT id, username, password_hash, role, email, disabled_at, + must_change_password, created_at, last_login_at + FROM users WHERE LOWER(username) = LOWER(?)`, username) + return scanUser(row.Scan) } // GetUserByID looks up a user by id. Returns ErrNotFound on miss. func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) { row := s.db.QueryRowContext(ctx, - `SELECT id, username, password_hash, role, created_at, last_login_at + `SELECT id, username, password_hash, role, email, disabled_at, + must_change_password, created_at, last_login_at FROM users WHERE id = ?`, id) - return scanUser(row) + return scanUser(row.Scan) } // ListUsers returns every user, sorted by username. Used by surfaces // that need to render a user-id → username map (audit log filter, -// "ack'd by" projections). +// "ack'd by" projections) and the user-management page. func (s *Store) ListUsers(ctx context.Context) ([]User, error) { rows, err := s.db.QueryContext(ctx, - `SELECT id, username, password_hash, role, created_at, last_login_at + `SELECT id, username, password_hash, role, email, disabled_at, + must_change_password, created_at, last_login_at FROM users ORDER BY username`) if err != nil { return nil, fmt.Errorf("store: list users: %w", err) @@ -51,21 +63,11 @@ func (s *Store) ListUsers(ctx context.Context) ([]User, error) { defer func() { _ = rows.Close() }() var out []User for rows.Next() { - var u User - var role string - var lastLogin sql.NullString - var created string - if err := rows.Scan(&u.ID, &u.Username, &u.PasswordHash, &role, &created, &lastLogin); err != nil { - return nil, fmt.Errorf("store: scan user row: %w", err) + u, err := scanUser(rows.Scan) + if err != nil { + return nil, err } - u.Role = Role(role) - t, _ := time.Parse(time.RFC3339Nano, created) - u.CreatedAt = t - if lastLogin.Valid { - t, _ := time.Parse(time.RFC3339Nano, lastLogin.String) - u.LastLoginAt = &t - } - out = append(out, u) + out = append(out, *u) } return out, rows.Err() } @@ -80,6 +82,19 @@ func (s *Store) CountUsers(ctx context.Context) (int, error) { return n, nil } +// CountEnabledAdmins returns the number of users with role='admin' +// AND disabled_at IS NULL. Used by the last-admin guard before +// disable / role-demote operations. +func (s *Store) CountEnabledAdmins(ctx context.Context) (int, error) { + var n int + if err := s.db.QueryRowContext(ctx, + `SELECT COUNT(*) FROM users WHERE role = 'admin' AND disabled_at IS NULL`, + ).Scan(&n); err != nil { + return 0, fmt.Errorf("store: count admins: %w", err) + } + return n, nil +} + // MarkUserLogin records a successful authentication. func (s *Store) MarkUserLogin(ctx context.Context, id string, when time.Time) error { _, err := s.db.ExecContext(ctx, @@ -91,28 +106,109 @@ func (s *Store) MarkUserLogin(ctx context.Context, id string, when time.Time) er return nil } -func scanUser(row *sql.Row) (*User, error) { +// SetUserEmail replaces the email field. Empty string clears it. +func (s *Store) SetUserEmail(ctx context.Context, id, email string) error { + em := strings.ToLower(strings.TrimSpace(email)) + var v any + if em == "" { + v = nil + } else { + v = em + } + _, err := s.db.ExecContext(ctx, + `UPDATE users SET email = ? WHERE id = ?`, v, id) + if err != nil { + return fmt.Errorf("store: set user email: %w", err) + } + return nil +} + +// SetUserRole changes a user's role. +func (s *Store) SetUserRole(ctx context.Context, id string, role Role) error { + _, err := s.db.ExecContext(ctx, + `UPDATE users SET role = ? WHERE id = ?`, string(role), id) + if err != nil { + return fmt.Errorf("store: set user role: %w", err) + } + return nil +} + +// DisableUser sets disabled_at = when. Idempotent on already-disabled +// rows (no-op). +func (s *Store) DisableUser(ctx context.Context, id string, when time.Time) error { + _, err := s.db.ExecContext(ctx, + `UPDATE users SET disabled_at = ? + WHERE id = ? AND disabled_at IS NULL`, + when.UTC().Format(time.RFC3339Nano), id) + if err != nil { + return fmt.Errorf("store: disable user: %w", err) + } + return nil +} + +// EnableUser clears disabled_at. +func (s *Store) EnableUser(ctx context.Context, id string) error { + _, err := s.db.ExecContext(ctx, + `UPDATE users SET disabled_at = NULL WHERE id = ?`, id) + if err != nil { + return fmt.Errorf("store: enable user: %w", err) + } + return nil +} + +// SetMustChangePassword toggles the must_change_password flag. +func (s *Store) SetMustChangePassword(ctx context.Context, id string, must bool) error { + v := 0 + if must { + v = 1 + } + _, err := s.db.ExecContext(ctx, + `UPDATE users SET must_change_password = ? WHERE id = ?`, v, id) + if err != nil { + return fmt.Errorf("store: set must_change_password: %w", err) + } + return nil +} + +// SetPasswordHash stores a new password_hash and clears the +// must_change_password flag in one go. +func (s *Store) SetPasswordHash(ctx context.Context, id, hash string) error { + _, err := s.db.ExecContext(ctx, + `UPDATE users SET password_hash = ?, must_change_password = 0 WHERE id = ?`, + hash, id) + if err != nil { + return fmt.Errorf("store: set password: %w", err) + } + return nil +} + +func scanUser(scan func(...any) error) (*User, error) { var u User var role string - var lastLogin sql.NullString + var email, disabledAt, lastLogin sql.NullString + var must int var created string - if err := row.Scan(&u.ID, &u.Username, &u.PasswordHash, &role, &created, &lastLogin); err != nil { + if err := scan(&u.ID, &u.Username, &u.PasswordHash, &role, + &email, &disabledAt, &must, &created, &lastLogin); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } return nil, fmt.Errorf("store: scan user: %w", err) } u.Role = Role(role) - t, err := time.Parse(time.RFC3339Nano, created) - if err != nil { - return nil, fmt.Errorf("store: parse created_at: %w", err) + if email.Valid { + v := email.String + u.Email = &v } + if disabledAt.Valid { + t, _ := time.Parse(time.RFC3339Nano, disabledAt.String) + u.DisabledAt = &t + } + u.MustChangePassword = must == 1 + t, _ := time.Parse(time.RFC3339Nano, created) u.CreatedAt = t if lastLogin.Valid { - t, err := time.Parse(time.RFC3339Nano, lastLogin.String) - if err != nil { - return nil, fmt.Errorf("store: parse last_login_at: %w", err) - } + t, _ := time.Parse(time.RFC3339Nano, lastLogin.String) u.LastLoginAt = &t } return &u, nil diff --git a/internal/store/users_test.go b/internal/store/users_test.go index e1eae99..a7684a9 100644 --- a/internal/store/users_test.go +++ b/internal/store/users_test.go @@ -131,6 +131,40 @@ func TestSessionLifecycle(t *testing.T) { } } +func TestCreateUserLowercasesUsername(t *testing.T) { + t.Parallel() + s := openTestStore(t) + ctx := context.Background() + now := time.Now().UTC() + + if err := s.CreateUser(ctx, User{ + ID: "u1", Username: "Alice", + PasswordHash: "x", Role: RoleAdmin, CreatedAt: now, + }); err != nil { + t.Fatalf("create: %v", err) + } + got, err := s.GetUserByUsername(ctx, "alice") + if err != nil { + t.Fatalf("get lower: %v", err) + } + if got.Username != "alice" { + t.Errorf("stored username: got %q want %q", got.Username, "alice") + } + got, err = s.GetUserByUsername(ctx, "ALICE") + if err != nil { + t.Fatalf("get upper: %v", err) + } + if got.ID != "u1" { + t.Errorf("upper-case lookup missed: got %+v", got) + } + if err := s.CreateUser(ctx, User{ + ID: "u2", Username: "AlIcE", + PasswordHash: "x", Role: RoleAdmin, CreatedAt: now, + }); err == nil { + t.Error("duplicate (different case) should fail") + } +} + func TestEnrollmentTokenSingleUse(t *testing.T) { t.Parallel() s := openTestStore(t)