// ui_repo_reinit_test.go — covers the danger-zone re-init handler: // hostname-confirm gate + offline guard + missing-creds guard. package http import ( "context" stdhttp "net/http" "net/http/httptest" "net/url" "path/filepath" "strings" "testing" "time" "github.com/coder/websocket" "github.com/oklog/ulid/v2" "gitea.dcglab.co.uk/steve/restic-manager/internal/api" "gitea.dcglab.co.uk/steve/restic-manager/internal/auth" "gitea.dcglab.co.uk/steve/restic-manager/internal/crypto" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/config" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/ui" "gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws" "gitea.dcglab.co.uk/steve/restic-manager/internal/store" ) // rawTestServerWithUI is the rawTestServer twin that also wires the // UI renderer in, returning the raw httptest server so callers can // dial /ws/agent. The UI is needed for the repo-reinit handler's // error re-render path. func rawTestServerWithUI(t *testing.T) (*Server, *httptest.Server, *store.Store) { t.Helper() dir := t.TempDir() st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db")) if err != nil { t.Fatalf("store: %v", err) } t.Cleanup(func() { _ = st.Close() }) keyPath := filepath.Join(dir, "secret.key") _ = crypto.GenerateKeyFile(keyPath) key, _ := crypto.LoadKeyFromFile(keyPath) aead, _ := crypto.NewAEAD(key) renderer, err := ui.New() if err != nil { t.Fatalf("ui.New: %v", err) } deps := Deps{ Cfg: config.Config{Listen: ":0", DataDir: dir, SecretKeyFile: keyPath}, Store: st, AEAD: aead, Hub: ws.NewHub(), UI: renderer, } srv := New(deps) ts := httptest.NewServer(srv.srv.Handler) t.Cleanup(ts.Close) return srv, ts, st } // enrolHostForUI is the enrolHostForWS twin for tests that use the // UI-enabled rawTestServerWithUI. func enrolHostForUI(t *testing.T, _ *Server, st *store.Store, name string) (hostID, token string) { t.Helper() hostID = ulid.Make().String() token, _ = auth.NewToken() if err := st.CreateHost(context.Background(), store.Host{ ID: hostID, Name: name, OS: "linux", Arch: "amd64", EnrolledAt: time.Now().UTC(), }, auth.HashToken(token), ""); err != nil { t.Fatalf("create host: %v", err) } return hostID, token } // TestRepoReinitWrongHostnameRejected: typing a different name keeps // the page on the repo screen with an error banner; no init job is // dispatched. func TestRepoReinitWrongHostnameRejected(t *testing.T) { t.Parallel() srv, ts, st := rawTestServerWithUI(t) hostID, token := enrolHostForUI(t, srv, st, "reinit-host") c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "reinit-host") _ = drainUntil(t, c, api.MsgScheduleSet) cookie := loginAsAdmin(t, st) form := url.Values{"confirm_hostname": {"WRONG-NAME"}} req, _ := stdhttp.NewRequest("POST", ts.URL+"/hosts/"+hostID+"/repo/reinit", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.AddCookie(cookie) res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusUnprocessableEntity { t.Fatalf("status: got %d, want 422 (re-rendered page with banner)", res.StatusCode) } // No init job should appear in the queue beyond the one auto-init // pushed on hello (which fires when no init has run yet — let's // just make sure no new "user" actor init was created). var n int if err := st.DB().QueryRow( `SELECT COUNT(*) FROM jobs WHERE host_id = ? AND kind = 'init' AND actor_kind = 'user'`, hostID).Scan(&n); err != nil { t.Fatalf("count: %v", err) } if n != 0 { t.Fatalf("user-actor init jobs: got %d, want 0 (gate was bypassed)", n) } } // TestRepoReinitDispatchesOnMatch: typing the right hostname dispatches // a new init job + audit row. func TestRepoReinitDispatchesOnMatch(t *testing.T) { t.Parallel() srv, ts, st := rawTestServerWithUI(t) hostID, token := enrolHostForUI(t, srv, st, "reinit-ok-host") // Bind repo creds — re-init guard requires them. enc, err := srv.encryptRepoCreds(repoCredsBlob{ RepoURL: "rest:http://r/x", RepoUsername: "u", RepoPassword: "p", }, []byte("host:"+hostID)) if err != nil { t.Fatalf("encrypt: %v", err) } if err := st.SetHostCredentials(context.Background(), hostID, store.CredKindRepo, enc); err != nil { t.Fatalf("set creds: %v", err) } // Pre-seed a successful init so auto-init doesn't fire on hello. preID := ulid.Make().String() if err := st.CreateJob(context.Background(), store.Job{ ID: preID, HostID: hostID, Kind: "init", ActorKind: "system", CreatedAt: time.Now().UTC(), }); err != nil { t.Fatalf("seed init: %v", err) } if err := st.MarkJobFinished(context.Background(), preID, "succeeded", 0, nil, "", time.Now().UTC()); err != nil { t.Fatalf("mark seed init: %v", err) } c := agentDial(t, srv, ts, hostID, token) sendHello(t, c, "reinit-ok-host") _ = drainUntil(t, c, api.MsgScheduleSet) cookie := loginAsAdmin(t, st) form := url.Values{"confirm_hostname": {"reinit-ok-host"}} req, _ := stdhttp.NewRequest("POST", ts.URL+"/hosts/"+hostID+"/repo/reinit", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("HX-Request", "true") // get HX-Redirect path req.AddCookie(cookie) res, err := stdhttp.DefaultClient.Do(req) if err != nil { t.Fatalf("do: %v", err) } defer res.Body.Close() if res.StatusCode != stdhttp.StatusNoContent { t.Fatalf("status: got %d, want 204", res.StatusCode) } if res.Header.Get("HX-Redirect") == "" { t.Fatal("expected HX-Redirect header") } // Read the dispatched command.run; assert it's an init job. deadline := time.Now().Add(2 * time.Second) for time.Now().Before(deadline) { ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) mt, raw, rerr := c.Read(ctx) cancel() if rerr != nil { break } if mt != websocket.MessageText { continue } // Quick parse — we only care about the type. Avoid full // envelope unmarshal here because the surrounding loop is just // looking for the command.run we triggered. if !strings.Contains(string(raw), `"command.run"`) { continue } // Verify a user-actor init job row was created. var n int if err := st.DB().QueryRow( `SELECT COUNT(*) FROM jobs WHERE host_id = ? AND kind = 'init' AND actor_kind = 'user'`, hostID).Scan(&n); err != nil { t.Fatalf("count: %v", err) } if n != 1 { t.Fatalf("user-actor init jobs: got %d, want 1", n) } // Audit row. var na int if err := st.DB().QueryRow( `SELECT COUNT(*) FROM audit_log WHERE action = 'host.repo_reinit' AND target_id = ?`, hostID).Scan(&na); err != nil { t.Fatalf("audit count: %v", err) } if na != 1 { t.Fatalf("audit rows: got %d, want 1", na) } return } t.Fatal("timed out waiting for command.run after re-init dispatch") }