store: P2R-10 schema for source-group + host-default hooks (migration 0010)
Adds pre_hook/post_hook BLOB columns to source_groups and pre_hook_default/post_hook_default to hosts. Bytes stored verbatim (AEAD encrypt/decrypt happens at the HTTP layer where the AEAD key lives). Round-trip tests cover set/clear semantics on both tables.
This commit is contained in:
@@ -0,0 +1,106 @@
|
|||||||
|
// hooks_test.go — covers the pre/post hook columns added in
|
||||||
|
// migration 0010 (P2R-10): set + reload roundtrip on both
|
||||||
|
// source_groups and hosts; nil clears the column.
|
||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oklog/ulid/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestStore(t *testing.T) *Store {
|
||||||
|
t.Helper()
|
||||||
|
dir := t.TempDir()
|
||||||
|
st, err := Open(context.Background(), filepath.Join(dir, "rm.db"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open store: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = st.Close() })
|
||||||
|
return st
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeHostInStore(t *testing.T, st *Store, name string) string {
|
||||||
|
t.Helper()
|
||||||
|
id := ulid.Make().String()
|
||||||
|
if err := st.CreateHost(context.Background(), Host{
|
||||||
|
ID: id, Name: name, OS: "linux", Arch: "amd64",
|
||||||
|
EnrolledAt: time.Now().UTC(),
|
||||||
|
}, "tokenhash-"+id, ""); err != nil {
|
||||||
|
t.Fatalf("create host: %v", err)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSourceGroupHooksRoundTrip(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
st := newTestStore(t)
|
||||||
|
hostID := makeHostInStore(t, st, "hooks-host")
|
||||||
|
|
||||||
|
g := &SourceGroup{
|
||||||
|
ID: ulid.Make().String(), HostID: hostID, Name: "etc",
|
||||||
|
PreHook: []byte("ENC-PRE"),
|
||||||
|
PostHook: []byte("ENC-POST"),
|
||||||
|
}
|
||||||
|
if err := st.CreateSourceGroup(context.Background(), g); err != nil {
|
||||||
|
t.Fatalf("create: %v", err)
|
||||||
|
}
|
||||||
|
got, err := st.GetSourceGroup(context.Background(), hostID, g.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get: %v", err)
|
||||||
|
}
|
||||||
|
if string(got.PreHook) != "ENC-PRE" {
|
||||||
|
t.Fatalf("PreHook: got %q, want ENC-PRE", got.PreHook)
|
||||||
|
}
|
||||||
|
if string(got.PostHook) != "ENC-POST" {
|
||||||
|
t.Fatalf("PostHook: got %q, want ENC-POST", got.PostHook)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update: clear PreHook, change PostHook.
|
||||||
|
got.PreHook = nil
|
||||||
|
got.PostHook = []byte("ENC-POST-2")
|
||||||
|
if err := st.UpdateSourceGroup(context.Background(), got); err != nil {
|
||||||
|
t.Fatalf("update: %v", err)
|
||||||
|
}
|
||||||
|
got, err = st.GetSourceGroup(context.Background(), hostID, g.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get: %v", err)
|
||||||
|
}
|
||||||
|
if got.PreHook != nil {
|
||||||
|
t.Fatalf("PreHook: want nil after clear, got %q", got.PreHook)
|
||||||
|
}
|
||||||
|
if string(got.PostHook) != "ENC-POST-2" {
|
||||||
|
t.Fatalf("PostHook: got %q, want ENC-POST-2", got.PostHook)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostHookDefaultsRoundTrip(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
st := newTestStore(t)
|
||||||
|
hostID := makeHostInStore(t, st, "host-hooks-host")
|
||||||
|
|
||||||
|
if err := st.SetHostHooks(context.Background(), hostID, []byte("PRE"), []byte("POST")); err != nil {
|
||||||
|
t.Fatalf("set: %v", err)
|
||||||
|
}
|
||||||
|
h, err := st.GetHost(context.Background(), hostID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get: %v", err)
|
||||||
|
}
|
||||||
|
if string(h.PreHookDefault) != "PRE" || string(h.PostHookDefault) != "POST" {
|
||||||
|
t.Fatalf("after set: pre=%q post=%q", h.PreHookDefault, h.PostHookDefault)
|
||||||
|
}
|
||||||
|
// Clear by passing nil.
|
||||||
|
if err := st.SetHostHooks(context.Background(), hostID, nil, nil); err != nil {
|
||||||
|
t.Fatalf("clear: %v", err)
|
||||||
|
}
|
||||||
|
h, err = st.GetHost(context.Background(), hostID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get: %v", err)
|
||||||
|
}
|
||||||
|
if h.PreHookDefault != nil || h.PostHookDefault != nil {
|
||||||
|
t.Fatalf("after clear: pre=%v post=%v (want nil)", h.PreHookDefault, h.PostHookDefault)
|
||||||
|
}
|
||||||
|
}
|
||||||
+24
-4
@@ -42,7 +42,8 @@ func (s *Store) LookupHostByAgentToken(ctx context.Context, tokenHash string) (*
|
|||||||
enrolled_at, last_seen_at, status, repo_id, tags,
|
enrolled_at, last_seen_at, status, repo_id, tags,
|
||||||
current_job_id, last_backup_at, last_backup_status,
|
current_job_id, last_backup_at, last_backup_status,
|
||||||
repo_size_bytes, snapshot_count, open_alert_count,
|
repo_size_bytes, snapshot_count, open_alert_count,
|
||||||
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps
|
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps,
|
||||||
|
pre_hook_default, post_hook_default
|
||||||
FROM hosts WHERE agent_token_hash = ?`,
|
FROM hosts WHERE agent_token_hash = ?`,
|
||||||
tokenHash)
|
tokenHash)
|
||||||
return scanHost(row)
|
return scanHost(row)
|
||||||
@@ -55,7 +56,8 @@ func (s *Store) GetHost(ctx context.Context, id string) (*Host, error) {
|
|||||||
enrolled_at, last_seen_at, status, repo_id, tags,
|
enrolled_at, last_seen_at, status, repo_id, tags,
|
||||||
current_job_id, last_backup_at, last_backup_status,
|
current_job_id, last_backup_at, last_backup_status,
|
||||||
repo_size_bytes, snapshot_count, open_alert_count,
|
repo_size_bytes, snapshot_count, open_alert_count,
|
||||||
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps
|
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps,
|
||||||
|
pre_hook_default, post_hook_default
|
||||||
FROM hosts WHERE id = ?`, id)
|
FROM hosts WHERE id = ?`, id)
|
||||||
return scanHost(row)
|
return scanHost(row)
|
||||||
}
|
}
|
||||||
@@ -116,7 +118,8 @@ func (s *Store) ListHosts(ctx context.Context) ([]Host, error) {
|
|||||||
enrolled_at, last_seen_at, status, repo_id, tags,
|
enrolled_at, last_seen_at, status, repo_id, tags,
|
||||||
current_job_id, last_backup_at, last_backup_status,
|
current_job_id, last_backup_at, last_backup_status,
|
||||||
repo_size_bytes, snapshot_count, open_alert_count,
|
repo_size_bytes, snapshot_count, open_alert_count,
|
||||||
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps
|
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps,
|
||||||
|
pre_hook_default, post_hook_default
|
||||||
FROM hosts ORDER BY name`)
|
FROM hosts ORDER BY name`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("store: list hosts: %w", err)
|
return nil, fmt.Errorf("store: list hosts: %w", err)
|
||||||
@@ -155,13 +158,15 @@ func scanHostRow(s hostScanner) (*Host, error) {
|
|||||||
enrolled string
|
enrolled string
|
||||||
tags string
|
tags string
|
||||||
bwUp, bwDown sql.NullInt64
|
bwUp, bwDown sql.NullInt64
|
||||||
|
preHook, postHook []byte
|
||||||
)
|
)
|
||||||
err := s.Scan(&h.ID, &h.Name, &h.OS, &h.Arch,
|
err := s.Scan(&h.ID, &h.Name, &h.OS, &h.Arch,
|
||||||
&h.AgentVersion, &h.ResticVersion, &h.ProtocolVersion,
|
&h.AgentVersion, &h.ResticVersion, &h.ProtocolVersion,
|
||||||
&enrolled, &lastSeen, &h.Status, &repoID, &tags,
|
&enrolled, &lastSeen, &h.Status, &repoID, &tags,
|
||||||
¤tJob, &lastBackupAt, &lastBkSt,
|
¤tJob, &lastBackupAt, &lastBkSt,
|
||||||
&h.RepoSizeBytes, &h.SnapshotCount, &h.OpenAlertCount,
|
&h.RepoSizeBytes, &h.SnapshotCount, &h.OpenAlertCount,
|
||||||
&h.AppliedScheduleVersion, &bwUp, &bwDown)
|
&h.AppliedScheduleVersion, &bwUp, &bwDown,
|
||||||
|
&preHook, &postHook)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
@@ -210,9 +215,24 @@ func scanHostRow(s hostScanner) (*Host, error) {
|
|||||||
v := int(bwDown.Int64)
|
v := int(bwDown.Int64)
|
||||||
h.BandwidthDownKBps = &v
|
h.BandwidthDownKBps = &v
|
||||||
}
|
}
|
||||||
|
h.PreHookDefault = preHook
|
||||||
|
h.PostHookDefault = postHook
|
||||||
return &h, nil
|
return &h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetHostHooks replaces the host-wide pre/post hook defaults. Pass
|
||||||
|
// nil/empty to clear that hook. Stored verbatim — caller is expected
|
||||||
|
// to encrypt the bytes before they reach this layer.
|
||||||
|
func (s *Store) SetHostHooks(ctx context.Context, hostID string, pre, post []byte) error {
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`UPDATE hosts SET pre_hook_default = ?, post_hook_default = ? WHERE id = ?`,
|
||||||
|
nullableBytes(pre), nullableBytes(post), hostID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("store: set host hooks: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetHostBandwidth replaces the host's upload/download caps. Pass nil
|
// SetHostBandwidth replaces the host's upload/download caps. Pass nil
|
||||||
// to clear a cap. Caller decides validation; non-positive caps are
|
// to clear a cap. Caller decides validation; non-positive caps are
|
||||||
// treated as "no cap" by the agent regardless.
|
// treated as "no cap" by the agent regardless.
|
||||||
|
|||||||
@@ -0,0 +1,25 @@
|
|||||||
|
-- 0010_hooks.sql
|
||||||
|
--
|
||||||
|
-- P2R-10: pre/post hooks on source groups + host-wide defaults.
|
||||||
|
--
|
||||||
|
-- Hook bodies are stored as AEAD ciphertext (existing crypto.AEAD)
|
||||||
|
-- because operators do put credentials in shell snippets — even
|
||||||
|
-- though we tell them not to. NULL means "no hook configured".
|
||||||
|
--
|
||||||
|
-- Hooks fire only for kind=backup jobs. forget/prune/check/unlock
|
||||||
|
-- skip them per spec.md §14.3 (P2R-11 enforces this in the agent
|
||||||
|
-- dispatcher).
|
||||||
|
--
|
||||||
|
-- Resolution order at dispatch time:
|
||||||
|
-- source_group.<phase>_hook (per-group override, AEAD blob)
|
||||||
|
-- host.<phase>_hook_default (host default, AEAD blob)
|
||||||
|
-- none → no hook runs
|
||||||
|
--
|
||||||
|
-- All four columns are added in-place via ALTER TABLE ADD COLUMN.
|
||||||
|
-- Per CLAUDE.md the table-rebuild pattern is unsafe with FK cascades.
|
||||||
|
|
||||||
|
ALTER TABLE source_groups ADD COLUMN pre_hook BLOB;
|
||||||
|
ALTER TABLE source_groups ADD COLUMN post_hook BLOB;
|
||||||
|
|
||||||
|
ALTER TABLE hosts ADD COLUMN pre_hook_default BLOB;
|
||||||
|
ALTER TABLE hosts ADD COLUMN post_hook_default BLOB;
|
||||||
@@ -45,13 +45,14 @@ func (st *Store) CreateSourceGroup(ctx context.Context, g *SourceGroup) error {
|
|||||||
`INSERT INTO source_groups (
|
`INSERT INTO source_groups (
|
||||||
id, host_id, name, includes, excludes, retention_policy,
|
id, host_id, name, includes, excludes, retention_policy,
|
||||||
retry_max, retry_backoff_seconds, conflict_dimension,
|
retry_max, retry_backoff_seconds, conflict_dimension,
|
||||||
created_at, updated_at
|
created_at, updated_at, pre_hook, post_hook
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
g.ID, g.HostID, g.Name,
|
g.ID, g.HostID, g.Name,
|
||||||
string(includesJSON), string(excludesJSON), string(retentionJSON),
|
string(includesJSON), string(excludesJSON), string(retentionJSON),
|
||||||
g.RetryMax, g.RetryBackoffSeconds,
|
g.RetryMax, g.RetryBackoffSeconds,
|
||||||
nullableString(g.ConflictDimension),
|
nullableString(g.ConflictDimension),
|
||||||
now.Format(time.RFC3339Nano), now.Format(time.RFC3339Nano),
|
now.Format(time.RFC3339Nano), now.Format(time.RFC3339Nano),
|
||||||
|
nullableBytes(g.PreHook), nullableBytes(g.PostHook),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return fmt.Errorf("store: create source group: %w", err)
|
return fmt.Errorf("store: create source group: %w", err)
|
||||||
}
|
}
|
||||||
@@ -88,13 +89,14 @@ func (st *Store) UpdateSourceGroup(ctx context.Context, g *SourceGroup) error {
|
|||||||
`UPDATE source_groups SET
|
`UPDATE source_groups SET
|
||||||
name = ?, includes = ?, excludes = ?, retention_policy = ?,
|
name = ?, includes = ?, excludes = ?, retention_policy = ?,
|
||||||
retry_max = ?, retry_backoff_seconds = ?, conflict_dimension = ?,
|
retry_max = ?, retry_backoff_seconds = ?, conflict_dimension = ?,
|
||||||
updated_at = ?
|
updated_at = ?, pre_hook = ?, post_hook = ?
|
||||||
WHERE id = ? AND host_id = ?`,
|
WHERE id = ? AND host_id = ?`,
|
||||||
g.Name,
|
g.Name,
|
||||||
string(includesJSON), string(excludesJSON), string(retentionJSON),
|
string(includesJSON), string(excludesJSON), string(retentionJSON),
|
||||||
g.RetryMax, g.RetryBackoffSeconds,
|
g.RetryMax, g.RetryBackoffSeconds,
|
||||||
nullableString(g.ConflictDimension),
|
nullableString(g.ConflictDimension),
|
||||||
now.Format(time.RFC3339Nano),
|
now.Format(time.RFC3339Nano),
|
||||||
|
nullableBytes(g.PreHook), nullableBytes(g.PostHook),
|
||||||
g.ID, g.HostID,
|
g.ID, g.HostID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -143,7 +145,7 @@ func (st *Store) GetSourceGroup(ctx context.Context, hostID, groupID string) (*S
|
|||||||
row := st.db.QueryRowContext(ctx,
|
row := st.db.QueryRowContext(ctx,
|
||||||
`SELECT id, host_id, name, includes, excludes, retention_policy,
|
`SELECT id, host_id, name, includes, excludes, retention_policy,
|
||||||
retry_max, retry_backoff_seconds, conflict_dimension,
|
retry_max, retry_backoff_seconds, conflict_dimension,
|
||||||
created_at, updated_at
|
created_at, updated_at, pre_hook, post_hook
|
||||||
FROM source_groups WHERE id = ? AND host_id = ?`,
|
FROM source_groups WHERE id = ? AND host_id = ?`,
|
||||||
groupID, hostID)
|
groupID, hostID)
|
||||||
g, err := scanSourceGroup(row)
|
g, err := scanSourceGroup(row)
|
||||||
@@ -159,7 +161,7 @@ func (st *Store) GetSourceGroupByName(ctx context.Context, hostID, name string)
|
|||||||
row := st.db.QueryRowContext(ctx,
|
row := st.db.QueryRowContext(ctx,
|
||||||
`SELECT id, host_id, name, includes, excludes, retention_policy,
|
`SELECT id, host_id, name, includes, excludes, retention_policy,
|
||||||
retry_max, retry_backoff_seconds, conflict_dimension,
|
retry_max, retry_backoff_seconds, conflict_dimension,
|
||||||
created_at, updated_at
|
created_at, updated_at, pre_hook, post_hook
|
||||||
FROM source_groups WHERE host_id = ? AND name = ?`,
|
FROM source_groups WHERE host_id = ? AND name = ?`,
|
||||||
hostID, name)
|
hostID, name)
|
||||||
g, err := scanSourceGroup(row)
|
g, err := scanSourceGroup(row)
|
||||||
@@ -177,7 +179,7 @@ func (st *Store) ListSourceGroupsByHost(ctx context.Context, hostID string) ([]S
|
|||||||
rows, err := st.db.QueryContext(ctx,
|
rows, err := st.db.QueryContext(ctx,
|
||||||
`SELECT id, host_id, name, includes, excludes, retention_policy,
|
`SELECT id, host_id, name, includes, excludes, retention_policy,
|
||||||
retry_max, retry_backoff_seconds, conflict_dimension,
|
retry_max, retry_backoff_seconds, conflict_dimension,
|
||||||
created_at, updated_at
|
created_at, updated_at, pre_hook, post_hook
|
||||||
FROM source_groups WHERE host_id = ? ORDER BY name`,
|
FROM source_groups WHERE host_id = ? ORDER BY name`,
|
||||||
hostID)
|
hostID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -224,14 +226,17 @@ func scanSourceGroupRow(s sourceGroupScanner) (*SourceGroup, error) {
|
|||||||
includes, excludes, retention string
|
includes, excludes, retention string
|
||||||
conflict sql.NullString
|
conflict sql.NullString
|
||||||
createdAt, updatedAt string
|
createdAt, updatedAt string
|
||||||
|
preHook, postHook []byte
|
||||||
)
|
)
|
||||||
err := s.Scan(&out.ID, &out.HostID, &out.Name,
|
err := s.Scan(&out.ID, &out.HostID, &out.Name,
|
||||||
&includes, &excludes, &retention,
|
&includes, &excludes, &retention,
|
||||||
&out.RetryMax, &out.RetryBackoffSeconds, &conflict,
|
&out.RetryMax, &out.RetryBackoffSeconds, &conflict,
|
||||||
&createdAt, &updatedAt)
|
&createdAt, &updatedAt, &preHook, &postHook)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
out.PreHook = preHook
|
||||||
|
out.PostHook = postHook
|
||||||
if includes != "" {
|
if includes != "" {
|
||||||
_ = json.Unmarshal([]byte(includes), &out.Includes)
|
_ = json.Unmarshal([]byte(includes), &out.Includes)
|
||||||
}
|
}
|
||||||
@@ -259,3 +264,13 @@ func nullableString(s string) any {
|
|||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nullableBytes returns nil for an empty/nil slice so SQL stores it
|
||||||
|
// as NULL rather than an empty BLOB. The agent treats both the same
|
||||||
|
// (no hook), but NULL is the canonical "absent" form on disk.
|
||||||
|
func nullableBytes(b []byte) any {
|
||||||
|
if len(b) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|||||||
@@ -66,6 +66,12 @@ type Host struct {
|
|||||||
// (backup, restore, prune). nil = no cap.
|
// (backup, restore, prune). nil = no cap.
|
||||||
BandwidthUpKBps *int
|
BandwidthUpKBps *int
|
||||||
BandwidthDownKBps *int
|
BandwidthDownKBps *int
|
||||||
|
|
||||||
|
// PreHookDefault / PostHookDefault are AEAD-encrypted host-wide
|
||||||
|
// hook bodies. Per source group hooks (SourceGroup.PreHook /
|
||||||
|
// PostHook) override these when set. nil = no default configured.
|
||||||
|
PreHookDefault []byte
|
||||||
|
PostHookDefault []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// Schedule is now intentionally slim: cron + which groups + enabled.
|
// Schedule is now intentionally slim: cron + which groups + enabled.
|
||||||
@@ -106,6 +112,13 @@ type SourceGroup struct {
|
|||||||
ConflictDimension string
|
ConflictDimension string
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
|
|
||||||
|
// PreHook / PostHook are AEAD-encrypted shell snippets (raw blob).
|
||||||
|
// nil means "no hook configured." Encryption/decryption happens at
|
||||||
|
// the HTTP layer (where AEAD lives); the store layer just persists
|
||||||
|
// the bytes verbatim.
|
||||||
|
PreHook []byte
|
||||||
|
PostHook []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// RetentionPolicy is the typed view of `restic forget --keep-*`.
|
// RetentionPolicy is the typed view of `restic forget --keep-*`.
|
||||||
|
|||||||
Reference in New Issue
Block a user