package http

import (
	"bytes"
	"encoding/json"
	stdhttp "net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)

// roleTestRequestTimeout bounds each real HTTP round trip these tests make
// against the test server, so a hung handler fails the individual test
// promptly instead of waiting for the global `go test -timeout` panic.
const roleTestRequestTimeout = 10 * time.Second

// TestRoleAtLeast checks the role ordering used by requireRole: viewer <
// operator < admin, and an unrecognised role on either side of the
// comparison never satisfies it.
func TestRoleAtLeast(t *testing.T) {
	t.Parallel()

	cases := []struct {
		have store.Role
		min  store.Role
		want bool
	}{
		{store.RoleViewer, store.RoleViewer, true},
		{store.RoleOperator, store.RoleViewer, true},
		{store.RoleAdmin, store.RoleViewer, true},
		{store.RoleAdmin, store.RoleOperator, true},
		{store.RoleAdmin, store.RoleAdmin, true},
		{store.RoleViewer, store.RoleOperator, false},
		{store.RoleViewer, store.RoleAdmin, false},
		{store.RoleOperator, store.RoleAdmin, false},
		{store.Role("nonsense"), store.RoleViewer, false},
		{store.RoleAdmin, store.Role("nonsense"), false},
	}
	for _, c := range cases {
		if got := roleAtLeast(c.have, c.min); got != c.want {
			t.Errorf("have=%q min=%q: got %v want %v", c.have, c.min, got, c.want)
		}
	}
}

// TestRequireRoleViewerAdmits verifies that a logged-in viewer passes a
// requireRole(RoleViewer) gate and actually reaches the wrapped handler.
func TestRequireRoleViewerAdmits(t *testing.T) {
	t.Parallel()
	srv, _ := newTestServer(t, false)
	uid := makeUser(t, srv, "viewer1", store.RoleViewer)
	cookie := loginAs(t, srv, uid)

	// ResponseRecorder.Code defaults to 200, so the status check alone cannot
	// tell "admitted" apart from "middleware returned without calling next";
	// track whether the wrapped handler really ran.
	reached := false
	mid := srv.requireRole(store.RoleViewer)
	h := mid(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, _ *stdhttp.Request) {
		reached = true
		w.WriteHeader(stdhttp.StatusOK)
	}))

	req, err := stdhttp.NewRequestWithContext(t.Context(), stdhttp.MethodGet, "/api/dummy", nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	req.AddCookie(cookie)
	rr := httptest.NewRecorder()
	h.ServeHTTP(rr, req)

	if rr.Code != stdhttp.StatusOK {
		t.Errorf("status: got %d want 200", rr.Code)
	}
	if !reached {
		t.Error("wrapped handler was not invoked for an admitted viewer")
	}
}

// TestRequireRoleViewerRejectedFromOperator verifies that a viewer hitting a
// requireRole(RoleOperator) gate gets 403 with an "insufficient_role" body and
// that the wrapped handler never runs.
func TestRequireRoleViewerRejectedFromOperator(t *testing.T) {
	t.Parallel()
	srv, _ := newTestServer(t, false)
	uid := makeUser(t, srv, "viewer2", store.RoleViewer)
	cookie := loginAs(t, srv, uid)

	// ResponseRecorder ignores a second WriteHeader, so a middleware that
	// writes 403 but forgets to return would still show 403; the flag catches
	// that fall-through.
	reached := false
	mid := srv.requireRole(store.RoleOperator)
	h := mid(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, _ *stdhttp.Request) {
		reached = true
		w.WriteHeader(stdhttp.StatusOK)
	}))

	req, err := stdhttp.NewRequestWithContext(t.Context(), stdhttp.MethodGet, "/api/dummy", nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	req.AddCookie(cookie)
	rr := httptest.NewRecorder()
	h.ServeHTTP(rr, req)

	if rr.Code != stdhttp.StatusForbidden {
		t.Errorf("status: got %d want 403", rr.Code)
	}
	if !strings.Contains(rr.Body.String(), "insufficient_role") {
		t.Errorf("body: got %q", rr.Body.String())
	}
	if reached {
		t.Error("wrapped handler ran despite insufficient role")
	}
}

// TestRequireRoleUnauthenticated401OnAPI verifies that a request to an /api/
// path with no session cookie is refused with 401 and never reaches the
// wrapped handler.
func TestRequireRoleUnauthenticated401OnAPI(t *testing.T) {
	t.Parallel()
	srv, _ := newTestServer(t, false)

	reached := false
	mid := srv.requireRole(store.RoleViewer)
	h := mid(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, _ *stdhttp.Request) {
		reached = true
		w.WriteHeader(stdhttp.StatusOK)
	}))

	req, err := stdhttp.NewRequestWithContext(t.Context(), stdhttp.MethodGet, "/api/dummy", nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	rr := httptest.NewRecorder()
	h.ServeHTTP(rr, req)

	if rr.Code != stdhttp.StatusUnauthorized {
		t.Errorf("status: got %d want 401", rr.Code)
	}
	if reached {
		t.Error("wrapped handler ran for an unauthenticated request")
	}
}

