package http import ( "context" "io" stdhttp "net/http" "net/http/httptest" "path/filepath" "strings" "testing" "gitea.dcglab.co.uk/steve/restic-manager/internal/crypto" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/config" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/metrics" "gitea.dcglab.co.uk/steve/restic-manager/internal/store" ) // newMetricsServer builds a Server with metrics enabled per cfg. // Returns (URL, registry) so tests can both observe job durations // directly and exercise the HTTP gate. func newMetricsServer(t *testing.T, cfg config.Config) (string, *metrics.Registry, *store.Store) { t.Helper() dir := t.TempDir() st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db")) if err != nil { t.Fatalf("store: %v", err) } t.Cleanup(func() { _ = st.Close() }) keyPath := filepath.Join(dir, "secret.key") if err := crypto.GenerateKeyFile(keyPath); err != nil { t.Fatalf("genkey: %v", err) } key, _ := crypto.LoadKeyFromFile(keyPath) aead, _ := crypto.NewAEAD(key) cfg.Listen = ":0" cfg.DataDir = dir cfg.SecretKeyFile = keyPath reg := metrics.NewRegistry() deps := Deps{ Cfg: cfg, Store: st, AEAD: aead, Metrics: reg, } s := New(deps) ts := httptest.NewServer(s.srv.Handler) t.Cleanup(ts.Close) return ts.URL, reg, st } func TestMetricsRouteNotMountedByDefault(t *testing.T) { t.Parallel() url, _, _ := newMetricsServer(t, config.Config{}) res, err := stdhttp.Get(url + "/metrics") if err != nil { t.Fatalf("GET: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusNotFound { t.Errorf("status: got %d, want 404 (route should not be mounted)", res.StatusCode) } } func TestMetricsTokenRequired(t *testing.T) { t.Parallel() url, _, _ := newMetricsServer(t, config.Config{ MetricsToken: "the-token", }) // Missing token. res, err := stdhttp.Get(url + "/metrics") if err != nil { t.Fatalf("GET: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusUnauthorized { t.Errorf("no token: got %d", res.StatusCode) } if !strings.Contains(res.Header.Get("WWW-Authenticate"), "Bearer") { t.Errorf("WWW-Authenticate hint missing: %q", res.Header.Get("WWW-Authenticate")) } // Wrong token. req, _ := stdhttp.NewRequest(stdhttp.MethodGet, url+"/metrics", nil) req.Header.Set("Authorization", "Bearer not-the-token") res2, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("GET: %v", err) } defer res2.Body.Close() if res2.StatusCode != stdhttp.StatusUnauthorized { t.Errorf("wrong token: got %d", res2.StatusCode) } // Right token. req3, _ := stdhttp.NewRequest(stdhttp.MethodGet, url+"/metrics", nil) req3.Header.Set("Authorization", "Bearer the-token") res3, err3 := stdhttp.DefaultClient.Do(req3) if err3 != nil { t.Fatalf("GET: %v", err3) } defer res3.Body.Close() if res3.StatusCode != stdhttp.StatusOK { t.Errorf("right token: got %d", res3.StatusCode) } if ct := res3.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/plain") { t.Errorf("content-type: %q", ct) } } func TestMetricsCIDRGate(t *testing.T) { t.Parallel() // 127.0.0.1 is what httptest hits with; pick a CIDR that excludes it // to assert the "wrong source" branch. url, _, _ := newMetricsServer(t, config.Config{ MetricsTrustedCIDRs: []string{"10.0.0.0/8"}, }) res, err := stdhttp.Get(url + "/metrics") if err != nil { t.Fatalf("GET: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusUnauthorized { t.Errorf("loopback hitting non-matching CIDR: got %d, want 401", res.StatusCode) } // Now allow loopback. url2, _, _ := newMetricsServer(t, config.Config{ MetricsTrustedCIDRs: []string{"127.0.0.0/8"}, }) res2, err := stdhttp.Get(url2 + "/metrics") if err != nil { t.Fatalf("GET: %v", err) } defer res2.Body.Close() if res2.StatusCode != stdhttp.StatusOK { t.Errorf("loopback in allow CIDR: got %d, want 200", res2.StatusCode) } } func TestMetricsTokenAndCIDRBothRequired(t *testing.T) { t.Parallel() url, _, _ := newMetricsServer(t, config.Config{ MetricsToken: "the-token", MetricsTrustedCIDRs: []string{"127.0.0.0/8"}, }) // Token only — CIDR ok (loopback) but token missing. res, err := stdhttp.Get(url + "/metrics") if err != nil { t.Fatalf("GET: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusUnauthorized { t.Errorf("missing token but in CIDR: got %d", res.StatusCode) } // Both right. req, _ := stdhttp.NewRequest(stdhttp.MethodGet, url+"/metrics", nil) req.Header.Set("Authorization", "Bearer the-token") res2, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("GET: %v", err) } defer res2.Body.Close() if res2.StatusCode != stdhttp.StatusOK { t.Errorf("both right: got %d", res2.StatusCode) } } func readAll(t *testing.T, r io.Reader) string { t.Helper() b, err := io.ReadAll(r) if err != nil { t.Fatalf("read: %v", err) } return string(b) } func TestMetricsBodyContainsExpectedLines(t *testing.T) { t.Parallel() url, reg, _ := newMetricsServer(t, config.Config{ MetricsToken: "the-token", }) reg.ObserveJob("backup", "succeeded", 0) // produce one histogram row req, _ := stdhttp.NewRequest(stdhttp.MethodGet, url+"/metrics", nil) req.Header.Set("Authorization", "Bearer the-token") res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("GET: %v", err) } defer res.Body.Close() body := readAll(t, res.Body) for _, want := range []string{ "rm_hosts_total", "rm_hosts_online", `rm_active_alerts{severity="critical"}`, "rm_build_info{", "rm_job_duration_seconds_count{kind=\"backup\",status=\"succeeded\"}", } { if !strings.Contains(body, want) { t.Errorf("body missing %q\n--- body ---\n%s", want, body) } } }