diff --git a/internal/store/users.go b/internal/store/users.go index f414e92..ed0ddb6 100644 --- a/internal/store/users.go +++ b/internal/store/users.go @@ -18,12 +18,18 @@ func (s *Store) CreateUser(ctx context.Context, u User) error { if u.MustChangePassword { must = 1 } + authSource := u.AuthSource + if authSource == "" { + authSource = "local" + } _, err := s.db.ExecContext(ctx, `INSERT INTO users (id, username, password_hash, role, email, - must_change_password, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?)`, + must_change_password, auth_source, + oidc_subject, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, u.ID, u.Username, u.PasswordHash, string(u.Role), - nullable(u.Email), must, + nullable(u.Email), must, authSource, + nullable(u.OIDCSubject), u.CreatedAt.UTC().Format(time.RFC3339Nano)) if err != nil { return fmt.Errorf("store: create user: %w", err) @@ -31,24 +37,49 @@ func (s *Store) CreateUser(ctx context.Context, u User) error { return nil } +// userSelectCols centralises the column list every read path uses so +// scanUser stays in lockstep. +const userSelectCols = `id, username, password_hash, role, email, + disabled_at, must_change_password, + auth_source, oidc_subject, + created_at, last_login_at` + // GetUserByUsername resolves a user case-insensitively. func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) { row := s.db.QueryRowContext(ctx, - `SELECT id, username, password_hash, role, email, disabled_at, - must_change_password, created_at, last_login_at - FROM users WHERE LOWER(username) = LOWER(?)`, username) + `SELECT `+userSelectCols+` FROM users WHERE LOWER(username) = LOWER(?)`, + username) return scanUser(row.Scan) } // GetUserByID looks up a user by id. Returns ErrNotFound on miss. func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) { row := s.db.QueryRowContext(ctx, - `SELECT id, username, password_hash, role, email, disabled_at, - must_change_password, created_at, last_login_at - FROM users WHERE id = ?`, id) + `SELECT `+userSelectCols+` FROM users WHERE id = ?`, id) return scanUser(row.Scan) } +// GetUserByOIDCSubject finds the user JIT-provisioned on a previous +// OIDC sign-in. ErrNotFound on miss. +func (s *Store) GetUserByOIDCSubject(ctx context.Context, sub string) (*User, error) { + row := s.db.QueryRowContext(ctx, + `SELECT `+userSelectCols+` FROM users WHERE oidc_subject = ?`, sub) + return scanUser(row.Scan) +} + +// SetUserOIDCSubject pins an existing user row to an IdP subject. +// Used by tests today; reserved for a future "link a local user to +// OIDC" flow. +func (s *Store) SetUserOIDCSubject(ctx context.Context, id, authSource, sub string) error { + _, err := s.db.ExecContext(ctx, + `UPDATE users SET auth_source = ?, oidc_subject = ? WHERE id = ?`, + authSource, sub, id) + if err != nil { + return fmt.Errorf("store: set oidc subject: %w", err) + } + return nil +} + // UserSort selects the column ListUsers orders by. OrderBy is // allowlisted in usersOrderColumn so callers can't inject SQL via // this field. Empty / unknown OrderBy falls back to "username". @@ -88,9 +119,8 @@ func (s *Store) ListUsers(ctx context.Context, sort UserSort) ([]User, error) { // Default: username ASC (alphabetical), matching pre-sort behaviour. asc = true } - q := `SELECT id, username, password_hash, role, email, disabled_at, - must_change_password, created_at, last_login_at - FROM users ORDER BY ` + usersOrderColumn(sort.OrderBy, asc) + q := `SELECT ` + userSelectCols + ` FROM users ORDER BY ` + + usersOrderColumn(sort.OrderBy, asc) rows, err := s.db.QueryContext(ctx, q) if err != nil { return nil, fmt.Errorf("store: list users: %w", err) @@ -220,11 +250,13 @@ func (s *Store) SetPasswordHash(ctx context.Context, id, hash string) error { func scanUser(scan func(...any) error) (*User, error) { var u User var role string - var email, disabledAt, lastLogin sql.NullString + var email, disabledAt, oidcSub, lastLogin sql.NullString var must int + var authSource string var created string if err := scan(&u.ID, &u.Username, &u.PasswordHash, &role, - &email, &disabledAt, &must, &created, &lastLogin); err != nil { + &email, &disabledAt, &must, &authSource, &oidcSub, + &created, &lastLogin); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } @@ -240,6 +272,11 @@ func scanUser(scan func(...any) error) (*User, error) { u.DisabledAt = &t } u.MustChangePassword = must == 1 + u.AuthSource = authSource + if oidcSub.Valid { + v := oidcSub.String + u.OIDCSubject = &v + } t, _ := time.Parse(time.RFC3339Nano, created) u.CreatedAt = t if lastLogin.Valid { diff --git a/internal/store/users_test.go b/internal/store/users_test.go index a7684a9..ce4679b 100644 --- a/internal/store/users_test.go +++ b/internal/store/users_test.go @@ -165,6 +165,54 @@ func TestCreateUserLowercasesUsername(t *testing.T) { } } +func TestGetUserByOIDCSubject(t *testing.T) { + t.Parallel() + s := openTestStore(t) + ctx := context.Background() + now := time.Now().UTC() + sub := "sub-abc-123" + + if err := s.CreateUser(ctx, User{ + ID: "u1", Username: "alice", PasswordHash: "", + Role: RoleAdmin, CreatedAt: now, + AuthSource: "oidc", OIDCSubject: &sub, + }); err != nil { + t.Fatalf("create: %v", err) + } + got, err := s.GetUserByOIDCSubject(ctx, sub) + if err != nil { + t.Fatalf("get by sub: %v", err) + } + if got.ID != "u1" || got.AuthSource != "oidc" { + t.Errorf("unexpected: %+v", got) + } + if _, err := s.GetUserByOIDCSubject(ctx, "nope"); !errors.Is(err, ErrNotFound) { + t.Errorf("missing sub: want ErrNotFound, got %v", err) + } +} + +func TestSetUserOIDCSubject(t *testing.T) { + t.Parallel() + s := openTestStore(t) + ctx := context.Background() + now := time.Now().UTC() + + if err := s.CreateUser(ctx, User{ + ID: "u1", Username: "alice", PasswordHash: "x", + Role: RoleAdmin, CreatedAt: now, + }); err != nil { + t.Fatalf("create: %v", err) + } + sub := "sub-456" + if err := s.SetUserOIDCSubject(ctx, "u1", "oidc", sub); err != nil { + t.Fatalf("set: %v", err) + } + got, _ := s.GetUserByID(ctx, "u1") + if got.AuthSource != "oidc" || got.OIDCSubject == nil || *got.OIDCSubject != sub { + t.Errorf("after set: %+v", got) + } +} + func TestEnrollmentTokenSingleUse(t *testing.T) { t.Parallel() s := openTestStore(t)