From e2976a42e631a06c0daa3a3fc93ab227e65f614a Mon Sep 17 00:00:00 2001 From: Steve Cliff Date: Tue, 5 May 2026 13:15:45 +0100 Subject: [PATCH] store: oidc_state CRUD + 5-minute cleanup --- internal/store/oidc_state.go | 65 +++++++++++++++++++++++++++++++ internal/store/oidc_state_test.go | 64 ++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+) create mode 100644 internal/store/oidc_state.go create mode 100644 internal/store/oidc_state_test.go diff --git a/internal/store/oidc_state.go b/internal/store/oidc_state.go new file mode 100644 index 0000000..cb6a5c9 --- /dev/null +++ b/internal/store/oidc_state.go @@ -0,0 +1,65 @@ +package store + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" +) + +// PutOIDCState stores the (state_hash, code_verifier) pair created +// at /auth/oidc/login start. Called once per login attempt. +func (s *Store) PutOIDCState(ctx context.Context, stateHash, verifier string, createdAt time.Time) error { + _, err := s.db.ExecContext(ctx, + `INSERT INTO oidc_state (state_hash, code_verifier, created_at) + VALUES (?, ?, ?)`, + stateHash, verifier, + createdAt.UTC().Format(time.RFC3339Nano)) + if err != nil { + return fmt.Errorf("store: put oidc state: %w", err) + } + return nil +} + +// ConsumeOIDCState atomically reads + deletes the row in one go, +// returning the code_verifier. Single-use — a re-play returns +// ErrNotFound. Used by the OIDC callback handler. +func (s *Store) ConsumeOIDCState(ctx context.Context, stateHash string) (string, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return "", fmt.Errorf("store: begin: %w", err) + } + defer func() { _ = tx.Rollback() }() + var verifier string + err = tx.QueryRowContext(ctx, + `SELECT code_verifier FROM oidc_state WHERE state_hash = ?`, + stateHash).Scan(&verifier) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", ErrNotFound + } + return "", fmt.Errorf("store: consume oidc state: %w", err) + } + if _, err := tx.ExecContext(ctx, + `DELETE FROM oidc_state WHERE state_hash = ?`, stateHash); err != nil { + return "", fmt.Errorf("store: delete oidc state: %w", err) + } + if err := tx.Commit(); err != nil { + return "", fmt.Errorf("store: commit: %w", err) + } + return verifier, nil +} + +// CleanupExpiredOIDCState removes entries created before cutoff. +// Called on the alert engine's 60s tick alongside setup-token sweep. +func (s *Store) CleanupExpiredOIDCState(ctx context.Context, cutoff time.Time) (int64, error) { + res, err := s.db.ExecContext(ctx, + `DELETE FROM oidc_state WHERE created_at < ?`, + cutoff.UTC().Format(time.RFC3339Nano)) + if err != nil { + return 0, fmt.Errorf("store: cleanup oidc state: %w", err) + } + n, _ := res.RowsAffected() + return n, nil +} diff --git a/internal/store/oidc_state_test.go b/internal/store/oidc_state_test.go new file mode 100644 index 0000000..a28b176 --- /dev/null +++ b/internal/store/oidc_state_test.go @@ -0,0 +1,64 @@ +package store + +import ( + "context" + "path/filepath" + "testing" + "time" +) + +func newOIDCStateTestStore(t *testing.T) *Store { + t.Helper() + st, err := Open(context.Background(), filepath.Join(t.TempDir(), "rm.db")) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { _ = st.Close() }) + return st +} + +func TestOIDCStatePutAndConsume(t *testing.T) { + t.Parallel() + st := newOIDCStateTestStore(t) + ctx := context.Background() + now := time.Now().UTC() + + if err := st.PutOIDCState(ctx, "hash1", "verifier-1", now); err != nil { + t.Fatalf("put: %v", err) + } + v, err := st.ConsumeOIDCState(ctx, "hash1") + if err != nil { + t.Fatalf("consume: %v", err) + } + if v != "verifier-1" { + t.Errorf("verifier: got %q want %q", v, "verifier-1") + } + if _, err := st.ConsumeOIDCState(ctx, "hash1"); err == nil { + t.Error("re-consume should fail") + } +} + +func TestOIDCStateCleanup(t *testing.T) { + t.Parallel() + st := newOIDCStateTestStore(t) + ctx := context.Background() + now := time.Now().UTC() + + _ = st.PutOIDCState(ctx, "stale", "v-stale", now.Add(-10*time.Minute)) + _ = st.PutOIDCState(ctx, "fresh", "v-fresh", now) + + cutoff := now.Add(-5 * time.Minute) + n, err := st.CleanupExpiredOIDCState(ctx, cutoff) + if err != nil { + t.Fatalf("cleanup: %v", err) + } + if n != 1 { + t.Errorf("cleanup count: got %d want 1", n) + } + if _, err := st.ConsumeOIDCState(ctx, "stale"); err == nil { + t.Error("stale entry should have been deleted") + } + if _, err := st.ConsumeOIDCState(ctx, "fresh"); err != nil { + t.Errorf("fresh entry should still be readable: %v", err) + } +}