Files
steve cd80be3b13 store+server: P2-18a announce-and-approve schema + endpoint
migration 0011 adds pending_hosts table (id, hostname, public_key,
fingerprint, expiry). store/pending_hosts.go covers full CRUD plus
hostname-collision count + expired-row sweeper.

POST /api/agents/announce takes {hostname, os, arch, agent_version,
restic_version, public_key (base64)}, returns {pending_id,
fingerprint, hostname_collision}. Per-source-IP token-bucket
rate limit (10/min) + global cap of 100 in-flight rows. Public
key must be exactly 32 bytes (Ed25519).
2026-05-04 11:03:41 +01:00

226 lines
7.3 KiB
Go

// pending_hosts.go — store layer for the announce-and-approve
// enrolment queue (P2-18a). Rows live for at most 1h; a sweeper
// deletes anything past expires_at.
package store
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"time"
)
// PendingHost is one row of the pending_hosts table: an agent that has
// announced itself and is waiting in the announce-and-approve queue.
//
// The hostname_collision flag that the announce endpoint hands back is
// not a field here; it is derived separately with
// CountPendingHostsByHostname.
type PendingHost struct {
	// Identity and platform details reported by the agent.
	ID, Hostname, OS, Arch      string
	AgentVersion, ResticVersion string

	// PublicKey is the agent's raw 32-byte Ed25519 key; Fingerprint is
	// its "SHA256:hex" digest (CreatePendingHost fills it in via
	// FingerprintForKey when left empty).
	PublicKey   []byte
	Fingerprint string

	// AnnouncedFromIP is presumably the client address of the announce
	// request — confirm against the HTTP handler.
	AnnouncedFromIP string

	// Lifecycle timestamps; persisted as RFC3339Nano text.
	FirstSeenAt, LastSeenAt, ExpiresAt time.Time
}
// FingerprintForKey hashes pubKey with SHA-256 and renders it as
// "SHA256:" followed by the lowercase hex digest. This is the
// fingerprint shown to the operator in the UI and on the endpoint
// terminal. A nil or empty key hashes like any other input.
func FingerprintForKey(pubKey []byte) string {
	const scheme = "SHA256:"
	digest := sha256.Sum256(pubKey)
	encoded := hex.EncodeToString(digest[:])
	return scheme + encoded
}
// CreatePendingHost inserts a new pending_hosts row built from ph.
//
// The caller is expected to have validated the public key length and
// applied rate limits already; this only rejects an empty ID or an
// empty PublicKey. ph is filled in place before the insert:
//   - Fingerprint defaults to FingerprintForKey(PublicKey) when empty;
//   - FirstSeenAt defaults to now, and LastSeenAt is always set to now;
//   - ExpiresAt defaults to now + 1h.
//
// All three timestamps are converted to UTC before formatting. The
// list/count/sweep queries compare expires_at against a UTC
// RFC3339Nano string parameter. If the column holds text (as the
// string formatting here implies), that comparison is lexical. A
// caller-supplied time in another zone (e.g. "...T12:00:00+01:00")
// would otherwise be ordered by its wall-clock text rather than by the
// instant it names.
//
// NOTE(review): RFC3339Nano drops trailing fractional zeros, so lexical
// order is only reliable between different seconds. A fixed-width
// layout would fix that, but every reader and writer of these columns
// must switch together.
func (s *Store) CreatePendingHost(ctx context.Context, ph *PendingHost) error {
	if ph.ID == "" || len(ph.PublicKey) == 0 {
		return errors.New("store: pending host id + public_key required")
	}
	if ph.Fingerprint == "" {
		ph.Fingerprint = FingerprintForKey(ph.PublicKey)
	}
	now := time.Now().UTC()
	if ph.FirstSeenAt.IsZero() {
		ph.FirstSeenAt = now
	}
	ph.LastSeenAt = now
	if ph.ExpiresAt.IsZero() {
		ph.ExpiresAt = now.Add(time.Hour)
	}
	_, err := s.db.ExecContext(ctx,
		`INSERT INTO pending_hosts (
id, hostname, os, arch, agent_version, restic_version,
public_key, fingerprint, announced_from_ip,
first_seen_at, last_seen_at, expires_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		ph.ID, ph.Hostname, ph.OS, ph.Arch, ph.AgentVersion, ph.ResticVersion,
		ph.PublicKey, ph.Fingerprint, ph.AnnouncedFromIP,
		// .UTC() on caller-supplied values keeps every stored timestamp
		// in the same "...Z" form the read queries compare against.
		ph.FirstSeenAt.UTC().Format(time.RFC3339Nano),
		ph.LastSeenAt.UTC().Format(time.RFC3339Nano),
		ph.ExpiresAt.UTC().Format(time.RFC3339Nano),
	)
	if err != nil {
		return fmt.Errorf("store: create pending host: %w", err)
	}
	return nil
}
// TouchPendingHost sets last_seen_at on the pending row with the given
// id to when, converted to UTC. It does NOT extend expires_at (the 1h
// cap is firm). Touching a row therefore has no effect on whether the
// expires_at-filtered list and count queries include it.
//
// A missing id is not an error: the UPDATE simply matches no rows.
// Driver errors are wrapped with a "store:" prefix like the rest of
// this file; errors.Is/As still reach the underlying error.
func (s *Store) TouchPendingHost(ctx context.Context, id string, when time.Time) error {
	_, err := s.db.ExecContext(ctx,
		`UPDATE pending_hosts SET last_seen_at = ? WHERE id = ?`,
		when.UTC().Format(time.RFC3339Nano), id)
	if err != nil {
		return fmt.Errorf("store: touch pending host: %w", err)
	}
	return nil
}
// GetPendingHost loads the pending row whose id matches, returning
// ErrNotFound when there is none. No expires_at filter is applied, so
// an expired row that the sweeper has not yet deleted is still
// returned.
func (s *Store) GetPendingHost(ctx context.Context, id string) (*PendingHost, error) {
	const byIDQuery = `SELECT id, hostname, os, arch, agent_version, restic_version,
public_key, fingerprint, announced_from_ip,
first_seen_at, last_seen_at, expires_at
FROM pending_hosts WHERE id = ?`
	return scanPendingHost(s.db.QueryRowContext(ctx, byIDQuery, id))
}
// GetPendingHostByFingerprint loads the pending row whose fingerprint
// equals fp, returning ErrNotFound when there is none. Per the original
// comment, the WS pending handler uses it to map an incoming connection
// to its row.
//
// Like GetPendingHost it does not filter on expires_at.
// NOTE(review): if fingerprint is not UNIQUE in the migration, which of
// several matching rows is returned is unspecified — confirm.
func (s *Store) GetPendingHostByFingerprint(ctx context.Context, fp string) (*PendingHost, error) {
	const byFingerprintQuery = `SELECT id, hostname, os, arch, agent_version, restic_version,
public_key, fingerprint, announced_from_ip,
first_seen_at, last_seen_at, expires_at
FROM pending_hosts WHERE fingerprint = ?`
	match := s.db.QueryRowContext(ctx, byFingerprintQuery, fp)
	return scanPendingHost(match)
}
// ListPendingHosts returns every row with expires_at > now, ordered by
// first_seen_at descending (newest first). The caller passes now so
// tests can fast-forward.
//
// When nothing matches, the result is an empty, non-nil slice. On any
// error the slice is nil; errors carry a "store: list pending hosts"
// prefix and wrap the underlying driver error.
func (s *Store) ListPendingHosts(ctx context.Context, now time.Time) ([]PendingHost, error) {
	rows, err := s.db.QueryContext(ctx,
		`SELECT id, hostname, os, arch, agent_version, restic_version,
public_key, fingerprint, announced_from_ip,
first_seen_at, last_seen_at, expires_at
FROM pending_hosts WHERE expires_at > ?
ORDER BY first_seen_at DESC`,
		now.UTC().Format(time.RFC3339Nano))
	if err != nil {
		return nil, fmt.Errorf("store: list pending hosts: %w", err)
	}
	defer func() { _ = rows.Close() }()

	out := []PendingHost{}
	for rows.Next() {
		ph, err := scanPendingHostRow(rows)
		if err != nil {
			return nil, fmt.Errorf("store: list pending hosts: scan: %w", err)
		}
		out = append(out, *ph)
	}
	// rows.Err reports failures that ended iteration early. Returning
	// the partial slice alongside it would look like a short but valid list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("store: list pending hosts: %w", err)
	}
	return out, nil
}
// CountPendingHosts reports how many rows have expires_at > now. The
// announce endpoint uses it for the global in-flight cap (P2-18:
// refuse new announces past 100). It returns 0 and a wrapped error if
// the query fails.
func (s *Store) CountPendingHosts(ctx context.Context, now time.Time) (int, error) {
	cutoff := now.UTC().Format(time.RFC3339Nano)
	row := s.db.QueryRowContext(ctx,
		`SELECT COUNT(*) FROM pending_hosts WHERE expires_at > ?`, cutoff)

	var inFlight int
	if err := row.Scan(&inFlight); err != nil {
		return 0, fmt.Errorf("store: count pending hosts: %w", err)
	}
	return inFlight, nil
}
// CountPendingHostsByHostname reports how many rows with
// expires_at > now have exactly the given hostname. The announce
// endpoint turns a non-zero count into the hostname_collision flag of
// its response. It returns 0 and a wrapped error if the query fails.
//
// NOTE(review): whether the match is case-sensitive depends on the
// hostname column's collation in the migration — confirm.
func (s *Store) CountPendingHostsByHostname(ctx context.Context, hostname string, now time.Time) (int, error) {
	cutoff := now.UTC().Format(time.RFC3339Nano)

	var matches int
	err := s.db.QueryRowContext(ctx,
		`SELECT COUNT(*) FROM pending_hosts WHERE hostname = ? AND expires_at > ?`,
		hostname, cutoff,
	).Scan(&matches)
	if err != nil {
		return 0, fmt.Errorf("store: count pending hosts by hostname: %w", err)
	}
	return matches, nil
}
// DeletePendingHost removes the row with the given id and returns
// ErrNotFound if no row matched.
//
// The RowsAffected error is checked rather than discarded. If it were
// ignored, a driver that cannot report affected rows would yield
// n == 0, and a successful delete would be misreported as ErrNotFound.
func (s *Store) DeletePendingHost(ctx context.Context, id string) error {
	res, err := s.db.ExecContext(ctx,
		`DELETE FROM pending_hosts WHERE id = ?`, id)
	if err != nil {
		return fmt.Errorf("store: delete pending host: %w", err)
	}
	n, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("store: delete pending host: rows affected: %w", err)
	}
	if n == 0 {
		return ErrNotFound
	}
	return nil
}
// DeleteExpiredPendingHosts removes every row with expires_at <= now.
// It returns the number of rows deleted so callers (such as the
// sweeper) can log non-zero events.
//
// A failure to read RowsAffected is reported as an error instead of
// being silently turned into a count of 0, which would hide the event
// from the caller's logging.
func (s *Store) DeleteExpiredPendingHosts(ctx context.Context, now time.Time) (int64, error) {
	res, err := s.db.ExecContext(ctx,
		`DELETE FROM pending_hosts WHERE expires_at <= ?`,
		now.UTC().Format(time.RFC3339Nano))
	if err != nil {
		return 0, fmt.Errorf("store: delete expired pending hosts: %w", err)
	}
	n, err := res.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("store: delete expired pending hosts: rows affected: %w", err)
	}
	return n, nil
}
// ----- scan helpers --------------------------------------------------
// pendingHostScanner is the single Scan method shared by *sql.Row and
// *sql.Rows. scanPendingHostRow accepts it so the same column decoder
// serves single-row lookups (scanPendingHost) and list iteration
// (ListPendingHosts).
type pendingHostScanner interface {
Scan(dest ...any) error
}
// scanPendingHost decodes a single-row query result. It translates
// sql.ErrNoRows into the package's ErrNotFound; any other scan error
// is passed through unchanged with a nil host.
func scanPendingHost(row *sql.Row) (*PendingHost, error) {
	host, scanErr := scanPendingHostRow(row)
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		return nil, ErrNotFound
	case scanErr != nil:
		return nil, scanErr
	default:
		return host, nil
	}
}
// scanPendingHostRow decodes one pending_hosts row. It expects the
// twelve columns in the order every SELECT in this file uses.
//
// Scan errors are returned as-is with a nil host. Timestamp columns are
// read as strings and parsed with time.RFC3339Nano. A value that fails
// to parse leaves the matching field at its zero time instead of
// failing the whole row.
func scanPendingHostRow(sc pendingHostScanner) (*PendingHost, error) {
	var host PendingHost
	var rawFirst, rawLast, rawExpires string

	dest := []any{
		&host.ID, &host.Hostname, &host.OS, &host.Arch,
		&host.AgentVersion, &host.ResticVersion,
		&host.PublicKey, &host.Fingerprint, &host.AnnouncedFromIP,
		&rawFirst, &rawLast, &rawExpires,
	}
	if err := sc.Scan(dest...); err != nil {
		return nil, err
	}

	stamps := []struct {
		raw string
		dst *time.Time
	}{
		{rawFirst, &host.FirstSeenAt},
		{rawLast, &host.LastSeenAt},
		{rawExpires, &host.ExpiresAt},
	}
	for _, st := range stamps {
		if parsed, err := time.Parse(time.RFC3339Nano, st.raw); err == nil {
			*st.dst = parsed
		}
	}
	return &host, nil
}