// pending_hosts.go — store layer for the announce-and-approve // enrolment queue (P2-18a). Rows live for at most 1h; a sweeper // deletes anything past expires_at. package store import ( "context" "crypto/sha256" "database/sql" "encoding/hex" "errors" "fmt" "time" ) // PendingHost mirrors the pending_hosts table row, plus the derived // HostnameCollision flag the API hands back to the agent so the // install script can warn the operator at announce time. type PendingHost struct { ID string Hostname string OS string Arch string AgentVersion string ResticVersion string PublicKey []byte // 32-byte Ed25519 Fingerprint string // "SHA256:hex" AnnouncedFromIP string FirstSeenAt time.Time LastSeenAt time.Time ExpiresAt time.Time } // FingerprintForKey returns the canonical "SHA256:hex" fingerprint // the operator sees in the UI and on the endpoint terminal. func FingerprintForKey(pubKey []byte) string { sum := sha256.Sum256(pubKey) return "SHA256:" + hex.EncodeToString(sum[:]) } // CreatePendingHost inserts a new row. Caller has already validated // the public key length and rate limits. func (s *Store) CreatePendingHost(ctx context.Context, ph *PendingHost) error { if ph.ID == "" || len(ph.PublicKey) == 0 { return errors.New("store: pending host id + public_key required") } if ph.Fingerprint == "" { ph.Fingerprint = FingerprintForKey(ph.PublicKey) } now := time.Now().UTC() if ph.FirstSeenAt.IsZero() { ph.FirstSeenAt = now } ph.LastSeenAt = now if ph.ExpiresAt.IsZero() { ph.ExpiresAt = now.Add(time.Hour) } _, err := s.db.ExecContext(ctx, `INSERT INTO pending_hosts ( id, hostname, os, arch, agent_version, restic_version, public_key, fingerprint, announced_from_ip, first_seen_at, last_seen_at, expires_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, ph.ID, ph.Hostname, ph.OS, ph.Arch, ph.AgentVersion, ph.ResticVersion, ph.PublicKey, ph.Fingerprint, ph.AnnouncedFromIP, ph.FirstSeenAt.Format(time.RFC3339Nano), ph.LastSeenAt.Format(time.RFC3339Nano), ph.ExpiresAt.Format(time.RFC3339Nano), ) if err != nil { return fmt.Errorf("store: create pending host: %w", err) } return nil } // TouchPendingHost bumps last_seen_at on the named pending row, // extending its visibility in the dashboard while the agent's // pending WS stays open. Does NOT extend expires_at — the 1h cap // is firm. func (s *Store) TouchPendingHost(ctx context.Context, id string, when time.Time) error { _, err := s.db.ExecContext(ctx, `UPDATE pending_hosts SET last_seen_at = ? WHERE id = ?`, when.UTC().Format(time.RFC3339Nano), id) return err } // GetPendingHost returns one row by ID. ErrNotFound on miss. func (s *Store) GetPendingHost(ctx context.Context, id string) (*PendingHost, error) { row := s.db.QueryRowContext(ctx, `SELECT id, hostname, os, arch, agent_version, restic_version, public_key, fingerprint, announced_from_ip, first_seen_at, last_seen_at, expires_at FROM pending_hosts WHERE id = ?`, id) return scanPendingHost(row) } // GetPendingHostByFingerprint resolves a row by its public key // fingerprint (used by the WS pending handler to look up which row // an incoming connection corresponds to). func (s *Store) GetPendingHostByFingerprint(ctx context.Context, fp string) (*PendingHost, error) { row := s.db.QueryRowContext(ctx, `SELECT id, hostname, os, arch, agent_version, restic_version, public_key, fingerprint, announced_from_ip, first_seen_at, last_seen_at, expires_at FROM pending_hosts WHERE fingerprint = ?`, fp) return scanPendingHost(row) } // ListPendingHosts returns every non-expired row, newest first. The // caller passes `now` so tests can fast-forward. func (s *Store) ListPendingHosts(ctx context.Context, now time.Time) ([]PendingHost, error) { rows, err := s.db.QueryContext(ctx, `SELECT id, hostname, os, arch, agent_version, restic_version, public_key, fingerprint, announced_from_ip, first_seen_at, last_seen_at, expires_at FROM pending_hosts WHERE expires_at > ? ORDER BY first_seen_at DESC`, now.UTC().Format(time.RFC3339Nano)) if err != nil { return nil, fmt.Errorf("store: list pending hosts: %w", err) } defer func() { _ = rows.Close() }() out := []PendingHost{} for rows.Next() { ph, err := scanPendingHostRow(rows) if err != nil { return nil, err } out = append(out, *ph) } return out, rows.Err() } // CountPendingHosts returns the count of non-expired rows. Used for // the global cap (P2-18: refuse new announces past 100 in flight). func (s *Store) CountPendingHosts(ctx context.Context, now time.Time) (int, error) { var n int err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_hosts WHERE expires_at > ?`, now.UTC().Format(time.RFC3339Nano)).Scan(&n) if err != nil { return 0, fmt.Errorf("store: count pending hosts: %w", err) } return n, nil } // CountPendingHostsByHostname returns the number of non-expired // pending rows that share the supplied hostname. Used by the // announce endpoint to set the hostname_collision flag in its // response. func (s *Store) CountPendingHostsByHostname(ctx context.Context, hostname string, now time.Time) (int, error) { var n int err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_hosts WHERE hostname = ? AND expires_at > ?`, hostname, now.UTC().Format(time.RFC3339Nano)).Scan(&n) if err != nil { return 0, fmt.Errorf("store: count pending hosts by hostname: %w", err) } return n, nil } // DeletePendingHost removes one row by ID. ErrNotFound on miss. func (s *Store) DeletePendingHost(ctx context.Context, id string) error { res, err := s.db.ExecContext(ctx, `DELETE FROM pending_hosts WHERE id = ?`, id) if err != nil { return fmt.Errorf("store: delete pending host: %w", err) } n, _ := res.RowsAffected() if n == 0 { return ErrNotFound } return nil } // DeleteExpiredPendingHosts removes every row whose expires_at is in // the past. Returns the number of rows deleted so the sweeper can // log non-zero events. func (s *Store) DeleteExpiredPendingHosts(ctx context.Context, now time.Time) (int64, error) { res, err := s.db.ExecContext(ctx, `DELETE FROM pending_hosts WHERE expires_at <= ?`, now.UTC().Format(time.RFC3339Nano)) if err != nil { return 0, fmt.Errorf("store: delete expired pending hosts: %w", err) } n, _ := res.RowsAffected() return n, nil } // ----- scan helpers -------------------------------------------------- type pendingHostScanner interface { Scan(dest ...any) error } func scanPendingHost(row *sql.Row) (*PendingHost, error) { ph, err := scanPendingHostRow(row) if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } return ph, err } func scanPendingHostRow(s pendingHostScanner) (*PendingHost, error) { var ( ph PendingHost firstSeenAt, lastSeenAt, expiresAt string ) if err := s.Scan(&ph.ID, &ph.Hostname, &ph.OS, &ph.Arch, &ph.AgentVersion, &ph.ResticVersion, &ph.PublicKey, &ph.Fingerprint, &ph.AnnouncedFromIP, &firstSeenAt, &lastSeenAt, &expiresAt); err != nil { return nil, err } if t, err := time.Parse(time.RFC3339Nano, firstSeenAt); err == nil { ph.FirstSeenAt = t } if t, err := time.Parse(time.RFC3339Nano, lastSeenAt); err == nil { ph.LastSeenAt = t } if t, err := time.Parse(time.RFC3339Nano, expiresAt); err == nil { ph.ExpiresAt = t } return &ph, nil }