store: GetUserByOIDCSubject + scanUser auth_source/oidc_subject
This commit is contained in:
+51
-14
@@ -18,12 +18,18 @@ func (s *Store) CreateUser(ctx context.Context, u User) error {
|
|||||||
if u.MustChangePassword {
|
if u.MustChangePassword {
|
||||||
must = 1
|
must = 1
|
||||||
}
|
}
|
||||||
|
authSource := u.AuthSource
|
||||||
|
if authSource == "" {
|
||||||
|
authSource = "local"
|
||||||
|
}
|
||||||
_, err := s.db.ExecContext(ctx,
|
_, err := s.db.ExecContext(ctx,
|
||||||
`INSERT INTO users (id, username, password_hash, role, email,
|
`INSERT INTO users (id, username, password_hash, role, email,
|
||||||
must_change_password, created_at)
|
must_change_password, auth_source,
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
oidc_subject, created_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
u.ID, u.Username, u.PasswordHash, string(u.Role),
|
u.ID, u.Username, u.PasswordHash, string(u.Role),
|
||||||
nullable(u.Email), must,
|
nullable(u.Email), must, authSource,
|
||||||
|
nullable(u.OIDCSubject),
|
||||||
u.CreatedAt.UTC().Format(time.RFC3339Nano))
|
u.CreatedAt.UTC().Format(time.RFC3339Nano))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("store: create user: %w", err)
|
return fmt.Errorf("store: create user: %w", err)
|
||||||
@@ -31,24 +37,49 @@ func (s *Store) CreateUser(ctx context.Context, u User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// userSelectCols centralises the column list every read path uses so
|
||||||
|
// scanUser stays in lockstep.
|
||||||
|
const userSelectCols = `id, username, password_hash, role, email,
|
||||||
|
disabled_at, must_change_password,
|
||||||
|
auth_source, oidc_subject,
|
||||||
|
created_at, last_login_at`
|
||||||
|
|
||||||
// GetUserByUsername resolves a user case-insensitively.
|
// GetUserByUsername resolves a user case-insensitively.
|
||||||
func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
||||||
row := s.db.QueryRowContext(ctx,
|
row := s.db.QueryRowContext(ctx,
|
||||||
`SELECT id, username, password_hash, role, email, disabled_at,
|
`SELECT `+userSelectCols+` FROM users WHERE LOWER(username) = LOWER(?)`,
|
||||||
must_change_password, created_at, last_login_at
|
username)
|
||||||
FROM users WHERE LOWER(username) = LOWER(?)`, username)
|
|
||||||
return scanUser(row.Scan)
|
return scanUser(row.Scan)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserByID looks up a user by id. Returns ErrNotFound on miss.
|
// GetUserByID looks up a user by id. Returns ErrNotFound on miss.
|
||||||
func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) {
|
func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) {
|
||||||
row := s.db.QueryRowContext(ctx,
|
row := s.db.QueryRowContext(ctx,
|
||||||
`SELECT id, username, password_hash, role, email, disabled_at,
|
`SELECT `+userSelectCols+` FROM users WHERE id = ?`, id)
|
||||||
must_change_password, created_at, last_login_at
|
|
||||||
FROM users WHERE id = ?`, id)
|
|
||||||
return scanUser(row.Scan)
|
return scanUser(row.Scan)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserByOIDCSubject finds the user JIT-provisioned on a previous
|
||||||
|
// OIDC sign-in. ErrNotFound on miss.
|
||||||
|
func (s *Store) GetUserByOIDCSubject(ctx context.Context, sub string) (*User, error) {
|
||||||
|
row := s.db.QueryRowContext(ctx,
|
||||||
|
`SELECT `+userSelectCols+` FROM users WHERE oidc_subject = ?`, sub)
|
||||||
|
return scanUser(row.Scan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserOIDCSubject pins an existing user row to an IdP subject.
|
||||||
|
// Used by tests today; reserved for a future "link a local user to
|
||||||
|
// OIDC" flow.
|
||||||
|
func (s *Store) SetUserOIDCSubject(ctx context.Context, id, authSource, sub string) error {
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`UPDATE users SET auth_source = ?, oidc_subject = ? WHERE id = ?`,
|
||||||
|
authSource, sub, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("store: set oidc subject: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// UserSort selects the column ListUsers orders by. OrderBy is
|
// UserSort selects the column ListUsers orders by. OrderBy is
|
||||||
// allowlisted in usersOrderColumn so callers can't inject SQL via
|
// allowlisted in usersOrderColumn so callers can't inject SQL via
|
||||||
// this field. Empty / unknown OrderBy falls back to "username".
|
// this field. Empty / unknown OrderBy falls back to "username".
|
||||||
@@ -88,9 +119,8 @@ func (s *Store) ListUsers(ctx context.Context, sort UserSort) ([]User, error) {
|
|||||||
// Default: username ASC (alphabetical), matching pre-sort behaviour.
|
// Default: username ASC (alphabetical), matching pre-sort behaviour.
|
||||||
asc = true
|
asc = true
|
||||||
}
|
}
|
||||||
q := `SELECT id, username, password_hash, role, email, disabled_at,
|
q := `SELECT ` + userSelectCols + ` FROM users ORDER BY ` +
|
||||||
must_change_password, created_at, last_login_at
|
usersOrderColumn(sort.OrderBy, asc)
|
||||||
FROM users ORDER BY ` + usersOrderColumn(sort.OrderBy, asc)
|
|
||||||
rows, err := s.db.QueryContext(ctx, q)
|
rows, err := s.db.QueryContext(ctx, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("store: list users: %w", err)
|
return nil, fmt.Errorf("store: list users: %w", err)
|
||||||
@@ -220,11 +250,13 @@ func (s *Store) SetPasswordHash(ctx context.Context, id, hash string) error {
|
|||||||
func scanUser(scan func(...any) error) (*User, error) {
|
func scanUser(scan func(...any) error) (*User, error) {
|
||||||
var u User
|
var u User
|
||||||
var role string
|
var role string
|
||||||
var email, disabledAt, lastLogin sql.NullString
|
var email, disabledAt, oidcSub, lastLogin sql.NullString
|
||||||
var must int
|
var must int
|
||||||
|
var authSource string
|
||||||
var created string
|
var created string
|
||||||
if err := scan(&u.ID, &u.Username, &u.PasswordHash, &role,
|
if err := scan(&u.ID, &u.Username, &u.PasswordHash, &role,
|
||||||
&email, &disabledAt, &must, &created, &lastLogin); err != nil {
|
&email, &disabledAt, &must, &authSource, &oidcSub,
|
||||||
|
&created, &lastLogin); err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
}
|
}
|
||||||
@@ -240,6 +272,11 @@ func scanUser(scan func(...any) error) (*User, error) {
|
|||||||
u.DisabledAt = &t
|
u.DisabledAt = &t
|
||||||
}
|
}
|
||||||
u.MustChangePassword = must == 1
|
u.MustChangePassword = must == 1
|
||||||
|
u.AuthSource = authSource
|
||||||
|
if oidcSub.Valid {
|
||||||
|
v := oidcSub.String
|
||||||
|
u.OIDCSubject = &v
|
||||||
|
}
|
||||||
t, _ := time.Parse(time.RFC3339Nano, created)
|
t, _ := time.Parse(time.RFC3339Nano, created)
|
||||||
u.CreatedAt = t
|
u.CreatedAt = t
|
||||||
if lastLogin.Valid {
|
if lastLogin.Valid {
|
||||||
|
|||||||
@@ -165,6 +165,54 @@ func TestCreateUserLowercasesUsername(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetUserByOIDCSubject(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
s := openTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
now := time.Now().UTC()
|
||||||
|
sub := "sub-abc-123"
|
||||||
|
|
||||||
|
if err := s.CreateUser(ctx, User{
|
||||||
|
ID: "u1", Username: "alice", PasswordHash: "",
|
||||||
|
Role: RoleAdmin, CreatedAt: now,
|
||||||
|
AuthSource: "oidc", OIDCSubject: &sub,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("create: %v", err)
|
||||||
|
}
|
||||||
|
got, err := s.GetUserByOIDCSubject(ctx, sub)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get by sub: %v", err)
|
||||||
|
}
|
||||||
|
if got.ID != "u1" || got.AuthSource != "oidc" {
|
||||||
|
t.Errorf("unexpected: %+v", got)
|
||||||
|
}
|
||||||
|
if _, err := s.GetUserByOIDCSubject(ctx, "nope"); !errors.Is(err, ErrNotFound) {
|
||||||
|
t.Errorf("missing sub: want ErrNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetUserOIDCSubject(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
s := openTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
if err := s.CreateUser(ctx, User{
|
||||||
|
ID: "u1", Username: "alice", PasswordHash: "x",
|
||||||
|
Role: RoleAdmin, CreatedAt: now,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("create: %v", err)
|
||||||
|
}
|
||||||
|
sub := "sub-456"
|
||||||
|
if err := s.SetUserOIDCSubject(ctx, "u1", "oidc", sub); err != nil {
|
||||||
|
t.Fatalf("set: %v", err)
|
||||||
|
}
|
||||||
|
got, _ := s.GetUserByID(ctx, "u1")
|
||||||
|
if got.AuthSource != "oidc" || got.OIDCSubject == nil || *got.OIDCSubject != sub {
|
||||||
|
t.Errorf("after set: %+v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestEnrollmentTokenSingleUse(t *testing.T) {
|
func TestEnrollmentTokenSingleUse(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
s := openTestStore(t)
|
s := openTestStore(t)
|
||||||
|
|||||||
Reference in New Issue
Block a user