store: GetUserByOIDCSubject + scanUser auth_source/oidc_subject

This commit is contained in:
2026-05-05 13:12:11 +01:00
parent 154b57a4cd
commit 70aa22e87e
2 changed files with 99 additions and 14 deletions
+51 -14
View File
@@ -18,12 +18,18 @@ func (s *Store) CreateUser(ctx context.Context, u User) error {
if u.MustChangePassword { if u.MustChangePassword {
must = 1 must = 1
} }
authSource := u.AuthSource
if authSource == "" {
authSource = "local"
}
_, err := s.db.ExecContext(ctx, _, err := s.db.ExecContext(ctx,
`INSERT INTO users (id, username, password_hash, role, email, `INSERT INTO users (id, username, password_hash, role, email,
must_change_password, created_at) must_change_password, auth_source,
VALUES (?, ?, ?, ?, ?, ?, ?)`, oidc_subject, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
u.ID, u.Username, u.PasswordHash, string(u.Role), u.ID, u.Username, u.PasswordHash, string(u.Role),
nullable(u.Email), must, nullable(u.Email), must, authSource,
nullable(u.OIDCSubject),
u.CreatedAt.UTC().Format(time.RFC3339Nano)) u.CreatedAt.UTC().Format(time.RFC3339Nano))
if err != nil { if err != nil {
return fmt.Errorf("store: create user: %w", err) return fmt.Errorf("store: create user: %w", err)
@@ -31,24 +37,49 @@ func (s *Store) CreateUser(ctx context.Context, u User) error {
return nil return nil
} }
// userSelectCols centralises the column list every read path uses so
// scanUser stays in lockstep.
const userSelectCols = `id, username, password_hash, role, email,
disabled_at, must_change_password,
auth_source, oidc_subject,
created_at, last_login_at`
// GetUserByUsername resolves a user case-insensitively. // GetUserByUsername resolves a user case-insensitively.
func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) { func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) {
row := s.db.QueryRowContext(ctx, row := s.db.QueryRowContext(ctx,
`SELECT id, username, password_hash, role, email, disabled_at, `SELECT `+userSelectCols+` FROM users WHERE LOWER(username) = LOWER(?)`,
must_change_password, created_at, last_login_at username)
FROM users WHERE LOWER(username) = LOWER(?)`, username)
return scanUser(row.Scan) return scanUser(row.Scan)
} }
// GetUserByID looks up a user by id. Returns ErrNotFound on miss. // GetUserByID looks up a user by id. Returns ErrNotFound on miss.
func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) { func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) {
row := s.db.QueryRowContext(ctx, row := s.db.QueryRowContext(ctx,
`SELECT id, username, password_hash, role, email, disabled_at, `SELECT `+userSelectCols+` FROM users WHERE id = ?`, id)
must_change_password, created_at, last_login_at
FROM users WHERE id = ?`, id)
return scanUser(row.Scan) return scanUser(row.Scan)
} }
// GetUserByOIDCSubject finds the user JIT-provisioned on a previous
// OIDC sign-in. ErrNotFound on miss.
func (s *Store) GetUserByOIDCSubject(ctx context.Context, sub string) (*User, error) {
row := s.db.QueryRowContext(ctx,
`SELECT `+userSelectCols+` FROM users WHERE oidc_subject = ?`, sub)
return scanUser(row.Scan)
}
// SetUserOIDCSubject pins an existing user row to an IdP subject.
// Used by tests today; reserved for a future "link a local user to
// OIDC" flow.
func (s *Store) SetUserOIDCSubject(ctx context.Context, id, authSource, sub string) error {
_, err := s.db.ExecContext(ctx,
`UPDATE users SET auth_source = ?, oidc_subject = ? WHERE id = ?`,
authSource, sub, id)
if err != nil {
return fmt.Errorf("store: set oidc subject: %w", err)
}
return nil
}
// UserSort selects the column ListUsers orders by. OrderBy is // UserSort selects the column ListUsers orders by. OrderBy is
// allowlisted in usersOrderColumn so callers can't inject SQL via // allowlisted in usersOrderColumn so callers can't inject SQL via
// this field. Empty / unknown OrderBy falls back to "username". // this field. Empty / unknown OrderBy falls back to "username".
@@ -88,9 +119,8 @@ func (s *Store) ListUsers(ctx context.Context, sort UserSort) ([]User, error) {
// Default: username ASC (alphabetical), matching pre-sort behaviour. // Default: username ASC (alphabetical), matching pre-sort behaviour.
asc = true asc = true
} }
q := `SELECT id, username, password_hash, role, email, disabled_at, q := `SELECT ` + userSelectCols + ` FROM users ORDER BY ` +
must_change_password, created_at, last_login_at usersOrderColumn(sort.OrderBy, asc)
FROM users ORDER BY ` + usersOrderColumn(sort.OrderBy, asc)
rows, err := s.db.QueryContext(ctx, q) rows, err := s.db.QueryContext(ctx, q)
if err != nil { if err != nil {
return nil, fmt.Errorf("store: list users: %w", err) return nil, fmt.Errorf("store: list users: %w", err)
@@ -220,11 +250,13 @@ func (s *Store) SetPasswordHash(ctx context.Context, id, hash string) error {
func scanUser(scan func(...any) error) (*User, error) { func scanUser(scan func(...any) error) (*User, error) {
var u User var u User
var role string var role string
var email, disabledAt, lastLogin sql.NullString var email, disabledAt, oidcSub, lastLogin sql.NullString
var must int var must int
var authSource string
var created string var created string
if err := scan(&u.ID, &u.Username, &u.PasswordHash, &role, if err := scan(&u.ID, &u.Username, &u.PasswordHash, &role,
&email, &disabledAt, &must, &created, &lastLogin); err != nil { &email, &disabledAt, &must, &authSource, &oidcSub,
&created, &lastLogin); err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound return nil, ErrNotFound
} }
@@ -240,6 +272,11 @@ func scanUser(scan func(...any) error) (*User, error) {
u.DisabledAt = &t u.DisabledAt = &t
} }
u.MustChangePassword = must == 1 u.MustChangePassword = must == 1
u.AuthSource = authSource
if oidcSub.Valid {
v := oidcSub.String
u.OIDCSubject = &v
}
t, _ := time.Parse(time.RFC3339Nano, created) t, _ := time.Parse(time.RFC3339Nano, created)
u.CreatedAt = t u.CreatedAt = t
if lastLogin.Valid { if lastLogin.Valid {
+48
View File
@@ -165,6 +165,54 @@ func TestCreateUserLowercasesUsername(t *testing.T) {
} }
} }
func TestGetUserByOIDCSubject(t *testing.T) {
t.Parallel()
s := openTestStore(t)
ctx := context.Background()
now := time.Now().UTC()
sub := "sub-abc-123"
if err := s.CreateUser(ctx, User{
ID: "u1", Username: "alice", PasswordHash: "",
Role: RoleAdmin, CreatedAt: now,
AuthSource: "oidc", OIDCSubject: &sub,
}); err != nil {
t.Fatalf("create: %v", err)
}
got, err := s.GetUserByOIDCSubject(ctx, sub)
if err != nil {
t.Fatalf("get by sub: %v", err)
}
if got.ID != "u1" || got.AuthSource != "oidc" {
t.Errorf("unexpected: %+v", got)
}
if _, err := s.GetUserByOIDCSubject(ctx, "nope"); !errors.Is(err, ErrNotFound) {
t.Errorf("missing sub: want ErrNotFound, got %v", err)
}
}
func TestSetUserOIDCSubject(t *testing.T) {
t.Parallel()
s := openTestStore(t)
ctx := context.Background()
now := time.Now().UTC()
if err := s.CreateUser(ctx, User{
ID: "u1", Username: "alice", PasswordHash: "x",
Role: RoleAdmin, CreatedAt: now,
}); err != nil {
t.Fatalf("create: %v", err)
}
sub := "sub-456"
if err := s.SetUserOIDCSubject(ctx, "u1", "oidc", sub); err != nil {
t.Fatalf("set: %v", err)
}
got, _ := s.GetUserByID(ctx, "u1")
if got.AuthSource != "oidc" || got.OIDCSubject == nil || *got.OIDCSubject != sub {
t.Errorf("after set: %+v", got)
}
}
func TestEnrollmentTokenSingleUse(t *testing.T) { func TestEnrollmentTokenSingleUse(t *testing.T) {
t.Parallel() t.Parallel()
s := openTestStore(t) s := openTestStore(t)