store: lowercase username, email/disable helpers, last-admin count
This commit is contained in:
+134
-38
@@ -5,45 +5,57 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateUser inserts a new user. The caller is responsible for
|
// CreateUser inserts a row. Username is lowercase-normalised so the
|
||||||
// generating an ID (typically a ULID) and hashing the password.
|
// case-insensitive unique index from migration 0017 doesn't surprise
|
||||||
|
// callers who insert 'Alice' and look up 'alice'.
|
||||||
func (s *Store) CreateUser(ctx context.Context, u User) error {
|
func (s *Store) CreateUser(ctx context.Context, u User) error {
|
||||||
|
u.Username = strings.ToLower(strings.TrimSpace(u.Username))
|
||||||
|
must := 0
|
||||||
|
if u.MustChangePassword {
|
||||||
|
must = 1
|
||||||
|
}
|
||||||
_, err := s.db.ExecContext(ctx,
|
_, err := s.db.ExecContext(ctx,
|
||||||
`INSERT INTO users (id, username, password_hash, role, created_at)
|
`INSERT INTO users (id, username, password_hash, role, email,
|
||||||
VALUES (?, ?, ?, ?, ?)`,
|
must_change_password, created_at)
|
||||||
u.ID, u.Username, u.PasswordHash, string(u.Role), u.CreatedAt.UTC().Format(time.RFC3339Nano))
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||||
|
u.ID, u.Username, u.PasswordHash, string(u.Role),
|
||||||
|
nullable(u.Email), must,
|
||||||
|
u.CreatedAt.UTC().Format(time.RFC3339Nano))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("store: create user: %w", err)
|
return fmt.Errorf("store: create user: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserByUsername looks up a user by their (case-sensitive) username.
|
// GetUserByUsername resolves a user case-insensitively.
|
||||||
// Returns ErrNotFound if no row matches.
|
|
||||||
func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
||||||
row := s.db.QueryRowContext(ctx,
|
row := s.db.QueryRowContext(ctx,
|
||||||
`SELECT id, username, password_hash, role, created_at, last_login_at
|
`SELECT id, username, password_hash, role, email, disabled_at,
|
||||||
FROM users WHERE username = ?`, username)
|
must_change_password, created_at, last_login_at
|
||||||
return scanUser(row)
|
FROM users WHERE LOWER(username) = LOWER(?)`, username)
|
||||||
|
return scanUser(row.Scan)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserByID looks up a user by id. Returns ErrNotFound on miss.
|
// GetUserByID looks up a user by id. Returns ErrNotFound on miss.
|
||||||
func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) {
|
func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) {
|
||||||
row := s.db.QueryRowContext(ctx,
|
row := s.db.QueryRowContext(ctx,
|
||||||
`SELECT id, username, password_hash, role, created_at, last_login_at
|
`SELECT id, username, password_hash, role, email, disabled_at,
|
||||||
|
must_change_password, created_at, last_login_at
|
||||||
FROM users WHERE id = ?`, id)
|
FROM users WHERE id = ?`, id)
|
||||||
return scanUser(row)
|
return scanUser(row.Scan)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListUsers returns every user, sorted by username. Used by surfaces
|
// ListUsers returns every user, sorted by username. Used by surfaces
|
||||||
// that need to render a user-id → username map (audit log filter,
|
// that need to render a user-id → username map (audit log filter,
|
||||||
// "ack'd by" projections).
|
// "ack'd by" projections) and the user-management page.
|
||||||
func (s *Store) ListUsers(ctx context.Context) ([]User, error) {
|
func (s *Store) ListUsers(ctx context.Context) ([]User, error) {
|
||||||
rows, err := s.db.QueryContext(ctx,
|
rows, err := s.db.QueryContext(ctx,
|
||||||
`SELECT id, username, password_hash, role, created_at, last_login_at
|
`SELECT id, username, password_hash, role, email, disabled_at,
|
||||||
|
must_change_password, created_at, last_login_at
|
||||||
FROM users ORDER BY username`)
|
FROM users ORDER BY username`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("store: list users: %w", err)
|
return nil, fmt.Errorf("store: list users: %w", err)
|
||||||
@@ -51,21 +63,11 @@ func (s *Store) ListUsers(ctx context.Context) ([]User, error) {
|
|||||||
defer func() { _ = rows.Close() }()
|
defer func() { _ = rows.Close() }()
|
||||||
var out []User
|
var out []User
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var u User
|
u, err := scanUser(rows.Scan)
|
||||||
var role string
|
if err != nil {
|
||||||
var lastLogin sql.NullString
|
return nil, err
|
||||||
var created string
|
|
||||||
if err := rows.Scan(&u.ID, &u.Username, &u.PasswordHash, &role, &created, &lastLogin); err != nil {
|
|
||||||
return nil, fmt.Errorf("store: scan user row: %w", err)
|
|
||||||
}
|
}
|
||||||
u.Role = Role(role)
|
out = append(out, *u)
|
||||||
t, _ := time.Parse(time.RFC3339Nano, created)
|
|
||||||
u.CreatedAt = t
|
|
||||||
if lastLogin.Valid {
|
|
||||||
t, _ := time.Parse(time.RFC3339Nano, lastLogin.String)
|
|
||||||
u.LastLoginAt = &t
|
|
||||||
}
|
|
||||||
out = append(out, u)
|
|
||||||
}
|
}
|
||||||
return out, rows.Err()
|
return out, rows.Err()
|
||||||
}
|
}
|
||||||
@@ -80,6 +82,19 @@ func (s *Store) CountUsers(ctx context.Context) (int, error) {
|
|||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CountEnabledAdmins returns the number of users with role='admin'
|
||||||
|
// AND disabled_at IS NULL. Used by the last-admin guard before
|
||||||
|
// disable / role-demote operations.
|
||||||
|
func (s *Store) CountEnabledAdmins(ctx context.Context) (int, error) {
|
||||||
|
var n int
|
||||||
|
if err := s.db.QueryRowContext(ctx,
|
||||||
|
`SELECT COUNT(*) FROM users WHERE role = 'admin' AND disabled_at IS NULL`,
|
||||||
|
).Scan(&n); err != nil {
|
||||||
|
return 0, fmt.Errorf("store: count admins: %w", err)
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
// MarkUserLogin records a successful authentication.
|
// MarkUserLogin records a successful authentication.
|
||||||
func (s *Store) MarkUserLogin(ctx context.Context, id string, when time.Time) error {
|
func (s *Store) MarkUserLogin(ctx context.Context, id string, when time.Time) error {
|
||||||
_, err := s.db.ExecContext(ctx,
|
_, err := s.db.ExecContext(ctx,
|
||||||
@@ -91,28 +106,109 @@ func (s *Store) MarkUserLogin(ctx context.Context, id string, when time.Time) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func scanUser(row *sql.Row) (*User, error) {
|
// SetUserEmail replaces the email field. Empty string clears it.
|
||||||
|
func (s *Store) SetUserEmail(ctx context.Context, id, email string) error {
|
||||||
|
em := strings.ToLower(strings.TrimSpace(email))
|
||||||
|
var v any
|
||||||
|
if em == "" {
|
||||||
|
v = nil
|
||||||
|
} else {
|
||||||
|
v = em
|
||||||
|
}
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`UPDATE users SET email = ? WHERE id = ?`, v, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("store: set user email: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserRole changes a user's role.
|
||||||
|
func (s *Store) SetUserRole(ctx context.Context, id string, role Role) error {
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`UPDATE users SET role = ? WHERE id = ?`, string(role), id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("store: set user role: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableUser sets disabled_at = when. Idempotent on already-disabled
|
||||||
|
// rows (no-op).
|
||||||
|
func (s *Store) DisableUser(ctx context.Context, id string, when time.Time) error {
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`UPDATE users SET disabled_at = ?
|
||||||
|
WHERE id = ? AND disabled_at IS NULL`,
|
||||||
|
when.UTC().Format(time.RFC3339Nano), id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("store: disable user: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableUser clears disabled_at.
|
||||||
|
func (s *Store) EnableUser(ctx context.Context, id string) error {
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`UPDATE users SET disabled_at = NULL WHERE id = ?`, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("store: enable user: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMustChangePassword toggles the must_change_password flag.
|
||||||
|
func (s *Store) SetMustChangePassword(ctx context.Context, id string, must bool) error {
|
||||||
|
v := 0
|
||||||
|
if must {
|
||||||
|
v = 1
|
||||||
|
}
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`UPDATE users SET must_change_password = ? WHERE id = ?`, v, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("store: set must_change_password: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPasswordHash stores a new password_hash and clears the
|
||||||
|
// must_change_password flag in one go.
|
||||||
|
func (s *Store) SetPasswordHash(ctx context.Context, id, hash string) error {
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`UPDATE users SET password_hash = ?, must_change_password = 0 WHERE id = ?`,
|
||||||
|
hash, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("store: set password: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanUser(scan func(...any) error) (*User, error) {
|
||||||
var u User
|
var u User
|
||||||
var role string
|
var role string
|
||||||
var lastLogin sql.NullString
|
var email, disabledAt, lastLogin sql.NullString
|
||||||
|
var must int
|
||||||
var created string
|
var created string
|
||||||
if err := row.Scan(&u.ID, &u.Username, &u.PasswordHash, &role, &created, &lastLogin); err != nil {
|
if err := scan(&u.ID, &u.Username, &u.PasswordHash, &role,
|
||||||
|
&email, &disabledAt, &must, &created, &lastLogin); err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("store: scan user: %w", err)
|
return nil, fmt.Errorf("store: scan user: %w", err)
|
||||||
}
|
}
|
||||||
u.Role = Role(role)
|
u.Role = Role(role)
|
||||||
t, err := time.Parse(time.RFC3339Nano, created)
|
if email.Valid {
|
||||||
if err != nil {
|
v := email.String
|
||||||
return nil, fmt.Errorf("store: parse created_at: %w", err)
|
u.Email = &v
|
||||||
}
|
}
|
||||||
|
if disabledAt.Valid {
|
||||||
|
t, _ := time.Parse(time.RFC3339Nano, disabledAt.String)
|
||||||
|
u.DisabledAt = &t
|
||||||
|
}
|
||||||
|
u.MustChangePassword = must == 1
|
||||||
|
t, _ := time.Parse(time.RFC3339Nano, created)
|
||||||
u.CreatedAt = t
|
u.CreatedAt = t
|
||||||
if lastLogin.Valid {
|
if lastLogin.Valid {
|
||||||
t, err := time.Parse(time.RFC3339Nano, lastLogin.String)
|
t, _ := time.Parse(time.RFC3339Nano, lastLogin.String)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("store: parse last_login_at: %w", err)
|
|
||||||
}
|
|
||||||
u.LastLoginAt = &t
|
u.LastLoginAt = &t
|
||||||
}
|
}
|
||||||
return &u, nil
|
return &u, nil
|
||||||
|
|||||||
@@ -131,6 +131,40 @@ func TestSessionLifecycle(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateUserLowercasesUsername(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
s := openTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
if err := s.CreateUser(ctx, User{
|
||||||
|
ID: "u1", Username: "Alice",
|
||||||
|
PasswordHash: "x", Role: RoleAdmin, CreatedAt: now,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("create: %v", err)
|
||||||
|
}
|
||||||
|
got, err := s.GetUserByUsername(ctx, "alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get lower: %v", err)
|
||||||
|
}
|
||||||
|
if got.Username != "alice" {
|
||||||
|
t.Errorf("stored username: got %q want %q", got.Username, "alice")
|
||||||
|
}
|
||||||
|
got, err = s.GetUserByUsername(ctx, "ALICE")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get upper: %v", err)
|
||||||
|
}
|
||||||
|
if got.ID != "u1" {
|
||||||
|
t.Errorf("upper-case lookup missed: got %+v", got)
|
||||||
|
}
|
||||||
|
if err := s.CreateUser(ctx, User{
|
||||||
|
ID: "u2", Username: "AlIcE",
|
||||||
|
PasswordHash: "x", Role: RoleAdmin, CreatedAt: now,
|
||||||
|
}); err == nil {
|
||||||
|
t.Error("duplicate (different case) should fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestEnrollmentTokenSingleUse(t *testing.T) {
|
func TestEnrollmentTokenSingleUse(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
s := openTestStore(t)
|
s := openTestStore(t)
|
||||||
|
|||||||
Reference in New Issue
Block a user