P4-05: OIDC login (generic, JIT-provisioned) #16
@@ -0,0 +1,65 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PutOIDCState stores the (state_hash, code_verifier) pair created
|
||||
// at /auth/oidc/login start. Called once per login attempt.
|
||||
func (s *Store) PutOIDCState(ctx context.Context, stateHash, verifier string, createdAt time.Time) error {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`INSERT INTO oidc_state (state_hash, code_verifier, created_at)
|
||||
VALUES (?, ?, ?)`,
|
||||
stateHash, verifier,
|
||||
createdAt.UTC().Format(time.RFC3339Nano))
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: put oidc state: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConsumeOIDCState atomically reads + deletes the row in one go,
|
||||
// returning the code_verifier. Single-use — a re-play returns
|
||||
// ErrNotFound. Used by the OIDC callback handler.
|
||||
func (s *Store) ConsumeOIDCState(ctx context.Context, stateHash string) (string, error) {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("store: begin: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
var verifier string
|
||||
err = tx.QueryRowContext(ctx,
|
||||
`SELECT code_verifier FROM oidc_state WHERE state_hash = ?`,
|
||||
stateHash).Scan(&verifier)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
return "", fmt.Errorf("store: consume oidc state: %w", err)
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx,
|
||||
`DELETE FROM oidc_state WHERE state_hash = ?`, stateHash); err != nil {
|
||||
return "", fmt.Errorf("store: delete oidc state: %w", err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return "", fmt.Errorf("store: commit: %w", err)
|
||||
}
|
||||
return verifier, nil
|
||||
}
|
||||
|
||||
// CleanupExpiredOIDCState removes entries created before cutoff.
|
||||
// Called on the alert engine's 60s tick alongside setup-token sweep.
|
||||
func (s *Store) CleanupExpiredOIDCState(ctx context.Context, cutoff time.Time) (int64, error) {
|
||||
res, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM oidc_state WHERE created_at < ?`,
|
||||
cutoff.UTC().Format(time.RFC3339Nano))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("store: cleanup oidc state: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
return n, nil
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newOIDCStateTestStore(t *testing.T) *Store {
|
||||
t.Helper()
|
||||
st, err := Open(context.Background(), filepath.Join(t.TempDir(), "rm.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("open: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = st.Close() })
|
||||
return st
|
||||
}
|
||||
|
||||
func TestOIDCStatePutAndConsume(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newOIDCStateTestStore(t)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
|
||||
if err := st.PutOIDCState(ctx, "hash1", "verifier-1", now); err != nil {
|
||||
t.Fatalf("put: %v", err)
|
||||
}
|
||||
v, err := st.ConsumeOIDCState(ctx, "hash1")
|
||||
if err != nil {
|
||||
t.Fatalf("consume: %v", err)
|
||||
}
|
||||
if v != "verifier-1" {
|
||||
t.Errorf("verifier: got %q want %q", v, "verifier-1")
|
||||
}
|
||||
if _, err := st.ConsumeOIDCState(ctx, "hash1"); err == nil {
|
||||
t.Error("re-consume should fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCStateCleanup(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newOIDCStateTestStore(t)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
|
||||
_ = st.PutOIDCState(ctx, "stale", "v-stale", now.Add(-10*time.Minute))
|
||||
_ = st.PutOIDCState(ctx, "fresh", "v-fresh", now)
|
||||
|
||||
cutoff := now.Add(-5 * time.Minute)
|
||||
n, err := st.CleanupExpiredOIDCState(ctx, cutoff)
|
||||
if err != nil {
|
||||
t.Fatalf("cleanup: %v", err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Errorf("cleanup count: got %d want 1", n)
|
||||
}
|
||||
if _, err := st.ConsumeOIDCState(ctx, "stale"); err == nil {
|
||||
t.Error("stale entry should have been deleted")
|
||||
}
|
||||
if _, err := st.ConsumeOIDCState(ctx, "fresh"); err != nil {
|
||||
t.Errorf("fresh entry should still be readable: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user