package store import ( "context" "database/sql" "encoding/json" "errors" "fmt" "time" ) // CreateHost inserts a new host row. Used by the enrollment flow. // The caller has already minted the host id and hashed the agent // bearer token. func (s *Store) CreateHost(ctx context.Context, h Host, agentTokenHash, certPinSHA256 string) error { tags, err := json.Marshal(h.Tags) if err != nil { return fmt.Errorf("store: marshal tags: %w", err) } _, err = s.db.ExecContext(ctx, `INSERT INTO hosts ( id, name, os, arch, agent_version, restic_version, protocol_version, enrolled_at, status, tags, agent_token_hash, cert_pin_sha256 ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 'offline', ?, ?, ?)`, h.ID, h.Name, h.OS, h.Arch, h.AgentVersion, h.ResticVersion, h.ProtocolVersion, h.EnrolledAt.UTC().Format(time.RFC3339Nano), string(tags), agentTokenHash, certPinSHA256) if err != nil { return fmt.Errorf("store: create host: %w", err) } return nil } // LookupHostByAgentToken resolves a hashed agent bearer token to the // host it belongs to. Returns ErrNotFound on miss. func (s *Store) LookupHostByAgentToken(ctx context.Context, tokenHash string) (*Host, error) { row := s.db.QueryRowContext(ctx, `SELECT id, name, os, arch, agent_version, restic_version, protocol_version, enrolled_at, last_seen_at, status, repo_id, tags, current_job_id, last_backup_at, last_backup_status, repo_size_bytes, snapshot_count, open_alert_count, applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps, pre_hook_default, post_hook_default FROM hosts WHERE agent_token_hash = ?`, tokenHash) return scanHost(row) } // GetHost returns a host by ID. Returns ErrNotFound on miss. func (s *Store) GetHost(ctx context.Context, id string) (*Host, error) { row := s.db.QueryRowContext(ctx, `SELECT id, name, os, arch, agent_version, restic_version, protocol_version, enrolled_at, last_seen_at, status, repo_id, tags, current_job_id, last_backup_at, last_backup_status, repo_size_bytes, snapshot_count, open_alert_count, applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps, pre_hook_default, post_hook_default FROM hosts WHERE id = ?`, id) return scanHost(row) } // MarkHostHello updates the host row with metadata received in the // agent's hello message and flips status to 'online'. func (s *Store) MarkHostHello(ctx context.Context, id string, agentVersion, resticVersion string, protoVersion int, when time.Time) error { _, err := s.db.ExecContext(ctx, `UPDATE hosts SET agent_version = ?, restic_version = ?, protocol_version = ?, last_seen_at = ?, status = 'online' WHERE id = ?`, agentVersion, resticVersion, protoVersion, when.UTC().Format(time.RFC3339Nano), id) if err != nil { return fmt.Errorf("store: mark hello: %w", err) } return nil } // TouchHost updates last_seen_at on heartbeat, leaving status alone if // already online (the offline-marker is a separate sweep). func (s *Store) TouchHost(ctx context.Context, id string, when time.Time) error { _, err := s.db.ExecContext(ctx, `UPDATE hosts SET last_seen_at = ?, status = CASE WHEN status = 'offline' THEN 'online' ELSE status END WHERE id = ?`, when.UTC().Format(time.RFC3339Nano), id) if err != nil { return fmt.Errorf("store: touch host: %w", err) } return nil } // MarkHostsOfflineStale flips any host that hasn't been seen since // before `cutoff` from 'online' to 'offline'. Returns the number of // rows affected so the caller can log non-zero events. func (s *Store) MarkHostsOfflineStale(ctx context.Context, cutoff time.Time) (int64, error) { res, err := s.db.ExecContext(ctx, `UPDATE hosts SET status = 'offline' WHERE status = 'online' AND (last_seen_at IS NULL OR last_seen_at < ?)`, cutoff.UTC().Format(time.RFC3339Nano)) if err != nil { return 0, fmt.Errorf("store: mark offline: %w", err) } n, _ := res.RowsAffected() return n, nil } // MarkHostsOfflineStaleReturnIDs flips any host that hasn't been seen // since before `cutoff` from 'online' to 'offline' and returns the IDs // of every host that was flipped. Uses a single transaction. func (s *Store) MarkHostsOfflineStaleReturnIDs(ctx context.Context, cutoff time.Time) ([]string, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return nil, fmt.Errorf("store: begin tx: %w", err) } defer func() { _ = tx.Rollback() }() cutoffStr := cutoff.UTC().Format(time.RFC3339Nano) rows, err := tx.QueryContext(ctx, `SELECT id FROM hosts WHERE status = 'online' AND (last_seen_at IS NULL OR last_seen_at < ?)`, cutoffStr) if err != nil { return nil, fmt.Errorf("store: select stale hosts: %w", err) } var ids []string for rows.Next() { var id string if err := rows.Scan(&id); err != nil { _ = rows.Close() return nil, fmt.Errorf("store: scan stale host id: %w", err) } ids = append(ids, id) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("store: iterate stale hosts: %w", err) } _ = rows.Close() if len(ids) > 0 { if _, err := tx.ExecContext(ctx, `UPDATE hosts SET status = 'offline' WHERE status = 'online' AND (last_seen_at IS NULL OR last_seen_at < ?)`, cutoffStr); err != nil { return nil, fmt.Errorf("store: mark offline: %w", err) } } if err := tx.Commit(); err != nil { return nil, fmt.Errorf("store: commit: %w", err) } return ids, nil } // ListHosts returns every host. Phase 1 callers fit a small fleet in // memory; pagination lands when it matters. func (s *Store) ListHosts(ctx context.Context) ([]Host, error) { rows, err := s.db.QueryContext(ctx, `SELECT id, name, os, arch, agent_version, restic_version, protocol_version, enrolled_at, last_seen_at, status, repo_id, tags, current_job_id, last_backup_at, last_backup_status, repo_size_bytes, snapshot_count, open_alert_count, applied_schedule_version, bandwidth_up_kbps, bandwidth_down_kbps, pre_hook_default, post_hook_default FROM hosts ORDER BY name`) if err != nil { return nil, fmt.Errorf("store: list hosts: %w", err) } defer func() { _ = rows.Close() }() var out []Host for rows.Next() { h, err := scanHostRow(rows) if err != nil { return nil, err } out = append(out, *h) } return out, rows.Err() } // ----- scan helpers -------------------------------------------------- type hostScanner interface { Scan(dest ...any) error } func scanHost(row *sql.Row) (*Host, error) { h, err := scanHostRow(row) if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } return h, err } func scanHostRow(s hostScanner) (*Host, error) { var h Host var ( lastSeen, lastBackupAt sql.NullString repoID, currentJob, lastBkSt sql.NullString enrolled string tags string bwUp, bwDown sql.NullInt64 preHook, postHook sql.NullString ) err := s.Scan(&h.ID, &h.Name, &h.OS, &h.Arch, &h.AgentVersion, &h.ResticVersion, &h.ProtocolVersion, &enrolled, &lastSeen, &h.Status, &repoID, &tags, ¤tJob, &lastBackupAt, &lastBkSt, &h.RepoSizeBytes, &h.SnapshotCount, &h.OpenAlertCount, &h.AppliedScheduleVersion, &bwUp, &bwDown, &preHook, &postHook) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } return nil, fmt.Errorf("store: scan host: %w", err) } t, err := time.Parse(time.RFC3339Nano, enrolled) if err != nil { return nil, fmt.Errorf("store: parse enrolled_at: %w", err) } h.EnrolledAt = t if lastSeen.Valid { t, err := time.Parse(time.RFC3339Nano, lastSeen.String) if err != nil { return nil, fmt.Errorf("store: parse last_seen_at: %w", err) } h.LastSeenAt = &t } if lastBackupAt.Valid { t, err := time.Parse(time.RFC3339Nano, lastBackupAt.String) if err != nil { return nil, fmt.Errorf("store: parse last_backup_at: %w", err) } h.LastBackupAt = &t } if repoID.Valid { s := repoID.String h.RepoID = &s } if currentJob.Valid { s := currentJob.String h.CurrentJobID = &s } if lastBkSt.Valid { s := lastBkSt.String h.LastBackupStatus = &s } if tags != "" { _ = json.Unmarshal([]byte(tags), &h.Tags) } if bwUp.Valid { v := int(bwUp.Int64) h.BandwidthUpKBps = &v } if bwDown.Valid { v := int(bwDown.Int64) h.BandwidthDownKBps = &v } if preHook.Valid { h.PreHookDefault = preHook.String } if postHook.Valid { h.PostHookDefault = postHook.String } return &h, nil } // SetHostHooks replaces the host-wide pre/post hook defaults. Pass // the empty string to clear that hook. Stored verbatim — caller is // expected to encrypt before they reach this layer. func (s *Store) SetHostHooks(ctx context.Context, hostID string, pre, post string) error { _, err := s.db.ExecContext(ctx, `UPDATE hosts SET pre_hook_default = ?, post_hook_default = ? WHERE id = ?`, nullableString(pre), nullableString(post), hostID) if err != nil { return fmt.Errorf("store: set host hooks: %w", err) } return nil } // SetHostBandwidth replaces the host's upload/download caps. Pass nil // to clear a cap. Caller decides validation; non-positive caps are // treated as "no cap" by the agent regardless. func (s *Store) SetHostBandwidth(ctx context.Context, hostID string, upKBps, downKBps *int) error { _, err := s.db.ExecContext(ctx, `UPDATE hosts SET bandwidth_up_kbps = ?, bandwidth_down_kbps = ? WHERE id = ?`, nullableInt(upKBps), nullableInt(downKBps), hostID) if err != nil { return fmt.Errorf("store: set host bandwidth: %w", err) } return nil } func nullableInt(p *int) any { if p == nil { return nil } return *p }