package http import ( "bytes" "context" "encoding/json" stdhttp "net/http" "net/http/cookiejar" "net/http/httptest" "net/url" "path/filepath" "strings" "testing" "time" "gitea.dcglab.co.uk/steve/restic-manager/internal/crypto" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/config" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc/oidctest" "gitea.dcglab.co.uk/steve/restic-manager/internal/store" ) // newTestServerWithOIDC returns a Server wired to a stub IdP. // Returned ts is the httptest.Server fronting the actual server; // stub is the IdP for minting codes / configuring claims. func newTestServerWithOIDC(t *testing.T) (*Server, *httptest.Server, *oidctest.StubIdP) { t.Helper() dir := t.TempDir() st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db")) if err != nil { t.Fatalf("store: %v", err) } t.Cleanup(func() { _ = st.Close() }) keyPath := filepath.Join(dir, "secret.key") if err := crypto.GenerateKeyFile(keyPath); err != nil { t.Fatalf("genkey: %v", err) } key, _ := crypto.LoadKeyFromFile(keyPath) aead, _ := crypto.NewAEAD(key) stub := oidctest.New(t) cfg := &config.OIDCConfig{ Issuer: stub.URL(), ClientID: "test-client", ClientSecret: "x", Scopes: []string{"openid"}, RoleClaim: "groups", RoleMapping: map[string]string{ "rm-admins": "admin", "rm-operators": "operator", "rm-viewers": "viewer", }, } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() oidcClient, err := oidc.New(ctx, cfg, "http://test") if err != nil { t.Fatalf("oidc client: %v", err) } deps := Deps{ Cfg: config.Config{Listen: ":0", DataDir: dir, SecretKeyFile: keyPath, BaseURL: "http://test"}, Store: st, AEAD: aead, OIDC: oidcClient, } s := New(deps) ts := httptest.NewServer(s.srv.Handler) t.Cleanup(ts.Close) return s, ts, stub } func TestOIDCLoginRedirectsToIdP(t *testing.T) { t.Parallel() srv, ts, _ := newTestServerWithOIDC(t) c := &stdhttp.Client{CheckRedirect: func(*stdhttp.Request, []*stdhttp.Request) error { return stdhttp.ErrUseLastResponse }} res, err := c.Get(ts.URL + "/auth/oidc/login") if err != nil { t.Fatalf("get: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusSeeOther { t.Errorf("status: got %d want 303", res.StatusCode) } loc := res.Header.Get("Location") if !strings.Contains(loc, "code_challenge=") || !strings.Contains(loc, "state=") { t.Errorf("location: %q", loc) } _ = srv } // runCallback drives the auth code flow against the stub: kicks off // /auth/oidc/login (capturing the state), mints a code at the stub // with the given claims, then GETs /auth/oidc/callback. Returns the // final response. func runCallback(t *testing.T, ts *httptest.Server, stub *oidctest.StubIdP, claims map[string]any) *stdhttp.Response { t.Helper() jar, _ := cookiejar.New(nil) c := &stdhttp.Client{Jar: jar, CheckRedirect: func(*stdhttp.Request, []*stdhttp.Request) error { return stdhttp.ErrUseLastResponse }} res, err := c.Get(ts.URL + "/auth/oidc/login") if err != nil { t.Fatalf("login: %v", err) } res.Body.Close() authURL, _ := url.Parse(res.Header.Get("Location")) state := authURL.Query().Get("state") code := stub.MintCode(claims) res, err = c.Get(ts.URL + "/auth/oidc/callback?code=" + code + "&state=" + state) if err != nil { t.Fatalf("callback: %v", err) } return res } func TestOIDCCallbackHappyPathAdmin(t *testing.T) { t.Parallel() srv, ts, stub := newTestServerWithOIDC(t) res := runCallback(t, ts, stub, map[string]any{ "sub": "admin-sub", "preferred_username": "alice", "email": "alice@example.com", "groups": []string{"rm-admins"}, "aud": "test-client", }) defer res.Body.Close() if res.StatusCode != stdhttp.StatusSeeOther || res.Header.Get("Location") != "/" { t.Errorf("status: %d Location: %q", res.StatusCode, res.Header.Get("Location")) } u, err := srv.deps.Store.GetUserByOIDCSubject(t.Context(), "admin-sub") if err != nil || u.AuthSource != "oidc" || u.Role != "admin" || u.Username != "alice" { t.Errorf("user: %+v err: %v", u, err) } } func TestOIDCCallbackNoRoleMatchDeny(t *testing.T) { t.Parallel() _, ts, stub := newTestServerWithOIDC(t) res := runCallback(t, ts, stub, map[string]any{ "sub": "other-sub", "preferred_username": "bob", "groups": []string{"something-else"}, "aud": "test-client", }) defer res.Body.Close() if res.StatusCode != stdhttp.StatusSeeOther { t.Errorf("status: got %d want 303", res.StatusCode) } loc := res.Header.Get("Location") if !strings.Contains(loc, "oidc_error=no_role_match") { t.Errorf("location: %q", loc) } } func TestOIDCCallbackUsernameCollision(t *testing.T) { t.Parallel() srv, ts, stub := newTestServerWithOIDC(t) if err := srv.deps.Store.CreateUser(t.Context(), store.User{ ID: "local-alice", Username: "alice", PasswordHash: "x", Role: store.RoleViewer, CreatedAt: time.Now().UTC(), }); err != nil { t.Fatalf("seed: %v", err) } res := runCallback(t, ts, stub, map[string]any{ "sub": "remote-sub", "preferred_username": "alice", "groups": []string{"rm-admins"}, "aud": "test-client", }) defer res.Body.Close() loc := res.Header.Get("Location") if !strings.Contains(loc, "oidc_error=username_taken") { t.Errorf("location: %q", loc) } if _, err := srv.deps.Store.GetUserByOIDCSubject(t.Context(), "remote-sub"); err == nil { t.Error("collision should not have provisioned a user") } } func TestOIDCCallbackReturningUserRefreshesRole(t *testing.T) { t.Parallel() srv, ts, stub := newTestServerWithOIDC(t) res := runCallback(t, ts, stub, map[string]any{ "sub": "carol-sub", "preferred_username": "carol", "groups": []string{"rm-operators"}, "aud": "test-client", }) res.Body.Close() res = runCallback(t, ts, stub, map[string]any{ "sub": "carol-sub", "preferred_username": "carol", "groups": []string{"rm-admins"}, "aud": "test-client", }) res.Body.Close() u, _ := srv.deps.Store.GetUserByOIDCSubject(t.Context(), "carol-sub") if u.Role != "admin" { t.Errorf("role refresh: got %q want admin", u.Role) } } func TestOIDCLogoutRedirectsToEndSession(t *testing.T) { t.Parallel() srv, ts, stub := newTestServerWithOIDC(t) endSessionURL := stub.URL() + "/logout-end" stub.SetEndSessionEndpoint(endSessionURL) // Rebuild the OIDC client because end_session_endpoint is read at // New() time from the discovery doc. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() cfg := &config.OIDCConfig{ Issuer: stub.URL(), ClientID: "test-client", ClientSecret: "x", Scopes: []string{"openid"}, RoleClaim: "groups", RoleMapping: map[string]string{"rm-admins": "admin"}, } newClient, err := oidc.New(ctx, cfg, "http://test") if err != nil { t.Fatalf("rebuild client: %v", err) } srv.deps.OIDC = newClient // Sign in via the OIDC flow. res := runCallback(t, ts, stub, map[string]any{ "sub": "logout-sub", "preferred_username": "lo", "groups": []string{"rm-admins"}, "aud": "test-client", }) res.Body.Close() cookies := res.Cookies() if len(cookies) == 0 { t.Fatal("expected session cookie after sign-in") } sessionCookie := cookies[0] // POST /logout — should 303 to the end_session endpoint with // id_token_hint + post_logout_redirect_uri. c := &stdhttp.Client{CheckRedirect: func(*stdhttp.Request, []*stdhttp.Request) error { return stdhttp.ErrUseLastResponse }} req, _ := stdhttp.NewRequest("POST", ts.URL+"/logout", nil) req.AddCookie(sessionCookie) res, err = c.Do(req) if err != nil { t.Fatalf("logout: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusSeeOther { t.Errorf("status: got %d want 303", res.StatusCode) } loc := res.Header.Get("Location") if !strings.Contains(loc, "/logout-end") { t.Errorf("location not at end_session: %q", loc) } if !strings.Contains(loc, "id_token_hint=") { t.Errorf("location missing id_token_hint: %q", loc) } if !strings.Contains(loc, "post_logout_redirect_uri=") { t.Errorf("location missing post_logout_redirect_uri: %q", loc) } } func TestLocalLoginRejectsOIDCUser(t *testing.T) { t.Parallel() srv, urlBase := newTestServer(t, false) uid := "u-oidc" sub := "sub-x" if err := srv.deps.Store.CreateUser(t.Context(), store.User{ ID: uid, Username: "ouser", PasswordHash: "", Role: store.RoleOperator, CreatedAt: time.Now().UTC(), AuthSource: "oidc", OIDCSubject: &sub, }); err != nil { t.Fatalf("create: %v", err) } body, _ := json.Marshal(map[string]string{ "username": "ouser", "password": "anything", }) res, err := stdhttp.Post(urlBase+"/api/auth/login", "application/json", bytes.NewReader(body)) if err != nil { t.Fatalf("post: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusUnauthorized { t.Errorf("status: got %d want 401", res.StatusCode) } }