store+server: P2-18a announce-and-approve schema + endpoint
migration 0011 adds pending_hosts table (id, hostname, public_key,
fingerprint, expiry). store/pending_hosts.go covers full CRUD plus
hostname-collision count + expired-row sweeper.
POST /api/agents/announce takes {hostname, os, arch, agent_version,
restic_version, public_key (base64)}, returns {pending_id,
fingerprint, hostname_collision}. Per-source-IP token-bucket
rate limit (10/min) + global cap of 100 in-flight rows. Public
key must be exactly 32 bytes (Ed25519).
This commit is contained in:
@@ -0,0 +1,39 @@
|
||||
-- 0011_pending_hosts.sql
|
||||
--
|
||||
-- P2-18: announce-and-approve enrolment.
|
||||
--
|
||||
-- Agents that don't have an enrolment token announce themselves
|
||||
-- with `POST /api/agents/announce`, persisting one row here. The
|
||||
-- admin sees them in the dashboard's Pending hosts panel and can
|
||||
-- accept (mints a real Host row + bearer) or reject (deletes the
|
||||
-- row + closes the agent's pending WS).
|
||||
--
|
||||
-- public_key is the agent's Ed25519 public key (32 raw bytes).
|
||||
-- fingerprint = "SHA256:" + hex(sha256(public_key)) — printed by
|
||||
-- the install script on the endpoint terminal so the operator can
|
||||
-- compare the two before clicking accept. This comparison is the
|
||||
-- load-bearing security gate for this flow.
|
||||
--
|
||||
-- expires_at is set to first_seen_at + 1h on insert; a sweeper
|
||||
-- goroutine (P2-18b) deletes rows past their expiry. Hostname
|
||||
-- collisions with existing or other pending rows are *not*
|
||||
-- prevented at the DB level — multiple announces with the same
|
||||
-- hostname are flagged in the UI so admin can pick the right one.
|
||||
|
||||
CREATE TABLE pending_hosts (
|
||||
id TEXT PRIMARY KEY,
|
||||
hostname TEXT NOT NULL,
|
||||
os TEXT NOT NULL,
|
||||
arch TEXT NOT NULL,
|
||||
agent_version TEXT NOT NULL,
|
||||
restic_version TEXT NOT NULL,
|
||||
public_key BLOB NOT NULL, -- 32-byte Ed25519
|
||||
fingerprint TEXT NOT NULL, -- "SHA256:hex(...)"
|
||||
announced_from_ip TEXT NOT NULL,
|
||||
first_seen_at TEXT NOT NULL,
|
||||
last_seen_at TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL
|
||||
);
|
||||
CREATE INDEX pending_hosts_expires ON pending_hosts(expires_at);
|
||||
CREATE INDEX pending_hosts_fingerprint ON pending_hosts(fingerprint);
|
||||
CREATE INDEX pending_hosts_hostname ON pending_hosts(hostname);
|
||||
@@ -0,0 +1,225 @@
|
||||
// pending_hosts.go — store layer for the announce-and-approve
|
||||
// enrolment queue (P2-18a). Rows live for at most 1h; a sweeper
|
||||
// deletes anything past expires_at.
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PendingHost mirrors the pending_hosts table row, plus the derived
|
||||
// HostnameCollision flag the API hands back to the agent so the
|
||||
// install script can warn the operator at announce time.
|
||||
type PendingHost struct {
|
||||
ID string
|
||||
Hostname string
|
||||
OS string
|
||||
Arch string
|
||||
AgentVersion string
|
||||
ResticVersion string
|
||||
PublicKey []byte // 32-byte Ed25519
|
||||
Fingerprint string // "SHA256:hex"
|
||||
AnnouncedFromIP string
|
||||
FirstSeenAt time.Time
|
||||
LastSeenAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// FingerprintForKey returns the canonical "SHA256:hex" fingerprint
|
||||
// the operator sees in the UI and on the endpoint terminal.
|
||||
func FingerprintForKey(pubKey []byte) string {
|
||||
sum := sha256.Sum256(pubKey)
|
||||
return "SHA256:" + hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// CreatePendingHost inserts a new row. Caller has already validated
|
||||
// the public key length and rate limits.
|
||||
func (s *Store) CreatePendingHost(ctx context.Context, ph *PendingHost) error {
|
||||
if ph.ID == "" || len(ph.PublicKey) == 0 {
|
||||
return errors.New("store: pending host id + public_key required")
|
||||
}
|
||||
if ph.Fingerprint == "" {
|
||||
ph.Fingerprint = FingerprintForKey(ph.PublicKey)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if ph.FirstSeenAt.IsZero() {
|
||||
ph.FirstSeenAt = now
|
||||
}
|
||||
ph.LastSeenAt = now
|
||||
if ph.ExpiresAt.IsZero() {
|
||||
ph.ExpiresAt = now.Add(time.Hour)
|
||||
}
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`INSERT INTO pending_hosts (
|
||||
id, hostname, os, arch, agent_version, restic_version,
|
||||
public_key, fingerprint, announced_from_ip,
|
||||
first_seen_at, last_seen_at, expires_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
ph.ID, ph.Hostname, ph.OS, ph.Arch, ph.AgentVersion, ph.ResticVersion,
|
||||
ph.PublicKey, ph.Fingerprint, ph.AnnouncedFromIP,
|
||||
ph.FirstSeenAt.Format(time.RFC3339Nano),
|
||||
ph.LastSeenAt.Format(time.RFC3339Nano),
|
||||
ph.ExpiresAt.Format(time.RFC3339Nano),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: create pending host: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TouchPendingHost bumps last_seen_at on the named pending row,
|
||||
// extending its visibility in the dashboard while the agent's
|
||||
// pending WS stays open. Does NOT extend expires_at — the 1h cap
|
||||
// is firm.
|
||||
func (s *Store) TouchPendingHost(ctx context.Context, id string, when time.Time) error {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`UPDATE pending_hosts SET last_seen_at = ? WHERE id = ?`,
|
||||
when.UTC().Format(time.RFC3339Nano), id)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetPendingHost returns one row by ID. ErrNotFound on miss.
|
||||
func (s *Store) GetPendingHost(ctx context.Context, id string) (*PendingHost, error) {
|
||||
row := s.db.QueryRowContext(ctx,
|
||||
`SELECT id, hostname, os, arch, agent_version, restic_version,
|
||||
public_key, fingerprint, announced_from_ip,
|
||||
first_seen_at, last_seen_at, expires_at
|
||||
FROM pending_hosts WHERE id = ?`, id)
|
||||
return scanPendingHost(row)
|
||||
}
|
||||
|
||||
// GetPendingHostByFingerprint resolves a row by its public key
|
||||
// fingerprint (used by the WS pending handler to look up which row
|
||||
// an incoming connection corresponds to).
|
||||
func (s *Store) GetPendingHostByFingerprint(ctx context.Context, fp string) (*PendingHost, error) {
|
||||
row := s.db.QueryRowContext(ctx,
|
||||
`SELECT id, hostname, os, arch, agent_version, restic_version,
|
||||
public_key, fingerprint, announced_from_ip,
|
||||
first_seen_at, last_seen_at, expires_at
|
||||
FROM pending_hosts WHERE fingerprint = ?`, fp)
|
||||
return scanPendingHost(row)
|
||||
}
|
||||
|
||||
// ListPendingHosts returns every non-expired row, newest first. The
|
||||
// caller passes `now` so tests can fast-forward.
|
||||
func (s *Store) ListPendingHosts(ctx context.Context, now time.Time) ([]PendingHost, error) {
|
||||
rows, err := s.db.QueryContext(ctx,
|
||||
`SELECT id, hostname, os, arch, agent_version, restic_version,
|
||||
public_key, fingerprint, announced_from_ip,
|
||||
first_seen_at, last_seen_at, expires_at
|
||||
FROM pending_hosts WHERE expires_at > ?
|
||||
ORDER BY first_seen_at DESC`,
|
||||
now.UTC().Format(time.RFC3339Nano))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("store: list pending hosts: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
out := []PendingHost{}
|
||||
for rows.Next() {
|
||||
ph, err := scanPendingHostRow(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, *ph)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// CountPendingHosts returns the count of non-expired rows. Used for
|
||||
// the global cap (P2-18: refuse new announces past 100 in flight).
|
||||
func (s *Store) CountPendingHosts(ctx context.Context, now time.Time) (int, error) {
|
||||
var n int
|
||||
err := s.db.QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM pending_hosts WHERE expires_at > ?`,
|
||||
now.UTC().Format(time.RFC3339Nano)).Scan(&n)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("store: count pending hosts: %w", err)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// CountPendingHostsByHostname returns the number of non-expired
|
||||
// pending rows that share the supplied hostname. Used by the
|
||||
// announce endpoint to set the hostname_collision flag in its
|
||||
// response.
|
||||
func (s *Store) CountPendingHostsByHostname(ctx context.Context, hostname string, now time.Time) (int, error) {
|
||||
var n int
|
||||
err := s.db.QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM pending_hosts WHERE hostname = ? AND expires_at > ?`,
|
||||
hostname, now.UTC().Format(time.RFC3339Nano)).Scan(&n)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("store: count pending hosts by hostname: %w", err)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// DeletePendingHost removes one row by ID. ErrNotFound on miss.
|
||||
func (s *Store) DeletePendingHost(ctx context.Context, id string) error {
|
||||
res, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM pending_hosts WHERE id = ?`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: delete pending host: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteExpiredPendingHosts removes every row whose expires_at is in
|
||||
// the past. Returns the number of rows deleted so the sweeper can
|
||||
// log non-zero events.
|
||||
func (s *Store) DeleteExpiredPendingHosts(ctx context.Context, now time.Time) (int64, error) {
|
||||
res, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM pending_hosts WHERE expires_at <= ?`,
|
||||
now.UTC().Format(time.RFC3339Nano))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("store: delete expired pending hosts: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// ----- scan helpers --------------------------------------------------
|
||||
|
||||
type pendingHostScanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanPendingHost(row *sql.Row) (*PendingHost, error) {
|
||||
ph, err := scanPendingHostRow(row)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return ph, err
|
||||
}
|
||||
|
||||
func scanPendingHostRow(s pendingHostScanner) (*PendingHost, error) {
|
||||
var (
|
||||
ph PendingHost
|
||||
firstSeenAt, lastSeenAt, expiresAt string
|
||||
)
|
||||
if err := s.Scan(&ph.ID, &ph.Hostname, &ph.OS, &ph.Arch,
|
||||
&ph.AgentVersion, &ph.ResticVersion,
|
||||
&ph.PublicKey, &ph.Fingerprint, &ph.AnnouncedFromIP,
|
||||
&firstSeenAt, &lastSeenAt, &expiresAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339Nano, firstSeenAt); err == nil {
|
||||
ph.FirstSeenAt = t
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339Nano, lastSeenAt); err == nil {
|
||||
ph.LastSeenAt = t
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339Nano, expiresAt); err == nil {
|
||||
ph.ExpiresAt = t
|
||||
}
|
||||
return &ph, nil
|
||||
}
|
||||
Reference in New Issue
Block a user