195 lines
5.9 KiB
Go
195 lines
5.9 KiB
Go
// Package oidc wraps go-oidc + oauth2 in the small surface the
// HTTP handlers need: discovery, code-exchange config, ID-token
// verification, and role-claim resolution.
|
|
package oidc
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
gooidc "github.com/coreos/go-oidc/v3/oidc"
|
|
"golang.org/x/oauth2"
|
|
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/config"
|
|
)
|
|
|
|
// Client bundles the discovered provider + a pre-built oauth2.Config.
|
|
// Constructed once at server start; safe for concurrent use.
|
|
type Client struct {
|
|
cfg *config.OIDCConfig
|
|
provider *gooidc.Provider
|
|
verifier *gooidc.IDTokenVerifier
|
|
oauth *oauth2.Config
|
|
endSession string // discovered end_session_endpoint, "" if none
|
|
}
|
|
|
|
// New discovers the provider's well-known config and builds a Client.
|
|
// Network call — should be invoked once at startup with a context
|
|
// carrying a sane timeout. Returns an error on a 4xx/5xx from
|
|
// discovery so the operator finds out at startup, not on first login.
|
|
func New(ctx context.Context, cfg *config.OIDCConfig, baseURL string) (*Client, error) {
|
|
if cfg == nil {
|
|
return nil, errors.New("oidc: config nil")
|
|
}
|
|
prov, err := gooidc.NewProvider(ctx, cfg.Issuer)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("oidc: discovery: %w", err)
|
|
}
|
|
redir := cfg.RedirectURL
|
|
if redir == "" {
|
|
redir = strings.TrimRight(baseURL, "/") + "/auth/oidc/callback"
|
|
}
|
|
oa := &oauth2.Config{
|
|
ClientID: cfg.ClientID,
|
|
ClientSecret: cfg.ClientSecret,
|
|
Endpoint: prov.Endpoint(),
|
|
RedirectURL: redir,
|
|
Scopes: cfg.Scopes,
|
|
}
|
|
verifier := prov.Verifier(&gooidc.Config{ClientID: cfg.ClientID})
|
|
|
|
// Pull end_session_endpoint out of the discovery doc — go-oidc
|
|
// doesn't expose it as a typed field, but the underlying claims
|
|
// blob does.
|
|
var doc struct {
|
|
EndSessionEndpoint string `json:"end_session_endpoint"`
|
|
}
|
|
_ = prov.Claims(&doc)
|
|
|
|
return &Client{
|
|
cfg: cfg,
|
|
provider: prov,
|
|
verifier: verifier,
|
|
oauth: oa,
|
|
endSession: doc.EndSessionEndpoint,
|
|
}, nil
|
|
}
|
|
|
|
// AuthURL returns the URL to redirect the browser to for the
|
|
// Authorization Code + PKCE flow. State + verifier are caller-
|
|
// supplied so the caller can persist them in the oidc_state table.
|
|
func (c *Client) AuthURL(state, codeChallenge string) string {
|
|
return c.oauth.AuthCodeURL(state,
|
|
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
|
)
|
|
}
|
|
|
|
// Exchange swaps a code+verifier for a token set and verifies the
|
|
// id_token. Returns the parsed Claims and the raw id_token (the
|
|
// caller stashes the raw on the session for RP-initiated logout).
|
|
func (c *Client) Exchange(ctx context.Context, code, verifier string) (*Claims, string, error) {
|
|
tok, err := c.oauth.Exchange(ctx, code,
|
|
oauth2.SetAuthURLParam("code_verifier", verifier))
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("oidc: token exchange: %w", err)
|
|
}
|
|
rawID, ok := tok.Extra("id_token").(string)
|
|
if !ok || rawID == "" {
|
|
return nil, "", errors.New("oidc: id_token missing from token response")
|
|
}
|
|
idTok, err := c.verifier.Verify(ctx, rawID)
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("oidc: verify id_token: %w", err)
|
|
}
|
|
var raw map[string]any
|
|
if err := idTok.Claims(&raw); err != nil {
|
|
return nil, "", fmt.Errorf("oidc: claims: %w", err)
|
|
}
|
|
return parseClaims(raw, c.cfg.RoleClaim), rawID, nil
|
|
}
|
|
|
|
// EndSessionEndpoint exposes the discovered end_session URL ("" if
|
|
// the IdP doesn't advertise one).
|
|
func (c *Client) EndSessionEndpoint() string { return c.endSession }
|
|
|
|
// DisplayName for the SSO button on the login page.
|
|
func (c *Client) DisplayName() string { return c.cfg.DisplayName }
|
|
|
|
// MapRole returns the role for the first matching claim value; "" if
|
|
// none match. Caller treats "" as deny.
|
|
func (c *Client) MapRole(roles []string) string {
|
|
for _, r := range roles {
|
|
if mapped, ok := c.cfg.RoleMapping[r]; ok {
|
|
return mapped
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// Claims is the minimal projection of an id_token's claims that the
// callback handler cares about.
type Claims struct {
	// Subject is the "sub" claim.
	Subject string
	// PreferredUsername is the "preferred_username" claim.
	PreferredUsername string
	// Email is the "email" claim.
	Email string
	// Roles holds the values of the configured role claim, normalised
	// into a flat list of trimmed, non-empty strings.
	Roles []string
}

// parseClaims extracts the four fields we need from the raw id_token
// claims.
//
// "sub", "preferred_username" and "email" are copied only when they
// are strings; any other type leaves the field empty.
//
// The claim named by roleClaim is normalised from the shapes IdPs emit:
//   - a single string, or a comma-separated string
//   - a JSON array (decoded as []any); non-string items are skipped
//   - a []string, for callers that build the claims map by hand
//
// Every candidate is trimmed of surrounding whitespace and dropped if
// empty, regardless of the shape it arrived in, so " admin" resolves
// to the same role whether it came from a CSV string or an array.
// A missing claim or one of any other type yields no roles.
func parseClaims(raw map[string]any, roleClaim string) *Claims {
	c := &Claims{}
	if v, ok := raw["sub"].(string); ok {
		c.Subject = v
	}
	if v, ok := raw["preferred_username"].(string); ok {
		c.PreferredUsername = v
	}
	if v, ok := raw["email"].(string); ok {
		c.Email = v
	}

	// First flatten the claim into a list of candidate strings...
	var candidates []string
	switch v := raw[roleClaim].(type) {
	case string:
		candidates = strings.Split(v, ",")
	case []string:
		candidates = v
	case []any:
		for _, item := range v {
			if s, ok := item.(string); ok {
				candidates = append(candidates, s)
			}
		}
	}

	// ...then apply one normalisation rule to all of them.
	for _, role := range candidates {
		if role = strings.TrimSpace(role); role != "" {
			c.Roles = append(c.Roles, role)
		}
	}
	return c
}
|
|
|
|
// RandomState returns 32 bytes from crypto/rand, encoded as unpadded
// URL-safe base64, for use as the 'state' parameter of the
// authorization request. The caller is expected to store
// sha256(state) rather than the value itself.
func RandomState() (string, error) {
	const stateBytes = 32

	buf := make([]byte, stateBytes)
	_, err := rand.Read(buf)
	if err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(buf), nil
}
|
|
|
|
// PKCEPair produces a PKCE code_verifier and its S256 code_challenge.
// The verifier is 48 random bytes in unpadded URL-safe base64 (64
// characters); the challenge is the same encoding of the SHA-256
// digest of the verifier string.
func PKCEPair() (verifier, challenge string, err error) {
	seed := make([]byte, 48)
	if _, err = rand.Read(seed); err != nil {
		return "", "", err
	}

	enc := base64.RawURLEncoding
	verifier = enc.EncodeToString(seed)
	digest := sha256.Sum256([]byte(verifier))
	return verifier, enc.EncodeToString(digest[:]), nil
}
|
|
|
|
// HashState returns the lowercase hex encoding of sha256(state). It is
// used as the primary key in the oidc_state table, so a database leak
// does not reveal active state values.
func HashState(state string) string {
	digest := sha256.Sum256([]byte(state))
	return fmt.Sprintf("%x", digest[:])
}
|