store: round-trip IDToken on sessions for RP-initiated logout
This commit is contained in:
@@ -12,13 +12,14 @@ import (
|
|||||||
// insert; the raw token is what the caller hands to the user (cookie).
|
// insert; the raw token is what the caller hands to the user (cookie).
|
||||||
func (s *Store) CreateSession(ctx context.Context, sess Session, tokenHash string) error {
|
func (s *Store) CreateSession(ctx context.Context, sess Session, tokenHash string) error {
|
||||||
_, err := s.db.ExecContext(ctx,
|
_, err := s.db.ExecContext(ctx,
|
||||||
`INSERT INTO sessions (id, user_id, created_at, expires_at, ip, ua)
|
`INSERT INTO sessions (id, user_id, created_at, expires_at, ip, ua, id_token)
|
||||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||||
tokenHash,
|
tokenHash,
|
||||||
sess.UserID,
|
sess.UserID,
|
||||||
sess.CreatedAt.UTC().Format(time.RFC3339Nano),
|
sess.CreatedAt.UTC().Format(time.RFC3339Nano),
|
||||||
sess.ExpiresAt.UTC().Format(time.RFC3339Nano),
|
sess.ExpiresAt.UTC().Format(time.RFC3339Nano),
|
||||||
sess.IP, sess.UA)
|
nullableStr(sess.IP), nullableStr(sess.UA),
|
||||||
|
nullableStr(sess.IDToken))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("store: create session: %w", err)
|
return fmt.Errorf("store: create session: %w", err)
|
||||||
}
|
}
|
||||||
@@ -32,15 +33,15 @@ func (s *Store) CreateSession(ctx context.Context, sess Session, tokenHash strin
|
|||||||
// of valid token hashes.
|
// of valid token hashes.
|
||||||
func (s *Store) LookupSession(ctx context.Context, tokenHash string) (*Session, error) {
|
func (s *Store) LookupSession(ctx context.Context, tokenHash string) (*Session, error) {
|
||||||
row := s.db.QueryRowContext(ctx,
|
row := s.db.QueryRowContext(ctx,
|
||||||
`SELECT id, user_id, created_at, expires_at, ip, ua
|
`SELECT id, user_id, created_at, expires_at, ip, ua, id_token
|
||||||
FROM sessions
|
FROM sessions
|
||||||
WHERE id = ? AND expires_at > ?`,
|
WHERE id = ? AND expires_at > ?`,
|
||||||
tokenHash, time.Now().UTC().Format(time.RFC3339Nano))
|
tokenHash, time.Now().UTC().Format(time.RFC3339Nano))
|
||||||
|
|
||||||
var sess Session
|
var sess Session
|
||||||
var created, expires string
|
var created, expires string
|
||||||
var ip, ua sql.NullString
|
var ip, ua, idTok sql.NullString
|
||||||
if err := row.Scan(&sess.ID, &sess.UserID, &created, &expires, &ip, &ua); err != nil {
|
if err := row.Scan(&sess.ID, &sess.UserID, &created, &expires, &ip, &ua, &idTok); err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
}
|
}
|
||||||
@@ -62,6 +63,9 @@ func (s *Store) LookupSession(ctx context.Context, tokenHash string) (*Session,
|
|||||||
if ua.Valid {
|
if ua.Valid {
|
||||||
sess.UA = ua.String
|
sess.UA = ua.String
|
||||||
}
|
}
|
||||||
|
if idTok.Valid {
|
||||||
|
sess.IDToken = idTok.String
|
||||||
|
}
|
||||||
return &sess, nil
|
return &sess, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -43,3 +43,34 @@ func TestDeleteSessionsByUserID(t *testing.T) {
|
|||||||
t.Error("hash1 should be gone")
|
t.Error("hash1 should be gone")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSessionRoundTripsIDToken(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
s := openTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
uid := "u-oidc"
|
||||||
|
if err := s.CreateUser(ctx, User{
|
||||||
|
ID: uid, Username: "ouser", PasswordHash: "",
|
||||||
|
Role: RoleOperator, CreatedAt: now,
|
||||||
|
AuthSource: "oidc",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("create user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.CreateSession(ctx, Session{
|
||||||
|
ID: "h1", UserID: uid, CreatedAt: now,
|
||||||
|
ExpiresAt: now.Add(time.Hour),
|
||||||
|
IDToken: "eyJ.fake.jwt",
|
||||||
|
}, "h1"); err != nil {
|
||||||
|
t.Fatalf("create session: %v", err)
|
||||||
|
}
|
||||||
|
got, err := s.LookupSession(ctx, "h1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("lookup: %v", err)
|
||||||
|
}
|
||||||
|
if got.IDToken != "eyJ.fake.jwt" {
|
||||||
|
t.Errorf("id_token round trip: got %q", got.IDToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user