265 lines
8.2 KiB
Go
265 lines
8.2 KiB
Go
package http
|
|
|
|
import (
|
|
"context"
|
|
stdhttp "net/http"
|
|
"net/http/cookiejar"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/config"
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc"
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc/oidctest"
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
|
)
|
|
|
|
// newTestServerWithOIDC returns a Server wired to a stub IdP.
|
|
// Returned ts is the httptest.Server fronting the actual server;
|
|
// stub is the IdP for minting codes / configuring claims.
|
|
func newTestServerWithOIDC(t *testing.T) (*Server, *httptest.Server, *oidctest.StubIdP) {
|
|
t.Helper()
|
|
dir := t.TempDir()
|
|
st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db"))
|
|
if err != nil {
|
|
t.Fatalf("store: %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = st.Close() })
|
|
|
|
keyPath := filepath.Join(dir, "secret.key")
|
|
if err := crypto.GenerateKeyFile(keyPath); err != nil {
|
|
t.Fatalf("genkey: %v", err)
|
|
}
|
|
key, _ := crypto.LoadKeyFromFile(keyPath)
|
|
aead, _ := crypto.NewAEAD(key)
|
|
|
|
stub := oidctest.New(t)
|
|
cfg := &config.OIDCConfig{
|
|
Issuer: stub.URL(), ClientID: "test-client", ClientSecret: "x",
|
|
Scopes: []string{"openid"}, RoleClaim: "groups",
|
|
RoleMapping: map[string]string{
|
|
"rm-admins": "admin",
|
|
"rm-operators": "operator",
|
|
"rm-viewers": "viewer",
|
|
},
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
oidcClient, err := oidc.New(ctx, cfg, "http://test")
|
|
if err != nil {
|
|
t.Fatalf("oidc client: %v", err)
|
|
}
|
|
|
|
deps := Deps{
|
|
Cfg: config.Config{Listen: ":0", DataDir: dir, SecretKeyFile: keyPath, BaseURL: "http://test"},
|
|
Store: st,
|
|
AEAD: aead,
|
|
OIDC: oidcClient,
|
|
}
|
|
s := New(deps)
|
|
ts := httptest.NewServer(s.srv.Handler)
|
|
t.Cleanup(ts.Close)
|
|
return s, ts, stub
|
|
}
|
|
|
|
func TestOIDCLoginRedirectsToIdP(t *testing.T) {
|
|
t.Parallel()
|
|
srv, ts, _ := newTestServerWithOIDC(t)
|
|
c := &stdhttp.Client{CheckRedirect: func(*stdhttp.Request, []*stdhttp.Request) error {
|
|
return stdhttp.ErrUseLastResponse
|
|
}}
|
|
res, err := c.Get(ts.URL + "/auth/oidc/login")
|
|
if err != nil {
|
|
t.Fatalf("get: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != stdhttp.StatusSeeOther {
|
|
t.Errorf("status: got %d want 303", res.StatusCode)
|
|
}
|
|
loc := res.Header.Get("Location")
|
|
if !strings.Contains(loc, "code_challenge=") || !strings.Contains(loc, "state=") {
|
|
t.Errorf("location: %q", loc)
|
|
}
|
|
_ = srv
|
|
}
|
|
|
|
// runCallback drives the auth code flow against the stub: kicks off
|
|
// /auth/oidc/login (capturing the state), mints a code at the stub
|
|
// with the given claims, then GETs /auth/oidc/callback. Returns the
|
|
// final response.
|
|
func runCallback(t *testing.T, ts *httptest.Server, stub *oidctest.StubIdP, claims map[string]any) *stdhttp.Response {
|
|
t.Helper()
|
|
jar, _ := cookiejar.New(nil)
|
|
c := &stdhttp.Client{Jar: jar, CheckRedirect: func(*stdhttp.Request, []*stdhttp.Request) error {
|
|
return stdhttp.ErrUseLastResponse
|
|
}}
|
|
res, err := c.Get(ts.URL + "/auth/oidc/login")
|
|
if err != nil {
|
|
t.Fatalf("login: %v", err)
|
|
}
|
|
res.Body.Close()
|
|
authURL, _ := url.Parse(res.Header.Get("Location"))
|
|
state := authURL.Query().Get("state")
|
|
|
|
code := stub.MintCode(claims)
|
|
res, err = c.Get(ts.URL + "/auth/oidc/callback?code=" + code + "&state=" + state)
|
|
if err != nil {
|
|
t.Fatalf("callback: %v", err)
|
|
}
|
|
return res
|
|
}
|
|
|
|
func TestOIDCCallbackHappyPathAdmin(t *testing.T) {
|
|
t.Parallel()
|
|
srv, ts, stub := newTestServerWithOIDC(t)
|
|
res := runCallback(t, ts, stub, map[string]any{
|
|
"sub": "admin-sub",
|
|
"preferred_username": "alice",
|
|
"email": "alice@example.com",
|
|
"groups": []string{"rm-admins"},
|
|
"aud": "test-client",
|
|
})
|
|
defer res.Body.Close()
|
|
if res.StatusCode != stdhttp.StatusSeeOther || res.Header.Get("Location") != "/" {
|
|
t.Errorf("status: %d Location: %q", res.StatusCode, res.Header.Get("Location"))
|
|
}
|
|
u, err := srv.deps.Store.GetUserByOIDCSubject(t.Context(), "admin-sub")
|
|
if err != nil || u.AuthSource != "oidc" || u.Role != "admin" || u.Username != "alice" {
|
|
t.Errorf("user: %+v err: %v", u, err)
|
|
}
|
|
}
|
|
|
|
func TestOIDCCallbackNoRoleMatchDeny(t *testing.T) {
|
|
t.Parallel()
|
|
_, ts, stub := newTestServerWithOIDC(t)
|
|
res := runCallback(t, ts, stub, map[string]any{
|
|
"sub": "other-sub",
|
|
"preferred_username": "bob",
|
|
"groups": []string{"something-else"},
|
|
"aud": "test-client",
|
|
})
|
|
defer res.Body.Close()
|
|
if res.StatusCode != stdhttp.StatusSeeOther {
|
|
t.Errorf("status: got %d want 303", res.StatusCode)
|
|
}
|
|
loc := res.Header.Get("Location")
|
|
if !strings.Contains(loc, "oidc_error=no_role_match") {
|
|
t.Errorf("location: %q", loc)
|
|
}
|
|
}
|
|
|
|
func TestOIDCCallbackUsernameCollision(t *testing.T) {
|
|
t.Parallel()
|
|
srv, ts, stub := newTestServerWithOIDC(t)
|
|
if err := srv.deps.Store.CreateUser(t.Context(), store.User{
|
|
ID: "local-alice", Username: "alice", PasswordHash: "x",
|
|
Role: store.RoleViewer, CreatedAt: time.Now().UTC(),
|
|
}); err != nil {
|
|
t.Fatalf("seed: %v", err)
|
|
}
|
|
|
|
res := runCallback(t, ts, stub, map[string]any{
|
|
"sub": "remote-sub",
|
|
"preferred_username": "alice",
|
|
"groups": []string{"rm-admins"},
|
|
"aud": "test-client",
|
|
})
|
|
defer res.Body.Close()
|
|
loc := res.Header.Get("Location")
|
|
if !strings.Contains(loc, "oidc_error=username_taken") {
|
|
t.Errorf("location: %q", loc)
|
|
}
|
|
if _, err := srv.deps.Store.GetUserByOIDCSubject(t.Context(), "remote-sub"); err == nil {
|
|
t.Error("collision should not have provisioned a user")
|
|
}
|
|
}
|
|
|
|
func TestOIDCCallbackReturningUserRefreshesRole(t *testing.T) {
|
|
t.Parallel()
|
|
srv, ts, stub := newTestServerWithOIDC(t)
|
|
res := runCallback(t, ts, stub, map[string]any{
|
|
"sub": "carol-sub",
|
|
"preferred_username": "carol",
|
|
"groups": []string{"rm-operators"},
|
|
"aud": "test-client",
|
|
})
|
|
res.Body.Close()
|
|
res = runCallback(t, ts, stub, map[string]any{
|
|
"sub": "carol-sub",
|
|
"preferred_username": "carol",
|
|
"groups": []string{"rm-admins"},
|
|
"aud": "test-client",
|
|
})
|
|
res.Body.Close()
|
|
u, _ := srv.deps.Store.GetUserByOIDCSubject(t.Context(), "carol-sub")
|
|
if u.Role != "admin" {
|
|
t.Errorf("role refresh: got %q want admin", u.Role)
|
|
}
|
|
}
|
|
|
|
func TestOIDCLogoutRedirectsToEndSession(t *testing.T) {
|
|
t.Parallel()
|
|
srv, ts, stub := newTestServerWithOIDC(t)
|
|
endSessionURL := stub.URL() + "/logout-end"
|
|
stub.SetEndSessionEndpoint(endSessionURL)
|
|
|
|
// Rebuild the OIDC client because end_session_endpoint is read at
|
|
// New() time from the discovery doc.
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
cfg := &config.OIDCConfig{
|
|
Issuer: stub.URL(), ClientID: "test-client", ClientSecret: "x",
|
|
Scopes: []string{"openid"}, RoleClaim: "groups",
|
|
RoleMapping: map[string]string{"rm-admins": "admin"},
|
|
}
|
|
newClient, err := oidc.New(ctx, cfg, "http://test")
|
|
if err != nil {
|
|
t.Fatalf("rebuild client: %v", err)
|
|
}
|
|
srv.deps.OIDC = newClient
|
|
|
|
// Sign in via the OIDC flow.
|
|
res := runCallback(t, ts, stub, map[string]any{
|
|
"sub": "logout-sub",
|
|
"preferred_username": "lo",
|
|
"groups": []string{"rm-admins"},
|
|
"aud": "test-client",
|
|
})
|
|
res.Body.Close()
|
|
cookies := res.Cookies()
|
|
if len(cookies) == 0 {
|
|
t.Fatal("expected session cookie after sign-in")
|
|
}
|
|
sessionCookie := cookies[0]
|
|
|
|
// POST /logout — should 303 to the end_session endpoint with
|
|
// id_token_hint + post_logout_redirect_uri.
|
|
c := &stdhttp.Client{CheckRedirect: func(*stdhttp.Request, []*stdhttp.Request) error {
|
|
return stdhttp.ErrUseLastResponse
|
|
}}
|
|
req, _ := stdhttp.NewRequest("POST", ts.URL+"/logout", nil)
|
|
req.AddCookie(sessionCookie)
|
|
res, err = c.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("logout: %v", err)
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != stdhttp.StatusSeeOther {
|
|
t.Errorf("status: got %d want 303", res.StatusCode)
|
|
}
|
|
loc := res.Header.Get("Location")
|
|
if !strings.Contains(loc, "/logout-end") {
|
|
t.Errorf("location not at end_session: %q", loc)
|
|
}
|
|
if !strings.Contains(loc, "id_token_hint=") {
|
|
t.Errorf("location missing id_token_hint: %q", loc)
|
|
}
|
|
if !strings.Contains(loc, "post_logout_redirect_uri=") {
|
|
t.Errorf("location missing post_logout_redirect_uri: %q", loc)
|
|
}
|
|
}
|