package store import ( "context" "database/sql" "errors" "fmt" "strings" "time" ) // CreateUser inserts a row. Username is lowercase-normalised so the // case-insensitive unique index from migration 0017 doesn't surprise // callers who insert 'Alice' and look up 'alice'. func (s *Store) CreateUser(ctx context.Context, u User) error { u.Username = strings.ToLower(strings.TrimSpace(u.Username)) must := 0 if u.MustChangePassword { must = 1 } _, err := s.db.ExecContext(ctx, `INSERT INTO users (id, username, password_hash, role, email, must_change_password, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)`, u.ID, u.Username, u.PasswordHash, string(u.Role), nullable(u.Email), must, u.CreatedAt.UTC().Format(time.RFC3339Nano)) if err != nil { return fmt.Errorf("store: create user: %w", err) } return nil } // GetUserByUsername resolves a user case-insensitively. func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) { row := s.db.QueryRowContext(ctx, `SELECT id, username, password_hash, role, email, disabled_at, must_change_password, created_at, last_login_at FROM users WHERE LOWER(username) = LOWER(?)`, username) return scanUser(row.Scan) } // GetUserByID looks up a user by id. Returns ErrNotFound on miss. func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) { row := s.db.QueryRowContext(ctx, `SELECT id, username, password_hash, role, email, disabled_at, must_change_password, created_at, last_login_at FROM users WHERE id = ?`, id) return scanUser(row.Scan) } // UserSort selects the column ListUsers orders by. OrderBy is // allowlisted in usersOrderColumn so callers can't inject SQL via // this field. Empty / unknown OrderBy falls back to "username". type UserSort struct { OrderBy string // "username" | "email" | "role" | "last_login_at" OrderAsc bool // false = DESC; true = ASC } // usersOrderColumn validates s.OrderBy and returns the SQL fragment. // last_login_at gets a NULL-tail trick so users who've never logged // in sort to the bottom regardless of asc/desc — matches operator // intuition ("show me real activity" not "show me NULLs first"). func usersOrderColumn(col string, asc bool) string { dir := "DESC" if asc { dir = "ASC" } switch col { case "email": return fmt.Sprintf("email IS NULL, email %s, username", dir) case "role": return fmt.Sprintf("role %s, username", dir) case "last_login_at": return fmt.Sprintf("last_login_at IS NULL, last_login_at %s, username", dir) default: // username (and unknown) return fmt.Sprintf("username %s", dir) } } // ListUsers returns users sorted per UserSort. Default (zero value) // is username ASC. Used by the user-management page (sort headers) // and by surfaces that need a user-id → username map (audit log // filter, "ack'd by" projections) — those callers pass UserSort{}. func (s *Store) ListUsers(ctx context.Context, sort UserSort) ([]User, error) { asc := sort.OrderAsc if sort.OrderBy == "" { // Default: username ASC (alphabetical), matching pre-sort behaviour. asc = true } q := `SELECT id, username, password_hash, role, email, disabled_at, must_change_password, created_at, last_login_at FROM users ORDER BY ` + usersOrderColumn(sort.OrderBy, asc) rows, err := s.db.QueryContext(ctx, q) if err != nil { return nil, fmt.Errorf("store: list users: %w", err) } defer func() { _ = rows.Close() }() var out []User for rows.Next() { u, err := scanUser(rows.Scan) if err != nil { return nil, err } out = append(out, *u) } return out, rows.Err() } // CountUsers returns the total number of user rows. The first-run // bootstrap uses this to detect a fresh install. func (s *Store) CountUsers(ctx context.Context) (int, error) { var n int if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM users`).Scan(&n); err != nil { return 0, fmt.Errorf("store: count users: %w", err) } return n, nil } // CountEnabledAdmins returns the number of users with role='admin' // AND disabled_at IS NULL. Used by the last-admin guard before // disable / role-demote operations. func (s *Store) CountEnabledAdmins(ctx context.Context) (int, error) { var n int if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM users WHERE role = 'admin' AND disabled_at IS NULL`, ).Scan(&n); err != nil { return 0, fmt.Errorf("store: count admins: %w", err) } return n, nil } // MarkUserLogin records a successful authentication. func (s *Store) MarkUserLogin(ctx context.Context, id string, when time.Time) error { _, err := s.db.ExecContext(ctx, `UPDATE users SET last_login_at = ? WHERE id = ?`, when.UTC().Format(time.RFC3339Nano), id) if err != nil { return fmt.Errorf("store: mark login: %w", err) } return nil } // SetUserEmail replaces the email field. Empty string clears it. func (s *Store) SetUserEmail(ctx context.Context, id, email string) error { em := strings.ToLower(strings.TrimSpace(email)) var v any if em == "" { v = nil } else { v = em } _, err := s.db.ExecContext(ctx, `UPDATE users SET email = ? WHERE id = ?`, v, id) if err != nil { return fmt.Errorf("store: set user email: %w", err) } return nil } // SetUserRole changes a user's role. func (s *Store) SetUserRole(ctx context.Context, id string, role Role) error { _, err := s.db.ExecContext(ctx, `UPDATE users SET role = ? WHERE id = ?`, string(role), id) if err != nil { return fmt.Errorf("store: set user role: %w", err) } return nil } // DisableUser sets disabled_at = when. Idempotent on already-disabled // rows (no-op). func (s *Store) DisableUser(ctx context.Context, id string, when time.Time) error { _, err := s.db.ExecContext(ctx, `UPDATE users SET disabled_at = ? WHERE id = ? AND disabled_at IS NULL`, when.UTC().Format(time.RFC3339Nano), id) if err != nil { return fmt.Errorf("store: disable user: %w", err) } return nil } // EnableUser clears disabled_at. func (s *Store) EnableUser(ctx context.Context, id string) error { _, err := s.db.ExecContext(ctx, `UPDATE users SET disabled_at = NULL WHERE id = ?`, id) if err != nil { return fmt.Errorf("store: enable user: %w", err) } return nil } // SetMustChangePassword toggles the must_change_password flag. func (s *Store) SetMustChangePassword(ctx context.Context, id string, must bool) error { v := 0 if must { v = 1 } _, err := s.db.ExecContext(ctx, `UPDATE users SET must_change_password = ? WHERE id = ?`, v, id) if err != nil { return fmt.Errorf("store: set must_change_password: %w", err) } return nil } // SetPasswordHash stores a new password_hash and clears the // must_change_password flag in one go. func (s *Store) SetPasswordHash(ctx context.Context, id, hash string) error { _, err := s.db.ExecContext(ctx, `UPDATE users SET password_hash = ?, must_change_password = 0 WHERE id = ?`, hash, id) if err != nil { return fmt.Errorf("store: set password: %w", err) } return nil } func scanUser(scan func(...any) error) (*User, error) { var u User var role string var email, disabledAt, lastLogin sql.NullString var must int var created string if err := scan(&u.ID, &u.Username, &u.PasswordHash, &role, &email, &disabledAt, &must, &created, &lastLogin); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } return nil, fmt.Errorf("store: scan user: %w", err) } u.Role = Role(role) if email.Valid { v := email.String u.Email = &v } if disabledAt.Valid { t, _ := time.Parse(time.RFC3339Nano, disabledAt.String) u.DisabledAt = &t } u.MustChangePassword = must == 1 t, _ := time.Parse(time.RFC3339Nano, created) u.CreatedAt = t if lastLogin.Valid { t, _ := time.Parse(time.RFC3339Nano, lastLogin.String) u.LastLoginAt = &t } return &u, nil }