Files
restic-manager/internal/store/hosts.go
T

432 lines
14 KiB
Go

package store
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"time"
)
// CreateHost inserts a new host row. Used by the enrollment flow.
// The caller has already minted the host id and hashed the agent
// bearer token; agentTokenHash and certPinSHA256 are stored verbatim.
//
// Every new host starts with status 'offline'. EnrolledAt is stored as
// UTC RFC 3339 (nanosecond) text. A nil or empty tag list is written as
// the JSON array "[]" rather than "null" (the same normalisation
// SetHostTags applies), so json_each over the tags column never yields
// a NULL element for this row.
func (s *Store) CreateHost(ctx context.Context, h Host, agentTokenHash, certPinSHA256 string) error {
	// json.Marshal encodes a nil slice as "null"; keep the column a JSON
	// array so tag queries see zero elements instead of one NULL value.
	tagsJSON := "[]"
	if len(h.Tags) > 0 {
		b, err := json.Marshal(h.Tags)
		if err != nil {
			return fmt.Errorf("store: marshal tags: %w", err)
		}
		tagsJSON = string(b)
	}
	_, err := s.db.ExecContext(ctx,
		`INSERT INTO hosts (
id, name, os, arch, agent_version, restic_version, protocol_version,
enrolled_at, status, tags,
agent_token_hash, cert_pin_sha256
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 'offline', ?, ?, ?)`,
		h.ID, h.Name, h.OS, h.Arch,
		h.AgentVersion, h.ResticVersion, h.ProtocolVersion,
		h.EnrolledAt.UTC().Format(time.RFC3339Nano),
		tagsJSON,
		agentTokenHash, certPinSHA256)
	if err != nil {
		return fmt.Errorf("store: create host: %w", err)
	}
	return nil
}
// LookupHostByAgentToken finds the host that owns the given (already
// hashed) agent bearer token by matching agent_token_hash. A token that
// matches no row yields ErrNotFound.
func (s *Store) LookupHostByAgentToken(ctx context.Context, tokenHash string) (*Host, error) {
	const query = `SELECT id, name, os, arch, agent_version, restic_version, protocol_version,
enrolled_at, last_seen_at, status, repo_id, tags,
current_job_id, last_backup_at, last_backup_status,
repo_size_bytes, snapshot_count, open_alert_count,
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps,
pre_hook_default, post_hook_default,
repo_status, repo_status_error, always_on
FROM hosts WHERE agent_token_hash = ?`
	return scanHost(s.db.QueryRowContext(ctx, query, tokenHash))
}
// GetHost loads the host whose primary key is id. A missing host is
// reported as ErrNotFound.
func (s *Store) GetHost(ctx context.Context, id string) (*Host, error) {
	const query = `SELECT id, name, os, arch, agent_version, restic_version, protocol_version,
enrolled_at, last_seen_at, status, repo_id, tags,
current_job_id, last_backup_at, last_backup_status,
repo_size_bytes, snapshot_count, open_alert_count,
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps,
pre_hook_default, post_hook_default,
repo_status, repo_status_error, always_on
FROM hosts WHERE id = ?`
	return scanHost(s.db.QueryRowContext(ctx, query, id))
}
// SetHostRepoStatus records the result of the most recent init/probe
// attempt against the host's repository. The WS handler calls it on
// every job.finished with kind=init, and repo-credential saves reset it
// to ("unknown", "") so the next probe reflects the new credentials.
//
// errMsg is written as-is (truncate at the call site if row size
// matters) and is empty for "ready". An unknown hostID updates zero
// rows and is not reported as an error.
func (s *Store) SetHostRepoStatus(ctx context.Context, hostID, status, errMsg string) error {
	if _, err := s.db.ExecContext(ctx,
		`UPDATE hosts SET repo_status = ?, repo_status_error = ? WHERE id = ?`,
		status, errMsg, hostID); err != nil {
		return fmt.Errorf("store: set host repo status: %w", err)
	}
	return nil
}
// SetHostLastBackup copies the outcome of a finished backup job onto
// the host row (last_backup_at, last_backup_status). This lets the
// dashboard show last-run state without scanning the jobs table. The WS
// handler calls it on job.finished where kind=backup. `when` is stored
// as UTC RFC 3339 (nanosecond) text.
func (s *Store) SetHostLastBackup(ctx context.Context, hostID, status string, when time.Time) error {
	stamp := when.UTC().Format(time.RFC3339Nano)
	if _, err := s.db.ExecContext(ctx,
		`UPDATE hosts SET last_backup_at = ?, last_backup_status = ? WHERE id = ?`,
		stamp, status, hostID); err != nil {
		return fmt.Errorf("store: set host last backup: %w", err)
	}
	return nil
}
// DeleteHost deletes the host row with the given id and returns
// ErrNotFound when nothing matched.
//
// Dependent rows (schedules, jobs, snapshots, source_groups,
// host_credentials, etc.) are removed by the ON DELETE CASCADE foreign
// keys declared on those tables. The connection DSN already enables
// PRAGMA foreign_keys=ON, so no per-call pragma is issued here.
//
// The agent's bearer hash lives in agent_token_hash on this same row.
// Deleting the host therefore also revokes its agent. A re-installed
// instance has to go through the normal pending-host accept flow again.
func (s *Store) DeleteHost(ctx context.Context, id string) error {
	res, err := s.db.ExecContext(ctx, `DELETE FROM hosts WHERE id = ?`, id)
	if err != nil {
		return fmt.Errorf("store: delete host: %w", err)
	}
	switch n, err := res.RowsAffected(); {
	case err != nil:
		return fmt.Errorf("store: delete host rows: %w", err)
	case n == 0:
		return ErrNotFound
	default:
		return nil
	}
}
// MarkHostHello stores the version metadata carried by the agent's
// hello message, sets last_seen_at to `when` (UTC RFC 3339) and forces
// status to 'online'. An unknown id updates zero rows and is not an
// error.
func (s *Store) MarkHostHello(ctx context.Context, id string, agentVersion, resticVersion string, protoVersion int, when time.Time) error {
	seenAt := when.UTC().Format(time.RFC3339Nano)
	if _, err := s.db.ExecContext(ctx,
		`UPDATE hosts
SET agent_version = ?, restic_version = ?, protocol_version = ?,
last_seen_at = ?, status = 'online'
WHERE id = ?`,
		agentVersion, resticVersion, protoVersion,
		seenAt, id); err != nil {
		return fmt.Errorf("store: mark hello: %w", err)
	}
	return nil
}
// TouchHost records a heartbeat by setting last_seen_at to `when` (UTC
// RFC 3339). A host currently marked 'offline' is flipped back to
// 'online'; any other status is left as-is. Marking hosts offline is
// the job of the separate stale sweep, not this method. An unknown id
// updates zero rows and is not an error.
func (s *Store) TouchHost(ctx context.Context, id string, when time.Time) error {
	seenAt := when.UTC().Format(time.RFC3339Nano)
	if _, err := s.db.ExecContext(ctx,
		`UPDATE hosts
SET last_seen_at = ?,
status = CASE WHEN status = 'offline' THEN 'online' ELSE status END
WHERE id = ?`,
		seenAt, id); err != nil {
		return fmt.Errorf("store: touch host: %w", err)
	}
	return nil
}
// MarkHostsOfflineStale moves every 'online' host whose last_seen_at is
// NULL or earlier than cutoff to 'offline'. It returns how many rows
// changed so the caller can log when the count is non-zero.
func (s *Store) MarkHostsOfflineStale(ctx context.Context, cutoff time.Time) (int64, error) {
	threshold := cutoff.UTC().Format(time.RFC3339Nano)
	res, err := s.db.ExecContext(ctx,
		`UPDATE hosts
SET status = 'offline'
WHERE status = 'online'
AND (last_seen_at IS NULL OR last_seen_at < ?)`,
		threshold)
	if err != nil {
		return 0, fmt.Errorf("store: mark offline: %w", err)
	}
	// The count only feeds logging. The UPDATE has already been applied,
	// so a RowsAffected failure is deliberately ignored rather than
	// surfaced as an error.
	changed, _ := res.RowsAffected()
	return changed, nil
}
// MarkHostsOfflineStaleReturnIDs is like MarkHostsOfflineStale but
// reports which hosts were flipped. Inside one transaction it selects
// the IDs of 'online' hosts whose last_seen_at is NULL or earlier than
// cutoff, then marks those hosts 'offline'. The UPDATE is skipped when
// nothing is stale. The result is nil when no host was flipped.
func (s *Store) MarkHostsOfflineStaleReturnIDs(ctx context.Context, cutoff time.Time) ([]string, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("store: begin tx: %w", err)
	}
	// Rolling back after a successful Commit is a no-op, so the deferred
	// Rollback only matters on the error paths.
	defer func() { _ = tx.Rollback() }()

	threshold := cutoff.UTC().Format(time.RFC3339Nano)

	// staleIDs reads the matching host IDs. Its deferred Close runs
	// before the UPDATE below, freeing the transaction's connection.
	staleIDs := func() ([]string, error) {
		rows, qerr := tx.QueryContext(ctx,
			`SELECT id FROM hosts
WHERE status = 'online'
AND (last_seen_at IS NULL OR last_seen_at < ?)`,
			threshold)
		if qerr != nil {
			return nil, fmt.Errorf("store: select stale hosts: %w", qerr)
		}
		defer func() { _ = rows.Close() }()

		var found []string
		for rows.Next() {
			var hostID string
			if serr := rows.Scan(&hostID); serr != nil {
				return nil, fmt.Errorf("store: scan stale host id: %w", serr)
			}
			found = append(found, hostID)
		}
		if rerr := rows.Err(); rerr != nil {
			return nil, fmt.Errorf("store: iterate stale hosts: %w", rerr)
		}
		return found, nil
	}

	ids, err := staleIDs()
	if err != nil {
		return nil, err
	}
	if len(ids) != 0 {
		_, err = tx.ExecContext(ctx,
			`UPDATE hosts SET status = 'offline'
WHERE status = 'online'
AND (last_seen_at IS NULL OR last_seen_at < ?)`,
			threshold)
		if err != nil {
			return nil, fmt.Errorf("store: mark offline: %w", err)
		}
	}
	if err = tx.Commit(); err != nil {
		return nil, fmt.Errorf("store: commit: %w", err)
	}
	return ids, nil
}
// ListHosts returns all hosts ordered by name. Phase 1 callers keep the
// whole (small) fleet in memory; pagination can be added once it is
// needed. The slice is nil when there are no hosts.
func (s *Store) ListHosts(ctx context.Context) ([]Host, error) {
	const query = `SELECT id, name, os, arch, agent_version, restic_version, protocol_version,
enrolled_at, last_seen_at, status, repo_id, tags,
current_job_id, last_backup_at, last_backup_status,
repo_size_bytes, snapshot_count, open_alert_count,
applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps,
pre_hook_default, post_hook_default,
repo_status, repo_status_error, always_on
FROM hosts ORDER BY name`
	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("store: list hosts: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var hosts []Host
	for rows.Next() {
		host, scanErr := scanHostRow(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		hosts = append(hosts, *host)
	}
	return hosts, rows.Err()
}
// ----- scan helpers --------------------------------------------------

// hostScanner is the one method scanHostRow needs. Both *sql.Row and
// *sql.Rows provide it, so a single decoder serves the single-row and
// multi-row queries.
type hostScanner interface {
	Scan(...any) error
}
// scanHost decodes the result of a QueryRowContext into a Host. A
// missing row is reported as ErrNotFound.
func scanHost(row *sql.Row) (*Host, error) {
	host, err := scanHostRow(row)
	if !errors.Is(err, sql.ErrNoRows) {
		return host, err
	}
	// Defensive: scanHostRow already maps sql.ErrNoRows to ErrNotFound,
	// but single-row callers must never see the raw sentinel.
	return nil, ErrNotFound
}
// scanHostRow decodes one hosts row into a Host. The column order must
// match the Scan call below, i.e. the SELECT lists used by GetHost,
// LookupHostByAgentToken and ListHosts.
//
// Timestamp columns are parsed as RFC 3339 (nanosecond) text. Nullable
// columns map to nil pointers, or to "" for the hook defaults.
// sql.ErrNoRows is translated to ErrNotFound. Any other scan or parse
// failure is returned wrapped with a "store:" prefix.
//
// The tags column is decoded best-effort. Malformed JSON leaves h.Tags
// nil rather than failing the row, so one corrupt tag list cannot hide
// a host from listings. A partially decoded slice is never exposed.
func scanHostRow(sc hostScanner) (*Host, error) {
	var (
		h                            Host
		enrolled, tags               string
		lastSeen, lastBackupAt       sql.NullString
		repoID, currentJob, lastBkSt sql.NullString
		preHook, postHook            sql.NullString
		bwUp, bwDown                 sql.NullInt64
		alwaysOn                     int
	)
	err := sc.Scan(&h.ID, &h.Name, &h.OS, &h.Arch,
		&h.AgentVersion, &h.ResticVersion, &h.ProtocolVersion,
		&enrolled, &lastSeen, &h.Status, &repoID, &tags,
		&currentJob, &lastBackupAt, &lastBkSt,
		&h.RepoSizeBytes, &h.SnapshotCount, &h.OpenAlertCount,
		&h.AppliedScheduleVersion, &bwUp, &bwDown,
		&preHook, &postHook,
		&h.RepoStatus, &h.RepoStatusError, &alwaysOn)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("store: scan host: %w", err)
	}

	enrolledAt, err := time.Parse(time.RFC3339Nano, enrolled)
	if err != nil {
		return nil, fmt.Errorf("store: parse enrolled_at: %w", err)
	}
	h.EnrolledAt = enrolledAt

	// optTime parses a nullable RFC 3339 column; NULL maps to nil.
	optTime := func(col string, v sql.NullString) (*time.Time, error) {
		if !v.Valid {
			return nil, nil
		}
		t, perr := time.Parse(time.RFC3339Nano, v.String)
		if perr != nil {
			return nil, fmt.Errorf("store: parse %s: %w", col, perr)
		}
		return &t, nil
	}
	// optString returns a pointer to a nullable text column's value.
	optString := func(v sql.NullString) *string {
		if !v.Valid {
			return nil
		}
		str := v.String
		return &str
	}
	// optInt returns a pointer to a nullable integer column's value.
	optInt := func(v sql.NullInt64) *int {
		if !v.Valid {
			return nil
		}
		n := int(v.Int64)
		return &n
	}

	if h.LastSeenAt, err = optTime("last_seen_at", lastSeen); err != nil {
		return nil, err
	}
	if h.LastBackupAt, err = optTime("last_backup_at", lastBackupAt); err != nil {
		return nil, err
	}
	h.RepoID = optString(repoID)
	h.CurrentJobID = optString(currentJob)
	h.LastBackupStatus = optString(lastBkSt)

	if tags != "" {
		if jerr := json.Unmarshal([]byte(tags), &h.Tags); jerr != nil {
			// Best-effort decode: json.Unmarshal may have filled part of
			// the slice before failing, so drop it entirely.
			h.Tags = nil
		}
	}

	h.BandwidthUpKBps = optInt(bwUp)
	h.BandwidthDownKBps = optInt(bwDown)
	// A NULL hook scans as String == "", i.e. "no hook".
	h.PreHookDefault = preHook.String
	h.PostHookDefault = postHook.String
	h.AlwaysOn = alwaysOn != 0
	return &h, nil
}
// SetHostHooks overwrites the host-wide pre/post hook defaults. An empty
// string clears that hook. nullableString, defined elsewhere in the
// package, presumably maps "" to SQL NULL (TODO confirm). Values are
// stored verbatim; the caller is expected to encrypt them before they
// reach this layer.
func (s *Store) SetHostHooks(ctx context.Context, hostID string, pre, post string) error {
	preArg, postArg := nullableString(pre), nullableString(post)
	if _, err := s.db.ExecContext(ctx,
		`UPDATE hosts SET pre_hook_default = ?, post_hook_default = ? WHERE id = ?`,
		preArg, postArg, hostID); err != nil {
		return fmt.Errorf("store: set host hooks: %w", err)
	}
	return nil
}
// SetHostBandwidth overwrites the host's upload/download caps. A nil
// pointer stores SQL NULL, clearing that cap. Validation is left to the
// caller; the agent treats non-positive caps as "no cap" regardless.
func (s *Store) SetHostBandwidth(ctx context.Context, hostID string, upKBps, downKBps *int) error {
	up, down := nullableInt(upKBps), nullableInt(downKBps)
	if _, err := s.db.ExecContext(ctx,
		`UPDATE hosts SET bandwidth_up_kbps = ?, bandwidth_down_kbps = ? WHERE id = ?`,
		up, down, hostID); err != nil {
		return fmt.Errorf("store: set host bandwidth: %w", err)
	}
	return nil
}
// SetHostTags overwrites the host's tag list with its JSON encoding.
// The caller passes tags already normalised (lowercase, deduplicated);
// this layer only marshals and writes them. Both nil and an empty slice
// clear all tags and are stored as "[]".
func (s *Store) SetHostTags(ctx context.Context, hostID string, tags []string) error {
	list := tags
	if list == nil {
		// A nil slice would marshal as "null"; store an empty array.
		list = []string{}
	}
	encoded, err := json.Marshal(list)
	if err != nil {
		return fmt.Errorf("store: marshal tags: %w", err)
	}
	if _, err := s.db.ExecContext(ctx,
		`UPDATE hosts SET tags = ? WHERE id = ?`, string(encoded), hostID); err != nil {
		return fmt.Errorf("store: set host tags: %w", err)
	}
	return nil
}
// SetHostAlwaysOn sets the host's always-on flag: true = 24x7 server
// (the default), false = intermittent host such as a laptop. See the
// always-on-host-mode spec.
//
// The flag is written as 1/0. It returns ErrNotFound when no host has
// the given id. A failure to read the affected-row count is returned
// as a wrapped error rather than misreported as ErrNotFound.
func (s *Store) SetHostAlwaysOn(ctx context.Context, hostID string, alwaysOn bool) error {
	v := 0
	if alwaysOn {
		v = 1
	}
	res, err := s.db.ExecContext(ctx,
		`UPDATE hosts SET always_on = ? WHERE id = ?`, v, hostID)
	if err != nil {
		return fmt.Errorf("store: set host always_on: %w", err)
	}
	n, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("store: set host always_on rows: %w", err)
	}
	if n == 0 {
		return ErrNotFound
	}
	return nil
}
// DistinctHostTags returns the union of every tag in use across the
// fleet, sorted (ORDER BY on the tag value). Powers the autocomplete on
// the host-tags editor and the chip-row filter on the dashboard. Cheap
// at the fleet sizes this codebase targets, so re-querying on each
// render is fine.
//
// JSON null values are skipped. json_each over a tags column holding
// "null" (what json.Marshal writes for a nil slice) yields one SQL NULL
// row, which cannot be scanned into a string and previously failed the
// whole query. Returns nil when no host has any tags.
func (s *Store) DistinctHostTags(ctx context.Context) ([]string, error) {
	rows, err := s.db.QueryContext(ctx,
		`SELECT DISTINCT json_each.value
FROM hosts, json_each(hosts.tags)
WHERE json_each.value IS NOT NULL
ORDER BY 1`)
	if err != nil {
		return nil, fmt.Errorf("store: distinct host tags: %w", err)
	}
	defer func() { _ = rows.Close() }()
	var out []string
	for rows.Next() {
		var tag string
		if err := rows.Scan(&tag); err != nil {
			return nil, fmt.Errorf("store: scan host tag: %w", err)
		}
		out = append(out, tag)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("store: iterate host tags: %w", err)
	}
	return out, nil
}
// nullableInt turns an optional int into a SQL argument: a nil pointer
// becomes an untyped nil (SQL NULL), otherwise the pointed-to value is
// copied out.
func nullableInt(p *int) any {
	if p != nil {
		return *p
	}
	return nil
}