From de6d51eeb1e93b2c964e6e2f06b168d37d0fc309 Mon Sep 17 00:00:00 2001 From: Steve Cliff Date: Sun, 3 May 2026 22:06:05 +0100 Subject: [PATCH] store: host_credentials becomes kind-aware (repo + admin slots) --- internal/server/http/enrollment.go | 2 +- internal/server/http/host_credentials.go | 10 +- internal/server/http/host_credentials_test.go | 6 +- internal/server/http/p2r01_ws_test.go | 2 +- internal/server/http/ui_repo.go | 6 +- internal/store/host_credentials.go | 45 ++++++-- internal/store/host_credentials_test.go | 103 ++++++++++++++++++ 7 files changed, 151 insertions(+), 23 deletions(-) create mode 100644 internal/store/host_credentials_test.go diff --git a/internal/server/http/enrollment.go b/internal/server/http/enrollment.go index 2706ea5..f1615e0 100644 --- a/internal/server/http/enrollment.go +++ b/internal/server/http/enrollment.go @@ -167,7 +167,7 @@ func (s *Server) handleAgentEnroll(w stdhttp.ResponseWriter, r *stdhttp.Request) // /api/hosts/{id}/repo-credentials. Failing the whole enrolment // here would leave a half-burned token + an orphan host. if encForHost != "" { - if err := s.deps.Store.SetHostCredentials(r.Context(), hostID, encForHost); err != nil { + if err := s.deps.Store.SetHostCredentials(r.Context(), hostID, store.CredKindRepo, encForHost); err != nil { slog.Error("enrollment: set host credentials failed", "host_id", hostID, "err", err) } diff --git a/internal/server/http/host_credentials.go b/internal/server/http/host_credentials.go index 5887a75..93bd1d0 100644 --- a/internal/server/http/host_credentials.go +++ b/internal/server/http/host_credentials.go @@ -39,7 +39,7 @@ func (s *Server) handleGetHostCredentials(w stdhttp.ResponseWriter, r *stdhttp.R writeJSONError(w, stdhttp.StatusBadRequest, "missing_id", "") return } - enc, err := s.deps.Store.GetHostCredentials(r.Context(), hostID) + enc, err := s.deps.Store.GetHostCredentials(r.Context(), hostID, store.CredKindRepo) if err != nil { if errors.Is(err, store.ErrNotFound) { writeJSONError(w, stdhttp.StatusNotFound, "not_set", "") @@ -107,7 +107,7 @@ func (s *Server) handleSetHostCredentials(w stdhttp.ResponseWriter, r *stdhttp.R // Merge with the existing row, if any. existing := repoCredsBlob{} - if cur, err := s.deps.Store.GetHostCredentials(r.Context(), hostID); err == nil { + if cur, err := s.deps.Store.GetHostCredentials(r.Context(), hostID, store.CredKindRepo); err == nil { plain, err := s.deps.AEAD.Decrypt(cur, []byte("host:"+hostID)) if err != nil { writeJSONError(w, stdhttp.StatusInternalServerError, "decrypt_failed", "") @@ -139,7 +139,7 @@ func (s *Server) handleSetHostCredentials(w stdhttp.ResponseWriter, r *stdhttp.R writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "") return } - if err := s.deps.Store.SetHostCredentials(r.Context(), hostID, enc); err != nil { + if err := s.deps.Store.SetHostCredentials(r.Context(), hostID, store.CredKindRepo, enc); err != nil { writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "") return } @@ -212,7 +212,7 @@ func (s *Server) onAgentHello(ctx context.Context, hostID string, conn *ws.Conn) // them the runner can't talk to the repo). We rely on Restic's // idempotent init for re-runs. func (s *Server) maybeAutoInit(ctx context.Context, hostID string, conn *ws.Conn) { - if _, err := s.deps.Store.GetHostCredentials(ctx, hostID); err != nil { + if _, err := s.deps.Store.GetHostCredentials(ctx, hostID, store.CredKindRepo); err != nil { // No creds bound yet — operator hasn't supplied them. The next // hello after creds land will pick this up. return @@ -266,7 +266,7 @@ func (s *Server) maybeAutoInit(ctx context.Context, hostID string, conn *ws.Conn // credentials. Silent no-op when the host has nothing on file // (the operator hasn't bound creds to it yet). func (s *Server) pushRepoCredsOnHello(ctx context.Context, hostID string, conn *ws.Conn) { - enc, err := s.deps.Store.GetHostCredentials(ctx, hostID) + enc, err := s.deps.Store.GetHostCredentials(ctx, hostID, store.CredKindRepo) if err != nil { if !errors.Is(err, store.ErrNotFound) { slog.Warn("on-hello: load host creds", "host_id", hostID, "err", err) diff --git a/internal/server/http/host_credentials_test.go b/internal/server/http/host_credentials_test.go index af2d286..a821f97 100644 --- a/internal/server/http/host_credentials_test.go +++ b/internal/server/http/host_credentials_test.go @@ -5,6 +5,8 @@ import ( "encoding/json" "testing" "time" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" ) // TestEnrollmentTransfersRepoCreds verifies the round-trip: @@ -57,12 +59,12 @@ func TestEnrollmentTransfersRepoCreds(t *testing.T) { hostID, "host42", "linux", "amd64", "2026-01-01T00:00:00Z"); err != nil { t.Fatalf("insert host: %v", err) } - if err := st.SetHostCredentials(ctx, hostID, encForHost); err != nil { + if err := st.SetHostCredentials(ctx, hostID, store.CredKindRepo, encForHost); err != nil { t.Fatalf("set host credentials: %v", err) } // host_credentials row should now hold the host-bound ciphertext. - got, err := st.GetHostCredentials(ctx, hostID) + got, err := st.GetHostCredentials(ctx, hostID, store.CredKindRepo) if err != nil { t.Fatalf("get host creds: %v", err) } diff --git a/internal/server/http/p2r01_ws_test.go b/internal/server/http/p2r01_ws_test.go index bc3c57a..23bb9a0 100644 --- a/internal/server/http/p2r01_ws_test.go +++ b/internal/server/http/p2r01_ws_test.go @@ -99,7 +99,7 @@ func enrolHostForWS(t *testing.T, srv *Server, st *store.Store, name string) (ho if err != nil { t.Fatalf("encrypt: %v", err) } - if err := st.SetHostCredentials(context.Background(), hostID, enc); err != nil { + if err := st.SetHostCredentials(context.Background(), hostID, store.CredKindRepo, enc); err != nil { t.Fatalf("set creds: %v", err) } return hostID, token diff --git a/internal/server/http/ui_repo.go b/internal/server/http/ui_repo.go index 79ad2ae..a8c9136 100644 --- a/internal/server/http/ui_repo.go +++ b/internal/server/http/ui_repo.go @@ -61,7 +61,7 @@ func (s *Server) loadHostRepoPage(r *stdhttp.Request, host store.Host) (*hostRep } // Credentials (redacted). - enc, err := s.deps.Store.GetHostCredentials(r.Context(), host.ID) + enc, err := s.deps.Store.GetHostCredentials(r.Context(), host.ID, store.CredKindRepo) switch { case err == nil: plain, derr := s.deps.AEAD.Decrypt(enc, []byte("host:"+host.ID)) @@ -204,7 +204,7 @@ func (s *Server) handleUIRepoCredentialsSave(w stdhttp.ResponseWriter, r *stdhtt // Merge with existing blob — same semantics as the JSON PUT. existing := repoCredsBlob{} - if cur, err := s.deps.Store.GetHostCredentials(r.Context(), host.ID); err == nil { + if cur, err := s.deps.Store.GetHostCredentials(r.Context(), host.ID, store.CredKindRepo); err == nil { if plain, derr := s.deps.AEAD.Decrypt(cur, []byte("host:"+host.ID)); derr == nil { _ = json.Unmarshal(plain, &existing) } @@ -227,7 +227,7 @@ func (s *Server) handleUIRepoCredentialsSave(w stdhttp.ResponseWriter, r *stdhtt stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError) return } - if err := s.deps.Store.SetHostCredentials(r.Context(), host.ID, enc); err != nil { + if err := s.deps.Store.SetHostCredentials(r.Context(), host.ID, store.CredKindRepo, enc); err != nil { slog.Error("ui repo creds: persist", "err", err) stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError) return diff --git a/internal/store/host_credentials.go b/internal/store/host_credentials.go index 22416c8..2dfbcbe 100644 --- a/internal/store/host_credentials.go +++ b/internal/store/host_credentials.go @@ -8,13 +8,23 @@ import ( "time" ) +// CredentialKind identifies the role of a host_credentials row. +type CredentialKind string + +const ( + // CredKindRepo is the append-only credential used for every backup. + CredKindRepo CredentialKind = "repo" + // CredKindAdmin is the delete-capable credential used for prune. + CredKindAdmin CredentialKind = "admin" +) + // GetHostCredentials returns the AEAD-encrypted repo creds blob for -// the host, or ("", ErrNotFound) if no credential has ever been set. +// the host + kind, or ("", ErrNotFound) if no matching row exists. // The caller decrypts using host_id as AEAD additional data. -func (s *Store) GetHostCredentials(ctx context.Context, hostID string) (string, error) { +func (s *Store) GetHostCredentials(ctx context.Context, hostID string, kind CredentialKind) (string, error) { row := s.db.QueryRowContext(ctx, - `SELECT enc_repo_creds FROM host_credentials WHERE host_id = ?`, - hostID) + `SELECT enc_repo_creds FROM host_credentials WHERE host_id = ? AND kind = ?`, + hostID, string(kind)) var enc string if err := row.Scan(&enc); err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -25,22 +35,35 @@ func (s *Store) GetHostCredentials(ctx context.Context, hostID string) (string, return enc, nil } -// SetHostCredentials replaces the host's encrypted repo creds blob. -// The caller has already encrypted using host_id as additional data. -func (s *Store) SetHostCredentials(ctx context.Context, hostID, encRepoCreds string) error { +// SetHostCredentials replaces the host's encrypted repo creds blob for +// the given kind. The caller has already encrypted using host_id as +// additional data. +func (s *Store) SetHostCredentials(ctx context.Context, hostID string, kind CredentialKind, encRepoCreds string) error { if encRepoCreds == "" { return fmt.Errorf("store: empty enc_repo_creds") } now := time.Now().UTC().Format(time.RFC3339Nano) _, err := s.db.ExecContext(ctx, - `INSERT INTO host_credentials (host_id, enc_repo_creds, updated_at) - VALUES (?, ?, ?) - ON CONFLICT(host_id) DO UPDATE SET + `INSERT INTO host_credentials (host_id, kind, enc_repo_creds, updated_at) + VALUES (?, ?, ?, ?) + ON CONFLICT(host_id, kind) DO UPDATE SET enc_repo_creds = excluded.enc_repo_creds, updated_at = excluded.updated_at`, - hostID, encRepoCreds, now) + hostID, string(kind), encRepoCreds, now) if err != nil { return fmt.Errorf("store: set host credentials: %w", err) } return nil } + +// DeleteHostCredentials removes the credential row for the given host +// and kind. A no-op if the row does not exist. +func (s *Store) DeleteHostCredentials(ctx context.Context, hostID string, kind CredentialKind) error { + _, err := s.db.ExecContext(ctx, + `DELETE FROM host_credentials WHERE host_id = ? AND kind = ?`, + hostID, string(kind)) + if err != nil { + return fmt.Errorf("store: delete host credentials: %w", err) + } + return nil +} diff --git a/internal/store/host_credentials_test.go b/internal/store/host_credentials_test.go new file mode 100644 index 0000000..ddca751 --- /dev/null +++ b/internal/store/host_credentials_test.go @@ -0,0 +1,103 @@ +package store + +import ( + "context" + "errors" + "testing" +) + +// seedHost inserts a minimal host row for testing. +func seedHost(t *testing.T, s *Store, hostID string) { + t.Helper() + _, err := s.DB().Exec( + `INSERT INTO hosts (id, name, os, arch, enrolled_at) VALUES (?,?,?,?,?)`, + hostID, hostID, "linux", "amd64", "2026-01-01T00:00:00Z") + if err != nil { + t.Fatalf("seed host %q: %v", hostID, err) + } +} + +func TestHostCredentialsAdminRowSeparate(t *testing.T) { + t.Parallel() + s := openTestStore(t) + ctx := context.Background() + + const hostID = "h-creds-test" + seedHost(t, s, hostID) + + const repoBlob = "enc-repo-blob" + const adminBlob = "enc-admin-blob" + + // Set repo creds. + if err := s.SetHostCredentials(ctx, hostID, CredKindRepo, repoBlob); err != nil { + t.Fatalf("set repo creds: %v", err) + } + // Set admin creds. + if err := s.SetHostCredentials(ctx, hostID, CredKindAdmin, adminBlob); err != nil { + t.Fatalf("set admin creds: %v", err) + } + + // Fetch each by kind and assert they differ. + gotRepo, err := s.GetHostCredentials(ctx, hostID, CredKindRepo) + if err != nil { + t.Fatalf("get repo creds: %v", err) + } + gotAdmin, err := s.GetHostCredentials(ctx, hostID, CredKindAdmin) + if err != nil { + t.Fatalf("get admin creds: %v", err) + } + if gotRepo != repoBlob { + t.Errorf("repo creds: got %q, want %q", gotRepo, repoBlob) + } + if gotAdmin != adminBlob { + t.Errorf("admin creds: got %q, want %q", gotAdmin, adminBlob) + } + if gotRepo == gotAdmin { + t.Error("repo and admin blobs must differ") + } + + // Delete admin; repo must be unaffected. + if err := s.DeleteHostCredentials(ctx, hostID, CredKindAdmin); err != nil { + t.Fatalf("delete admin creds: %v", err) + } + if _, err := s.GetHostCredentials(ctx, hostID, CredKindAdmin); !errors.Is(err, ErrNotFound) { + t.Errorf("after delete, expected ErrNotFound for admin; got %v", err) + } + if got, err := s.GetHostCredentials(ctx, hostID, CredKindRepo); err != nil || got != repoBlob { + t.Errorf("repo creds should survive admin delete; got %q, err %v", got, err) + } +} + +func TestHostCredentialsNotFound(t *testing.T) { + t.Parallel() + s := openTestStore(t) + ctx := context.Background() + + _, err := s.GetHostCredentials(ctx, "no-such-host", CredKindRepo) + if !errors.Is(err, ErrNotFound) { + t.Errorf("expected ErrNotFound, got %v", err) + } +} + +func TestHostCredentialsUpsert(t *testing.T) { + t.Parallel() + s := openTestStore(t) + ctx := context.Background() + + const hostID = "h-upsert-test" + seedHost(t, s, hostID) + + if err := s.SetHostCredentials(ctx, hostID, CredKindRepo, "v1"); err != nil { + t.Fatalf("set v1: %v", err) + } + if err := s.SetHostCredentials(ctx, hostID, CredKindRepo, "v2"); err != nil { + t.Fatalf("set v2 (upsert): %v", err) + } + got, err := s.GetHostCredentials(ctx, hostID, CredKindRepo) + if err != nil { + t.Fatalf("get: %v", err) + } + if got != "v2" { + t.Errorf("expected v2, got %q", got) + } +}