store: GetUserByOIDCSubject + scanUser auth_source/oidc_subject

This commit is contained in:
2026-05-05 13:12:11 +01:00
parent 154b57a4cd
commit 70aa22e87e
2 changed files with 99 additions and 14 deletions
+51 -14
View File
@@ -18,12 +18,18 @@ func (s *Store) CreateUser(ctx context.Context, u User) error {
if u.MustChangePassword {
must = 1
}
authSource := u.AuthSource
if authSource == "" {
authSource = "local"
}
_, err := s.db.ExecContext(ctx,
`INSERT INTO users (id, username, password_hash, role, email,
must_change_password, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
must_change_password, auth_source,
oidc_subject, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
u.ID, u.Username, u.PasswordHash, string(u.Role),
nullable(u.Email), must,
nullable(u.Email), must, authSource,
nullable(u.OIDCSubject),
u.CreatedAt.UTC().Format(time.RFC3339Nano))
if err != nil {
return fmt.Errorf("store: create user: %w", err)
@@ -31,24 +37,49 @@ func (s *Store) CreateUser(ctx context.Context, u User) error {
return nil
}
// userSelectCols centralises the column list every read path uses so
// scanUser stays in lockstep.
const userSelectCols = `id, username, password_hash, role, email,
disabled_at, must_change_password,
auth_source, oidc_subject,
created_at, last_login_at`
// GetUserByUsername resolves a user case-insensitively.
func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) {
row := s.db.QueryRowContext(ctx,
`SELECT id, username, password_hash, role, email, disabled_at,
must_change_password, created_at, last_login_at
FROM users WHERE LOWER(username) = LOWER(?)`, username)
`SELECT `+userSelectCols+` FROM users WHERE LOWER(username) = LOWER(?)`,
username)
return scanUser(row.Scan)
}
// GetUserByID looks up a user by id. Returns ErrNotFound on miss.
func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) {
row := s.db.QueryRowContext(ctx,
`SELECT id, username, password_hash, role, email, disabled_at,
must_change_password, created_at, last_login_at
FROM users WHERE id = ?`, id)
`SELECT `+userSelectCols+` FROM users WHERE id = ?`, id)
return scanUser(row.Scan)
}
// GetUserByOIDCSubject finds the user JIT-provisioned on a previous
// OIDC sign-in. ErrNotFound on miss.
func (s *Store) GetUserByOIDCSubject(ctx context.Context, sub string) (*User, error) {
row := s.db.QueryRowContext(ctx,
`SELECT `+userSelectCols+` FROM users WHERE oidc_subject = ?`, sub)
return scanUser(row.Scan)
}
// SetUserOIDCSubject pins an existing user row to an IdP subject.
// Used by tests today; reserved for a future "link a local user to
// OIDC" flow.
func (s *Store) SetUserOIDCSubject(ctx context.Context, id, authSource, sub string) error {
_, err := s.db.ExecContext(ctx,
`UPDATE users SET auth_source = ?, oidc_subject = ? WHERE id = ?`,
authSource, sub, id)
if err != nil {
return fmt.Errorf("store: set oidc subject: %w", err)
}
return nil
}
// UserSort selects the column ListUsers orders by. OrderBy is
// allowlisted in usersOrderColumn so callers can't inject SQL via
// this field. Empty / unknown OrderBy falls back to "username".
@@ -88,9 +119,8 @@ func (s *Store) ListUsers(ctx context.Context, sort UserSort) ([]User, error) {
// Default: username ASC (alphabetical), matching pre-sort behaviour.
asc = true
}
q := `SELECT id, username, password_hash, role, email, disabled_at,
must_change_password, created_at, last_login_at
FROM users ORDER BY ` + usersOrderColumn(sort.OrderBy, asc)
q := `SELECT ` + userSelectCols + ` FROM users ORDER BY ` +
usersOrderColumn(sort.OrderBy, asc)
rows, err := s.db.QueryContext(ctx, q)
if err != nil {
return nil, fmt.Errorf("store: list users: %w", err)
@@ -220,11 +250,13 @@ func (s *Store) SetPasswordHash(ctx context.Context, id, hash string) error {
func scanUser(scan func(...any) error) (*User, error) {
var u User
var role string
var email, disabledAt, lastLogin sql.NullString
var email, disabledAt, oidcSub, lastLogin sql.NullString
var must int
var authSource string
var created string
if err := scan(&u.ID, &u.Username, &u.PasswordHash, &role,
&email, &disabledAt, &must, &created, &lastLogin); err != nil {
&email, &disabledAt, &must, &authSource, &oidcSub,
&created, &lastLogin); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
@@ -240,6 +272,11 @@ func scanUser(scan func(...any) error) (*User, error) {
u.DisabledAt = &t
}
u.MustChangePassword = must == 1
u.AuthSource = authSource
if oidcSub.Valid {
v := oidcSub.String
u.OIDCSubject = &v
}
t, _ := time.Parse(time.RFC3339Nano, created)
u.CreatedAt = t
if lastLogin.Valid {