P4-05: OIDC login (generic, JIT-provisioned) #16

Merged
steve merged 19 commits from p4-05-oidc into main 2026-05-05 14:46:23 +01:00
3 changed files with 287 additions and 1 deletions
Showing only changes of commit c55a75355a - Show all commits
+170
View File
@@ -3,11 +3,18 @@
package http
import (
"encoding/json"
"errors"
"log/slog"
stdhttp "net/http"
"strings"
"time"
"github.com/oklog/ulid/v2"
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
// handleOIDCLogin generates state + PKCE pair, persists them, and
@@ -33,3 +40,166 @@ func (s *Server) handleOIDCLogin(w stdhttp.ResponseWriter, r *stdhttp.Request) {
}
stdhttp.Redirect(w, r, s.deps.OIDC.AuthURL(state, challenge), stdhttp.StatusSeeOther)
}
func (s *Server) handleOIDCCallback(w stdhttp.ResponseWriter, r *stdhttp.Request) {
q := r.URL.Query()
code := q.Get("code")
state := q.Get("state")
if code == "" || state == "" {
s.oidcRedirectError(w, r, "missing_params")
return
}
verifier, err := s.deps.Store.ConsumeOIDCState(r.Context(), oidc.HashState(state))
if err != nil {
s.oidcRedirectError(w, r, "bad_state")
return
}
claims, rawIDToken, err := s.deps.OIDC.Exchange(r.Context(), code, verifier)
if err != nil {
slog.Warn("oidc callback: exchange", "err", err)
s.oidcRedirectError(w, r, "exchange_failed")
return
}
uname := strings.ToLower(strings.TrimSpace(claims.PreferredUsername))
if uname == "" {
uname = strings.ToLower(strings.TrimSpace(claims.Email))
}
if uname == "" || claims.Subject == "" {
s.oidcRedirectError(w, r, "missing_claims")
return
}
role := s.deps.OIDC.MapRole(claims.Roles)
if role == "" {
_ = s.auditOIDCBlocked(r, claims, "no_role_match")
s.oidcRedirectError(w, r, "no_role_match")
return
}
now := time.Now().UTC()
// Returning OIDC user — refresh role + email + last_login.
existing, err := s.deps.Store.GetUserByOIDCSubject(r.Context(), claims.Subject)
if err == nil {
if existing.DisabledAt != nil {
s.oidcRedirectError(w, r, "user_disabled")
return
}
_ = s.deps.Store.SetUserRole(r.Context(), existing.ID, store.Role(role))
_ = s.deps.Store.SetUserEmail(r.Context(), existing.ID, claims.Email)
_ = s.deps.Store.MarkUserLogin(r.Context(), existing.ID, now)
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
ID: ulid.Make().String(), UserID: &existing.ID, Actor: "user",
Action: "user.oidc_login", TargetKind: ptr("user"),
TargetID: &existing.ID, TS: now,
})
s.oidcDropSessionAndRedirect(w, r, existing.ID, rawIDToken, now)
return
} else if !errors.Is(err, store.ErrNotFound) {
slog.Error("oidc callback: lookup by sub", "err", err)
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
return
}
// New OIDC user — first check the username doesn't collide with
// a local user.
if _, err := s.deps.Store.GetUserByUsername(r.Context(), uname); err == nil {
_ = s.auditOIDCBlocked(r, claims, "username_taken")
s.oidcRedirectError(w, r, "username_taken")
return
} else if !errors.Is(err, store.ErrNotFound) {
slog.Error("oidc callback: lookup by username", "err", err)
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
return
}
// JIT-provision.
id := ulid.Make().String()
var emailPtr *string
if claims.Email != "" {
em := strings.ToLower(claims.Email)
emailPtr = &em
}
sub := claims.Subject
if err := s.deps.Store.CreateUser(r.Context(), store.User{
ID: id, Username: uname, PasswordHash: "",
Role: store.Role(role), Email: emailPtr,
AuthSource: "oidc", OIDCSubject: &sub,
CreatedAt: now,
}); err != nil {
slog.Error("oidc callback: provision", "err", err)
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
return
}
_ = s.deps.Store.MarkUserLogin(r.Context(), id, now)
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
ID: ulid.Make().String(), UserID: &id, Actor: "user",
Action: "user.created", TargetKind: ptr("user"), TargetID: &id,
TS: now,
Payload: jsonMust(map[string]any{"auth_source": "oidc"}),
})
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
ID: ulid.Make().String(), UserID: &id, Actor: "user",
Action: "user.oidc_login", TargetKind: ptr("user"), TargetID: &id,
TS: now,
})
s.oidcDropSessionAndRedirect(w, r, id, rawIDToken, now)
}
func (s *Server) oidcDropSessionAndRedirect(w stdhttp.ResponseWriter, r *stdhttp.Request, userID, idToken string, now time.Time) {
rawSession, err := auth.NewToken()
if err != nil {
slog.Error("oidc: session token", "err", err)
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
return
}
hashed := auth.HashToken(rawSession)
if err := s.deps.Store.CreateSession(r.Context(), store.Session{
ID: hashed, UserID: userID, CreatedAt: now,
ExpiresAt: now.Add(8 * time.Hour),
IDToken: idToken,
}, hashed); err != nil {
slog.Error("oidc: create session", "err", err)
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
return
}
stdhttp.SetCookie(w, &stdhttp.Cookie{
Name: sessionCookieName, Value: rawSession,
Path: "/", HttpOnly: true,
SameSite: stdhttp.SameSiteLaxMode,
Secure: s.deps.Cfg.CookieSecure,
Expires: now.Add(8 * time.Hour),
})
stdhttp.Redirect(w, r, "/", stdhttp.StatusSeeOther)
}
func (s *Server) oidcRedirectError(w stdhttp.ResponseWriter, r *stdhttp.Request, code string) {
stdhttp.Redirect(w, r, "/login?oidc_error="+code, stdhttp.StatusSeeOther)
}
// auditOIDCBlocked records a failed sign-in. user_id is nil because
// no row was created; the IdP subject + reason go in the payload so
// admin can correlate.
func (s *Server) auditOIDCBlocked(r *stdhttp.Request, claims *oidc.Claims, reason string) error {
return s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
ID: ulid.Make().String(), UserID: nil, Actor: "system",
Action: "user.oidc_login_blocked", TargetKind: ptr("user"),
TargetID: nil, TS: time.Now().UTC(),
Payload: jsonMust(map[string]any{
"sub": claims.Subject,
"username": claims.PreferredUsername,
"reason": reason,
}),
})
}
// jsonMust marshals to json.RawMessage; on error returns nil so the
// audit row still lands without the payload (best-effort).
func jsonMust(v any) json.RawMessage {
b, err := json.Marshal(v)
if err != nil {
return nil
}
return json.RawMessage(b)
}
+116
View File
@@ -3,7 +3,9 @@ package http
import (
"context"
stdhttp "net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"path/filepath"
"strings"
"testing"
@@ -84,3 +86,117 @@ func TestOIDCLoginRedirectsToIdP(t *testing.T) {
}
_ = srv
}
// runCallback drives the auth code flow against the stub: kicks off
// /auth/oidc/login (capturing the state), mints a code at the stub
// with the given claims, then GETs /auth/oidc/callback. Returns the
// final response.
func runCallback(t *testing.T, ts *httptest.Server, stub *oidctest.StubIdP, claims map[string]any) *stdhttp.Response {
t.Helper()
jar, _ := cookiejar.New(nil)
c := &stdhttp.Client{Jar: jar, CheckRedirect: func(*stdhttp.Request, []*stdhttp.Request) error {
return stdhttp.ErrUseLastResponse
}}
res, err := c.Get(ts.URL + "/auth/oidc/login")
if err != nil {
t.Fatalf("login: %v", err)
}
res.Body.Close()
authURL, _ := url.Parse(res.Header.Get("Location"))
state := authURL.Query().Get("state")
code := stub.MintCode(claims)
res, err = c.Get(ts.URL + "/auth/oidc/callback?code=" + code + "&state=" + state)
if err != nil {
t.Fatalf("callback: %v", err)
}
return res
}
func TestOIDCCallbackHappyPathAdmin(t *testing.T) {
t.Parallel()
srv, ts, stub := newTestServerWithOIDC(t)
res := runCallback(t, ts, stub, map[string]any{
"sub": "admin-sub",
"preferred_username": "alice",
"email": "alice@example.com",
"groups": []string{"rm-admins"},
"aud": "test-client",
})
defer res.Body.Close()
if res.StatusCode != stdhttp.StatusSeeOther || res.Header.Get("Location") != "/" {
t.Errorf("status: %d Location: %q", res.StatusCode, res.Header.Get("Location"))
}
u, err := srv.deps.Store.GetUserByOIDCSubject(t.Context(), "admin-sub")
if err != nil || u.AuthSource != "oidc" || u.Role != "admin" || u.Username != "alice" {
t.Errorf("user: %+v err: %v", u, err)
}
}
func TestOIDCCallbackNoRoleMatchDeny(t *testing.T) {
t.Parallel()
_, ts, stub := newTestServerWithOIDC(t)
res := runCallback(t, ts, stub, map[string]any{
"sub": "other-sub",
"preferred_username": "bob",
"groups": []string{"something-else"},
"aud": "test-client",
})
defer res.Body.Close()
if res.StatusCode != stdhttp.StatusSeeOther {
t.Errorf("status: got %d want 303", res.StatusCode)
}
loc := res.Header.Get("Location")
if !strings.Contains(loc, "oidc_error=no_role_match") {
t.Errorf("location: %q", loc)
}
}
func TestOIDCCallbackUsernameCollision(t *testing.T) {
t.Parallel()
srv, ts, stub := newTestServerWithOIDC(t)
if err := srv.deps.Store.CreateUser(t.Context(), store.User{
ID: "local-alice", Username: "alice", PasswordHash: "x",
Role: store.RoleViewer, CreatedAt: time.Now().UTC(),
}); err != nil {
t.Fatalf("seed: %v", err)
}
res := runCallback(t, ts, stub, map[string]any{
"sub": "remote-sub",
"preferred_username": "alice",
"groups": []string{"rm-admins"},
"aud": "test-client",
})
defer res.Body.Close()
loc := res.Header.Get("Location")
if !strings.Contains(loc, "oidc_error=username_taken") {
t.Errorf("location: %q", loc)
}
if _, err := srv.deps.Store.GetUserByOIDCSubject(t.Context(), "remote-sub"); err == nil {
t.Error("collision should not have provisioned a user")
}
}
func TestOIDCCallbackReturningUserRefreshesRole(t *testing.T) {
t.Parallel()
srv, ts, stub := newTestServerWithOIDC(t)
res := runCallback(t, ts, stub, map[string]any{
"sub": "carol-sub",
"preferred_username": "carol",
"groups": []string{"rm-operators"},
"aud": "test-client",
})
res.Body.Close()
res = runCallback(t, ts, stub, map[string]any{
"sub": "carol-sub",
"preferred_username": "carol",
"groups": []string{"rm-admins"},
"aud": "test-client",
})
res.Body.Close()
u, _ := srv.deps.Store.GetUserByOIDCSubject(t.Context(), "carol-sub")
if u.Role != "admin" {
t.Errorf("role refresh: got %q want admin", u.Role)
}
}
+1 -1
View File
@@ -146,7 +146,7 @@ func (s *Server) routes(r chi.Router) {
}
if s.deps.OIDC != nil {
r.Get("/auth/oidc/login", s.handleOIDCLogin)
// /auth/oidc/callback registered in D2
r.Get("/auth/oidc/callback", s.handleOIDCCallback)
}
// Viewer band — anyone authenticated can read.