store: P2R-10 schema for source-group + host-default hooks (migration 0010)

Adds pre_hook/post_hook BLOB columns to source_groups and
pre_hook_default/post_hook_default to hosts. Bytes stored verbatim
(AEAD encrypt/decrypt happens at the HTTP layer where the AEAD key
lives). Round-trip tests cover set/clear semantics on both tables.
This commit is contained in:
2026-05-04 10:52:16 +01:00
parent c9b49637d1
commit 18b0bf976d
5 changed files with 190 additions and 11 deletions
+106
View File
@@ -0,0 +1,106 @@
// hooks_test.go — covers the pre/post hook columns added in
// migration 0010 (P2R-10): set + reload roundtrip on both
// source_groups and hosts; nil clears the column.
package store
import (
"context"
"path/filepath"
"testing"
"time"
"github.com/oklog/ulid/v2"
)
func newTestStore(t *testing.T) *Store {
t.Helper()
dir := t.TempDir()
st, err := Open(context.Background(), filepath.Join(dir, "rm.db"))
if err != nil {
t.Fatalf("open store: %v", err)
}
t.Cleanup(func() { _ = st.Close() })
return st
}
func makeHostInStore(t *testing.T, st *Store, name string) string {
t.Helper()
id := ulid.Make().String()
if err := st.CreateHost(context.Background(), Host{
ID: id, Name: name, OS: "linux", Arch: "amd64",
EnrolledAt: time.Now().UTC(),
}, "tokenhash-"+id, ""); err != nil {
t.Fatalf("create host: %v", err)
}
return id
}
func TestSourceGroupHooksRoundTrip(t *testing.T) {
t.Parallel()
st := newTestStore(t)
hostID := makeHostInStore(t, st, "hooks-host")
g := &SourceGroup{
ID: ulid.Make().String(), HostID: hostID, Name: "etc",
PreHook: []byte("ENC-PRE"),
PostHook: []byte("ENC-POST"),
}
if err := st.CreateSourceGroup(context.Background(), g); err != nil {
t.Fatalf("create: %v", err)
}
got, err := st.GetSourceGroup(context.Background(), hostID, g.ID)
if err != nil {
t.Fatalf("get: %v", err)
}
if string(got.PreHook) != "ENC-PRE" {
t.Fatalf("PreHook: got %q, want ENC-PRE", got.PreHook)
}
if string(got.PostHook) != "ENC-POST" {
t.Fatalf("PostHook: got %q, want ENC-POST", got.PostHook)
}
// Update: clear PreHook, change PostHook.
got.PreHook = nil
got.PostHook = []byte("ENC-POST-2")
if err := st.UpdateSourceGroup(context.Background(), got); err != nil {
t.Fatalf("update: %v", err)
}
got, err = st.GetSourceGroup(context.Background(), hostID, g.ID)
if err != nil {
t.Fatalf("get: %v", err)
}
if got.PreHook != nil {
t.Fatalf("PreHook: want nil after clear, got %q", got.PreHook)
}
if string(got.PostHook) != "ENC-POST-2" {
t.Fatalf("PostHook: got %q, want ENC-POST-2", got.PostHook)
}
}
func TestHostHookDefaultsRoundTrip(t *testing.T) {
t.Parallel()
st := newTestStore(t)
hostID := makeHostInStore(t, st, "host-hooks-host")
if err := st.SetHostHooks(context.Background(), hostID, []byte("PRE"), []byte("POST")); err != nil {
t.Fatalf("set: %v", err)
}
h, err := st.GetHost(context.Background(), hostID)
if err != nil {
t.Fatalf("get: %v", err)
}
if string(h.PreHookDefault) != "PRE" || string(h.PostHookDefault) != "POST" {
t.Fatalf("after set: pre=%q post=%q", h.PreHookDefault, h.PostHookDefault)
}
// Clear by passing nil.
if err := st.SetHostHooks(context.Background(), hostID, nil, nil); err != nil {
t.Fatalf("clear: %v", err)
}
h, err = st.GetHost(context.Background(), hostID)
if err != nil {
t.Fatalf("get: %v", err)
}
if h.PreHookDefault != nil || h.PostHookDefault != nil {
t.Fatalf("after clear: pre=%v post=%v (want nil)", h.PreHookDefault, h.PostHookDefault)
}
}
+24 -4
View File
@@ -42,7 +42,8 @@ func (s *Store) LookupHostByAgentToken(ctx context.Context, tokenHash string) (*
enrolled_at, last_seen_at, status, repo_id, tags, enrolled_at, last_seen_at, status, repo_id, tags,
current_job_id, last_backup_at, last_backup_status, current_job_id, last_backup_at, last_backup_status,
repo_size_bytes, snapshot_count, open_alert_count, repo_size_bytes, snapshot_count, open_alert_count,
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps,
pre_hook_default, post_hook_default
FROM hosts WHERE agent_token_hash = ?`, FROM hosts WHERE agent_token_hash = ?`,
tokenHash) tokenHash)
return scanHost(row) return scanHost(row)
@@ -55,7 +56,8 @@ func (s *Store) GetHost(ctx context.Context, id string) (*Host, error) {
enrolled_at, last_seen_at, status, repo_id, tags, enrolled_at, last_seen_at, status, repo_id, tags,
current_job_id, last_backup_at, last_backup_status, current_job_id, last_backup_at, last_backup_status,
repo_size_bytes, snapshot_count, open_alert_count, repo_size_bytes, snapshot_count, open_alert_count,
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps,
pre_hook_default, post_hook_default
FROM hosts WHERE id = ?`, id) FROM hosts WHERE id = ?`, id)
return scanHost(row) return scanHost(row)
} }
@@ -116,7 +118,8 @@ func (s *Store) ListHosts(ctx context.Context) ([]Host, error) {
enrolled_at, last_seen_at, status, repo_id, tags, enrolled_at, last_seen_at, status, repo_id, tags,
current_job_id, last_backup_at, last_backup_status, current_job_id, last_backup_at, last_backup_status,
repo_size_bytes, snapshot_count, open_alert_count, repo_size_bytes, snapshot_count, open_alert_count,
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps,
pre_hook_default, post_hook_default
FROM hosts ORDER BY name`) FROM hosts ORDER BY name`)
if err != nil { if err != nil {
return nil, fmt.Errorf("store: list hosts: %w", err) return nil, fmt.Errorf("store: list hosts: %w", err)
@@ -155,13 +158,15 @@ func scanHostRow(s hostScanner) (*Host, error) {
enrolled string enrolled string
tags string tags string
bwUp, bwDown sql.NullInt64 bwUp, bwDown sql.NullInt64
preHook, postHook []byte
) )
err := s.Scan(&h.ID, &h.Name, &h.OS, &h.Arch, err := s.Scan(&h.ID, &h.Name, &h.OS, &h.Arch,
&h.AgentVersion, &h.ResticVersion, &h.ProtocolVersion, &h.AgentVersion, &h.ResticVersion, &h.ProtocolVersion,
&enrolled, &lastSeen, &h.Status, &repoID, &tags, &enrolled, &lastSeen, &h.Status, &repoID, &tags,
&currentJob, &lastBackupAt, &lastBkSt, &currentJob, &lastBackupAt, &lastBkSt,
&h.RepoSizeBytes, &h.SnapshotCount, &h.OpenAlertCount, &h.RepoSizeBytes, &h.SnapshotCount, &h.OpenAlertCount,
&h.AppliedScheduleVersion, &bwUp, &bwDown) &h.AppliedScheduleVersion, &bwUp, &bwDown,
&preHook, &postHook)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound return nil, ErrNotFound
@@ -210,9 +215,24 @@ func scanHostRow(s hostScanner) (*Host, error) {
v := int(bwDown.Int64) v := int(bwDown.Int64)
h.BandwidthDownKBps = &v h.BandwidthDownKBps = &v
} }
h.PreHookDefault = preHook
h.PostHookDefault = postHook
return &h, nil return &h, nil
} }
// SetHostHooks replaces the host-wide pre/post hook defaults. Pass
// nil/empty to clear that hook. Stored verbatim — caller is expected
// to encrypt the bytes before they reach this layer.
func (s *Store) SetHostHooks(ctx context.Context, hostID string, pre, post []byte) error {
_, err := s.db.ExecContext(ctx,
`UPDATE hosts SET pre_hook_default = ?, post_hook_default = ? WHERE id = ?`,
nullableBytes(pre), nullableBytes(post), hostID)
if err != nil {
return fmt.Errorf("store: set host hooks: %w", err)
}
return nil
}
// SetHostBandwidth replaces the host's upload/download caps. Pass nil // SetHostBandwidth replaces the host's upload/download caps. Pass nil
// to clear a cap. Caller decides validation; non-positive caps are // to clear a cap. Caller decides validation; non-positive caps are
// treated as "no cap" by the agent regardless. // treated as "no cap" by the agent regardless.
+25
View File
@@ -0,0 +1,25 @@
-- 0010_hooks.sql
--
-- P2R-10: pre/post hooks on source groups + host-wide defaults.
--
-- Hook bodies are stored as AEAD ciphertext (existing crypto.AEAD)
-- because operators do put credentials in shell snippets — even
-- though we tell them not to. NULL means "no hook configured".
--
-- Hooks fire only for kind=backup jobs. forget/prune/check/unlock
-- skip them per spec.md §14.3 (P2R-11 enforces this in the agent
-- dispatcher).
--
-- Resolution order at dispatch time:
-- source_group.<phase>_hook (per-group override, AEAD blob)
-- host.<phase>_hook_default (host default, AEAD blob)
-- none → no hook runs
--
-- All four columns are added in-place via ALTER TABLE ADD COLUMN.
-- Per CLAUDE.md the table-rebuild pattern is unsafe with FK cascades.
ALTER TABLE source_groups ADD COLUMN pre_hook BLOB;
ALTER TABLE source_groups ADD COLUMN post_hook BLOB;
ALTER TABLE hosts ADD COLUMN pre_hook_default BLOB;
ALTER TABLE hosts ADD COLUMN post_hook_default BLOB;
+22 -7
View File
@@ -45,13 +45,14 @@ func (st *Store) CreateSourceGroup(ctx context.Context, g *SourceGroup) error {
`INSERT INTO source_groups ( `INSERT INTO source_groups (
id, host_id, name, includes, excludes, retention_policy, id, host_id, name, includes, excludes, retention_policy,
retry_max, retry_backoff_seconds, conflict_dimension, retry_max, retry_backoff_seconds, conflict_dimension,
created_at, updated_at created_at, updated_at, pre_hook, post_hook
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
g.ID, g.HostID, g.Name, g.ID, g.HostID, g.Name,
string(includesJSON), string(excludesJSON), string(retentionJSON), string(includesJSON), string(excludesJSON), string(retentionJSON),
g.RetryMax, g.RetryBackoffSeconds, g.RetryMax, g.RetryBackoffSeconds,
nullableString(g.ConflictDimension), nullableString(g.ConflictDimension),
now.Format(time.RFC3339Nano), now.Format(time.RFC3339Nano), now.Format(time.RFC3339Nano), now.Format(time.RFC3339Nano),
nullableBytes(g.PreHook), nullableBytes(g.PostHook),
); err != nil { ); err != nil {
return fmt.Errorf("store: create source group: %w", err) return fmt.Errorf("store: create source group: %w", err)
} }
@@ -88,13 +89,14 @@ func (st *Store) UpdateSourceGroup(ctx context.Context, g *SourceGroup) error {
`UPDATE source_groups SET `UPDATE source_groups SET
name = ?, includes = ?, excludes = ?, retention_policy = ?, name = ?, includes = ?, excludes = ?, retention_policy = ?,
retry_max = ?, retry_backoff_seconds = ?, conflict_dimension = ?, retry_max = ?, retry_backoff_seconds = ?, conflict_dimension = ?,
updated_at = ? updated_at = ?, pre_hook = ?, post_hook = ?
WHERE id = ? AND host_id = ?`, WHERE id = ? AND host_id = ?`,
g.Name, g.Name,
string(includesJSON), string(excludesJSON), string(retentionJSON), string(includesJSON), string(excludesJSON), string(retentionJSON),
g.RetryMax, g.RetryBackoffSeconds, g.RetryMax, g.RetryBackoffSeconds,
nullableString(g.ConflictDimension), nullableString(g.ConflictDimension),
now.Format(time.RFC3339Nano), now.Format(time.RFC3339Nano),
nullableBytes(g.PreHook), nullableBytes(g.PostHook),
g.ID, g.HostID, g.ID, g.HostID,
) )
if err != nil { if err != nil {
@@ -143,7 +145,7 @@ func (st *Store) GetSourceGroup(ctx context.Context, hostID, groupID string) (*S
row := st.db.QueryRowContext(ctx, row := st.db.QueryRowContext(ctx,
`SELECT id, host_id, name, includes, excludes, retention_policy, `SELECT id, host_id, name, includes, excludes, retention_policy,
retry_max, retry_backoff_seconds, conflict_dimension, retry_max, retry_backoff_seconds, conflict_dimension,
created_at, updated_at created_at, updated_at, pre_hook, post_hook
FROM source_groups WHERE id = ? AND host_id = ?`, FROM source_groups WHERE id = ? AND host_id = ?`,
groupID, hostID) groupID, hostID)
g, err := scanSourceGroup(row) g, err := scanSourceGroup(row)
@@ -159,7 +161,7 @@ func (st *Store) GetSourceGroupByName(ctx context.Context, hostID, name string)
row := st.db.QueryRowContext(ctx, row := st.db.QueryRowContext(ctx,
`SELECT id, host_id, name, includes, excludes, retention_policy, `SELECT id, host_id, name, includes, excludes, retention_policy,
retry_max, retry_backoff_seconds, conflict_dimension, retry_max, retry_backoff_seconds, conflict_dimension,
created_at, updated_at created_at, updated_at, pre_hook, post_hook
FROM source_groups WHERE host_id = ? AND name = ?`, FROM source_groups WHERE host_id = ? AND name = ?`,
hostID, name) hostID, name)
g, err := scanSourceGroup(row) g, err := scanSourceGroup(row)
@@ -177,7 +179,7 @@ func (st *Store) ListSourceGroupsByHost(ctx context.Context, hostID string) ([]S
rows, err := st.db.QueryContext(ctx, rows, err := st.db.QueryContext(ctx,
`SELECT id, host_id, name, includes, excludes, retention_policy, `SELECT id, host_id, name, includes, excludes, retention_policy,
retry_max, retry_backoff_seconds, conflict_dimension, retry_max, retry_backoff_seconds, conflict_dimension,
created_at, updated_at created_at, updated_at, pre_hook, post_hook
FROM source_groups WHERE host_id = ? ORDER BY name`, FROM source_groups WHERE host_id = ? ORDER BY name`,
hostID) hostID)
if err != nil { if err != nil {
@@ -224,14 +226,17 @@ func scanSourceGroupRow(s sourceGroupScanner) (*SourceGroup, error) {
includes, excludes, retention string includes, excludes, retention string
conflict sql.NullString conflict sql.NullString
createdAt, updatedAt string createdAt, updatedAt string
preHook, postHook []byte
) )
err := s.Scan(&out.ID, &out.HostID, &out.Name, err := s.Scan(&out.ID, &out.HostID, &out.Name,
&includes, &excludes, &retention, &includes, &excludes, &retention,
&out.RetryMax, &out.RetryBackoffSeconds, &conflict, &out.RetryMax, &out.RetryBackoffSeconds, &conflict,
&createdAt, &updatedAt) &createdAt, &updatedAt, &preHook, &postHook)
if err != nil { if err != nil {
return nil, err return nil, err
} }
out.PreHook = preHook
out.PostHook = postHook
if includes != "" { if includes != "" {
_ = json.Unmarshal([]byte(includes), &out.Includes) _ = json.Unmarshal([]byte(includes), &out.Includes)
} }
@@ -259,3 +264,13 @@ func nullableString(s string) any {
} }
return s return s
} }
// nullableBytes returns nil for an empty/nil slice so SQL stores it
// as NULL rather than an empty BLOB. The agent treats both the same
// (no hook), but NULL is the canonical "absent" form on disk.
func nullableBytes(b []byte) any {
if len(b) == 0 {
return nil
}
return b
}
+13
View File
@@ -66,6 +66,12 @@ type Host struct {
// (backup, restore, prune). nil = no cap. // (backup, restore, prune). nil = no cap.
BandwidthUpKBps *int BandwidthUpKBps *int
BandwidthDownKBps *int BandwidthDownKBps *int
// PreHookDefault / PostHookDefault are AEAD-encrypted host-wide
// hook bodies. Per source group hooks (SourceGroup.PreHook /
// PostHook) override these when set. nil = no default configured.
PreHookDefault []byte
PostHookDefault []byte
} }
// Schedule is now intentionally slim: cron + which groups + enabled. // Schedule is now intentionally slim: cron + which groups + enabled.
@@ -106,6 +112,13 @@ type SourceGroup struct {
ConflictDimension string ConflictDimension string
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
// PreHook / PostHook are AEAD-encrypted shell snippets (raw blob).
// nil means "no hook configured." Encryption/decryption happens at
// the HTTP layer (where AEAD lives); the store layer just persists
// the bytes verbatim.
PreHook []byte
PostHook []byte
} }
// RetentionPolicy is the typed view of `restic forget --keep-*`. // RetentionPolicy is the typed view of `restic forget --keep-*`.