package store

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"
)

// CreateUser inserts a row. Username is trimmed and lowercase-normalised
// so the case-insensitive unique index from migration 0017 doesn't
// surprise callers who insert 'Alice' and look up 'alice'.
//
// Email, when non-nil, is normalised exactly as SetUserEmail does it:
// surrounding whitespace is trimmed, the address is lowercased, and a
// result that is empty is stored as NULL. Without this, an address
// written at creation time and one written later via SetUserEmail would
// be stored in different forms.
func (s *Store) CreateUser(ctx context.Context, u User) error {
	u.Username = strings.ToLower(strings.TrimSpace(u.Username))
	if u.Email != nil {
		// Rebind the pointer instead of writing through it: u is a copy,
		// but u.Email still points at the caller's string.
		em := strings.ToLower(strings.TrimSpace(*u.Email))
		if em == "" {
			u.Email = nil
		} else {
			u.Email = &em
		}
	}
	must := 0
	if u.MustChangePassword {
		must = 1
	}
	_, err := s.db.ExecContext(ctx,
		`INSERT INTO users (id, username, password_hash, role, email, must_change_password, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)`,
		u.ID, u.Username, u.PasswordHash, string(u.Role), nullable(u.Email), must,
		u.CreatedAt.UTC().Format(time.RFC3339Nano))
	if err != nil {
		return fmt.Errorf("store: create user: %w", err)
	}
	return nil
}

// GetUserByUsername resolves a user case-insensitively. Returns
// ErrNotFound on miss.
//
// The argument is trimmed and lowercased in Go with the same rule
// CreateUser applies before inserting. This keeps lookups in agreement
// with stored usernames even when the database's LOWER() folds less than
// strings.ToLower does. If this is SQLite, its built-in LOWER() folds
// ASCII letters only. LOWER() is still applied to the column so that any
// mixed-case rows continue to match ASCII-insensitively.
func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) {
	name := strings.ToLower(strings.TrimSpace(username))
	row := s.db.QueryRowContext(ctx,
		`SELECT id, username, password_hash, role, email, disabled_at, must_change_password, created_at, last_login_at FROM users WHERE LOWER(username) = LOWER(?)`,
		name)
	return scanUser(row.Scan)
}

// GetUserByID looks up a user by id. Returns ErrNotFound on miss.
func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) {
	row := s.db.QueryRowContext(ctx,
		`SELECT id, username, password_hash, role, email, disabled_at, must_change_password, created_at, last_login_at FROM users WHERE id = ?`,
		id)
	return scanUser(row.Scan)
}

// ListUsers returns every user, sorted by username. Used by surfaces
// that need to render a user-id → username map (audit log filter,
// "ack'd by" projections) and the user-management page.
func (s *Store) ListUsers(ctx context.Context) ([]User, error) { rows, err := s.db.QueryContext(ctx, `SELECT id, username, password_hash, role, email, disabled_at, must_change_password, created_at, last_login_at FROM users ORDER BY username`) if err != nil { return nil, fmt.Errorf("store: list users: %w", err) } defer func() { _ = rows.Close() }() var out []User for rows.Next() { u, err := scanUser(rows.Scan) if err != nil { return nil, err } out = append(out, *u) } return out, rows.Err() } // CountUsers returns the total number of user rows. The first-run // bootstrap uses this to detect a fresh install. func (s *Store) CountUsers(ctx context.Context) (int, error) { var n int if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM users`).Scan(&n); err != nil { return 0, fmt.Errorf("store: count users: %w", err) } return n, nil } // CountEnabledAdmins returns the number of users with role='admin' // AND disabled_at IS NULL. Used by the last-admin guard before // disable / role-demote operations. func (s *Store) CountEnabledAdmins(ctx context.Context) (int, error) { var n int if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM users WHERE role = 'admin' AND disabled_at IS NULL`, ).Scan(&n); err != nil { return 0, fmt.Errorf("store: count admins: %w", err) } return n, nil } // MarkUserLogin records a successful authentication. func (s *Store) MarkUserLogin(ctx context.Context, id string, when time.Time) error { _, err := s.db.ExecContext(ctx, `UPDATE users SET last_login_at = ? WHERE id = ?`, when.UTC().Format(time.RFC3339Nano), id) if err != nil { return fmt.Errorf("store: mark login: %w", err) } return nil } // SetUserEmail replaces the email field. Empty string clears it. func (s *Store) SetUserEmail(ctx context.Context, id, email string) error { em := strings.ToLower(strings.TrimSpace(email)) var v any if em == "" { v = nil } else { v = em } _, err := s.db.ExecContext(ctx, `UPDATE users SET email = ? 
WHERE id = ?`, v, id) if err != nil { return fmt.Errorf("store: set user email: %w", err) } return nil } // SetUserRole changes a user's role. func (s *Store) SetUserRole(ctx context.Context, id string, role Role) error { _, err := s.db.ExecContext(ctx, `UPDATE users SET role = ? WHERE id = ?`, string(role), id) if err != nil { return fmt.Errorf("store: set user role: %w", err) } return nil } // DisableUser sets disabled_at = when. Idempotent on already-disabled // rows (no-op). func (s *Store) DisableUser(ctx context.Context, id string, when time.Time) error { _, err := s.db.ExecContext(ctx, `UPDATE users SET disabled_at = ? WHERE id = ? AND disabled_at IS NULL`, when.UTC().Format(time.RFC3339Nano), id) if err != nil { return fmt.Errorf("store: disable user: %w", err) } return nil } // EnableUser clears disabled_at. func (s *Store) EnableUser(ctx context.Context, id string) error { _, err := s.db.ExecContext(ctx, `UPDATE users SET disabled_at = NULL WHERE id = ?`, id) if err != nil { return fmt.Errorf("store: enable user: %w", err) } return nil } // SetMustChangePassword toggles the must_change_password flag. func (s *Store) SetMustChangePassword(ctx context.Context, id string, must bool) error { v := 0 if must { v = 1 } _, err := s.db.ExecContext(ctx, `UPDATE users SET must_change_password = ? WHERE id = ?`, v, id) if err != nil { return fmt.Errorf("store: set must_change_password: %w", err) } return nil } // SetPasswordHash stores a new password_hash and clears the // must_change_password flag in one go. 
func (s *Store) SetPasswordHash(ctx context.Context, id, hash string) error { _, err := s.db.ExecContext(ctx, `UPDATE users SET password_hash = ?, must_change_password = 0 WHERE id = ?`, hash, id) if err != nil { return fmt.Errorf("store: set password: %w", err) } return nil } func scanUser(scan func(...any) error) (*User, error) { var u User var role string var email, disabledAt, lastLogin sql.NullString var must int var created string if err := scan(&u.ID, &u.Username, &u.PasswordHash, &role, &email, &disabledAt, &must, &created, &lastLogin); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } return nil, fmt.Errorf("store: scan user: %w", err) } u.Role = Role(role) if email.Valid { v := email.String u.Email = &v } if disabledAt.Valid { t, _ := time.Parse(time.RFC3339Nano, disabledAt.String) u.DisabledAt = &t } u.MustChangePassword = must == 1 t, _ := time.Parse(time.RFC3339Nano, created) u.CreatedAt = t if lastLogin.Valid { t, _ := time.Parse(time.RFC3339Nano, lastLogin.String) u.LastLoginAt = &t } return &u, nil }