Files
steve 7b1990cf11 agent+server: P2R-11 pre/post hook execution for backup jobs
Agent: new runner.BackupHooks struct + runHook helper invoked via
/bin/sh -c (cmd.exe /C on Windows). pre_hook non-zero exit aborts
the backup; post_hook always runs with RM_JOB_STATUS=succeeded|failed
in env. Output streamed as 'hook(<phase>): …' log.stream lines.
Hooks only run for kind=backup (other kinds skip both phases).

Server: resolveBackupHooks resolves group → host default → empty,
decrypts via crypto.AEAD with per-slot ad bytes, plumbs plaintext
into CommandRunPayload for both schedule.fire and per-group
Run-now dispatch sites. Decrypt failures degrade silently to no
hook so a malformed blob can't poison every backup.
2026-05-04 10:57:28 +01:00

271 lines
8.2 KiB
Go

package store
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"time"
)
// CreateSourceGroup inserts a new source group and bumps the host's
// schedule version in one transaction. The group name doubles as the
// snapshot tag on backups; the (host_id, name) UNIQUE constraint
// enforces tag unambiguity.
//
// g.ID, g.HostID and g.Name are required. g is normalised in place:
// CreatedAt and UpdatedAt are set to the current UTC time, nil
// Includes/Excludes become empty slices (stored as "[]" rather than
// "null"), and a zero RetryMax or RetryBackoffSeconds is replaced by
// the defaults 3 and 60.
//
// An error is returned if validation fails, if any JSON column cannot
// be encoded, or if any statement in the transaction (including the
// commit) fails; in the latter cases the transaction is rolled back.
func (st *Store) CreateSourceGroup(ctx context.Context, g *SourceGroup) error {
	if g.ID == "" || g.HostID == "" || g.Name == "" {
		return errors.New("store: source group id, host_id, name required")
	}

	now := time.Now().UTC()
	g.CreatedAt = now
	g.UpdatedAt = now
	if g.Includes == nil {
		g.Includes = []string{}
	}
	if g.Excludes == nil {
		g.Excludes = []string{}
	}
	if g.RetryMax == 0 {
		g.RetryMax = 3
	}
	if g.RetryBackoffSeconds == 0 {
		g.RetryBackoffSeconds = 60
	}

	// Encode before opening the transaction so a bad value never
	// reaches the database as an empty or truncated column.
	includesJSON, err := json.Marshal(g.Includes)
	if err != nil {
		return fmt.Errorf("store: encode source group includes: %w", err)
	}
	excludesJSON, err := json.Marshal(g.Excludes)
	if err != nil {
		return fmt.Errorf("store: encode source group excludes: %w", err)
	}
	retentionJSON, err := json.Marshal(g.RetentionPolicy)
	if err != nil {
		return fmt.Errorf("store: encode source group retention policy: %w", err)
	}

	tx, err := st.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("store: begin tx: %w", err)
	}
	// Safety net for every early return; the error from rolling back
	// an already-committed transaction is deliberately discarded.
	defer func() { _ = tx.Rollback() }()

	stamp := now.Format(time.RFC3339Nano)
	if _, err := tx.ExecContext(ctx,
		`INSERT INTO source_groups (
id, host_id, name, includes, excludes, retention_policy,
retry_max, retry_backoff_seconds, conflict_dimension,
created_at, updated_at, pre_hook, post_hook
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		g.ID, g.HostID, g.Name,
		string(includesJSON), string(excludesJSON), string(retentionJSON),
		g.RetryMax, g.RetryBackoffSeconds,
		nullableString(g.ConflictDimension),
		stamp, stamp,
		nullableString(g.PreHook), nullableString(g.PostHook),
	); err != nil {
		return fmt.Errorf("store: create source group: %w", err)
	}
	if err := bumpHostScheduleVersionTx(ctx, tx, g.HostID); err != nil {
		return err
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("store: commit create source group: %w", err)
	}
	return nil
}
// UpdateSourceGroup replaces every editable field on an existing row
// (matched by id AND host_id) and bumps the host's schedule version
// in one transaction.
//
// g.ID, g.HostID and g.Name are required. Nil Includes/Excludes are
// replaced in place by empty slices. RetryMax and RetryBackoffSeconds
// are written exactly as given; no defaults are applied here.
//
// Returns ErrNotFound if no row matched. g.UpdatedAt is set only once
// the transaction has committed, so on any error the caller's struct
// keeps its previous timestamp.
func (st *Store) UpdateSourceGroup(ctx context.Context, g *SourceGroup) error {
	if g.ID == "" || g.HostID == "" || g.Name == "" {
		return errors.New("store: source group id, host_id, name required")
	}
	if g.Includes == nil {
		g.Includes = []string{}
	}
	if g.Excludes == nil {
		g.Excludes = []string{}
	}

	// Encode before opening the transaction so a bad value never
	// overwrites a good column with an empty string.
	includesJSON, err := json.Marshal(g.Includes)
	if err != nil {
		return fmt.Errorf("store: encode source group includes: %w", err)
	}
	excludesJSON, err := json.Marshal(g.Excludes)
	if err != nil {
		return fmt.Errorf("store: encode source group excludes: %w", err)
	}
	retentionJSON, err := json.Marshal(g.RetentionPolicy)
	if err != nil {
		return fmt.Errorf("store: encode source group retention policy: %w", err)
	}

	now := time.Now().UTC()
	tx, err := st.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("store: begin tx: %w", err)
	}
	// Safety net for every early return; the error from rolling back
	// an already-committed transaction is deliberately discarded.
	defer func() { _ = tx.Rollback() }()

	res, err := tx.ExecContext(ctx,
		`UPDATE source_groups SET
name = ?, includes = ?, excludes = ?, retention_policy = ?,
retry_max = ?, retry_backoff_seconds = ?, conflict_dimension = ?,
updated_at = ?, pre_hook = ?, post_hook = ?
WHERE id = ? AND host_id = ?`,
		g.Name,
		string(includesJSON), string(excludesJSON), string(retentionJSON),
		g.RetryMax, g.RetryBackoffSeconds,
		nullableString(g.ConflictDimension),
		now.Format(time.RFC3339Nano),
		nullableString(g.PreHook), nullableString(g.PostHook),
		g.ID, g.HostID,
	)
	if err != nil {
		return fmt.Errorf("store: update source group: %w", err)
	}
	// A driver failure here must not be mistaken for "no such row".
	n, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("store: update source group: rows affected: %w", err)
	}
	if n == 0 {
		return ErrNotFound
	}
	if err := bumpHostScheduleVersionTx(ctx, tx, g.HostID); err != nil {
		return err
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("store: commit update source group: %w", err)
	}
	g.UpdatedAt = now
	return nil
}
// DeleteSourceGroup removes the group identified by (hostID, groupID)
// and bumps the host's schedule version in one transaction.
// Junction rows in schedule_source_groups go via ON DELETE CASCADE.
// The caller is expected to have already enforced the "default group
// can't be the only one" UI rule; this layer just deletes.
//
// Returns ErrNotFound if no row matched. Any other failure (including
// being unable to read the affected-row count) is returned as a
// wrapped error and the transaction is rolled back.
func (st *Store) DeleteSourceGroup(ctx context.Context, hostID, groupID string) error {
	tx, err := st.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("store: begin tx: %w", err)
	}
	// Safety net for every early return; the error from rolling back
	// an already-committed transaction is deliberately discarded.
	defer func() { _ = tx.Rollback() }()

	res, err := tx.ExecContext(ctx,
		`DELETE FROM source_groups WHERE id = ? AND host_id = ?`,
		groupID, hostID)
	if err != nil {
		return fmt.Errorf("store: delete source group: %w", err)
	}
	// A driver failure here must not be mistaken for "no such row".
	n, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("store: delete source group: rows affected: %w", err)
	}
	if n == 0 {
		return ErrNotFound
	}
	if err := bumpHostScheduleVersionTx(ctx, tx, hostID); err != nil {
		return err
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("store: commit delete source group: %w", err)
	}
	return nil
}
// GetSourceGroup looks up a single group by its id, scoped to the
// given host. A missing row is reported as ErrNotFound; any other
// scan error is passed through unchanged.
func (st *Store) GetSourceGroup(ctx context.Context, hostID, groupID string) (*SourceGroup, error) {
	const query = `SELECT id, host_id, name, includes, excludes, retention_policy,
retry_max, retry_backoff_seconds, conflict_dimension,
created_at, updated_at, pre_hook, post_hook
FROM source_groups WHERE id = ? AND host_id = ?`

	grp, err := scanSourceGroup(st.db.QueryRowContext(ctx, query, groupID, hostID))
	if err != nil && errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	return grp, err
}
// GetSourceGroupByName looks up a single group by name within one
// host (names are unique per host). Used by retention-conflict
// detection and the auto-init flow. A missing row is reported as
// ErrNotFound; any other scan error is passed through unchanged.
func (st *Store) GetSourceGroupByName(ctx context.Context, hostID, name string) (*SourceGroup, error) {
	const query = `SELECT id, host_id, name, includes, excludes, retention_policy,
retry_max, retry_backoff_seconds, conflict_dimension,
created_at, updated_at, pre_hook, post_hook
FROM source_groups WHERE host_id = ? AND name = ?`

	grp, err := scanSourceGroup(st.db.QueryRowContext(ctx, query, hostID, name))
	if err != nil && errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	return grp, err
}
// ListSourceGroupsByHost returns every group belonging to hostID,
// sorted by name. When the host has no groups the result is an empty,
// non-nil slice. A row that fails to scan aborts the listing and its
// error is returned as-is.
func (st *Store) ListSourceGroupsByHost(ctx context.Context, hostID string) ([]SourceGroup, error) {
	const query = `SELECT id, host_id, name, includes, excludes, retention_policy,
retry_max, retry_backoff_seconds, conflict_dimension,
created_at, updated_at, pre_hook, post_hook
FROM source_groups WHERE host_id = ? ORDER BY name`

	rows, err := st.db.QueryContext(ctx, query, hostID)
	if err != nil {
		return nil, fmt.Errorf("store: list source groups: %w", err)
	}
	defer func() { _ = rows.Close() }()

	groups := make([]SourceGroup, 0)
	for rows.Next() {
		grp, scanErr := scanSourceGroupRow(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		groups = append(groups, *grp)
	}
	// Surface any iteration error alongside what was collected so far.
	return groups, rows.Err()
}
// SetSourceGroupConflict writes only the cached conflict_dimension
// column for the group with the given id. It does not bump the
// schedule version (the cache is server-internal, the agent doesn't
// see it). Passing an empty dimension stores NULL, clearing the
// conflict.
func (st *Store) SetSourceGroupConflict(ctx context.Context, groupID, dimension string) error {
	const stmt = `UPDATE source_groups SET conflict_dimension = ? WHERE id = ?`

	if _, err := st.db.ExecContext(ctx, stmt, nullableString(dimension), groupID); err != nil {
		return fmt.Errorf("store: set source group conflict: %w", err)
	}
	return nil
}
// ----- scan helpers --------------------------------------------------
// scanSourceGroup decodes a single-row query result into a
// SourceGroup by delegating to scanSourceGroupRow. The error from the
// underlying Scan is returned unchanged, so callers can test it with
// errors.Is(err, sql.ErrNoRows).
func scanSourceGroup(row *sql.Row) (*SourceGroup, error) {
return scanSourceGroupRow(row)
}
// sourceGroupScanner is the one-method view of a query result that
// scanSourceGroupRow needs. In this file it is satisfied by the
// *sql.Row passed in by scanSourceGroup and by the rows value iterated
// in ListSourceGroupsByHost, so a single decoder serves both paths.
type sourceGroupScanner interface {
Scan(dest ...any) error
}
// scanSourceGroupRow reads one source_groups row from s and converts
// it into a SourceGroup. The column order must match the SELECT lists
// used by the query functions in this file.
//
// Only a failure of s.Scan itself is returned as an error. NULL
// conflict_dimension / pre_hook / post_hook leave the corresponding
// field empty. Errors while decoding the JSON columns or parsing the
// RFC3339Nano timestamps are ignored.
func scanSourceGroupRow(s sourceGroupScanner) (*SourceGroup, error) {
	var (
		grp                                    SourceGroup
		rawIncludes, rawExcludes, rawRetention string
		rawCreated, rawUpdated                 string
		conflict, preHook, postHook            sql.NullString
	)
	if err := s.Scan(
		&grp.ID, &grp.HostID, &grp.Name,
		&rawIncludes, &rawExcludes, &rawRetention,
		&grp.RetryMax, &grp.RetryBackoffSeconds, &conflict,
		&rawCreated, &rawUpdated, &preHook, &postHook,
	); err != nil {
		return nil, err
	}

	// Nullable text columns: copy the value only when it is non-NULL.
	if conflict.Valid {
		grp.ConflictDimension = conflict.String
	}
	if preHook.Valid {
		grp.PreHook = preHook.String
	}
	if postHook.Valid {
		grp.PostHook = postHook.String
	}

	// JSON columns: skip empty text, ignore decode errors.
	jsonCols := []struct {
		raw string
		dst any
	}{
		{rawIncludes, &grp.Includes},
		{rawExcludes, &grp.Excludes},
		{rawRetention, &grp.RetentionPolicy},
	}
	for _, col := range jsonCols {
		if col.raw == "" {
			continue
		}
		_ = json.Unmarshal([]byte(col.raw), col.dst)
	}

	// Timestamps: assign only when the stored text parses.
	if created, perr := time.Parse(time.RFC3339Nano, rawCreated); perr == nil {
		grp.CreatedAt = created
	}
	if updated, perr := time.Parse(time.RFC3339Nano, rawUpdated); perr == nil {
		grp.UpdatedAt = updated
	}
	return &grp, nil
}
// nullableString converts s into a SQL bind argument: the empty
// string becomes an untyped nil (bound as NULL), any other value is
// passed through as a string.
func nullableString(s string) any {
	var arg any
	if s != "" {
		arg = s
	}
	return arg
}