store: round-trip IDToken on sessions for RP-initiated logout

This commit is contained in:
2026-05-05 13:14:27 +01:00
parent 805380f52d
commit 7f8bd13a07
2 changed files with 41 additions and 6 deletions
+10 -6
View File
@@ -12,13 +12,14 @@ import (
// insert; the raw token is what the caller hands to the user (cookie). // insert; the raw token is what the caller hands to the user (cookie).
func (s *Store) CreateSession(ctx context.Context, sess Session, tokenHash string) error { func (s *Store) CreateSession(ctx context.Context, sess Session, tokenHash string) error {
_, err := s.db.ExecContext(ctx, _, err := s.db.ExecContext(ctx,
`INSERT INTO sessions (id, user_id, created_at, expires_at, ip, ua) `INSERT INTO sessions (id, user_id, created_at, expires_at, ip, ua, id_token)
VALUES (?, ?, ?, ?, ?, ?)`, VALUES (?, ?, ?, ?, ?, ?, ?)`,
tokenHash, tokenHash,
sess.UserID, sess.UserID,
sess.CreatedAt.UTC().Format(time.RFC3339Nano), sess.CreatedAt.UTC().Format(time.RFC3339Nano),
sess.ExpiresAt.UTC().Format(time.RFC3339Nano), sess.ExpiresAt.UTC().Format(time.RFC3339Nano),
sess.IP, sess.UA) nullableStr(sess.IP), nullableStr(sess.UA),
nullableStr(sess.IDToken))
if err != nil { if err != nil {
return fmt.Errorf("store: create session: %w", err) return fmt.Errorf("store: create session: %w", err)
} }
@@ -32,15 +33,15 @@ func (s *Store) CreateSession(ctx context.Context, sess Session, tokenHash strin
// of valid token hashes. // of valid token hashes.
func (s *Store) LookupSession(ctx context.Context, tokenHash string) (*Session, error) { func (s *Store) LookupSession(ctx context.Context, tokenHash string) (*Session, error) {
row := s.db.QueryRowContext(ctx, row := s.db.QueryRowContext(ctx,
`SELECT id, user_id, created_at, expires_at, ip, ua `SELECT id, user_id, created_at, expires_at, ip, ua, id_token
FROM sessions FROM sessions
WHERE id = ? AND expires_at > ?`, WHERE id = ? AND expires_at > ?`,
tokenHash, time.Now().UTC().Format(time.RFC3339Nano)) tokenHash, time.Now().UTC().Format(time.RFC3339Nano))
var sess Session var sess Session
var created, expires string var created, expires string
var ip, ua sql.NullString var ip, ua, idTok sql.NullString
if err := row.Scan(&sess.ID, &sess.UserID, &created, &expires, &ip, &ua); err != nil { if err := row.Scan(&sess.ID, &sess.UserID, &created, &expires, &ip, &ua, &idTok); err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound return nil, ErrNotFound
} }
@@ -62,6 +63,9 @@ func (s *Store) LookupSession(ctx context.Context, tokenHash string) (*Session,
if ua.Valid { if ua.Valid {
sess.UA = ua.String sess.UA = ua.String
} }
if idTok.Valid {
sess.IDToken = idTok.String
}
return &sess, nil return &sess, nil
} }
+31
View File
@@ -43,3 +43,34 @@ func TestDeleteSessionsByUserID(t *testing.T) {
t.Error("hash1 should be gone") t.Error("hash1 should be gone")
} }
} }
func TestSessionRoundTripsIDToken(t *testing.T) {
t.Parallel()
s := openTestStore(t)
ctx := context.Background()
now := time.Now().UTC()
uid := "u-oidc"
if err := s.CreateUser(ctx, User{
ID: uid, Username: "ouser", PasswordHash: "",
Role: RoleOperator, CreatedAt: now,
AuthSource: "oidc",
}); err != nil {
t.Fatalf("create user: %v", err)
}
if err := s.CreateSession(ctx, Session{
ID: "h1", UserID: uid, CreatedAt: now,
ExpiresAt: now.Add(time.Hour),
IDToken: "eyJ.fake.jwt",
}, "h1"); err != nil {
t.Fatalf("create session: %v", err)
}
got, err := s.LookupSession(ctx, "h1")
if err != nil {
t.Fatalf("lookup: %v", err)
}
if got.IDToken != "eyJ.fake.jwt" {
t.Errorf("id_token round trip: got %q", got.IDToken)
}
}