oidc: client wrapper around go-oidc — discovery, exchange, claim parse
This commit is contained in:
@@ -0,0 +1,194 @@
|
||||
// Package oidc wraps go-oidc + oauth2 in the small surface the
|
||||
// HTTP handlers need: discovery, code-exchange config, ID-token
|
||||
// verification, and role-claim resolution.
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
gooidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/config"
|
||||
)
|
||||
|
||||
// Client bundles the discovered provider + a pre-built oauth2.Config.
|
||||
// Constructed once at server start; safe for concurrent use.
|
||||
type Client struct {
|
||||
cfg *config.OIDCConfig
|
||||
provider *gooidc.Provider
|
||||
verifier *gooidc.IDTokenVerifier
|
||||
oauth *oauth2.Config
|
||||
endSession string // discovered end_session_endpoint, "" if none
|
||||
}
|
||||
|
||||
// New discovers the provider's well-known config and builds a Client.
|
||||
// Network call — should be invoked once at startup with a context
|
||||
// carrying a sane timeout. Returns an error on a 4xx/5xx from
|
||||
// discovery so the operator finds out at startup, not on first login.
|
||||
func New(ctx context.Context, cfg *config.OIDCConfig, baseURL string) (*Client, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("oidc: config nil")
|
||||
}
|
||||
prov, err := gooidc.NewProvider(ctx, cfg.Issuer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: discovery: %w", err)
|
||||
}
|
||||
redir := cfg.RedirectURL
|
||||
if redir == "" {
|
||||
redir = strings.TrimRight(baseURL, "/") + "/auth/oidc/callback"
|
||||
}
|
||||
oa := &oauth2.Config{
|
||||
ClientID: cfg.ClientID,
|
||||
ClientSecret: cfg.ClientSecret,
|
||||
Endpoint: prov.Endpoint(),
|
||||
RedirectURL: redir,
|
||||
Scopes: cfg.Scopes,
|
||||
}
|
||||
verifier := prov.Verifier(&gooidc.Config{ClientID: cfg.ClientID})
|
||||
|
||||
// Pull end_session_endpoint out of the discovery doc — go-oidc
|
||||
// doesn't expose it as a typed field, but the underlying claims
|
||||
// blob does.
|
||||
var doc struct {
|
||||
EndSessionEndpoint string `json:"end_session_endpoint"`
|
||||
}
|
||||
_ = prov.Claims(&doc)
|
||||
|
||||
return &Client{
|
||||
cfg: cfg,
|
||||
provider: prov,
|
||||
verifier: verifier,
|
||||
oauth: oa,
|
||||
endSession: doc.EndSessionEndpoint,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthURL returns the URL to redirect the browser to for the
|
||||
// Authorization Code + PKCE flow. State + verifier are caller-
|
||||
// supplied so the caller can persist them in the oidc_state table.
|
||||
func (c *Client) AuthURL(state, codeChallenge string) string {
|
||||
return c.oauth.AuthCodeURL(state,
|
||||
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
)
|
||||
}
|
||||
|
||||
// Exchange swaps a code+verifier for a token set and verifies the
|
||||
// id_token. Returns the parsed Claims and the raw id_token (the
|
||||
// caller stashes the raw on the session for RP-initiated logout).
|
||||
func (c *Client) Exchange(ctx context.Context, code, verifier string) (*Claims, string, error) {
|
||||
tok, err := c.oauth.Exchange(ctx, code,
|
||||
oauth2.SetAuthURLParam("code_verifier", verifier))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("oidc: token exchange: %w", err)
|
||||
}
|
||||
rawID, ok := tok.Extra("id_token").(string)
|
||||
if !ok || rawID == "" {
|
||||
return nil, "", errors.New("oidc: id_token missing from token response")
|
||||
}
|
||||
idTok, err := c.verifier.Verify(ctx, rawID)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("oidc: verify id_token: %w", err)
|
||||
}
|
||||
var raw map[string]any
|
||||
if err := idTok.Claims(&raw); err != nil {
|
||||
return nil, "", fmt.Errorf("oidc: claims: %w", err)
|
||||
}
|
||||
return parseClaims(raw, c.cfg.RoleClaim), rawID, nil
|
||||
}
|
||||
|
||||
// EndSessionEndpoint exposes the discovered end_session URL ("" if
|
||||
// the IdP doesn't advertise one).
|
||||
func (c *Client) EndSessionEndpoint() string { return c.endSession }
|
||||
|
||||
// DisplayName for the SSO button on the login page.
|
||||
func (c *Client) DisplayName() string { return c.cfg.DisplayName }
|
||||
|
||||
// MapRole returns the role for the first matching claim value; "" if
|
||||
// none match. Caller treats "" as deny.
|
||||
func (c *Client) MapRole(roles []string) string {
|
||||
for _, r := range roles {
|
||||
if mapped, ok := c.cfg.RoleMapping[r]; ok {
|
||||
return mapped
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Claims is the minimal projection the callback handler cares about.
|
||||
type Claims struct {
|
||||
Subject string
|
||||
PreferredUsername string
|
||||
Email string
|
||||
Roles []string // normalised from string|[]string|csv
|
||||
}
|
||||
|
||||
// parseClaims pulls the four fields we need from the raw id_token
|
||||
// claims. The 'roles' field is normalised from the three shapes
|
||||
// IdPs emit (string, []string, comma-separated string).
|
||||
func parseClaims(raw map[string]any, roleClaim string) *Claims {
|
||||
c := &Claims{}
|
||||
if v, ok := raw["sub"].(string); ok {
|
||||
c.Subject = v
|
||||
}
|
||||
if v, ok := raw["preferred_username"].(string); ok {
|
||||
c.PreferredUsername = v
|
||||
}
|
||||
if v, ok := raw["email"].(string); ok {
|
||||
c.Email = v
|
||||
}
|
||||
switch v := raw[roleClaim].(type) {
|
||||
case string:
|
||||
for _, p := range strings.Split(v, ",") {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
c.Roles = append(c.Roles, p)
|
||||
}
|
||||
}
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok && s != "" {
|
||||
c.Roles = append(c.Roles, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// RandomState generates 32 random bytes URL-safe base64-encoded —
|
||||
// used as the 'state' parameter on the authorization request.
|
||||
// Caller is expected to compute sha256(state) for storage.
|
||||
func RandomState() (string, error) {
|
||||
var b [32]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b[:]), nil
|
||||
}
|
||||
|
||||
// PKCEPair generates a code_verifier (base64-url 64 chars) and the
|
||||
// corresponding S256 code_challenge.
|
||||
func PKCEPair() (verifier, challenge string, err error) {
|
||||
var b [48]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(b[:])
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
|
||||
// HashState returns sha256(state) hex — used as the primary key in
|
||||
// the oidc_state table (so a DB leak doesn't leak active states).
|
||||
func HashState(state string) string {
|
||||
sum := sha256.Sum256([]byte(state))
|
||||
return fmt.Sprintf("%x", sum)
|
||||
}
|
||||
Reference in New Issue
Block a user