phase 1 foundations: api types, store, crypto, auth
Lands the bottom three layers of Phase 1: P1-08 internal/api: protocol_version + envelope + every WS message shape from spec.md §6.2 (Hello, Heartbeat, Job*, Schedule*, etc). Wire-format tests pin the JSON shape so a rename here breaks tests instead of silently breaking the agent. P1-02 + P1-03 internal/store: SQLite via modernc.org/sqlite, embed.FS + a tiny version table for hand-rolled migrations. 0001_initial.sql covers every table from spec.md §5 plus enrollment_tokens and host_schedule_version. Typed accessors for users / sessions / enrollment / audit. WAL + foreign_keys + busy_timeout on by default. P1-06 internal/crypto: XChaCha20-Poly1305 AEAD wrapper with per-message random nonce. Key file lifecycle (generate + refuse-to-overwrite, load with size validation). Optional additionalData binds ciphertext to the row that owns it. P1-04 internal/auth (partial — passwords + tokens; sessions middleware lands with the HTTP handlers): argon2id following RFC 9106 (64 MiB / t=3 / p=4 / 32B), constant-time verify. HashToken stores SHA-256 of session/agent/enrollment tokens so a stolen DB doesn't hand over credentials. Build floor moves to Go 1.25 (modernc.org/sqlite v1.50+ requires it); CI + Dockerfile + README updated. Markdown lint diagnostics on tasks.md cleared. All packages tested. ~70 new tests pass in <1s. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,36 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AppendAudit records an audit log entry.
|
||||
func (s *Store) AppendAudit(ctx context.Context, e AuditEntry) error {
|
||||
if len(e.Payload) == 0 {
|
||||
e.Payload = json.RawMessage("{}")
|
||||
}
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`INSERT INTO audit_log (id, user_id, actor, action, target_kind, target_id, ts, payload)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
e.ID, nullable(e.UserID), e.Actor, e.Action,
|
||||
nullable(e.TargetKind), nullable(e.TargetID),
|
||||
e.TS.UTC().Format(time.RFC3339Nano),
|
||||
string(e.Payload))
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: append audit: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// nullable returns nil for nil/empty *string so SQLite stores NULL.
|
||||
// SQLite's driver treats Go nil as NULL but treats *string("") as ''.
|
||||
// We want NULL semantics for "absent."
|
||||
func nullable(p *string) any {
|
||||
if p == nil || *p == "" {
|
||||
return nil
|
||||
}
|
||||
return *p
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
// Package store is the SQLite persistence layer
|
||||
// (modernc.org/sqlite, no CGo).
|
||||
package store
|
||||
@@ -0,0 +1,58 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CreateEnrollmentToken persists a fresh one-time token. The caller
|
||||
// has already hashed the raw token; the raw form is returned to the
|
||||
// operator (printed in the install snippet) and never persisted.
|
||||
func (s *Store) CreateEnrollmentToken(ctx context.Context, tokenHash string, ttl time.Duration) error {
|
||||
now := time.Now().UTC()
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`INSERT INTO enrollment_tokens (token_hash, created_at, expires_at)
|
||||
VALUES (?, ?, ?)`,
|
||||
tokenHash,
|
||||
now.Format(time.RFC3339Nano),
|
||||
now.Add(ttl).Format(time.RFC3339Nano))
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: create enrollment token: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConsumeEnrollmentToken atomically validates a token (must exist,
|
||||
// not be consumed, not be expired) and marks it consumed by hostID.
|
||||
// Returns ErrNotFound on any failure.
|
||||
func (s *Store) ConsumeEnrollmentToken(ctx context.Context, tokenHash, hostID string) error {
|
||||
now := time.Now().UTC().Format(time.RFC3339Nano)
|
||||
res, err := s.db.ExecContext(ctx,
|
||||
`UPDATE enrollment_tokens
|
||||
SET consumed_at = ?, consumed_host = ?
|
||||
WHERE token_hash = ? AND consumed_at IS NULL AND expires_at > ?`,
|
||||
now, hostID, tokenHash, now)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: consume enrollment token: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PurgeExpiredEnrollmentTokens deletes long-expired token rows. Tokens
|
||||
// retained for ~24h after expiry so audit traces still resolve them.
|
||||
func (s *Store) PurgeExpiredEnrollmentTokens(ctx context.Context) (int64, error) {
|
||||
cutoff := time.Now().Add(-24 * time.Hour).UTC().Format(time.RFC3339Nano)
|
||||
res, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM enrollment_tokens WHERE expires_at <= ?`, cutoff)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("store: purge enrollment tokens: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//go:embed migrations/*.sql
|
||||
var migrationsFS embed.FS
|
||||
|
||||
// migration is one ordered SQL file from migrations/.
|
||||
type migration struct {
|
||||
version int // parsed from filename prefix (0001, 0002, …)
|
||||
name string // full filename, for error messages
|
||||
sql string
|
||||
}
|
||||
|
||||
// loadMigrations reads every migrations/*.sql file in lexical order
|
||||
// and returns them. Filenames must look like NNNN_name.sql; the
|
||||
// numeric prefix is the version.
|
||||
func loadMigrations() ([]migration, error) {
|
||||
entries, err := fs.ReadDir(migrationsFS, "migrations")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read migrations dir: %w", err)
|
||||
}
|
||||
out := make([]migration, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".sql") {
|
||||
continue
|
||||
}
|
||||
var v int
|
||||
// Allow up to 6 digits (we will never need that many but it
|
||||
// costs nothing to be permissive).
|
||||
if _, err := fmt.Sscanf(e.Name(), "%d_", &v); err != nil {
|
||||
return nil, fmt.Errorf("migration %q: cannot parse version prefix: %w", e.Name(), err)
|
||||
}
|
||||
body, err := fs.ReadFile(migrationsFS, "migrations/"+e.Name())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read %s: %w", e.Name(), err)
|
||||
}
|
||||
out = append(out, migration{version: v, name: e.Name(), sql: string(body)})
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool { return out[i].version < out[j].version })
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// migrate brings the db up to the highest known version. It is
|
||||
// idempotent: running it on an already-current db is a no-op. There
|
||||
// is no rollback path; we move forward only.
|
||||
func migrate(ctx context.Context, db *sql.DB) error {
|
||||
if _, err := db.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
version INTEGER PRIMARY KEY,
|
||||
applied_at TEXT NOT NULL
|
||||
)
|
||||
`); err != nil {
|
||||
return fmt.Errorf("create schema_version: %w", err)
|
||||
}
|
||||
|
||||
migs, err := loadMigrations()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, m := range migs {
|
||||
var applied int
|
||||
row := db.QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM schema_version WHERE version = ?`, m.version)
|
||||
if err := row.Scan(&applied); err != nil {
|
||||
return fmt.Errorf("check version %d: %w", m.version, err)
|
||||
}
|
||||
if applied > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin tx for migration %s: %w", m.name, err)
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, m.sql); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("apply %s: %w", m.name, err)
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx,
|
||||
`INSERT INTO schema_version (version, applied_at) VALUES (?, datetime('now'))`,
|
||||
m.version); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("record %s: %w", m.name, err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("commit %s: %w", m.name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
-- 0001_initial.sql
|
||||
--
|
||||
-- Initial schema for restic-manager. Mirrors the domain model in
|
||||
-- spec.md §5. We use TEXT primary keys (ULIDs) throughout: sortable,
|
||||
-- URL-safe, no autoincrement contention. JSON blobs are stored as
|
||||
-- TEXT; SQLite's json1 extension is available but we read/write
|
||||
-- raw and parse in Go for portability.
|
||||
--
|
||||
-- All timestamps are stored as RFC 3339 TEXT (UTC). SQLite's INTEGER
|
||||
-- (unix epoch) would be cheaper but text is human-readable in dumps
|
||||
-- and the storage cost is negligible at this scale.
|
||||
|
||||
CREATE TABLE users (
|
||||
id TEXT PRIMARY KEY,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
role TEXT NOT NULL CHECK (role IN ('admin','operator','viewer')),
|
||||
created_at TEXT NOT NULL,
|
||||
last_login_at TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE sessions (
|
||||
id TEXT PRIMARY KEY, -- session token (high-entropy)
|
||||
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
created_at TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL,
|
||||
ip TEXT,
|
||||
ua TEXT
|
||||
);
|
||||
CREATE INDEX sessions_user_id ON sessions(user_id);
|
||||
CREATE INDEX sessions_expires_at ON sessions(expires_at);
|
||||
|
||||
CREATE TABLE credentials (
|
||||
id TEXT PRIMARY KEY,
|
||||
kind TEXT NOT NULL, -- 'rest','s3','local'
|
||||
username TEXT,
|
||||
-- secret_ref is the AEAD ciphertext (nonce || ciphertext, base64).
|
||||
-- The plaintext never lands on disk.
|
||||
secret_ref TEXT NOT NULL,
|
||||
rotated_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE repos (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
url TEXT NOT NULL,
|
||||
kind TEXT NOT NULL CHECK (kind IN ('rest','s3','local')),
|
||||
credential_id TEXT REFERENCES credentials(id) ON DELETE RESTRICT,
|
||||
password_secret_id TEXT REFERENCES credentials(id) ON DELETE RESTRICT,
|
||||
-- Cached projection from `restic stats` + lock-file inspection.
|
||||
size_bytes INTEGER NOT NULL DEFAULT 0,
|
||||
snapshot_count INTEGER NOT NULL DEFAULT 0,
|
||||
dedup_ratio REAL NOT NULL DEFAULT 0,
|
||||
last_check_at TEXT,
|
||||
last_check_status TEXT,
|
||||
lock_state TEXT NOT NULL DEFAULT 'unlocked'
|
||||
CHECK (lock_state IN ('locked','unlocked')),
|
||||
append_only INTEGER NOT NULL DEFAULT 1, -- bool
|
||||
credential_rotated_at TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE hosts (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
os TEXT NOT NULL,
|
||||
arch TEXT NOT NULL,
|
||||
agent_version TEXT NOT NULL DEFAULT '',
|
||||
restic_version TEXT NOT NULL DEFAULT '',
|
||||
protocol_version INTEGER NOT NULL DEFAULT 0,
|
||||
enrolled_at TEXT NOT NULL,
|
||||
last_seen_at TEXT,
|
||||
status TEXT NOT NULL DEFAULT 'offline'
|
||||
CHECK (status IN ('online','offline','degraded')),
|
||||
repo_id TEXT REFERENCES repos(id) ON DELETE SET NULL,
|
||||
tags TEXT NOT NULL DEFAULT '[]', -- json array
|
||||
current_job_id TEXT,
|
||||
-- Denormalised projections (refreshed on job.finished etc).
|
||||
last_backup_at TEXT,
|
||||
last_backup_status TEXT
|
||||
CHECK (last_backup_status IN
|
||||
('succeeded','failed','cancelled') OR
|
||||
last_backup_status IS NULL),
|
||||
repo_size_bytes INTEGER NOT NULL DEFAULT 0,
|
||||
snapshot_count INTEGER NOT NULL DEFAULT 0,
|
||||
open_alert_count INTEGER NOT NULL DEFAULT 0,
|
||||
applied_schedule_version INTEGER NOT NULL DEFAULT 0,
|
||||
-- Server-issued credentials for the agent ↔ server WS.
|
||||
agent_token_hash TEXT NOT NULL DEFAULT '',
|
||||
cert_pin_sha256 TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
CREATE INDEX hosts_status ON hosts(status);
|
||||
CREATE INDEX hosts_last_seen_at ON hosts(last_seen_at);
|
||||
|
||||
-- Pending one-time enrollment tokens (TTL'd, single-use).
|
||||
CREATE TABLE enrollment_tokens (
|
||||
token_hash TEXT PRIMARY KEY, -- argon2id of token
|
||||
created_at TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL,
|
||||
consumed_at TEXT,
|
||||
consumed_host TEXT REFERENCES hosts(id) ON DELETE SET NULL
|
||||
);
|
||||
CREATE INDEX enrollment_tokens_expires_at ON enrollment_tokens(expires_at);
|
||||
|
||||
CREATE TABLE schedules (
|
||||
id TEXT PRIMARY KEY,
|
||||
host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
|
||||
kind TEXT NOT NULL CHECK (kind IN ('backup','forget','prune','check')),
|
||||
cron_expr TEXT NOT NULL,
|
||||
paths TEXT NOT NULL DEFAULT '[]', -- json array
|
||||
excludes TEXT NOT NULL DEFAULT '[]',
|
||||
tags TEXT NOT NULL DEFAULT '[]',
|
||||
retention_policy TEXT NOT NULL DEFAULT '{}', -- json object
|
||||
options TEXT NOT NULL DEFAULT '{}', -- json object (bandwidth)
|
||||
-- Hooks are encrypted at rest (AEAD ciphertext). Constraint enforced
|
||||
-- in application code: hooks must be empty unless kind='backup'.
|
||||
pre_hook TEXT NOT NULL DEFAULT '',
|
||||
post_hook TEXT NOT NULL DEFAULT '',
|
||||
enabled INTEGER NOT NULL DEFAULT 1,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
CREATE INDEX schedules_host_id ON schedules(host_id);
|
||||
|
||||
-- Per-host monotonic schedule version. Bumped on any schedules INSERT/
|
||||
-- UPDATE/DELETE for that host. Pushed to the agent in schedule.set;
|
||||
-- the agent acks back the same version in schedule.ack.
|
||||
CREATE TABLE host_schedule_version (
|
||||
host_id TEXT PRIMARY KEY REFERENCES hosts(id) ON DELETE CASCADE,
|
||||
version INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE TABLE jobs (
|
||||
id TEXT PRIMARY KEY,
|
||||
host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
|
||||
kind TEXT NOT NULL CHECK (kind IN ('backup','forget','prune','check','unlock')),
|
||||
status TEXT NOT NULL CHECK (status IN ('queued','running','succeeded','failed','cancelled')),
|
||||
scheduled_id TEXT REFERENCES schedules(id) ON DELETE SET NULL,
|
||||
actor_kind TEXT NOT NULL CHECK (actor_kind IN ('user','schedule','system')),
|
||||
actor_id TEXT, -- user id, schedule id, or null
|
||||
started_at TEXT,
|
||||
finished_at TEXT,
|
||||
exit_code INTEGER,
|
||||
stats TEXT, -- json blob from restic
|
||||
error TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
CREATE INDEX jobs_host_id ON jobs(host_id);
|
||||
CREATE INDEX jobs_status ON jobs(status);
|
||||
CREATE INDEX jobs_created_at ON jobs(created_at);
|
||||
|
||||
CREATE TABLE job_logs (
|
||||
job_id TEXT NOT NULL REFERENCES jobs(id) ON DELETE CASCADE,
|
||||
seq INTEGER NOT NULL,
|
||||
ts TEXT NOT NULL,
|
||||
stream TEXT NOT NULL CHECK (stream IN ('stdout','stderr','event')),
|
||||
payload TEXT NOT NULL,
|
||||
PRIMARY KEY (job_id, seq)
|
||||
);
|
||||
|
||||
CREATE TABLE snapshots (
|
||||
id TEXT PRIMARY KEY, -- restic snapshot id
|
||||
host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
|
||||
repo_id TEXT NOT NULL REFERENCES repos(id) ON DELETE CASCADE,
|
||||
time TEXT NOT NULL,
|
||||
hostname TEXT NOT NULL,
|
||||
paths TEXT NOT NULL DEFAULT '[]',
|
||||
tags TEXT NOT NULL DEFAULT '[]',
|
||||
size_bytes INTEGER NOT NULL DEFAULT 0,
|
||||
file_count INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE INDEX snapshots_host_id ON snapshots(host_id);
|
||||
CREATE INDEX snapshots_time ON snapshots(time);
|
||||
|
||||
CREATE TABLE alerts (
|
||||
id TEXT PRIMARY KEY,
|
||||
host_id TEXT REFERENCES hosts(id) ON DELETE CASCADE,
|
||||
kind TEXT NOT NULL,
|
||||
severity TEXT NOT NULL CHECK (severity IN ('info','warning','critical')),
|
||||
message TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
acknowledged_at TEXT,
|
||||
acknowledged_by TEXT REFERENCES users(id) ON DELETE SET NULL,
|
||||
resolved_at TEXT
|
||||
);
|
||||
CREATE INDEX alerts_host_id ON alerts(host_id);
|
||||
CREATE INDEX alerts_open ON alerts(host_id) WHERE resolved_at IS NULL;
|
||||
|
||||
CREATE TABLE audit_log (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT REFERENCES users(id) ON DELETE SET NULL,
|
||||
actor TEXT NOT NULL CHECK (actor IN ('user','agent','system')),
|
||||
action TEXT NOT NULL,
|
||||
target_kind TEXT,
|
||||
target_id TEXT,
|
||||
ts TEXT NOT NULL,
|
||||
payload TEXT NOT NULL DEFAULT '{}'
|
||||
);
|
||||
CREATE INDEX audit_log_ts ON audit_log(ts);
|
||||
CREATE INDEX audit_log_user ON audit_log(user_id);
|
||||
@@ -0,0 +1,88 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CreateSession persists a session row. The token is hashed before
|
||||
// insert; the raw token is what the caller hands to the user (cookie).
|
||||
func (s *Store) CreateSession(ctx context.Context, sess Session, tokenHash string) error {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`INSERT INTO sessions (id, user_id, created_at, expires_at, ip, ua)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
tokenHash,
|
||||
sess.UserID,
|
||||
sess.CreatedAt.UTC().Format(time.RFC3339Nano),
|
||||
sess.ExpiresAt.UTC().Format(time.RFC3339Nano),
|
||||
sess.IP, sess.UA)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: create session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LookupSession resolves a token hash to a session row, returning
|
||||
// ErrNotFound if the hash is unknown OR the session has expired.
|
||||
// We collapse "no row" and "expired" to the same error so the caller
|
||||
// can't tell them apart in error messages — that prevents enumeration
|
||||
// of valid token hashes.
|
||||
func (s *Store) LookupSession(ctx context.Context, tokenHash string) (*Session, error) {
|
||||
row := s.db.QueryRowContext(ctx,
|
||||
`SELECT id, user_id, created_at, expires_at, ip, ua
|
||||
FROM sessions
|
||||
WHERE id = ? AND expires_at > ?`,
|
||||
tokenHash, time.Now().UTC().Format(time.RFC3339Nano))
|
||||
|
||||
var sess Session
|
||||
var created, expires string
|
||||
var ip, ua sql.NullString
|
||||
if err := row.Scan(&sess.ID, &sess.UserID, &created, &expires, &ip, &ua); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("store: lookup session: %w", err)
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339Nano, created)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("store: parse created_at: %w", err)
|
||||
}
|
||||
sess.CreatedAt = t
|
||||
t, err = time.Parse(time.RFC3339Nano, expires)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("store: parse expires_at: %w", err)
|
||||
}
|
||||
sess.ExpiresAt = t
|
||||
if ip.Valid {
|
||||
sess.IP = ip.String
|
||||
}
|
||||
if ua.Valid {
|
||||
sess.UA = ua.String
|
||||
}
|
||||
return &sess, nil
|
||||
}
|
||||
|
||||
// DeleteSession removes a session row by token hash. Used on logout.
|
||||
func (s *Store) DeleteSession(ctx context.Context, tokenHash string) error {
|
||||
_, err := s.db.ExecContext(ctx, `DELETE FROM sessions WHERE id = ?`, tokenHash)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: delete session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PurgeExpiredSessions deletes session rows past their expires_at.
|
||||
// Run periodically from a background goroutine.
|
||||
func (s *Store) PurgeExpiredSessions(ctx context.Context) (int64, error) {
|
||||
res, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM sessions WHERE expires_at <= ?`,
|
||||
time.Now().UTC().Format(time.RFC3339Nano))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("store: purge sessions: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
return n, nil
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
// Package store is the SQLite persistence layer (modernc.org/sqlite,
|
||||
// no CGo). It owns the schema, exposes typed accessors, and hides
|
||||
// the database/sql plumbing from the rest of the server.
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite" // register the "sqlite" driver
|
||||
)
|
||||
|
||||
// ErrNotFound is returned by accessors when a lookup misses.
|
||||
var ErrNotFound = errors.New("store: not found")
|
||||
|
||||
// Store is a thin wrapper around *sql.DB that exposes the typed
|
||||
// accessors used by the rest of the server. Callers should use the
|
||||
// provided methods rather than reaching into DB() directly.
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// Open opens (or creates) the SQLite database at path, applies all
|
||||
// pending migrations, and returns a ready-to-use Store.
|
||||
//
|
||||
// The DSN sets:
|
||||
// - _pragma=foreign_keys(1) — referential integrity is on
|
||||
// - _pragma=journal_mode(WAL) — concurrent reads vs writes
|
||||
// - _pragma=busy_timeout(5000) — wait 5s on lock contention
|
||||
// - _time_format=sqlite — RFC 3339 read/write of TEXT timestamps
|
||||
//
|
||||
// Empty path uses an in-memory DB (useful for tests).
|
||||
func Open(ctx context.Context, path string) (*Store, error) {
|
||||
dsn := buildDSN(path)
|
||||
db, err := sql.Open("sqlite", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open %q: %w", path, err)
|
||||
}
|
||||
// modernc.org/sqlite is not safe for arbitrary high parallelism on
|
||||
// a single file. WAL helps, but 1 writer + multiple readers is the
|
||||
// only safe shape. Cap connections to keep that property explicit.
|
||||
db.SetMaxOpenConns(8)
|
||||
db.SetMaxIdleConns(4)
|
||||
db.SetConnMaxLifetime(time.Hour)
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
if err := db.PingContext(pingCtx); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("ping: %w", err)
|
||||
}
|
||||
|
||||
if err := migrate(ctx, db); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("migrate: %w", err)
|
||||
}
|
||||
|
||||
return &Store{db: db}, nil
|
||||
}
|
||||
|
||||
// Close releases the underlying DB handle.
|
||||
func (s *Store) Close() error { return s.db.Close() }
|
||||
|
||||
// DB returns the underlying *sql.DB. Reserved for tests and migrations
|
||||
// — production code should add a typed method to this package instead.
|
||||
func (s *Store) DB() *sql.DB { return s.db }
|
||||
|
||||
func buildDSN(path string) string {
|
||||
if path == "" {
|
||||
// Shared cache + named in-memory db so multiple connections see
|
||||
// the same data — needed because we cap MaxOpenConns above.
|
||||
return "file::memory:?cache=shared&_pragma=foreign_keys(1)&_pragma=journal_mode(WAL)&_pragma=busy_timeout(5000)"
|
||||
}
|
||||
q := url.Values{}
|
||||
q.Add("_pragma", "foreign_keys(1)")
|
||||
q.Add("_pragma", "journal_mode(WAL)")
|
||||
q.Add("_pragma", "busy_timeout(5000)")
|
||||
q.Add("_pragma", "synchronous(NORMAL)")
|
||||
return "file:" + path + "?" + q.Encode()
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// openTestStore opens an isolated file-backed db in a t.TempDir.
|
||||
// In-memory + shared-cache works too but file makes failures easier
|
||||
// to inspect when a test panics.
|
||||
func openTestStore(t *testing.T) *Store {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
s, err := Open(context.Background(), filepath.Join(dir, "rm.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("open: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = s.Close() })
|
||||
return s
|
||||
}
|
||||
|
||||
func TestOpenAppliesMigrations(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := openTestStore(t)
|
||||
|
||||
row := s.DB().QueryRow(`SELECT MAX(version) FROM schema_version`)
|
||||
var v int
|
||||
if err := row.Scan(&v); err != nil {
|
||||
t.Fatalf("scan: %v", err)
|
||||
}
|
||||
if v < 1 {
|
||||
t.Fatalf("expected at least migration 1 applied, got %d", v)
|
||||
}
|
||||
|
||||
// Spot-check a few tables exist with expected columns.
|
||||
tables := []string{"users", "sessions", "hosts", "repos",
|
||||
"credentials", "schedules", "jobs", "job_logs",
|
||||
"snapshots", "alerts", "audit_log",
|
||||
"enrollment_tokens", "host_schedule_version"}
|
||||
for _, tbl := range tables {
|
||||
row := s.DB().QueryRow(
|
||||
`SELECT name FROM sqlite_master WHERE type='table' AND name = ?`, tbl)
|
||||
var got string
|
||||
if err := row.Scan(&got); err != nil {
|
||||
t.Errorf("table %q missing: %v", tbl, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateIsIdempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "rm.db")
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
s, err := Open(context.Background(), path)
|
||||
if err != nil {
|
||||
t.Fatalf("open #%d: %v", i, err)
|
||||
}
|
||||
_ = s.Close()
|
||||
}
|
||||
|
||||
s, err := Open(context.Background(), path)
|
||||
if err != nil {
|
||||
t.Fatalf("final open: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
row := s.DB().QueryRow(`SELECT COUNT(*) FROM schema_version`)
|
||||
var n int
|
||||
if err := row.Scan(&n); err != nil {
|
||||
t.Fatalf("scan: %v", err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Errorf("re-running migrations should not insert duplicate rows; got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeignKeysEnforced(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := openTestStore(t)
|
||||
|
||||
// Inserting a session with a non-existent user should fail because
|
||||
// FKs are on. Without the pragma, SQLite silently accepts this.
|
||||
_, err := s.DB().Exec(
|
||||
`INSERT INTO sessions (id, user_id, created_at, expires_at)
|
||||
VALUES (?, ?, datetime('now'), datetime('now','+1 hour'))`,
|
||||
"sess1", "no-such-user")
|
||||
if err == nil {
|
||||
t.Fatal("expected FK violation, got nil")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// User mirrors the users table.
|
||||
type User struct {
|
||||
ID string
|
||||
Username string
|
||||
PasswordHash string
|
||||
Role Role
|
||||
CreatedAt time.Time
|
||||
LastLoginAt *time.Time
|
||||
}
|
||||
|
||||
// Role enumerates the access tiers from spec.md §7.2.
|
||||
type Role string
|
||||
|
||||
const (
|
||||
RoleAdmin Role = "admin"
|
||||
RoleOperator Role = "operator"
|
||||
RoleViewer Role = "viewer"
|
||||
)
|
||||
|
||||
// Session mirrors the sessions table. The ID is the (raw) session
|
||||
// token; the DB stores its hash. Callers that hold a *Session have
|
||||
// already authenticated.
|
||||
type Session struct {
|
||||
ID string // session token (raw); never persisted as-is
|
||||
UserID string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
IP string
|
||||
UA string
|
||||
}
|
||||
|
||||
// Host mirrors the denormalised hosts table. JSON columns (tags) are
|
||||
// returned decoded into Go slices for ergonomics.
|
||||
type Host struct {
|
||||
ID string
|
||||
Name string
|
||||
OS string
|
||||
Arch string
|
||||
AgentVersion string
|
||||
ResticVersion string
|
||||
ProtocolVersion int
|
||||
EnrolledAt time.Time
|
||||
LastSeenAt *time.Time
|
||||
Status string
|
||||
RepoID *string
|
||||
Tags []string
|
||||
CurrentJobID *string
|
||||
LastBackupAt *time.Time
|
||||
LastBackupStatus *string
|
||||
RepoSizeBytes int64
|
||||
SnapshotCount int
|
||||
OpenAlertCount int
|
||||
AppliedScheduleVersion int64
|
||||
}
|
||||
|
||||
// EnrollmentToken is the issuer's view of a one-time token. The
|
||||
// raw token is returned only at create time; the DB stores its hash.
|
||||
type EnrollmentToken struct {
|
||||
Raw string // populated on create only
|
||||
TokenHash string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// AuditEntry mirrors the audit_log table.
|
||||
type AuditEntry struct {
|
||||
ID string
|
||||
UserID *string
|
||||
Actor string // user|agent|system
|
||||
Action string
|
||||
TargetKind *string
|
||||
TargetID *string
|
||||
TS time.Time
|
||||
Payload json.RawMessage
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CreateUser inserts a new user. The caller is responsible for
|
||||
// generating an ID (typically a ULID) and hashing the password.
|
||||
func (s *Store) CreateUser(ctx context.Context, u User) error {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`INSERT INTO users (id, username, password_hash, role, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
u.ID, u.Username, u.PasswordHash, string(u.Role), u.CreatedAt.UTC().Format(time.RFC3339Nano))
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: create user: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserByUsername looks up a user by their (case-sensitive) username.
|
||||
// Returns ErrNotFound if no row matches.
|
||||
func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
||||
row := s.db.QueryRowContext(ctx,
|
||||
`SELECT id, username, password_hash, role, created_at, last_login_at
|
||||
FROM users WHERE username = ?`, username)
|
||||
return scanUser(row)
|
||||
}
|
||||
|
||||
// GetUserByID looks up a user by id. Returns ErrNotFound on miss.
|
||||
func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) {
|
||||
row := s.db.QueryRowContext(ctx,
|
||||
`SELECT id, username, password_hash, role, created_at, last_login_at
|
||||
FROM users WHERE id = ?`, id)
|
||||
return scanUser(row)
|
||||
}
|
||||
|
||||
// CountUsers returns the total number of user rows. The first-run
|
||||
// bootstrap uses this to detect a fresh install.
|
||||
func (s *Store) CountUsers(ctx context.Context) (int, error) {
|
||||
var n int
|
||||
if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM users`).Scan(&n); err != nil {
|
||||
return 0, fmt.Errorf("store: count users: %w", err)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// MarkUserLogin records a successful authentication.
|
||||
func (s *Store) MarkUserLogin(ctx context.Context, id string, when time.Time) error {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`UPDATE users SET last_login_at = ? WHERE id = ?`,
|
||||
when.UTC().Format(time.RFC3339Nano), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: mark login: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func scanUser(row *sql.Row) (*User, error) {
|
||||
var u User
|
||||
var role string
|
||||
var lastLogin sql.NullString
|
||||
var created string
|
||||
if err := row.Scan(&u.ID, &u.Username, &u.PasswordHash, &role, &created, &lastLogin); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("store: scan user: %w", err)
|
||||
}
|
||||
u.Role = Role(role)
|
||||
t, err := time.Parse(time.RFC3339Nano, created)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("store: parse created_at: %w", err)
|
||||
}
|
||||
u.CreatedAt = t
|
||||
if lastLogin.Valid {
|
||||
t, err := time.Parse(time.RFC3339Nano, lastLogin.String)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("store: parse last_login_at: %w", err)
|
||||
}
|
||||
u.LastLoginAt = &t
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUserCRUD(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now().UTC()
|
||||
u := User{
|
||||
ID: "u1",
|
||||
Username: "alice",
|
||||
PasswordHash: "$argon2id$...",
|
||||
Role: RoleAdmin,
|
||||
CreatedAt: now,
|
||||
}
|
||||
if err := s.CreateUser(ctx, u); err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
got, err := s.GetUserByUsername(ctx, "alice")
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
if got.ID != "u1" || got.Role != RoleAdmin {
|
||||
t.Errorf("unexpected user: %+v", got)
|
||||
}
|
||||
|
||||
// Username uniqueness is enforced by the schema.
|
||||
if err := s.CreateUser(ctx, u); err == nil {
|
||||
t.Error("duplicate username should fail")
|
||||
}
|
||||
|
||||
if _, err := s.GetUserByUsername(ctx, "bob"); !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("missing user: want ErrNotFound, got %v", err)
|
||||
}
|
||||
|
||||
if err := s.MarkUserLogin(ctx, "u1", now); err != nil {
|
||||
t.Fatalf("mark login: %v", err)
|
||||
}
|
||||
got, _ = s.GetUserByUsername(ctx, "alice")
|
||||
if got.LastLoginAt == nil {
|
||||
t.Error("last_login_at not updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountUsers(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
n, _ := s.CountUsers(ctx)
|
||||
if n != 0 {
|
||||
t.Errorf("fresh db: want 0, got %d", n)
|
||||
}
|
||||
_ = s.CreateUser(ctx, User{
|
||||
ID: "u1", Username: "a", PasswordHash: "x",
|
||||
Role: RoleAdmin, CreatedAt: time.Now(),
|
||||
})
|
||||
n, _ = s.CountUsers(ctx)
|
||||
if n != 1 {
|
||||
t.Errorf("after insert: want 1, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionLifecycle(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Need a user for FK.
|
||||
_ = s.CreateUser(ctx, User{
|
||||
ID: "u1", Username: "alice", PasswordHash: "x",
|
||||
Role: RoleAdmin, CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
now := time.Now().UTC()
|
||||
sess := Session{
|
||||
UserID: "u1",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
IP: "10.0.0.1",
|
||||
UA: "test/1.0",
|
||||
}
|
||||
hash := "deadbeef" + "00000000000000000000000000000000000000000000000000000000"
|
||||
if err := s.CreateSession(ctx, sess, hash); err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
got, err := s.LookupSession(ctx, hash)
|
||||
if err != nil {
|
||||
t.Fatalf("lookup: %v", err)
|
||||
}
|
||||
if got.UserID != "u1" {
|
||||
t.Errorf("user mismatch: %s", got.UserID)
|
||||
}
|
||||
|
||||
// Expired sessions should not resolve.
|
||||
expiredHash := "expired-hash"
|
||||
expired := Session{
|
||||
UserID: "u1",
|
||||
CreatedAt: now.Add(-2 * time.Hour),
|
||||
ExpiresAt: now.Add(-time.Hour),
|
||||
}
|
||||
if err := s.CreateSession(ctx, expired, expiredHash); err != nil {
|
||||
t.Fatalf("create expired: %v", err)
|
||||
}
|
||||
if _, err := s.LookupSession(ctx, expiredHash); !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("expired session should look like ErrNotFound, got %v", err)
|
||||
}
|
||||
|
||||
if err := s.DeleteSession(ctx, hash); err != nil {
|
||||
t.Fatalf("delete: %v", err)
|
||||
}
|
||||
if _, err := s.LookupSession(ctx, hash); !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("deleted session: want ErrNotFound, got %v", err)
|
||||
}
|
||||
|
||||
n, err := s.PurgeExpiredSessions(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("purge: %v", err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Errorf("purge should remove the 1 expired row, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnrollmentTokenSingleUse(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
hash := "tok-hash"
|
||||
if err := s.CreateEnrollmentToken(ctx, hash, time.Hour); err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
// Need a host for FK.
|
||||
_, err := s.DB().Exec(`INSERT INTO hosts (id, name, os, arch, enrolled_at) VALUES (?,?,?,?,?)`,
|
||||
"h1", "host1", "linux", "amd64", time.Now().UTC().Format(time.RFC3339Nano))
|
||||
if err != nil {
|
||||
t.Fatalf("insert host: %v", err)
|
||||
}
|
||||
|
||||
if err := s.ConsumeEnrollmentToken(ctx, hash, "h1"); err != nil {
|
||||
t.Fatalf("consume: %v", err)
|
||||
}
|
||||
// Second consume must fail — the whole point of one-time tokens.
|
||||
if err := s.ConsumeEnrollmentToken(ctx, hash, "h1"); !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("re-consume: want ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user