Files

288 lines
8.8 KiB
Go

package store
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
)
// CreateUser inserts a new users row built from u.
//
// The username is trimmed and lower-cased before the insert, so that,
// together with the case-insensitive unique index from migration 0017,
// a user created as 'Alice' is the same row later looked up as 'alice'.
// An empty AuthSource is stored as "local". MustChangePassword is stored
// as 0/1, and CreatedAt as a UTC RFC 3339 (nanosecond) string. Driver
// errors are wrapped with a "store: create user" prefix.
func (s *Store) CreateUser(ctx context.Context, u User) error {
	username := strings.ToLower(strings.TrimSpace(u.Username))

	mustChange := 0
	if u.MustChangePassword {
		mustChange = 1
	}

	source := "local"
	if u.AuthSource != "" {
		source = u.AuthSource
	}

	createdAt := u.CreatedAt.UTC().Format(time.RFC3339Nano)

	if _, err := s.db.ExecContext(ctx,
		`INSERT INTO users (id, username, password_hash, role, email,
must_change_password, auth_source,
oidc_subject, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		u.ID, username, u.PasswordHash, string(u.Role),
		nullable(u.Email), mustChange, source,
		nullable(u.OIDCSubject),
		createdAt,
	); err != nil {
		return fmt.Errorf("store: create user: %w", err)
	}
	return nil
}
// userSelectCols is the column list shared by every users read path
// (GetUserByUsername, GetUserByID, GetUserByOIDCSubject, ListUsers).
// scanUser binds its Scan targets by position, so this order must stay
// in lockstep with it. Any column added, removed or reordered here has
// to be mirrored in scanUser's Scan call.
const userSelectCols = `id, username, password_hash, role, email,
disabled_at, must_change_password,
auth_source, oidc_subject,
created_at, last_login_at`
// GetUserByUsername resolves a user case-insensitively. It returns
// ErrNotFound when no row matches.
//
// The argument gets the same normalisation CreateUser applies on insert
// (strings.TrimSpace, then strings.ToLower), done in Go rather than SQL.
// SQL LOWER() alone is not enough: SQLite's built-in lower() folds ASCII
// letters only. A name with non-ASCII capitals, which CreateUser stored
// lower-cased, would therefore never match. Surrounding whitespace that
// the insert path stripped would also make the lookup miss.
// LOWER(username) is kept so any rows stored with mixed ASCII case still
// match.
// NOTE(review): assumes the backing database is SQLite — confirm.
func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) {
	normalized := strings.ToLower(strings.TrimSpace(username))
	row := s.db.QueryRowContext(ctx,
		`SELECT `+userSelectCols+` FROM users WHERE LOWER(username) = LOWER(?)`,
		normalized)
	return scanUser(row.Scan)
}
// GetUserByID loads the user whose primary key is id. A missing row
// comes back as ErrNotFound (scanUser translates sql.ErrNoRows).
func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) {
	const query = `SELECT ` + userSelectCols + ` FROM users WHERE id = ?`
	return scanUser(s.db.QueryRowContext(ctx, query, id).Scan)
}
// GetUserByOIDCSubject returns the user whose oidc_subject equals sub,
// i.e. the account JIT-provisioned on an earlier OIDC sign-in. A missing
// row comes back as ErrNotFound.
func (s *Store) GetUserByOIDCSubject(ctx context.Context, sub string) (*User, error) {
	const query = `SELECT ` + userSelectCols + ` FROM users WHERE oidc_subject = ?`
	return scanUser(s.db.QueryRowContext(ctx, query, sub).Scan)
}
// SetUserOIDCSubject pins an existing user row to an IdP subject and
// sets its auth_source. It is used by tests today and reserved for a
// future "link a local user to OIDC" flow.
//
// An empty sub is written as SQL NULL rather than '' (the same way
// SetUserEmail clears a blank email). The row then reads back with a nil
// OIDCSubject through scanUser, and GetUserByOIDCSubject("") cannot
// match it. An unknown id is not reported: the UPDATE simply touches no
// rows.
func (s *Store) SetUserOIDCSubject(ctx context.Context, id, authSource, sub string) error {
	var subject any // nil interface is bound as SQL NULL
	if sub != "" {
		subject = sub
	}
	_, err := s.db.ExecContext(ctx,
		`UPDATE users SET auth_source = ?, oidc_subject = ? WHERE id = ?`,
		authSource, subject, id)
	if err != nil {
		return fmt.Errorf("store: set oidc subject: %w", err)
	}
	return nil
}
// UserSort selects the column and direction ListUsers orders by.
//
// OrderBy is never spliced into SQL as-is. usersOrderColumn maps it
// through an allowlist, so callers cannot inject SQL via this field.
// An empty or unrecognised OrderBy falls back to ordering by username.
// The zero value means username ascending, because ListUsers forces ASC
// when OrderBy is empty.
type UserSort struct {
	OrderBy  string // "username" | "email" | "role" | "last_login_at"
	OrderAsc bool   // false = DESC; true = ASC (ignored when OrderBy is "")
}
// usersOrderColumn turns a requested sort column and direction into an
// ORDER BY fragment. Only literal, allowlisted column names ever reach
// the SQL. Anything outside the allowlist (including "username" and "")
// orders by username.
//
// For the nullable columns email and last_login_at, a leading
// "<col> IS NULL" key is emitted. It is always ascending, so non-NULL
// rows (false) come before NULL rows (true) whichever direction is
// requested. This keeps users who never logged in, or have no email, at
// the bottom, so "show me real activity" is what surfaces first. Ties
// on every non-username column break on username.
func usersOrderColumn(col string, asc bool) string {
	direction := "ASC"
	if !asc {
		direction = "DESC"
	}
	switch col {
	case "email", "last_login_at":
		// col is one of the two literals above, so interpolating it is safe.
		return fmt.Sprintf("%[1]s IS NULL, %[1]s %[2]s, username", col, direction)
	case "role":
		return fmt.Sprintf("role %s, username", direction)
	}
	return fmt.Sprintf("username %s", direction)
}
// ListUsers returns every user, sorted per sort. The zero UserSort
// means username ASC. It serves the user-management page (sort headers)
// and surfaces that need a user-id → username map (audit log filter,
// "ack'd by" projections); those callers pass UserSort{}.
//
// On any error, including one raised while iterating rows, the slice is
// nil. Callers never see a partially-filled list. With no users, the
// result is a nil slice and a nil error.
func (s *Store) ListUsers(ctx context.Context, sort UserSort) ([]User, error) {
	asc := sort.OrderAsc
	if sort.OrderBy == "" {
		// Default: username ASC (alphabetical), matching pre-sort behaviour.
		asc = true
	}
	// usersOrderColumn only emits allowlisted column names, so building
	// the ORDER BY clause by concatenation is safe here.
	q := `SELECT ` + userSelectCols + ` FROM users ORDER BY ` +
		usersOrderColumn(sort.OrderBy, asc)
	rows, err := s.db.QueryContext(ctx, q)
	if err != nil {
		return nil, fmt.Errorf("store: list users: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var out []User
	for rows.Next() {
		u, err := scanUser(rows.Scan)
		if err != nil {
			return nil, err // already wrapped by scanUser
		}
		out = append(out, *u)
	}
	// rows.Next returning false may mean an I/O or driver error rather
	// than end-of-results; surface it instead of returning a truncated list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("store: list users: %w", err)
	}
	return out, nil
}
// CountUsers reports how many rows the users table holds. The first-run
// bootstrap uses the count to detect a fresh install. Errors are wrapped
// with a "store: count users" prefix, and the count is then 0.
func (s *Store) CountUsers(ctx context.Context) (int, error) {
	row := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM users`)
	var total int
	err := row.Scan(&total)
	if err != nil {
		return 0, fmt.Errorf("store: count users: %w", err)
	}
	return total, nil
}
// CountEnabledAdmins counts users whose role is 'admin' and whose
// disabled_at is NULL. The last-admin guard consults it before a disable
// or role-demote operation.
//
// NOTE(review): a count-then-update guard can race with a concurrent
// disable/demote unless the caller serialises both steps (e.g. in one
// transaction). Confirm at the call site.
func (s *Store) CountEnabledAdmins(ctx context.Context) (int, error) {
	const query = `SELECT COUNT(*) FROM users WHERE role = 'admin' AND disabled_at IS NULL`
	row := s.db.QueryRowContext(ctx, query)
	var admins int
	if err := row.Scan(&admins); err != nil {
		return 0, fmt.Errorf("store: count admins: %w", err)
	}
	return admins, nil
}
// MarkUserLogin records a successful authentication. It sets user id's
// last_login_at to when, stored as a UTC RFC 3339 (nanosecond) string.
// An unknown id is not reported as an error.
func (s *Store) MarkUserLogin(ctx context.Context, id string, when time.Time) error {
	stamp := when.UTC().Format(time.RFC3339Nano)
	if _, err := s.db.ExecContext(ctx,
		`UPDATE users SET last_login_at = ? WHERE id = ?`, stamp, id,
	); err != nil {
		return fmt.Errorf("store: mark login: %w", err)
	}
	return nil
}
// SetUserEmail overwrites the email column for user id. The address is
// trimmed and lower-cased first. If nothing is left, the column is set
// to NULL, so passing "" (or only whitespace) clears it.
func (s *Store) SetUserEmail(ctx context.Context, id, email string) error {
	var value any // stays a nil interface (SQL NULL) for a blank address
	if normalized := strings.ToLower(strings.TrimSpace(email)); normalized != "" {
		value = normalized
	}
	_, err := s.db.ExecContext(ctx, `UPDATE users SET email = ? WHERE id = ?`, value, id)
	if err != nil {
		return fmt.Errorf("store: set user email: %w", err)
	}
	return nil
}
// SetUserRole writes role into the role column of user id. The role
// value is not validated here, and an unknown id is not reported as an
// error.
func (s *Store) SetUserRole(ctx context.Context, id string, role Role) error {
	roleName := string(role)
	if _, err := s.db.ExecContext(ctx,
		`UPDATE users SET role = ? WHERE id = ?`, roleName, id,
	); err != nil {
		return fmt.Errorf("store: set user role: %w", err)
	}
	return nil
}
// DisableUser stamps disabled_at with when (UTC, RFC 3339 nanosecond)
// for user id. The WHERE clause skips rows that are already disabled.
// A repeat call is therefore a no-op that keeps the original
// disabled_at value.
func (s *Store) DisableUser(ctx context.Context, id string, when time.Time) error {
	const stmt = `UPDATE users SET disabled_at = ?
WHERE id = ? AND disabled_at IS NULL`
	stamp := when.UTC().Format(time.RFC3339Nano)
	_, err := s.db.ExecContext(ctx, stmt, stamp, id)
	if err == nil {
		return nil
	}
	return fmt.Errorf("store: disable user: %w", err)
}
// EnableUser re-enables user id by setting disabled_at back to NULL.
// Running it on a user that is not disabled leaves the row as it was.
func (s *Store) EnableUser(ctx context.Context, id string) error {
	const stmt = `UPDATE users SET disabled_at = NULL WHERE id = ?`
	if _, err := s.db.ExecContext(ctx, stmt, id); err != nil {
		return fmt.Errorf("store: enable user: %w", err)
	}
	return nil
}
// SetMustChangePassword sets (must == true) or clears the
// must_change_password flag for user id. The column stores 1 for true
// and 0 for false.
func (s *Store) SetMustChangePassword(ctx context.Context, id string, must bool) error {
	var flag int
	if must {
		flag = 1
	}
	const stmt = `UPDATE users SET must_change_password = ? WHERE id = ?`
	_, err := s.db.ExecContext(ctx, stmt, flag, id)
	if err == nil {
		return nil
	}
	return fmt.Errorf("store: set must_change_password: %w", err)
}
// SetPasswordHash replaces password_hash for user id. In the same UPDATE
// statement it resets must_change_password to 0, so a successful
// password change also lifts any forced-change requirement.
func (s *Store) SetPasswordHash(ctx context.Context, id, hash string) error {
	const stmt = `UPDATE users SET password_hash = ?, must_change_password = 0 WHERE id = ?`
	if _, err := s.db.ExecContext(ctx, stmt, hash, id); err != nil {
		return fmt.Errorf("store: set password: %w", err)
	}
	return nil
}
// scanUser decodes one row laid out as userSelectCols into a User. scan
// is the Scan method of a *sql.Row or *sql.Rows. sql.ErrNoRows is
// translated to ErrNotFound, and any other scan failure is wrapped with
// a "store: scan user" prefix.
//
// NULL in email, disabled_at, oidc_subject or last_login_at leaves the
// matching pointer field nil. must_change_password is true only when the
// stored value is exactly 1. Timestamps are parsed as RFC 3339
// (nanosecond).
// NOTE(review): a timestamp that fails to parse silently becomes the
// zero time rather than an error. Confirm this best-effort behaviour is
// intended.
func scanUser(scan func(...any) error) (*User, error) {
	var u User
	var role, authSource, created string
	var email, disabledAt, oidcSub, lastLogin sql.NullString
	var must int

	// Target order must match userSelectCols exactly.
	err := scan(&u.ID, &u.Username, &u.PasswordHash, &role,
		&email, &disabledAt, &must, &authSource, &oidcSub,
		&created, &lastLogin)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("store: scan user: %w", err)
	}

	// parseTS deliberately discards the parse error (see doc comment).
	parseTS := func(raw string) time.Time {
		ts, _ := time.Parse(time.RFC3339Nano, raw)
		return ts
	}
	optTime := func(ns sql.NullString) *time.Time {
		if !ns.Valid {
			return nil
		}
		ts := parseTS(ns.String)
		return &ts
	}
	optString := func(ns sql.NullString) *string {
		if !ns.Valid {
			return nil
		}
		str := ns.String
		return &str
	}

	u.Role = Role(role)
	u.Email = optString(email)
	u.DisabledAt = optTime(disabledAt)
	u.MustChangePassword = must == 1
	u.AuthSource = authSource
	u.OIDCSubject = optString(oidcSub)
	u.CreatedAt = parseTS(created)
	u.LastLoginAt = optTime(lastLogin)
	return &u, nil
}