From 1fd9dce8a2cfd09a1b0dad075ab5c9f8d768965e Mon Sep 17 00:00:00 2001 From: Steve Cliff Date: Tue, 5 May 2026 13:30:00 +0100 Subject: [PATCH] =?UTF-8?q?http:=20GET=20/auth/oidc/callback=20=E2=80=94?= =?UTF-8?q?=20JIT-provision,=20refresh,=20deny=20paths?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/server/http/oidc_handlers.go | 170 +++++++++++++++++++++ internal/server/http/oidc_handlers_test.go | 116 ++++++++++++++ internal/server/http/server.go | 2 +- 3 files changed, 287 insertions(+), 1 deletion(-) diff --git a/internal/server/http/oidc_handlers.go b/internal/server/http/oidc_handlers.go index b29bf22..45763c8 100644 --- a/internal/server/http/oidc_handlers.go +++ b/internal/server/http/oidc_handlers.go @@ -3,11 +3,18 @@ package http import ( + "encoding/json" + "errors" "log/slog" stdhttp "net/http" + "strings" "time" + "github.com/oklog/ulid/v2" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/auth" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc" + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" ) // handleOIDCLogin generates state + PKCE pair, persists them, and @@ -33,3 +40,166 @@ func (s *Server) handleOIDCLogin(w stdhttp.ResponseWriter, r *stdhttp.Request) { } stdhttp.Redirect(w, r, s.deps.OIDC.AuthURL(state, challenge), stdhttp.StatusSeeOther) } + +func (s *Server) handleOIDCCallback(w stdhttp.ResponseWriter, r *stdhttp.Request) { + q := r.URL.Query() + code := q.Get("code") + state := q.Get("state") + if code == "" || state == "" { + s.oidcRedirectError(w, r, "missing_params") + return + } + verifier, err := s.deps.Store.ConsumeOIDCState(r.Context(), oidc.HashState(state)) + if err != nil { + s.oidcRedirectError(w, r, "bad_state") + return + } + claims, rawIDToken, err := s.deps.OIDC.Exchange(r.Context(), code, verifier) + if err != nil { + slog.Warn("oidc callback: exchange", "err", err) + s.oidcRedirectError(w, r, "exchange_failed") + return + } + + uname := strings.ToLower(strings.TrimSpace(claims.PreferredUsername)) + if uname == "" { + uname = strings.ToLower(strings.TrimSpace(claims.Email)) + } + if uname == "" || claims.Subject == "" { + s.oidcRedirectError(w, r, "missing_claims") + return + } + + role := s.deps.OIDC.MapRole(claims.Roles) + if role == "" { + _ = s.auditOIDCBlocked(r, claims, "no_role_match") + s.oidcRedirectError(w, r, "no_role_match") + return + } + + now := time.Now().UTC() + + // Returning OIDC user — refresh role + email + last_login. + existing, err := s.deps.Store.GetUserByOIDCSubject(r.Context(), claims.Subject) + if err == nil { + if existing.DisabledAt != nil { + s.oidcRedirectError(w, r, "user_disabled") + return + } + _ = s.deps.Store.SetUserRole(r.Context(), existing.ID, store.Role(role)) + _ = s.deps.Store.SetUserEmail(r.Context(), existing.ID, claims.Email) + _ = s.deps.Store.MarkUserLogin(r.Context(), existing.ID, now) + _ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{ + ID: ulid.Make().String(), UserID: &existing.ID, Actor: "user", + Action: "user.oidc_login", TargetKind: ptr("user"), + TargetID: &existing.ID, TS: now, + }) + s.oidcDropSessionAndRedirect(w, r, existing.ID, rawIDToken, now) + return + } else if !errors.Is(err, store.ErrNotFound) { + slog.Error("oidc callback: lookup by sub", "err", err) + stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError) + return + } + + // New OIDC user — first check the username doesn't collide with + // a local user. + if _, err := s.deps.Store.GetUserByUsername(r.Context(), uname); err == nil { + _ = s.auditOIDCBlocked(r, claims, "username_taken") + s.oidcRedirectError(w, r, "username_taken") + return + } else if !errors.Is(err, store.ErrNotFound) { + slog.Error("oidc callback: lookup by username", "err", err) + stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError) + return + } + + // JIT-provision. + id := ulid.Make().String() + var emailPtr *string + if claims.Email != "" { + em := strings.ToLower(claims.Email) + emailPtr = &em + } + sub := claims.Subject + if err := s.deps.Store.CreateUser(r.Context(), store.User{ + ID: id, Username: uname, PasswordHash: "", + Role: store.Role(role), Email: emailPtr, + AuthSource: "oidc", OIDCSubject: &sub, + CreatedAt: now, + }); err != nil { + slog.Error("oidc callback: provision", "err", err) + stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError) + return + } + _ = s.deps.Store.MarkUserLogin(r.Context(), id, now) + _ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{ + ID: ulid.Make().String(), UserID: &id, Actor: "user", + Action: "user.created", TargetKind: ptr("user"), TargetID: &id, + TS: now, + Payload: jsonMust(map[string]any{"auth_source": "oidc"}), + }) + _ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{ + ID: ulid.Make().String(), UserID: &id, Actor: "user", + Action: "user.oidc_login", TargetKind: ptr("user"), TargetID: &id, + TS: now, + }) + s.oidcDropSessionAndRedirect(w, r, id, rawIDToken, now) +} + +func (s *Server) oidcDropSessionAndRedirect(w stdhttp.ResponseWriter, r *stdhttp.Request, userID, idToken string, now time.Time) { + rawSession, err := auth.NewToken() + if err != nil { + slog.Error("oidc: session token", "err", err) + stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError) + return + } + hashed := auth.HashToken(rawSession) + if err := s.deps.Store.CreateSession(r.Context(), store.Session{ + ID: hashed, UserID: userID, CreatedAt: now, + ExpiresAt: now.Add(8 * time.Hour), + IDToken: idToken, + }, hashed); err != nil { + slog.Error("oidc: create session", "err", err) + stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError) + return + } + stdhttp.SetCookie(w, &stdhttp.Cookie{ + Name: sessionCookieName, Value: rawSession, + Path: "/", HttpOnly: true, + SameSite: stdhttp.SameSiteLaxMode, + Secure: s.deps.Cfg.CookieSecure, + Expires: now.Add(8 * time.Hour), + }) + stdhttp.Redirect(w, r, "/", stdhttp.StatusSeeOther) +} + +func (s *Server) oidcRedirectError(w stdhttp.ResponseWriter, r *stdhttp.Request, code string) { + stdhttp.Redirect(w, r, "/login?oidc_error="+code, stdhttp.StatusSeeOther) +} + +// auditOIDCBlocked records a failed sign-in. user_id is nil because +// no row was created; the IdP subject + reason go in the payload so +// admin can correlate. +func (s *Server) auditOIDCBlocked(r *stdhttp.Request, claims *oidc.Claims, reason string) error { + return s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{ + ID: ulid.Make().String(), UserID: nil, Actor: "system", + Action: "user.oidc_login_blocked", TargetKind: ptr("user"), + TargetID: nil, TS: time.Now().UTC(), + Payload: jsonMust(map[string]any{ + "sub": claims.Subject, + "username": claims.PreferredUsername, + "reason": reason, + }), + }) +} + +// jsonMust marshals to json.RawMessage; on error returns nil so the +// audit row still lands without the payload (best-effort). +func jsonMust(v any) json.RawMessage { + b, err := json.Marshal(v) + if err != nil { + return nil + } + return json.RawMessage(b) +} diff --git a/internal/server/http/oidc_handlers_test.go b/internal/server/http/oidc_handlers_test.go index 7fb40ab..7b48c20 100644 --- a/internal/server/http/oidc_handlers_test.go +++ b/internal/server/http/oidc_handlers_test.go @@ -3,7 +3,9 @@ package http import ( "context" stdhttp "net/http" + "net/http/cookiejar" "net/http/httptest" + "net/url" "path/filepath" "strings" "testing" @@ -84,3 +86,117 @@ func TestOIDCLoginRedirectsToIdP(t *testing.T) { } _ = srv } + +// runCallback drives the auth code flow against the stub: kicks off +// /auth/oidc/login (capturing the state), mints a code at the stub +// with the given claims, then GETs /auth/oidc/callback. Returns the +// final response. +func runCallback(t *testing.T, ts *httptest.Server, stub *oidctest.StubIdP, claims map[string]any) *stdhttp.Response { + t.Helper() + jar, _ := cookiejar.New(nil) + c := &stdhttp.Client{Jar: jar, CheckRedirect: func(*stdhttp.Request, []*stdhttp.Request) error { + return stdhttp.ErrUseLastResponse + }} + res, err := c.Get(ts.URL + "/auth/oidc/login") + if err != nil { + t.Fatalf("login: %v", err) + } + res.Body.Close() + authURL, _ := url.Parse(res.Header.Get("Location")) + state := authURL.Query().Get("state") + + code := stub.MintCode(claims) + res, err = c.Get(ts.URL + "/auth/oidc/callback?code=" + code + "&state=" + state) + if err != nil { + t.Fatalf("callback: %v", err) + } + return res +} + +func TestOIDCCallbackHappyPathAdmin(t *testing.T) { + t.Parallel() + srv, ts, stub := newTestServerWithOIDC(t) + res := runCallback(t, ts, stub, map[string]any{ + "sub": "admin-sub", + "preferred_username": "alice", + "email": "alice@example.com", + "groups": []string{"rm-admins"}, + "aud": "test-client", + }) + defer res.Body.Close() + if res.StatusCode != stdhttp.StatusSeeOther || res.Header.Get("Location") != "/" { + t.Errorf("status: %d Location: %q", res.StatusCode, res.Header.Get("Location")) + } + u, err := srv.deps.Store.GetUserByOIDCSubject(t.Context(), "admin-sub") + if err != nil || u.AuthSource != "oidc" || u.Role != "admin" || u.Username != "alice" { + t.Errorf("user: %+v err: %v", u, err) + } +} + +func TestOIDCCallbackNoRoleMatchDeny(t *testing.T) { + t.Parallel() + _, ts, stub := newTestServerWithOIDC(t) + res := runCallback(t, ts, stub, map[string]any{ + "sub": "other-sub", + "preferred_username": "bob", + "groups": []string{"something-else"}, + "aud": "test-client", + }) + defer res.Body.Close() + if res.StatusCode != stdhttp.StatusSeeOther { + t.Errorf("status: got %d want 303", res.StatusCode) + } + loc := res.Header.Get("Location") + if !strings.Contains(loc, "oidc_error=no_role_match") { + t.Errorf("location: %q", loc) + } +} + +func TestOIDCCallbackUsernameCollision(t *testing.T) { + t.Parallel() + srv, ts, stub := newTestServerWithOIDC(t) + if err := srv.deps.Store.CreateUser(t.Context(), store.User{ + ID: "local-alice", Username: "alice", PasswordHash: "x", + Role: store.RoleViewer, CreatedAt: time.Now().UTC(), + }); err != nil { + t.Fatalf("seed: %v", err) + } + + res := runCallback(t, ts, stub, map[string]any{ + "sub": "remote-sub", + "preferred_username": "alice", + "groups": []string{"rm-admins"}, + "aud": "test-client", + }) + defer res.Body.Close() + loc := res.Header.Get("Location") + if !strings.Contains(loc, "oidc_error=username_taken") { + t.Errorf("location: %q", loc) + } + if _, err := srv.deps.Store.GetUserByOIDCSubject(t.Context(), "remote-sub"); err == nil { + t.Error("collision should not have provisioned a user") + } +} + +func TestOIDCCallbackReturningUserRefreshesRole(t *testing.T) { + t.Parallel() + srv, ts, stub := newTestServerWithOIDC(t) + res := runCallback(t, ts, stub, map[string]any{ + "sub": "carol-sub", + "preferred_username": "carol", + "groups": []string{"rm-operators"}, + "aud": "test-client", + }) + res.Body.Close() + res = runCallback(t, ts, stub, map[string]any{ + "sub": "carol-sub", + "preferred_username": "carol", + "groups": []string{"rm-admins"}, + "aud": "test-client", + }) + res.Body.Close() + u, _ := srv.deps.Store.GetUserByOIDCSubject(t.Context(), "carol-sub") + if u.Role != "admin" { + t.Errorf("role refresh: got %q want admin", u.Role) + } +} diff --git a/internal/server/http/server.go b/internal/server/http/server.go index 326ad3b..a8eddd5 100644 --- a/internal/server/http/server.go +++ b/internal/server/http/server.go @@ -146,7 +146,7 @@ func (s *Server) routes(r chi.Router) { } if s.deps.OIDC != nil { r.Get("/auth/oidc/login", s.handleOIDCLogin) - // /auth/oidc/callback registered in D2 + r.Get("/auth/oidc/callback", s.handleOIDCCallback) } // Viewer band — anyone authenticated can read.