From 18b0bf976dbd6f5a340bbaa32acedb92da396bc1 Mon Sep 17 00:00:00 2001 From: Steve Cliff Date: Mon, 4 May 2026 10:52:16 +0100 Subject: [PATCH] store: P2R-10 schema for source-group + host-default hooks (migration 0010) Adds pre_hook/post_hook BLOB columns to source_groups and pre_hook_default/post_hook_default to hosts. Bytes stored verbatim (AEAD encrypt/decrypt happens at the HTTP layer where the AEAD key lives). Round-trip tests cover set/clear semantics on both tables. --- internal/store/hooks_test.go | 106 +++++++++++++++++++++++ internal/store/hosts.go | 28 +++++- internal/store/migrations/0010_hooks.sql | 25 ++++++ internal/store/sources.go | 29 +++++-- internal/store/types.go | 13 +++ 5 files changed, 190 insertions(+), 11 deletions(-) create mode 100644 internal/store/hooks_test.go create mode 100644 internal/store/migrations/0010_hooks.sql diff --git a/internal/store/hooks_test.go b/internal/store/hooks_test.go new file mode 100644 index 0000000..18a7864 --- /dev/null +++ b/internal/store/hooks_test.go @@ -0,0 +1,106 @@ +// hooks_test.go — covers the pre/post hook columns added in +// migration 0010 (P2R-10): set + reload roundtrip on both +// source_groups and hosts; nil clears the column. +package store + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/oklog/ulid/v2" +) + +func newTestStore(t *testing.T) *Store { + t.Helper() + dir := t.TempDir() + st, err := Open(context.Background(), filepath.Join(dir, "rm.db")) + if err != nil { + t.Fatalf("open store: %v", err) + } + t.Cleanup(func() { _ = st.Close() }) + return st +} + +func makeHostInStore(t *testing.T, st *Store, name string) string { + t.Helper() + id := ulid.Make().String() + if err := st.CreateHost(context.Background(), Host{ + ID: id, Name: name, OS: "linux", Arch: "amd64", + EnrolledAt: time.Now().UTC(), + }, "tokenhash-"+id, ""); err != nil { + t.Fatalf("create host: %v", err) + } + return id +} + +func TestSourceGroupHooksRoundTrip(t *testing.T) { + t.Parallel() + st := newTestStore(t) + hostID := makeHostInStore(t, st, "hooks-host") + + g := &SourceGroup{ + ID: ulid.Make().String(), HostID: hostID, Name: "etc", + PreHook: []byte("ENC-PRE"), + PostHook: []byte("ENC-POST"), + } + if err := st.CreateSourceGroup(context.Background(), g); err != nil { + t.Fatalf("create: %v", err) + } + got, err := st.GetSourceGroup(context.Background(), hostID, g.ID) + if err != nil { + t.Fatalf("get: %v", err) + } + if string(got.PreHook) != "ENC-PRE" { + t.Fatalf("PreHook: got %q, want ENC-PRE", got.PreHook) + } + if string(got.PostHook) != "ENC-POST" { + t.Fatalf("PostHook: got %q, want ENC-POST", got.PostHook) + } + + // Update: clear PreHook, change PostHook. + got.PreHook = nil + got.PostHook = []byte("ENC-POST-2") + if err := st.UpdateSourceGroup(context.Background(), got); err != nil { + t.Fatalf("update: %v", err) + } + got, err = st.GetSourceGroup(context.Background(), hostID, g.ID) + if err != nil { + t.Fatalf("get: %v", err) + } + if got.PreHook != nil { + t.Fatalf("PreHook: want nil after clear, got %q", got.PreHook) + } + if string(got.PostHook) != "ENC-POST-2" { + t.Fatalf("PostHook: got %q, want ENC-POST-2", got.PostHook) + } +} + +func TestHostHookDefaultsRoundTrip(t *testing.T) { + t.Parallel() + st := newTestStore(t) + hostID := makeHostInStore(t, st, "host-hooks-host") + + if err := st.SetHostHooks(context.Background(), hostID, []byte("PRE"), []byte("POST")); err != nil { + t.Fatalf("set: %v", err) + } + h, err := st.GetHost(context.Background(), hostID) + if err != nil { + t.Fatalf("get: %v", err) + } + if string(h.PreHookDefault) != "PRE" || string(h.PostHookDefault) != "POST" { + t.Fatalf("after set: pre=%q post=%q", h.PreHookDefault, h.PostHookDefault) + } + // Clear by passing nil. + if err := st.SetHostHooks(context.Background(), hostID, nil, nil); err != nil { + t.Fatalf("clear: %v", err) + } + h, err = st.GetHost(context.Background(), hostID) + if err != nil { + t.Fatalf("get: %v", err) + } + if h.PreHookDefault != nil || h.PostHookDefault != nil { + t.Fatalf("after clear: pre=%v post=%v (want nil)", h.PreHookDefault, h.PostHookDefault) + } +} diff --git a/internal/store/hosts.go b/internal/store/hosts.go index bd6a24d..fc0b383 100644 --- a/internal/store/hosts.go +++ b/internal/store/hosts.go @@ -42,7 +42,8 @@ func (s *Store) LookupHostByAgentToken(ctx context.Context, tokenHash string) (* enrolled_at, last_seen_at, status, repo_id, tags, current_job_id, last_backup_at, last_backup_status, repo_size_bytes, snapshot_count, open_alert_count, - applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps + applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps, + pre_hook_default, post_hook_default FROM hosts WHERE agent_token_hash = ?`, tokenHash) return scanHost(row) @@ -55,7 +56,8 @@ func (s *Store) GetHost(ctx context.Context, id string) (*Host, error) { enrolled_at, last_seen_at, status, repo_id, tags, current_job_id, last_backup_at, last_backup_status, repo_size_bytes, snapshot_count, open_alert_count, - applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps + applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps, + pre_hook_default, post_hook_default FROM hosts WHERE id = ?`, id) return scanHost(row) } @@ -116,7 +118,8 @@ func (s *Store) ListHosts(ctx context.Context) ([]Host, error) { enrolled_at, last_seen_at, status, repo_id, tags, current_job_id, last_backup_at, last_backup_status, repo_size_bytes, snapshot_count, open_alert_count, - applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps + applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps, + pre_hook_default, post_hook_default FROM hosts ORDER BY name`) if err != nil { return nil, fmt.Errorf("store: list hosts: %w", err) @@ -155,13 +158,15 @@ func scanHostRow(s hostScanner) (*Host, error) { enrolled string tags string bwUp, bwDown sql.NullInt64 + preHook, postHook []byte ) err := s.Scan(&h.ID, &h.Name, &h.OS, &h.Arch, &h.AgentVersion, &h.ResticVersion, &h.ProtocolVersion, &enrolled, &lastSeen, &h.Status, &repoID, &tags, ¤tJob, &lastBackupAt, &lastBkSt, &h.RepoSizeBytes, &h.SnapshotCount, &h.OpenAlertCount, - &h.AppliedScheduleVersion, &bwUp, &bwDown) + &h.AppliedScheduleVersion, &bwUp, &bwDown, + &preHook, &postHook) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound @@ -210,9 +215,24 @@ func scanHostRow(s hostScanner) (*Host, error) { v := int(bwDown.Int64) h.BandwidthDownKBps = &v } + h.PreHookDefault = preHook + h.PostHookDefault = postHook return &h, nil } +// SetHostHooks replaces the host-wide pre/post hook defaults. Pass +// nil/empty to clear that hook. Stored verbatim — caller is expected +// to encrypt the bytes before they reach this layer. +func (s *Store) SetHostHooks(ctx context.Context, hostID string, pre, post []byte) error { + _, err := s.db.ExecContext(ctx, + `UPDATE hosts SET pre_hook_default = ?, post_hook_default = ? WHERE id = ?`, + nullableBytes(pre), nullableBytes(post), hostID) + if err != nil { + return fmt.Errorf("store: set host hooks: %w", err) + } + return nil +} + // SetHostBandwidth replaces the host's upload/download caps. Pass nil // to clear a cap. Caller decides validation; non-positive caps are // treated as "no cap" by the agent regardless. diff --git a/internal/store/migrations/0010_hooks.sql b/internal/store/migrations/0010_hooks.sql new file mode 100644 index 0000000..f944dcc --- /dev/null +++ b/internal/store/migrations/0010_hooks.sql @@ -0,0 +1,25 @@ +-- 0010_hooks.sql +-- +-- P2R-10: pre/post hooks on source groups + host-wide defaults. +-- +-- Hook bodies are stored as AEAD ciphertext (existing crypto.AEAD) +-- because operators do put credentials in shell snippets — even +-- though we tell them not to. NULL means "no hook configured". +-- +-- Hooks fire only for kind=backup jobs. forget/prune/check/unlock +-- skip them per spec.md §14.3 (P2R-11 enforces this in the agent +-- dispatcher). +-- +-- Resolution order at dispatch time: +-- source_group._hook (per-group override, AEAD blob) +-- host._hook_default (host default, AEAD blob) +-- none → no hook runs +-- +-- All four columns are added in-place via ALTER TABLE ADD COLUMN. +-- Per CLAUDE.md the table-rebuild pattern is unsafe with FK cascades. + +ALTER TABLE source_groups ADD COLUMN pre_hook BLOB; +ALTER TABLE source_groups ADD COLUMN post_hook BLOB; + +ALTER TABLE hosts ADD COLUMN pre_hook_default BLOB; +ALTER TABLE hosts ADD COLUMN post_hook_default BLOB; diff --git a/internal/store/sources.go b/internal/store/sources.go index 6ec3115..164e893 100644 --- a/internal/store/sources.go +++ b/internal/store/sources.go @@ -45,13 +45,14 @@ func (st *Store) CreateSourceGroup(ctx context.Context, g *SourceGroup) error { `INSERT INTO source_groups ( id, host_id, name, includes, excludes, retention_policy, retry_max, retry_backoff_seconds, conflict_dimension, - created_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + created_at, updated_at, pre_hook, post_hook + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, g.ID, g.HostID, g.Name, string(includesJSON), string(excludesJSON), string(retentionJSON), g.RetryMax, g.RetryBackoffSeconds, nullableString(g.ConflictDimension), now.Format(time.RFC3339Nano), now.Format(time.RFC3339Nano), + nullableBytes(g.PreHook), nullableBytes(g.PostHook), ); err != nil { return fmt.Errorf("store: create source group: %w", err) } @@ -88,13 +89,14 @@ func (st *Store) UpdateSourceGroup(ctx context.Context, g *SourceGroup) error { `UPDATE source_groups SET name = ?, includes = ?, excludes = ?, retention_policy = ?, retry_max = ?, retry_backoff_seconds = ?, conflict_dimension = ?, - updated_at = ? + updated_at = ?, pre_hook = ?, post_hook = ? WHERE id = ? AND host_id = ?`, g.Name, string(includesJSON), string(excludesJSON), string(retentionJSON), g.RetryMax, g.RetryBackoffSeconds, nullableString(g.ConflictDimension), now.Format(time.RFC3339Nano), + nullableBytes(g.PreHook), nullableBytes(g.PostHook), g.ID, g.HostID, ) if err != nil { @@ -143,7 +145,7 @@ func (st *Store) GetSourceGroup(ctx context.Context, hostID, groupID string) (*S row := st.db.QueryRowContext(ctx, `SELECT id, host_id, name, includes, excludes, retention_policy, retry_max, retry_backoff_seconds, conflict_dimension, - created_at, updated_at + created_at, updated_at, pre_hook, post_hook FROM source_groups WHERE id = ? AND host_id = ?`, groupID, hostID) g, err := scanSourceGroup(row) @@ -159,7 +161,7 @@ func (st *Store) GetSourceGroupByName(ctx context.Context, hostID, name string) row := st.db.QueryRowContext(ctx, `SELECT id, host_id, name, includes, excludes, retention_policy, retry_max, retry_backoff_seconds, conflict_dimension, - created_at, updated_at + created_at, updated_at, pre_hook, post_hook FROM source_groups WHERE host_id = ? AND name = ?`, hostID, name) g, err := scanSourceGroup(row) @@ -177,7 +179,7 @@ func (st *Store) ListSourceGroupsByHost(ctx context.Context, hostID string) ([]S rows, err := st.db.QueryContext(ctx, `SELECT id, host_id, name, includes, excludes, retention_policy, retry_max, retry_backoff_seconds, conflict_dimension, - created_at, updated_at + created_at, updated_at, pre_hook, post_hook FROM source_groups WHERE host_id = ? ORDER BY name`, hostID) if err != nil { @@ -224,14 +226,17 @@ func scanSourceGroupRow(s sourceGroupScanner) (*SourceGroup, error) { includes, excludes, retention string conflict sql.NullString createdAt, updatedAt string + preHook, postHook []byte ) err := s.Scan(&out.ID, &out.HostID, &out.Name, &includes, &excludes, &retention, &out.RetryMax, &out.RetryBackoffSeconds, &conflict, - &createdAt, &updatedAt) + &createdAt, &updatedAt, &preHook, &postHook) if err != nil { return nil, err } + out.PreHook = preHook + out.PostHook = postHook if includes != "" { _ = json.Unmarshal([]byte(includes), &out.Includes) } @@ -259,3 +264,13 @@ func nullableString(s string) any { } return s } + +// nullableBytes returns nil for an empty/nil slice so SQL stores it +// as NULL rather than an empty BLOB. The agent treats both the same +// (no hook), but NULL is the canonical "absent" form on disk. +func nullableBytes(b []byte) any { + if len(b) == 0 { + return nil + } + return b +} diff --git a/internal/store/types.go b/internal/store/types.go index 6f99f69..4e52d91 100644 --- a/internal/store/types.go +++ b/internal/store/types.go @@ -66,6 +66,12 @@ type Host struct { // (backup, restore, prune). nil = no cap. BandwidthUpKBps *int BandwidthDownKBps *int + + // PreHookDefault / PostHookDefault are AEAD-encrypted host-wide + // hook bodies. Per source group hooks (SourceGroup.PreHook / + // PostHook) override these when set. nil = no default configured. + PreHookDefault []byte + PostHookDefault []byte } // Schedule is now intentionally slim: cron + which groups + enabled. @@ -106,6 +112,13 @@ type SourceGroup struct { ConflictDimension string CreatedAt time.Time UpdatedAt time.Time + + // PreHook / PostHook are AEAD-encrypted shell snippets (raw blob). + // nil means "no hook configured." Encryption/decryption happens at + // the HTTP layer (where AEAD lives); the store layer just persists + // the bytes verbatim. + PreHook []byte + PostHook []byte } // RetentionPolicy is the typed view of `restic forget --keep-*`.