206 lines
6.8 KiB
Go
206 lines
6.8 KiB
Go
// oidc_handlers.go — OIDC sign-in handlers. These are public routes, mounted
// only when OIDC is configured (s.deps.OIDC != nil).
|
|
package http
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"log/slog"
|
|
stdhttp "net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/oklog/ulid/v2"
|
|
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc"
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
|
)
|
|
|
|
// handleOIDCLogin generates state + PKCE pair, persists them, and
|
|
// redirects to the IdP authorization endpoint.
|
|
func (s *Server) handleOIDCLogin(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
|
state, err := oidc.RandomState()
|
|
if err != nil {
|
|
slog.Error("oidc login: state", "err", err)
|
|
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
|
return
|
|
}
|
|
verifier, challenge, err := oidc.PKCEPair()
|
|
if err != nil {
|
|
slog.Error("oidc login: pkce", "err", err)
|
|
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
|
return
|
|
}
|
|
if err := s.deps.Store.PutOIDCState(r.Context(),
|
|
oidc.HashState(state), verifier, time.Now().UTC()); err != nil {
|
|
slog.Error("oidc login: persist state", "err", err)
|
|
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
|
return
|
|
}
|
|
stdhttp.Redirect(w, r, s.deps.OIDC.AuthURL(state, challenge), stdhttp.StatusSeeOther)
|
|
}
|
|
|
|
func (s *Server) handleOIDCCallback(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
|
q := r.URL.Query()
|
|
code := q.Get("code")
|
|
state := q.Get("state")
|
|
if code == "" || state == "" {
|
|
s.oidcRedirectError(w, r, "missing_params")
|
|
return
|
|
}
|
|
verifier, err := s.deps.Store.ConsumeOIDCState(r.Context(), oidc.HashState(state))
|
|
if err != nil {
|
|
s.oidcRedirectError(w, r, "bad_state")
|
|
return
|
|
}
|
|
claims, rawIDToken, err := s.deps.OIDC.Exchange(r.Context(), code, verifier)
|
|
if err != nil {
|
|
slog.Warn("oidc callback: exchange", "err", err)
|
|
s.oidcRedirectError(w, r, "exchange_failed")
|
|
return
|
|
}
|
|
|
|
uname := strings.ToLower(strings.TrimSpace(claims.PreferredUsername))
|
|
if uname == "" {
|
|
uname = strings.ToLower(strings.TrimSpace(claims.Email))
|
|
}
|
|
if uname == "" || claims.Subject == "" {
|
|
s.oidcRedirectError(w, r, "missing_claims")
|
|
return
|
|
}
|
|
|
|
role := s.deps.OIDC.MapRole(claims.Roles)
|
|
if role == "" {
|
|
_ = s.auditOIDCBlocked(r, claims, "no_role_match")
|
|
s.oidcRedirectError(w, r, "no_role_match")
|
|
return
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
|
|
// Returning OIDC user — refresh role + email + last_login.
|
|
existing, err := s.deps.Store.GetUserByOIDCSubject(r.Context(), claims.Subject)
|
|
if err == nil {
|
|
if existing.DisabledAt != nil {
|
|
s.oidcRedirectError(w, r, "user_disabled")
|
|
return
|
|
}
|
|
_ = s.deps.Store.SetUserRole(r.Context(), existing.ID, store.Role(role))
|
|
_ = s.deps.Store.SetUserEmail(r.Context(), existing.ID, claims.Email)
|
|
_ = s.deps.Store.MarkUserLogin(r.Context(), existing.ID, now)
|
|
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
|
ID: ulid.Make().String(), UserID: &existing.ID, Actor: "user",
|
|
Action: "user.oidc_login", TargetKind: ptr("user"),
|
|
TargetID: &existing.ID, TS: now,
|
|
})
|
|
s.oidcDropSessionAndRedirect(w, r, existing.ID, rawIDToken, now)
|
|
return
|
|
} else if !errors.Is(err, store.ErrNotFound) {
|
|
slog.Error("oidc callback: lookup by sub", "err", err)
|
|
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// New OIDC user — first check the username doesn't collide with
|
|
// a local user.
|
|
if _, err := s.deps.Store.GetUserByUsername(r.Context(), uname); err == nil {
|
|
_ = s.auditOIDCBlocked(r, claims, "username_taken")
|
|
s.oidcRedirectError(w, r, "username_taken")
|
|
return
|
|
} else if !errors.Is(err, store.ErrNotFound) {
|
|
slog.Error("oidc callback: lookup by username", "err", err)
|
|
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// JIT-provision.
|
|
id := ulid.Make().String()
|
|
var emailPtr *string
|
|
if claims.Email != "" {
|
|
em := strings.ToLower(claims.Email)
|
|
emailPtr = &em
|
|
}
|
|
sub := claims.Subject
|
|
if err := s.deps.Store.CreateUser(r.Context(), store.User{
|
|
ID: id, Username: uname, PasswordHash: "",
|
|
Role: store.Role(role), Email: emailPtr,
|
|
AuthSource: "oidc", OIDCSubject: &sub,
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
slog.Error("oidc callback: provision", "err", err)
|
|
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
|
return
|
|
}
|
|
_ = s.deps.Store.MarkUserLogin(r.Context(), id, now)
|
|
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
|
ID: ulid.Make().String(), UserID: &id, Actor: "user",
|
|
Action: "user.created", TargetKind: ptr("user"), TargetID: &id,
|
|
TS: now,
|
|
Payload: jsonMust(map[string]any{"auth_source": "oidc"}),
|
|
})
|
|
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
|
ID: ulid.Make().String(), UserID: &id, Actor: "user",
|
|
Action: "user.oidc_login", TargetKind: ptr("user"), TargetID: &id,
|
|
TS: now,
|
|
})
|
|
s.oidcDropSessionAndRedirect(w, r, id, rawIDToken, now)
|
|
}
|
|
|
|
func (s *Server) oidcDropSessionAndRedirect(w stdhttp.ResponseWriter, r *stdhttp.Request, userID, idToken string, now time.Time) {
|
|
rawSession, err := auth.NewToken()
|
|
if err != nil {
|
|
slog.Error("oidc: session token", "err", err)
|
|
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
|
return
|
|
}
|
|
hashed := auth.HashToken(rawSession)
|
|
if err := s.deps.Store.CreateSession(r.Context(), store.Session{
|
|
ID: hashed, UserID: userID, CreatedAt: now,
|
|
ExpiresAt: now.Add(8 * time.Hour),
|
|
IDToken: idToken,
|
|
}, hashed); err != nil {
|
|
slog.Error("oidc: create session", "err", err)
|
|
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
|
return
|
|
}
|
|
stdhttp.SetCookie(w, &stdhttp.Cookie{
|
|
Name: sessionCookieName, Value: rawSession,
|
|
Path: "/", HttpOnly: true,
|
|
SameSite: stdhttp.SameSiteLaxMode,
|
|
Secure: s.deps.Cfg.CookieSecure,
|
|
Expires: now.Add(8 * time.Hour),
|
|
})
|
|
stdhttp.Redirect(w, r, "/", stdhttp.StatusSeeOther)
|
|
}
|
|
|
|
func (s *Server) oidcRedirectError(w stdhttp.ResponseWriter, r *stdhttp.Request, code string) {
|
|
stdhttp.Redirect(w, r, "/login?oidc_error="+code, stdhttp.StatusSeeOther)
|
|
}
|
|
|
|
// auditOIDCBlocked records a failed sign-in. user_id is nil because
|
|
// no row was created; the IdP subject + reason go in the payload so
|
|
// admin can correlate.
|
|
func (s *Server) auditOIDCBlocked(r *stdhttp.Request, claims *oidc.Claims, reason string) error {
|
|
return s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
|
ID: ulid.Make().String(), UserID: nil, Actor: "system",
|
|
Action: "user.oidc_login_blocked", TargetKind: ptr("user"),
|
|
TargetID: nil, TS: time.Now().UTC(),
|
|
Payload: jsonMust(map[string]any{
|
|
"sub": claims.Subject,
|
|
"username": claims.PreferredUsername,
|
|
"reason": reason,
|
|
}),
|
|
})
|
|
}
|
|
|
|
// jsonMust marshals v to a json.RawMessage for use as an audit payload.
//
// Despite the "Must" prefix it never panics. If v cannot be marshalled, the
// error is logged and nil is returned, so the audit row still lands without
// its payload (best-effort).
func jsonMust(v any) json.RawMessage {
	b, err := json.Marshal(v)
	if err != nil {
		// Previously dropped silently; log so a missing payload is explainable.
		slog.Warn("audit payload: marshal", "err", err)
		return nil
	}
	return json.RawMessage(b)
}
|