diff --git a/internal/server/http/enrollment.go b/internal/server/http/enrollment.go index fbf961e..6755a6c 100644 --- a/internal/server/http/enrollment.go +++ b/internal/server/http/enrollment.go @@ -1,7 +1,9 @@ package http import ( + "context" "encoding/json" + "fmt" stdhttp "net/http" "strings" "time" @@ -41,9 +43,18 @@ type enrollResponse struct { // enrollOperatorRequest creates a one-time enrollment token for an // operator who is about to install an agent. Authenticated UI route. +// +// Repo creds are required at token-mint time so the agent can run a +// backup the moment it comes online. The trio is JSON-encoded, +// AEAD-encrypted with token_hash as additional data, and stashed on +// the token row. ConsumeEnrollmentToken re-encrypts under host_id +// and writes the host_credentials row in the same tx as token-burn. type enrollOperatorRequest struct { - HostName string `json:"hostname"` - Tags []string `json:"tags,omitempty"` + HostName string `json:"hostname"` + Tags []string `json:"tags,omitempty"` + RepoURL string `json:"repo_url"` + RepoUsername string `json:"repo_username"` + RepoPassword string `json:"repo_password"` } type enrollOperatorResponse struct { @@ -51,6 +62,15 @@ type enrollOperatorResponse struct { ExpiresAt time.Time `json:"expires_at"` } +// repoCredsBlob is the JSON shape of the encrypted repo-creds blob. +// Lives only inside AEAD ciphertext — never on the wire as plaintext +// outside the WS config.update push. +type repoCredsBlob struct { + RepoURL string `json:"repo_url"` + RepoUsername string `json:"repo_username"` + RepoPassword string `json:"repo_password"` +} + // handleAgentEnroll consumes a one-time token, persists a Host row, // and returns persistent agent credentials. Open endpoint (no // session) — the token is the credential. @@ -72,7 +92,18 @@ func (s *Server) handleAgentEnroll(w stdhttp.ResponseWriter, r *stdhttp.Request) // We do these in two statements; if create-host fails, the token // is already burned. That's acceptable — operator just regens. tokHash := auth.HashToken(req.Token) - if err := s.deps.Store.ConsumeEnrollmentToken(r.Context(), tokHash, hostID); err != nil { + + // If the token carries repo creds, re-encrypt them under the new + // host_id so the host_credentials row is bound to the host (not + // the token, which is about to disappear). + encForHost, err := s.rebindTokenCreds(r.Context(), tokHash, hostID) + if err != nil { + writeJSONError(w, stdhttp.StatusUnauthorized, "invalid_token", + "token unknown, expired, or already used") + return + } + + if err := s.deps.Store.ConsumeEnrollmentToken(r.Context(), tokHash, hostID, encForHost); err != nil { writeJSONError(w, stdhttp.StatusUnauthorized, "invalid_token", "token unknown, expired, or already used") return @@ -137,14 +168,29 @@ func (s *Server) handleCreateEnrollmentToken(w stdhttp.ResponseWriter, r *stdhtt writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error()) return } + if req.RepoURL == "" || req.RepoPassword == "" { + writeJSONError(w, stdhttp.StatusBadRequest, "missing_field", + "repo_url and repo_password are required so the agent can run backups on first connect") + return + } token, err := auth.NewToken() if err != nil { writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "") return } + tokHash := auth.HashToken(token) + + enc, err := s.encryptRepoCreds(repoCredsBlob{ + RepoURL: req.RepoURL, RepoUsername: req.RepoUsername, RepoPassword: req.RepoPassword, + }, []byte("token:"+tokHash)) + if err != nil { + writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "") + return + } + const ttl = time.Hour - if err := s.deps.Store.CreateEnrollmentToken(r.Context(), auth.HashToken(token), ttl); err != nil { + if err := s.deps.Store.CreateEnrollmentToken(r.Context(), tokHash, ttl, enc); err != nil { writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "") return } @@ -155,6 +201,39 @@ func (s *Server) handleCreateEnrollmentToken(w stdhttp.ResponseWriter, r *stdhtt }) } +// rebindTokenCreds decrypts the creds attached to the token (if any), +// re-encrypts under the new host_id, and returns the new ciphertext. +// Empty return = the token had no creds attached, which we treat as +// a hard error today (the operator must supply creds at mint time). +func (s *Server) rebindTokenCreds(ctx context.Context, tokHash, hostID string) (string, error) { + enc, err := s.deps.Store.GetEnrollmentTokenCreds(ctx, tokHash) + if err != nil { + return "", err + } + if enc == "" { + return "", nil + } + plain, err := s.deps.AEAD.Decrypt(enc, []byte("token:"+tokHash)) + if err != nil { + return "", fmt.Errorf("decrypt token creds: %w", err) + } + out, err := s.deps.AEAD.Encrypt(plain, []byte("host:"+hostID)) + if err != nil { + return "", fmt.Errorf("re-encrypt for host: %w", err) + } + return out, nil +} + +// encryptRepoCreds JSON-encodes blob and seals it with the given +// additional-data context. +func (s *Server) encryptRepoCreds(blob repoCredsBlob, ad []byte) (string, error) { + body, err := json.Marshal(blob) + if err != nil { + return "", fmt.Errorf("marshal repo creds: %w", err) + } + return s.deps.AEAD.Encrypt(body, ad) +} + // authedUser returns true iff the request carries a valid session // cookie. Minimal stub for now; full RBAC middleware lands with // P4-03. diff --git a/internal/server/http/enrollment_test.go b/internal/server/http/enrollment_test.go index cc71604..755c4e1 100644 --- a/internal/server/http/enrollment_test.go +++ b/internal/server/http/enrollment_test.go @@ -73,7 +73,7 @@ func TestEnrollmentHappyPath(t *testing.T) { // Issue a token directly via the store (skipping the operator UI). rawToken, _ := auth.NewToken() if err := st.CreateEnrollmentToken(context.Background(), - auth.HashToken(rawToken), 5*time.Minute); err != nil { + auth.HashToken(rawToken), 5*time.Minute, ""); err != nil { t.Fatalf("issue: %v", err) } diff --git a/internal/server/http/host_credentials.go b/internal/server/http/host_credentials.go new file mode 100644 index 0000000..b8dab1c --- /dev/null +++ b/internal/server/http/host_credentials.go @@ -0,0 +1,179 @@ +package http + +import ( + "context" + "encoding/json" + "errors" + "log/slog" + stdhttp "net/http" + "time" + + "github.com/go-chi/chi/v5" + "github.com/oklog/ulid/v2" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" + "gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws" + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" +) + +func nowUTC() time.Time { return time.Now().UTC() } + +// hostRepoCredsRequest is the body of PUT /api/hosts/{id}/repo-credentials. +// Operator can edit any subset; missing fields preserve the existing +// value (so changing only the password doesn't require resending the URL). +// +// We model this as plaintext on the wire because the wire is HTTPS to +// the proxy. The values are AEAD-encrypted before they touch SQLite, +// and only ever leave the server again inside the authenticated WS +// `config.update` push. +type hostRepoCredsRequest struct { + RepoURL *string `json:"repo_url,omitempty"` + RepoUsername *string `json:"repo_username,omitempty"` + RepoPassword *string `json:"repo_password,omitempty"` +} + +// handleSetHostCredentials lets an operator/admin update a host's +// repo creds. Any fields the operator sends overwrite the +// corresponding fields in the existing blob; the others are +// preserved. Re-encrypts under host_id and pushes a config.update +// over the WS if the agent is connected. +func (s *Server) handleSetHostCredentials(w stdhttp.ResponseWriter, r *stdhttp.Request) { + if !s.authedUser(r) { + writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "") + return + } + hostID := chi.URLParam(r, "id") + if hostID == "" { + writeJSONError(w, stdhttp.StatusBadRequest, "missing_id", "") + return + } + if _, err := s.deps.Store.GetHost(r.Context(), hostID); err != nil { + writeJSONError(w, stdhttp.StatusNotFound, "host_not_found", "") + return + } + + var req hostRepoCredsRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error()) + return + } + + // Merge with the existing row, if any. + existing := repoCredsBlob{} + if cur, err := s.deps.Store.GetHostCredentials(r.Context(), hostID); err == nil { + plain, err := s.deps.AEAD.Decrypt(cur, []byte("host:"+hostID)) + if err != nil { + writeJSONError(w, stdhttp.StatusInternalServerError, "decrypt_failed", "") + return + } + _ = json.Unmarshal(plain, &existing) + } else if !errors.Is(err, store.ErrNotFound) { + writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "") + return + } + + if req.RepoURL != nil { + existing.RepoURL = *req.RepoURL + } + if req.RepoUsername != nil { + existing.RepoUsername = *req.RepoUsername + } + if req.RepoPassword != nil { + existing.RepoPassword = *req.RepoPassword + } + if existing.RepoURL == "" || existing.RepoPassword == "" { + writeJSONError(w, stdhttp.StatusBadRequest, "missing_field", + "repo_url and repo_password must end up non-empty") + return + } + + enc, err := s.encryptRepoCreds(existing, []byte("host:"+hostID)) + if err != nil { + writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "") + return + } + if err := s.deps.Store.SetHostCredentials(r.Context(), hostID, enc); err != nil { + writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "") + return + } + + _ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{ + ID: ulid.Make().String(), + Actor: "user", + Action: "host.repo_credentials_set", + TargetKind: ptr("host"), + TargetID: &hostID, + TS: nowUTC(), + }) + + // Push to the agent if it's connected. Errors here are non-fatal: + // the next reconnect will pick the row up via the hello handler. + if s.deps.Hub != nil && s.deps.Hub.Connected(hostID) { + _ = s.pushRepoCredsToAgent(r.Context(), hostID, existing) + } + + w.WriteHeader(stdhttp.StatusNoContent) +} + +// pushRepoCredsToAgent serialises blob into a config.update envelope +// and ships it down the agent's WS. Returns an error from the hub +// (no-op if not connected — caller is expected to check first when it +// matters). +func (s *Server) pushRepoCredsToAgent(ctx context.Context, hostID string, blob repoCredsBlob) error { + env, err := api.Marshal(api.MsgConfigUpdate, "", api.ConfigUpdatePayload{ + RepoURL: blob.RepoURL, + RepoUsername: blob.RepoUsername, + RepoPassword: blob.RepoPassword, + }) + if err != nil { + return err + } + sendCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := s.deps.Hub.Send(sendCtx, hostID, env); err != nil { + slog.Warn("push repo creds: hub send failed", "host_id", hostID, "err", err) + return err + } + return nil +} + +// onAgentHello runs synchronously inside the WS handler immediately +// after a successful hello. It loads the host's encrypted creds (if +// any), decrypts, and ships them down the conn as a config.update so +// the agent has them before any command.run lands. +// +// The conn argument is used directly (rather than via the hub) so we +// don't race a brand-new register against an old still-closing conn. +func (s *Server) onAgentHello(ctx context.Context, hostID string, conn *ws.Conn) { + enc, err := s.deps.Store.GetHostCredentials(ctx, hostID) + if err != nil { + if !errors.Is(err, store.ErrNotFound) { + slog.Warn("on-hello: load host creds", "host_id", hostID, "err", err) + } + return + } + plain, err := s.deps.AEAD.Decrypt(enc, []byte("host:"+hostID)) + if err != nil { + slog.Error("on-hello: decrypt host creds", "host_id", hostID, "err", err) + return + } + var blob repoCredsBlob + if err := json.Unmarshal(plain, &blob); err != nil { + slog.Error("on-hello: parse host creds", "host_id", hostID, "err", err) + return + } + env, err := api.Marshal(api.MsgConfigUpdate, "", api.ConfigUpdatePayload{ + RepoURL: blob.RepoURL, + RepoUsername: blob.RepoUsername, + RepoPassword: blob.RepoPassword, + }) + if err != nil { + slog.Error("on-hello: marshal config.update", "host_id", hostID, "err", err) + return + } + sendCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := conn.Send(sendCtx, env); err != nil { + slog.Warn("on-hello: send config.update", "host_id", hostID, "err", err) + } +} diff --git a/internal/server/http/host_credentials_test.go b/internal/server/http/host_credentials_test.go new file mode 100644 index 0000000..f435a0e --- /dev/null +++ b/internal/server/http/host_credentials_test.go @@ -0,0 +1,102 @@ +package http + +import ( + "context" + "encoding/json" + "testing" +) + +// TestEnrollmentTransfersRepoCreds verifies the round-trip: +// - operator mints a token with repo_url/username/password +// - encrypted blob lands on the token row, bound to token_hash +// - on consume, the blob is re-encrypted bound to host_id and +// written to host_credentials in the same tx. +func TestEnrollmentTransfersRepoCreds(t *testing.T) { + t.Parallel() + srv, _, st := newTestServerWithHub(t) + + ctx := context.Background() + want := repoCredsBlob{ + RepoURL: "rest:https://repo.example/host42", + RepoUsername: "host42", + RepoPassword: "hunter2", + } + + // Encrypt + create token like the operator endpoint would. + const tokHash = "tok-hash-fixture" + enc, err := srv.encryptRepoCreds(want, []byte("token:"+tokHash)) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + if err := st.CreateEnrollmentToken(ctx, tokHash, 1<<20, enc); err != nil { + t.Fatalf("create token: %v", err) + } + + // Rebind under host_id, then consume (this is what the agent + // enroll handler does inline). + const hostID = "h-fixture" + encForHost, err := srv.rebindTokenCreds(ctx, tokHash, hostID) + if err != nil { + t.Fatalf("rebind: %v", err) + } + if encForHost == "" { + t.Fatal("rebind returned empty blob; expected re-encrypted ciphertext") + } + if encForHost == enc { + t.Errorf("rebind should change ciphertext (additional-data differs); got identical") + } + + // Need a host row for the FK. + if _, err := st.DB().Exec( + `INSERT INTO hosts (id, name, os, arch, enrolled_at) VALUES (?,?,?,?,?)`, + hostID, "host42", "linux", "amd64", "2026-01-01T00:00:00Z"); err != nil { + t.Fatalf("insert host: %v", err) + } + if err := st.ConsumeEnrollmentToken(ctx, tokHash, hostID, encForHost); err != nil { + t.Fatalf("consume: %v", err) + } + + // host_credentials row should now hold the host-bound ciphertext. + got, err := st.GetHostCredentials(ctx, hostID) + if err != nil { + t.Fatalf("get host creds: %v", err) + } + plain, err := srv.deps.AEAD.Decrypt(got, []byte("host:"+hostID)) + if err != nil { + t.Fatalf("decrypt: %v", err) + } + var blob repoCredsBlob + if err := json.Unmarshal(plain, &blob); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if blob != want { + t.Errorf("blob mismatch:\n got %+v\nwant %+v", blob, want) + } + + // Cross-check: decrypting with a wrong AD must fail (swap + // detection — proves the AAD binding is doing real work). + if _, err := srv.deps.AEAD.Decrypt(got, []byte("host:other-host")); err == nil { + t.Error("decrypt with wrong AD must fail; AAD binding is broken") + } +} + +// TestEnrollmentTokenWithoutCreds is the regression that ensures the +// existing ttl/single-use semantics still work when no creds are +// attached (used by the enrollment_test.go fixture path). +func TestEnrollmentTokenWithoutCreds(t *testing.T) { + t.Parallel() + _, _, st := newTestServerWithHub(t) + ctx := context.Background() + + const tokHash = "no-creds-token" + if err := st.CreateEnrollmentToken(ctx, tokHash, 1<<20, ""); err != nil { + t.Fatalf("create: %v", err) + } + enc, err := st.GetEnrollmentTokenCreds(ctx, tokHash) + if err != nil { + t.Fatalf("get token creds: %v", err) + } + if enc != "" { + t.Errorf("token without creds should return empty blob; got %q", enc) + } +} diff --git a/internal/server/http/server.go b/internal/server/http/server.go index c415b85..56edbd6 100644 --- a/internal/server/http/server.go +++ b/internal/server/http/server.go @@ -89,13 +89,18 @@ func (s *Server) routes(r chi.Router) { // Snapshot projection (refreshed by the agent after each backup). r.Get("/hosts/{id}/snapshots", s.handleListHostSnapshots) + + // Repo credentials — operator can edit after enrollment. The + // initial set is supplied at token-mint time (see enrollment.go). + r.Put("/hosts/{id}/repo-credentials", s.handleSetHostCredentials) }) // Agent ↔ server WebSocket. Bearer-authenticated inside the handler. if s.deps.Hub != nil { r.Mount("/ws/agent", ws.AgentHandler(ws.HandlerDeps{ - Hub: s.deps.Hub, - Store: s.deps.Store, + Hub: s.deps.Hub, + Store: s.deps.Store, + OnHello: s.onAgentHello, })) } diff --git a/internal/server/ws/handler.go b/internal/server/ws/handler.go index b330b33..b600a2e 100644 --- a/internal/server/ws/handler.go +++ b/internal/server/ws/handler.go @@ -21,6 +21,11 @@ import ( type HandlerDeps struct { Hub *Hub Store *store.Store + // OnHello is called once per successful hello, after the host row + // has been touched and the conn registered. Used by the HTTP + // layer to push host_credentials down as a config.update before + // the agent starts asking for jobs. Optional; nil = no-op. + OnHello func(ctx context.Context, hostID string, conn *Conn) } // AgentHandler is the http.Handler that owns /ws/agent. Agents @@ -136,6 +141,12 @@ func runAgentLoop(ctx context.Context, c *Conn, hostID string, deps HandlerDeps) "agent_version", helloPayload.AgentVersion, "protocol_version", helloPayload.ProtocolVersion) + if deps.OnHello != nil { + // Run synchronously so the config.update lands before any + // command.run an operator might race in. + deps.OnHello(ctx, hostID, c) + } + // Stage 2: main read loop. for { env, err := c.Read(ctx) diff --git a/internal/store/enrollment.go b/internal/store/enrollment.go index c054af9..edf8383 100644 --- a/internal/store/enrollment.go +++ b/internal/store/enrollment.go @@ -2,6 +2,8 @@ package store import ( "context" + "database/sql" + "errors" "fmt" "time" ) @@ -9,14 +11,25 @@ import ( // CreateEnrollmentToken persists a fresh one-time token. The caller // has already hashed the raw token; the raw form is returned to the // operator (printed in the install snippet) and never persisted. -func (s *Store) CreateEnrollmentToken(ctx context.Context, tokenHash string, ttl time.Duration) error { +// +// encRepoCreds is the AEAD-encrypted blob of {repo_url, repo_username, +// repo_password} that ConsumeEnrollmentToken will promote to a +// host_credentials row. Empty string = operator chose to set creds +// later via PUT /api/hosts/{id}/repo-credentials; the agent will +// refuse backup jobs until that lands. +func (s *Store) CreateEnrollmentToken(ctx context.Context, tokenHash string, ttl time.Duration, encRepoCreds string) error { now := time.Now().UTC() + var enc any = nil + if encRepoCreds != "" { + enc = encRepoCreds + } _, err := s.db.ExecContext(ctx, - `INSERT INTO enrollment_tokens (token_hash, created_at, expires_at) - VALUES (?, ?, ?)`, + `INSERT INTO enrollment_tokens (token_hash, created_at, expires_at, enc_repo_creds) + VALUES (?, ?, ?, ?)`, tokenHash, now.Format(time.RFC3339Nano), - now.Add(ttl).Format(time.RFC3339Nano)) + now.Add(ttl).Format(time.RFC3339Nano), + enc) if err != nil { return fmt.Errorf("store: create enrollment token: %w", err) } @@ -24,11 +37,22 @@ func (s *Store) CreateEnrollmentToken(ctx context.Context, tokenHash string, ttl } // ConsumeEnrollmentToken atomically validates a token (must exist, -// not be consumed, not be expired) and marks it consumed by hostID. +// not be consumed, not be expired), marks it consumed by hostID, and +// — if the token carries encrypted repo creds — promotes them to a +// host_credentials row in the same tx. The encrypted blob is +// re-encrypted by the caller with host_id as additional data; we +// don't crack it open here. +// // Returns ErrNotFound on any failure. -func (s *Store) ConsumeEnrollmentToken(ctx context.Context, tokenHash, hostID string) error { +func (s *Store) ConsumeEnrollmentToken(ctx context.Context, tokenHash, hostID, encRepoCredsForHost string) error { now := time.Now().UTC().Format(time.RFC3339Nano) - res, err := s.db.ExecContext(ctx, + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("store: consume enrollment token: begin: %w", err) + } + defer func() { _ = tx.Rollback() }() + + res, err := tx.ExecContext(ctx, `UPDATE enrollment_tokens SET consumed_at = ?, consumed_host = ? WHERE token_hash = ? AND consumed_at IS NULL AND expires_at > ?`, @@ -40,9 +64,51 @@ func (s *Store) ConsumeEnrollmentToken(ctx context.Context, tokenHash, hostID st if n == 0 { return ErrNotFound } + + if encRepoCredsForHost != "" { + if _, err := tx.ExecContext(ctx, + `INSERT INTO host_credentials (host_id, enc_repo_creds, updated_at) + VALUES (?, ?, ?) + ON CONFLICT(host_id) DO UPDATE SET + enc_repo_creds = excluded.enc_repo_creds, + updated_at = excluded.updated_at`, + hostID, encRepoCredsForHost, now); err != nil { + return fmt.Errorf("store: promote host credentials: %w", err) + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("store: consume enrollment token: commit: %w", err) + } return nil } +// GetEnrollmentTokenCreds returns the encrypted repo-creds blob the +// operator stashed when creating the token, or ("", ErrNotFound) if +// the token is gone / consumed / expired / had no creds attached. +// +// The caller decrypts using token_hash as the AEAD additional data, +// then re-encrypts using host_id as additional data before passing +// to ConsumeEnrollmentToken. +func (s *Store) GetEnrollmentTokenCreds(ctx context.Context, tokenHash string) (string, error) { + now := time.Now().UTC().Format(time.RFC3339Nano) + row := s.db.QueryRowContext(ctx, + `SELECT enc_repo_creds FROM enrollment_tokens + WHERE token_hash = ? AND consumed_at IS NULL AND expires_at > ?`, + tokenHash, now) + var enc sql.NullString + if err := row.Scan(&enc); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", ErrNotFound + } + return "", fmt.Errorf("store: get enrollment token creds: %w", err) + } + if !enc.Valid { + return "", nil + } + return enc.String, nil +} + // PurgeExpiredEnrollmentTokens deletes long-expired token rows. Tokens // retained for ~24h after expiry so audit traces still resolve them. func (s *Store) PurgeExpiredEnrollmentTokens(ctx context.Context) (int64, error) { diff --git a/internal/store/host_credentials.go b/internal/store/host_credentials.go new file mode 100644 index 0000000..22416c8 --- /dev/null +++ b/internal/store/host_credentials.go @@ -0,0 +1,46 @@ +package store + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" +) + +// GetHostCredentials returns the AEAD-encrypted repo creds blob for +// the host, or ("", ErrNotFound) if no credential has ever been set. +// The caller decrypts using host_id as AEAD additional data. +func (s *Store) GetHostCredentials(ctx context.Context, hostID string) (string, error) { + row := s.db.QueryRowContext(ctx, + `SELECT enc_repo_creds FROM host_credentials WHERE host_id = ?`, + hostID) + var enc string + if err := row.Scan(&enc); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", ErrNotFound + } + return "", fmt.Errorf("store: get host credentials: %w", err) + } + return enc, nil +} + +// SetHostCredentials replaces the host's encrypted repo creds blob. +// The caller has already encrypted using host_id as additional data. +func (s *Store) SetHostCredentials(ctx context.Context, hostID, encRepoCreds string) error { + if encRepoCreds == "" { + return fmt.Errorf("store: empty enc_repo_creds") + } + now := time.Now().UTC().Format(time.RFC3339Nano) + _, err := s.db.ExecContext(ctx, + `INSERT INTO host_credentials (host_id, enc_repo_creds, updated_at) + VALUES (?, ?, ?) + ON CONFLICT(host_id) DO UPDATE SET + enc_repo_creds = excluded.enc_repo_creds, + updated_at = excluded.updated_at`, + hostID, encRepoCreds, now) + if err != nil { + return fmt.Errorf("store: set host credentials: %w", err) + } + return nil +} diff --git a/internal/store/migrations/0002_host_credentials.sql b/internal/store/migrations/0002_host_credentials.sql new file mode 100644 index 0000000..5754b6e --- /dev/null +++ b/internal/store/migrations/0002_host_credentials.sql @@ -0,0 +1,26 @@ +-- 0002_host_credentials.sql +-- +-- Repo credentials carried on the enrollment token, then promoted to +-- a per-host row on consume. Pulled forward from Phase 2 so the +-- "Add host" flow is genuinely one-shot — operator supplies repo +-- creds at token-mint time, agent receives them via config.update on +-- first WS connect. +-- +-- See spec.md §7.3 for the threat model and tasks.md P1-32 for the +-- end-to-end flow. + +-- Token row optionally carries an AEAD-encrypted JSON blob of +-- {repo_url, repo_username, repo_password}. AEAD additional-data +-- binds it to the token_hash so swap attacks between rows fail. +ALTER TABLE enrollment_tokens + ADD COLUMN enc_repo_creds TEXT; + +-- Per-host repo credential, replaces the blob from the token row on +-- ConsumeEnrollmentToken. AEAD additional-data binds it to host_id. +-- One row per host; absence means "no creds set yet, agent will +-- refuse backup jobs until the operator sets them via the UI." +CREATE TABLE host_credentials ( + host_id TEXT PRIMARY KEY REFERENCES hosts(id) ON DELETE CASCADE, + enc_repo_creds TEXT NOT NULL, + updated_at TEXT NOT NULL +); diff --git a/internal/store/store_test.go b/internal/store/store_test.go index ff4fe2b..c34a7e2 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -72,8 +72,13 @@ func TestMigrateIsIdempotent(t *testing.T) { if err := row.Scan(&n); err != nil { t.Fatalf("scan: %v", err) } - if n != 1 { - t.Errorf("re-running migrations should not insert duplicate rows; got %d", n) + migs, err := loadMigrations() + if err != nil { + t.Fatalf("load migrations: %v", err) + } + if n != len(migs) { + t.Errorf("re-running migrations should not insert duplicate rows; want %d, got %d", + len(migs), n) } } diff --git a/internal/store/users_test.go b/internal/store/users_test.go index 64ddefd..35a1aa0 100644 --- a/internal/store/users_test.go +++ b/internal/store/users_test.go @@ -137,7 +137,7 @@ func TestEnrollmentTokenSingleUse(t *testing.T) { ctx := context.Background() hash := "tok-hash" - if err := s.CreateEnrollmentToken(ctx, hash, time.Hour); err != nil { + if err := s.CreateEnrollmentToken(ctx, hash, time.Hour, ""); err != nil { t.Fatalf("create: %v", err) } @@ -148,11 +148,11 @@ func TestEnrollmentTokenSingleUse(t *testing.T) { t.Fatalf("insert host: %v", err) } - if err := s.ConsumeEnrollmentToken(ctx, hash, "h1"); err != nil { + if err := s.ConsumeEnrollmentToken(ctx, hash, "h1", ""); err != nil { t.Fatalf("consume: %v", err) } // Second consume must fail — the whole point of one-time tokens. - if err := s.ConsumeEnrollmentToken(ctx, hash, "h1"); !errors.Is(err, ErrNotFound) { + if err := s.ConsumeEnrollmentToken(ctx, hash, "h1", ""); !errors.Is(err, ErrNotFound) { t.Errorf("re-consume: want ErrNotFound, got %v", err) } }