From 7f8bd13a07ae952bd222fe88e6f1173d127a9d13 Mon Sep 17 00:00:00 2001 From: Steve Cliff Date: Tue, 5 May 2026 13:14:27 +0100 Subject: [PATCH] store: round-trip IDToken on sessions for RP-initiated logout --- internal/store/sessions.go | 16 ++++++++++------ internal/store/sessions_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/internal/store/sessions.go b/internal/store/sessions.go index a2ef31c..b02e90d 100644 --- a/internal/store/sessions.go +++ b/internal/store/sessions.go @@ -12,13 +12,14 @@ import ( // insert; the raw token is what the caller hands to the user (cookie). func (s *Store) CreateSession(ctx context.Context, sess Session, tokenHash string) error { _, err := s.db.ExecContext(ctx, - `INSERT INTO sessions (id, user_id, created_at, expires_at, ip, ua) - VALUES (?, ?, ?, ?, ?, ?)`, + `INSERT INTO sessions (id, user_id, created_at, expires_at, ip, ua, id_token) + VALUES (?, ?, ?, ?, ?, ?, ?)`, tokenHash, sess.UserID, sess.CreatedAt.UTC().Format(time.RFC3339Nano), sess.ExpiresAt.UTC().Format(time.RFC3339Nano), - sess.IP, sess.UA) + nullableStr(sess.IP), nullableStr(sess.UA), + nullableStr(sess.IDToken)) if err != nil { return fmt.Errorf("store: create session: %w", err) } @@ -32,15 +33,15 @@ func (s *Store) CreateSession(ctx context.Context, sess Session, tokenHash strin // of valid token hashes. func (s *Store) LookupSession(ctx context.Context, tokenHash string) (*Session, error) { row := s.db.QueryRowContext(ctx, - `SELECT id, user_id, created_at, expires_at, ip, ua + `SELECT id, user_id, created_at, expires_at, ip, ua, id_token FROM sessions WHERE id = ? AND expires_at > ?`, tokenHash, time.Now().UTC().Format(time.RFC3339Nano)) var sess Session var created, expires string - var ip, ua sql.NullString - if err := row.Scan(&sess.ID, &sess.UserID, &created, &expires, &ip, &ua); err != nil { + var ip, ua, idTok sql.NullString + if err := row.Scan(&sess.ID, &sess.UserID, &created, &expires, &ip, &ua, &idTok); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } @@ -62,6 +63,9 @@ func (s *Store) LookupSession(ctx context.Context, tokenHash string) (*Session, if ua.Valid { sess.UA = ua.String } + if idTok.Valid { + sess.IDToken = idTok.String + } return &sess, nil } diff --git a/internal/store/sessions_test.go b/internal/store/sessions_test.go index 81222ee..0dcd553 100644 --- a/internal/store/sessions_test.go +++ b/internal/store/sessions_test.go @@ -43,3 +43,34 @@ func TestDeleteSessionsByUserID(t *testing.T) { t.Error("hash1 should be gone") } } + +func TestSessionRoundTripsIDToken(t *testing.T) { + t.Parallel() + s := openTestStore(t) + ctx := context.Background() + now := time.Now().UTC() + + uid := "u-oidc" + if err := s.CreateUser(ctx, User{ + ID: uid, Username: "ouser", PasswordHash: "", + Role: RoleOperator, CreatedAt: now, + AuthSource: "oidc", + }); err != nil { + t.Fatalf("create user: %v", err) + } + + if err := s.CreateSession(ctx, Session{ + ID: "h1", UserID: uid, CreatedAt: now, + ExpiresAt: now.Add(time.Hour), + IDToken: "eyJ.fake.jwt", + }, "h1"); err != nil { + t.Fatalf("create session: %v", err) + } + got, err := s.LookupSession(ctx, "h1") + if err != nil { + t.Fatalf("lookup: %v", err) + } + if got.IDToken != "eyJ.fake.jwt" { + t.Errorf("id_token round trip: got %q", got.IDToken) + } +}