store: P2R-10 schema for source-group + host-default hooks (migration 0010)
Adds pre_hook/post_hook BLOB columns to source_groups and pre_hook_default/post_hook_default to hosts. Bytes stored verbatim (AEAD encrypt/decrypt happens at the HTTP layer where the AEAD key lives). Round-trip tests cover set/clear semantics on both tables.
This commit is contained in:
@@ -0,0 +1,106 @@
|
||||
// hooks_test.go — covers the pre/post hook columns added in
|
||||
// migration 0010 (P2R-10): set + reload roundtrip on both
|
||||
// source_groups and hosts; nil clears the column.
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
)
|
||||
|
||||
func newTestStore(t *testing.T) *Store {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
st, err := Open(context.Background(), filepath.Join(dir, "rm.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("open store: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = st.Close() })
|
||||
return st
|
||||
}
|
||||
|
||||
func makeHostInStore(t *testing.T, st *Store, name string) string {
|
||||
t.Helper()
|
||||
id := ulid.Make().String()
|
||||
if err := st.CreateHost(context.Background(), Host{
|
||||
ID: id, Name: name, OS: "linux", Arch: "amd64",
|
||||
EnrolledAt: time.Now().UTC(),
|
||||
}, "tokenhash-"+id, ""); err != nil {
|
||||
t.Fatalf("create host: %v", err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func TestSourceGroupHooksRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newTestStore(t)
|
||||
hostID := makeHostInStore(t, st, "hooks-host")
|
||||
|
||||
g := &SourceGroup{
|
||||
ID: ulid.Make().String(), HostID: hostID, Name: "etc",
|
||||
PreHook: []byte("ENC-PRE"),
|
||||
PostHook: []byte("ENC-POST"),
|
||||
}
|
||||
if err := st.CreateSourceGroup(context.Background(), g); err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
got, err := st.GetSourceGroup(context.Background(), hostID, g.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
if string(got.PreHook) != "ENC-PRE" {
|
||||
t.Fatalf("PreHook: got %q, want ENC-PRE", got.PreHook)
|
||||
}
|
||||
if string(got.PostHook) != "ENC-POST" {
|
||||
t.Fatalf("PostHook: got %q, want ENC-POST", got.PostHook)
|
||||
}
|
||||
|
||||
// Update: clear PreHook, change PostHook.
|
||||
got.PreHook = nil
|
||||
got.PostHook = []byte("ENC-POST-2")
|
||||
if err := st.UpdateSourceGroup(context.Background(), got); err != nil {
|
||||
t.Fatalf("update: %v", err)
|
||||
}
|
||||
got, err = st.GetSourceGroup(context.Background(), hostID, g.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
if got.PreHook != nil {
|
||||
t.Fatalf("PreHook: want nil after clear, got %q", got.PreHook)
|
||||
}
|
||||
if string(got.PostHook) != "ENC-POST-2" {
|
||||
t.Fatalf("PostHook: got %q, want ENC-POST-2", got.PostHook)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostHookDefaultsRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
st := newTestStore(t)
|
||||
hostID := makeHostInStore(t, st, "host-hooks-host")
|
||||
|
||||
if err := st.SetHostHooks(context.Background(), hostID, []byte("PRE"), []byte("POST")); err != nil {
|
||||
t.Fatalf("set: %v", err)
|
||||
}
|
||||
h, err := st.GetHost(context.Background(), hostID)
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
if string(h.PreHookDefault) != "PRE" || string(h.PostHookDefault) != "POST" {
|
||||
t.Fatalf("after set: pre=%q post=%q", h.PreHookDefault, h.PostHookDefault)
|
||||
}
|
||||
// Clear by passing nil.
|
||||
if err := st.SetHostHooks(context.Background(), hostID, nil, nil); err != nil {
|
||||
t.Fatalf("clear: %v", err)
|
||||
}
|
||||
h, err = st.GetHost(context.Background(), hostID)
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
if h.PreHookDefault != nil || h.PostHookDefault != nil {
|
||||
t.Fatalf("after clear: pre=%v post=%v (want nil)", h.PreHookDefault, h.PostHookDefault)
|
||||
}
|
||||
}
|
||||
+24
-4
@@ -42,7 +42,8 @@ func (s *Store) LookupHostByAgentToken(ctx context.Context, tokenHash string) (*
|
||||
enrolled_at, last_seen_at, status, repo_id, tags,
|
||||
current_job_id, last_backup_at, last_backup_status,
|
||||
repo_size_bytes, snapshot_count, open_alert_count,
|
||||
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps
|
||||
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps,
|
||||
pre_hook_default, post_hook_default
|
||||
FROM hosts WHERE agent_token_hash = ?`,
|
||||
tokenHash)
|
||||
return scanHost(row)
|
||||
@@ -55,7 +56,8 @@ func (s *Store) GetHost(ctx context.Context, id string) (*Host, error) {
|
||||
enrolled_at, last_seen_at, status, repo_id, tags,
|
||||
current_job_id, last_backup_at, last_backup_status,
|
||||
repo_size_bytes, snapshot_count, open_alert_count,
|
||||
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps
|
||||
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps,
|
||||
pre_hook_default, post_hook_default
|
||||
FROM hosts WHERE id = ?`, id)
|
||||
return scanHost(row)
|
||||
}
|
||||
@@ -116,7 +118,8 @@ func (s *Store) ListHosts(ctx context.Context) ([]Host, error) {
|
||||
enrolled_at, last_seen_at, status, repo_id, tags,
|
||||
current_job_id, last_backup_at, last_backup_status,
|
||||
repo_size_bytes, snapshot_count, open_alert_count,
|
||||
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps
|
||||
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps,
|
||||
pre_hook_default, post_hook_default
|
||||
FROM hosts ORDER BY name`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("store: list hosts: %w", err)
|
||||
@@ -155,13 +158,15 @@ func scanHostRow(s hostScanner) (*Host, error) {
|
||||
enrolled string
|
||||
tags string
|
||||
bwUp, bwDown sql.NullInt64
|
||||
preHook, postHook []byte
|
||||
)
|
||||
err := s.Scan(&h.ID, &h.Name, &h.OS, &h.Arch,
|
||||
&h.AgentVersion, &h.ResticVersion, &h.ProtocolVersion,
|
||||
&enrolled, &lastSeen, &h.Status, &repoID, &tags,
|
||||
¤tJob, &lastBackupAt, &lastBkSt,
|
||||
&h.RepoSizeBytes, &h.SnapshotCount, &h.OpenAlertCount,
|
||||
&h.AppliedScheduleVersion, &bwUp, &bwDown)
|
||||
&h.AppliedScheduleVersion, &bwUp, &bwDown,
|
||||
&preHook, &postHook)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
@@ -210,9 +215,24 @@ func scanHostRow(s hostScanner) (*Host, error) {
|
||||
v := int(bwDown.Int64)
|
||||
h.BandwidthDownKBps = &v
|
||||
}
|
||||
h.PreHookDefault = preHook
|
||||
h.PostHookDefault = postHook
|
||||
return &h, nil
|
||||
}
|
||||
|
||||
// SetHostHooks replaces the host-wide pre/post hook defaults. Pass
|
||||
// nil/empty to clear that hook. Stored verbatim — caller is expected
|
||||
// to encrypt the bytes before they reach this layer.
|
||||
func (s *Store) SetHostHooks(ctx context.Context, hostID string, pre, post []byte) error {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`UPDATE hosts SET pre_hook_default = ?, post_hook_default = ? WHERE id = ?`,
|
||||
nullableBytes(pre), nullableBytes(post), hostID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: set host hooks: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetHostBandwidth replaces the host's upload/download caps. Pass nil
|
||||
// to clear a cap. Caller decides validation; non-positive caps are
|
||||
// treated as "no cap" by the agent regardless.
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
-- 0010_hooks.sql
|
||||
--
|
||||
-- P2R-10: pre/post hooks on source groups + host-wide defaults.
|
||||
--
|
||||
-- Hook bodies are stored as AEAD ciphertext (existing crypto.AEAD)
|
||||
-- because operators do put credentials in shell snippets — even
|
||||
-- though we tell them not to. NULL means "no hook configured".
|
||||
--
|
||||
-- Hooks fire only for kind=backup jobs. forget/prune/check/unlock
|
||||
-- skip them per spec.md §14.3 (P2R-11 enforces this in the agent
|
||||
-- dispatcher).
|
||||
--
|
||||
-- Resolution order at dispatch time:
|
||||
-- source_group.<phase>_hook (per-group override, AEAD blob)
|
||||
-- host.<phase>_hook_default (host default, AEAD blob)
|
||||
-- none → no hook runs
|
||||
--
|
||||
-- All four columns are added in-place via ALTER TABLE ADD COLUMN.
|
||||
-- Per CLAUDE.md the table-rebuild pattern is unsafe with FK cascades.
|
||||
|
||||
ALTER TABLE source_groups ADD COLUMN pre_hook BLOB;
|
||||
ALTER TABLE source_groups ADD COLUMN post_hook BLOB;
|
||||
|
||||
ALTER TABLE hosts ADD COLUMN pre_hook_default BLOB;
|
||||
ALTER TABLE hosts ADD COLUMN post_hook_default BLOB;
|
||||
@@ -45,13 +45,14 @@ func (st *Store) CreateSourceGroup(ctx context.Context, g *SourceGroup) error {
|
||||
`INSERT INTO source_groups (
|
||||
id, host_id, name, includes, excludes, retention_policy,
|
||||
retry_max, retry_backoff_seconds, conflict_dimension,
|
||||
created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
created_at, updated_at, pre_hook, post_hook
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
g.ID, g.HostID, g.Name,
|
||||
string(includesJSON), string(excludesJSON), string(retentionJSON),
|
||||
g.RetryMax, g.RetryBackoffSeconds,
|
||||
nullableString(g.ConflictDimension),
|
||||
now.Format(time.RFC3339Nano), now.Format(time.RFC3339Nano),
|
||||
nullableBytes(g.PreHook), nullableBytes(g.PostHook),
|
||||
); err != nil {
|
||||
return fmt.Errorf("store: create source group: %w", err)
|
||||
}
|
||||
@@ -88,13 +89,14 @@ func (st *Store) UpdateSourceGroup(ctx context.Context, g *SourceGroup) error {
|
||||
`UPDATE source_groups SET
|
||||
name = ?, includes = ?, excludes = ?, retention_policy = ?,
|
||||
retry_max = ?, retry_backoff_seconds = ?, conflict_dimension = ?,
|
||||
updated_at = ?
|
||||
updated_at = ?, pre_hook = ?, post_hook = ?
|
||||
WHERE id = ? AND host_id = ?`,
|
||||
g.Name,
|
||||
string(includesJSON), string(excludesJSON), string(retentionJSON),
|
||||
g.RetryMax, g.RetryBackoffSeconds,
|
||||
nullableString(g.ConflictDimension),
|
||||
now.Format(time.RFC3339Nano),
|
||||
nullableBytes(g.PreHook), nullableBytes(g.PostHook),
|
||||
g.ID, g.HostID,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -143,7 +145,7 @@ func (st *Store) GetSourceGroup(ctx context.Context, hostID, groupID string) (*S
|
||||
row := st.db.QueryRowContext(ctx,
|
||||
`SELECT id, host_id, name, includes, excludes, retention_policy,
|
||||
retry_max, retry_backoff_seconds, conflict_dimension,
|
||||
created_at, updated_at
|
||||
created_at, updated_at, pre_hook, post_hook
|
||||
FROM source_groups WHERE id = ? AND host_id = ?`,
|
||||
groupID, hostID)
|
||||
g, err := scanSourceGroup(row)
|
||||
@@ -159,7 +161,7 @@ func (st *Store) GetSourceGroupByName(ctx context.Context, hostID, name string)
|
||||
row := st.db.QueryRowContext(ctx,
|
||||
`SELECT id, host_id, name, includes, excludes, retention_policy,
|
||||
retry_max, retry_backoff_seconds, conflict_dimension,
|
||||
created_at, updated_at
|
||||
created_at, updated_at, pre_hook, post_hook
|
||||
FROM source_groups WHERE host_id = ? AND name = ?`,
|
||||
hostID, name)
|
||||
g, err := scanSourceGroup(row)
|
||||
@@ -177,7 +179,7 @@ func (st *Store) ListSourceGroupsByHost(ctx context.Context, hostID string) ([]S
|
||||
rows, err := st.db.QueryContext(ctx,
|
||||
`SELECT id, host_id, name, includes, excludes, retention_policy,
|
||||
retry_max, retry_backoff_seconds, conflict_dimension,
|
||||
created_at, updated_at
|
||||
created_at, updated_at, pre_hook, post_hook
|
||||
FROM source_groups WHERE host_id = ? ORDER BY name`,
|
||||
hostID)
|
||||
if err != nil {
|
||||
@@ -224,14 +226,17 @@ func scanSourceGroupRow(s sourceGroupScanner) (*SourceGroup, error) {
|
||||
includes, excludes, retention string
|
||||
conflict sql.NullString
|
||||
createdAt, updatedAt string
|
||||
preHook, postHook []byte
|
||||
)
|
||||
err := s.Scan(&out.ID, &out.HostID, &out.Name,
|
||||
&includes, &excludes, &retention,
|
||||
&out.RetryMax, &out.RetryBackoffSeconds, &conflict,
|
||||
&createdAt, &updatedAt)
|
||||
&createdAt, &updatedAt, &preHook, &postHook)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out.PreHook = preHook
|
||||
out.PostHook = postHook
|
||||
if includes != "" {
|
||||
_ = json.Unmarshal([]byte(includes), &out.Includes)
|
||||
}
|
||||
@@ -259,3 +264,13 @@ func nullableString(s string) any {
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// nullableBytes returns nil for an empty/nil slice so SQL stores it
|
||||
// as NULL rather than an empty BLOB. The agent treats both the same
|
||||
// (no hook), but NULL is the canonical "absent" form on disk.
|
||||
func nullableBytes(b []byte) any {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -66,6 +66,12 @@ type Host struct {
|
||||
// (backup, restore, prune). nil = no cap.
|
||||
BandwidthUpKBps *int
|
||||
BandwidthDownKBps *int
|
||||
|
||||
// PreHookDefault / PostHookDefault are AEAD-encrypted host-wide
|
||||
// hook bodies. Per source group hooks (SourceGroup.PreHook /
|
||||
// PostHook) override these when set. nil = no default configured.
|
||||
PreHookDefault []byte
|
||||
PostHookDefault []byte
|
||||
}
|
||||
|
||||
// Schedule is now intentionally slim: cron + which groups + enabled.
|
||||
@@ -106,6 +112,13 @@ type SourceGroup struct {
|
||||
ConflictDimension string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// PreHook / PostHook are AEAD-encrypted shell snippets (raw blob).
|
||||
// nil means "no hook configured." Encryption/decryption happens at
|
||||
// the HTTP layer (where AEAD lives); the store layer just persists
|
||||
// the bytes verbatim.
|
||||
PreHook []byte
|
||||
PostHook []byte
|
||||
}
|
||||
|
||||
// RetentionPolicy is the typed view of `restic forget --keep-*`.
|
||||
|
||||
Reference in New Issue
Block a user