66 lines
2.0 KiB
Go
66 lines
2.0 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
)
|
|
|
|
// PutOIDCState stores the (state_hash, code_verifier) pair created
|
|
// at /auth/oidc/login start. Called once per login attempt.
|
|
func (s *Store) PutOIDCState(ctx context.Context, stateHash, verifier string, createdAt time.Time) error {
|
|
_, err := s.db.ExecContext(ctx,
|
|
`INSERT INTO oidc_state (state_hash, code_verifier, created_at)
|
|
VALUES (?, ?, ?)`,
|
|
stateHash, verifier,
|
|
createdAt.UTC().Format(time.RFC3339Nano))
|
|
if err != nil {
|
|
return fmt.Errorf("store: put oidc state: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ConsumeOIDCState atomically reads + deletes the row in one go,
|
|
// returning the code_verifier. Single-use — a re-play returns
|
|
// ErrNotFound. Used by the OIDC callback handler.
|
|
func (s *Store) ConsumeOIDCState(ctx context.Context, stateHash string) (string, error) {
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("store: begin: %w", err)
|
|
}
|
|
defer func() { _ = tx.Rollback() }()
|
|
var verifier string
|
|
err = tx.QueryRowContext(ctx,
|
|
`SELECT code_verifier FROM oidc_state WHERE state_hash = ?`,
|
|
stateHash).Scan(&verifier)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return "", ErrNotFound
|
|
}
|
|
return "", fmt.Errorf("store: consume oidc state: %w", err)
|
|
}
|
|
if _, err := tx.ExecContext(ctx,
|
|
`DELETE FROM oidc_state WHERE state_hash = ?`, stateHash); err != nil {
|
|
return "", fmt.Errorf("store: delete oidc state: %w", err)
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
return "", fmt.Errorf("store: commit: %w", err)
|
|
}
|
|
return verifier, nil
|
|
}
|
|
|
|
// CleanupExpiredOIDCState removes entries created before cutoff.
|
|
// Called on the alert engine's 60s tick alongside setup-token sweep.
|
|
func (s *Store) CleanupExpiredOIDCState(ctx context.Context, cutoff time.Time) (int64, error) {
|
|
res, err := s.db.ExecContext(ctx,
|
|
`DELETE FROM oidc_state WHERE created_at < ?`,
|
|
cutoff.UTC().Format(time.RFC3339Nano))
|
|
if err != nil {
|
|
return 0, fmt.Errorf("store: cleanup oidc state: %w", err)
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
return n, nil
|
|
}
|