// TestRequireRoleRejectsDisabledMidSession verifies that disabling a user
// invalidates a session cookie they already hold: GET /api/hosts is accepted
// before the disable and answered with 401 afterwards.
func TestRequireRoleRejectsDisabledMidSession(t *testing.T) {
	t.Parallel()
	srv, urlBase := newTestServer(t, false)
	uid := makeUser(t, srv, "victim", store.RoleOperator)
	cookie := loginAs(t, srv, uid)
	client := &stdhttp.Client{Timeout: roleTestRequestTimeout}

	// getHosts issues an authenticated GET /api/hosts and returns the status.
	getHosts := func() int {
		t.Helper()
		req, err := stdhttp.NewRequestWithContext(t.Context(), stdhttp.MethodGet, urlBase+"/api/hosts", nil)
		if err != nil {
			t.Fatalf("new request: %v", err)
		}
		req.AddCookie(cookie)
		res, err := client.Do(req)
		if err != nil {
			t.Fatalf("GET: %v", err)
		}
		defer res.Body.Close()
		return res.StatusCode
	}

	// Control: without this, a cookie that was never valid would also yield
	// 401 below and the test would pass without exercising the disable path.
	if got := getHosts(); got == stdhttp.StatusUnauthorized {
		t.Fatalf("precondition: session rejected before disable (status %d)", got)
	}

	// Disable the user *while their session is still valid*.
	if err := srv.deps.Store.DisableUser(t.Context(), uid, time.Now().UTC()); err != nil {
		t.Fatalf("disable: %v", err)
	}

	if got := getHosts(); got != stdhttp.StatusUnauthorized {
		t.Errorf("status: got %d want 401", got)
	}
}

// TestLoginRejectsDisabledUser verifies that POST /api/auth/login refuses a
// disabled user with 401 even when the credentials are otherwise accepted.
func TestLoginRejectsDisabledUser(t *testing.T) {
	t.Parallel()
	srv, urlBase := newTestServer(t, false)
	uid := makeUser(t, srv, "disabled1", store.RoleOperator)
	client := &stdhttp.Client{Timeout: roleTestRequestTimeout}

	// NOTE(review): assumes makeUser assigns "test-password"; the control
	// login below fails loudly if that assumption ever stops holding.
	body, err := json.Marshal(map[string]string{
		"username": "disabled1",
		"password": "test-password",
	})
	if err != nil {
		t.Fatalf("marshal: %v", err)
	}

	// login posts the credentials above and returns the response status.
	login := func() int {
		t.Helper()
		req, err := stdhttp.NewRequestWithContext(t.Context(), stdhttp.MethodPost, urlBase+"/api/auth/login", bytes.NewReader(body))
		if err != nil {
			t.Fatalf("new request: %v", err)
		}
		req.Header.Set("Content-Type", "application/json")
		res, err := client.Do(req)
		if err != nil {
			t.Fatalf("POST: %v", err)
		}
		defer res.Body.Close()
		return res.StatusCode
	}

	// Control: the same credentials must not be rejected while the user is
	// enabled, otherwise the 401 below would not be caused by the disable.
	if got := login(); got == stdhttp.StatusUnauthorized {
		t.Fatalf("precondition: login rejected before disable (status %d)", got)
	}

	if err := srv.deps.Store.DisableUser(t.Context(), uid, time.Now().UTC()); err != nil {
		t.Fatalf("disable: %v", err)
	}

	if got := login(); got != stdhttp.StatusUnauthorized {
		t.Errorf("status: got %d want 401", got)
	}
}

// TestAdminBandRejectsOperator verifies that an authenticated operator is
// refused with 403 on GET /api/users.
func TestAdminBandRejectsOperator(t *testing.T) {
	t.Parallel()
	srv, urlBase := newTestServer(t, false)
	// NOTE(review): an admin is created but never used directly; presumably
	// so the server is not in a no-admin/bootstrap state — TODO confirm.
	makeUser(t, srv, "admin1", store.RoleAdmin)
	opID := makeUser(t, srv, "op1", store.RoleOperator)
	cookie := loginAs(t, srv, opID)
	client := &stdhttp.Client{Timeout: roleTestRequestTimeout}

	req, err := stdhttp.NewRequestWithContext(t.Context(), stdhttp.MethodGet, urlBase+"/api/users", nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	req.AddCookie(cookie)
	res, err := client.Do(req)
	if err != nil {
		t.Fatalf("GET: %v", err)
	}
	defer res.Body.Close()

	if res.StatusCode != stdhttp.StatusForbidden {
		t.Errorf("status: got %d want 403", res.StatusCode)
	}
}