package http

import (
	"context"
	stdhttp "net/http"
	"net/http/httptest"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
	"gitea.dcglab.co.uk/steve/restic-manager/internal/server/config"
	"gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc"
	"gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc/oidctest"
	"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)

// newTestServerWithOIDC returns a Server wired to a stub IdP.
// Returned ts is the httptest.Server fronting the actual server;
// stub is the IdP for minting codes / configuring claims.
//
// All on-disk state (database file, secret key) lives under t.TempDir().
// The store and the httptest.Server are closed through t.Cleanup. Any
// setup failure aborts the calling test via t.Fatalf.
func newTestServerWithOIDC(t *testing.T) (*Server, *httptest.Server, *oidctest.StubIdP) {
	t.Helper()

	dir := t.TempDir()

	st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db"))
	if err != nil {
		t.Fatalf("store: %v", err)
	}
	t.Cleanup(func() { _ = st.Close() }) // best-effort close during teardown

	// Generate a fresh key file and build the AEAD from it. Each step is
	// checked so a crypto setup failure is reported here rather than as a
	// confusing failure later in the test.
	keyPath := filepath.Join(dir, "secret.key")
	if err := crypto.GenerateKeyFile(keyPath); err != nil {
		t.Fatalf("genkey: %v", err)
	}
	key, err := crypto.LoadKeyFromFile(keyPath)
	if err != nil {
		t.Fatalf("loadkey: %v", err)
	}
	aead, err := crypto.NewAEAD(key)
	if err != nil {
		t.Fatalf("aead: %v", err)
	}

	stub := oidctest.New(t)
	cfg := &config.OIDCConfig{
		Issuer:       stub.URL(),
		ClientID:     "test-client",
		ClientSecret: "x",
		Scopes:       []string{"openid"},
		RoleClaim:    "groups",
		RoleMapping: map[string]string{
			"rm-admins":    "admin",
			"rm-operators": "operator",
			"rm-viewers":   "viewer",
		},
	}

	// Bound OIDC client construction so an unresponsive stub cannot hang
	// the test.
	// NOTE(review): ctx is cancelled when this helper returns; confirm that
	// oidc.New does not retain it for requests made later.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	oidcClient, err := oidc.New(ctx, cfg, "http://test")
	if err != nil {
		t.Fatalf("oidc client: %v", err)
	}

	deps := Deps{
		Cfg: config.Config{
			Listen:        ":0",
			DataDir:       dir,
			SecretKeyFile: keyPath,
			BaseURL:       "http://test",
		},
		Store: st,
		AEAD:  aead,
		OIDC:  oidcClient,
	}
	s := New(deps)

	ts := httptest.NewServer(s.srv.Handler)
	t.Cleanup(ts.Close)

	return s, ts, stub
}

// TestOIDCLoginRedirectsToIdP checks that GET /auth/oidc/login responds with
// 303 See Other and a Location header containing both "code_challenge=" and
// "state=".
func TestOIDCLoginRedirectsToIdP(t *testing.T) {
	t.Parallel()

	_, ts, _ := newTestServerWithOIDC(t)

	// Do not follow redirects: the redirect response itself is under test.
	c := &stdhttp.Client{
		CheckRedirect: func(*stdhttp.Request, []*stdhttp.Request) error {
			return stdhttp.ErrUseLastResponse
		},
	}

	res, err := c.Get(ts.URL + "/auth/oidc/login")
	if err != nil {
		t.Fatalf("get: %v", err)
	}
	defer res.Body.Close()

	if res.StatusCode != stdhttp.StatusSeeOther {
		t.Errorf("status: got %d want 303", res.StatusCode)
	}

	loc := res.Header.Get("Location")
	if !strings.Contains(loc, "code_challenge=") || !strings.Contains(loc, "state=") {
		t.Errorf("location: %q", loc)
	}
}