phase 1 foundations: api types, store, crypto, auth

Lands the bottom three layers of Phase 1:

P1-08 internal/api: protocol_version + envelope + every WS message
  shape from spec.md §6.2 (Hello, Heartbeat, Job*, Schedule*, etc).
  Wire-format tests pin the JSON shape so a rename here breaks
  tests instead of silently breaking the agent.

P1-02 + P1-03 internal/store: SQLite via modernc.org/sqlite,
  embed.FS + a tiny version table for hand-rolled migrations.
  0001_initial.sql covers every table from spec.md §5 plus
  enrollment_tokens and host_schedule_version. Typed accessors
  for users / sessions / enrollment / audit. WAL + foreign_keys
  + busy_timeout on by default.

P1-06 internal/crypto: XChaCha20-Poly1305 AEAD wrapper with
  per-message random nonce. Key file lifecycle (generate +
  refuse-to-overwrite, load with size validation). Optional
  additionalData binds ciphertext to the row that owns it.

P1-04 internal/auth (partial — passwords + tokens; sessions
  middleware lands with the HTTP handlers): argon2id following
  RFC 9106 (64 MiB / t=3 / p=4 / 32B), constant-time verify.
  HashToken stores SHA-256 of session/agent/enrollment tokens
  so a stolen DB doesn't hand over credentials.

Build floor moves to Go 1.25 (modernc.org/sqlite v1.50+ requires
it); CI + Dockerfile + README updated. Markdown lint diagnostics
on tasks.md cleared.

All packages tested. ~70 new tests pass in <1s.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-01 00:24:40 +01:00
parent c821ec1fe0
commit f55747a281
28 changed files with 1952 additions and 13 deletions
+36
View File
@@ -0,0 +1,36 @@
package store
import (
"context"
"encoding/json"
"fmt"
"time"
)
// AppendAudit records an audit log entry.
func (s *Store) AppendAudit(ctx context.Context, e AuditEntry) error {
if len(e.Payload) == 0 {
e.Payload = json.RawMessage("{}")
}
_, err := s.db.ExecContext(ctx,
`INSERT INTO audit_log (id, user_id, actor, action, target_kind, target_id, ts, payload)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
e.ID, nullable(e.UserID), e.Actor, e.Action,
nullable(e.TargetKind), nullable(e.TargetID),
e.TS.UTC().Format(time.RFC3339Nano),
string(e.Payload))
if err != nil {
return fmt.Errorf("store: append audit: %w", err)
}
return nil
}
// nullable returns nil for nil/empty *string so SQLite stores NULL.
// SQLite's driver treats Go nil as NULL but treats *string("") as ''.
// We want NULL semantics for "absent."
func nullable(p *string) any {
if p == nil || *p == "" {
return nil
}
return *p
}
-3
View File
@@ -1,3 +0,0 @@
// Package store is the SQLite persistence layer
// (modernc.org/sqlite, no CGo).
package store
+58
View File
@@ -0,0 +1,58 @@
package store
import (
"context"
"fmt"
"time"
)
// CreateEnrollmentToken persists a fresh one-time token. The caller
// has already hashed the raw token; the raw form is returned to the
// operator (printed in the install snippet) and never persisted.
func (s *Store) CreateEnrollmentToken(ctx context.Context, tokenHash string, ttl time.Duration) error {
now := time.Now().UTC()
_, err := s.db.ExecContext(ctx,
`INSERT INTO enrollment_tokens (token_hash, created_at, expires_at)
VALUES (?, ?, ?)`,
tokenHash,
now.Format(time.RFC3339Nano),
now.Add(ttl).Format(time.RFC3339Nano))
if err != nil {
return fmt.Errorf("store: create enrollment token: %w", err)
}
return nil
}
// ConsumeEnrollmentToken atomically validates a token (must exist,
// not be consumed, not be expired) and marks it consumed by hostID.
// Returns ErrNotFound on any failure.
func (s *Store) ConsumeEnrollmentToken(ctx context.Context, tokenHash, hostID string) error {
now := time.Now().UTC().Format(time.RFC3339Nano)
res, err := s.db.ExecContext(ctx,
`UPDATE enrollment_tokens
SET consumed_at = ?, consumed_host = ?
WHERE token_hash = ? AND consumed_at IS NULL AND expires_at > ?`,
now, hostID, tokenHash, now)
if err != nil {
return fmt.Errorf("store: consume enrollment token: %w", err)
}
n, _ := res.RowsAffected()
if n == 0 {
return ErrNotFound
}
return nil
}
// PurgeExpiredEnrollmentTokens deletes long-expired token rows. Tokens
// retained for ~24h after expiry so audit traces still resolve them.
func (s *Store) PurgeExpiredEnrollmentTokens(ctx context.Context) (int64, error) {
cutoff := time.Now().Add(-24 * time.Hour).UTC().Format(time.RFC3339Nano)
res, err := s.db.ExecContext(ctx,
`DELETE FROM enrollment_tokens WHERE expires_at <= ?`, cutoff)
if err != nil {
return 0, fmt.Errorf("store: purge enrollment tokens: %w", err)
}
n, _ := res.RowsAffected()
return n, nil
}
+100
View File
@@ -0,0 +1,100 @@
package store
import (
"context"
"database/sql"
"embed"
"fmt"
"io/fs"
"sort"
"strings"
)
//go:embed migrations/*.sql
var migrationsFS embed.FS
// migration is one ordered SQL file from migrations/.
type migration struct {
version int // parsed from filename prefix (0001, 0002, …)
name string // full filename, for error messages
sql string
}
// loadMigrations reads every migrations/*.sql file in lexical order
// and returns them. Filenames must look like NNNN_name.sql; the
// numeric prefix is the version.
func loadMigrations() ([]migration, error) {
entries, err := fs.ReadDir(migrationsFS, "migrations")
if err != nil {
return nil, fmt.Errorf("read migrations dir: %w", err)
}
out := make([]migration, 0, len(entries))
for _, e := range entries {
if e.IsDir() || !strings.HasSuffix(e.Name(), ".sql") {
continue
}
var v int
// Allow up to 6 digits (we will never need that many but it
// costs nothing to be permissive).
if _, err := fmt.Sscanf(e.Name(), "%d_", &v); err != nil {
return nil, fmt.Errorf("migration %q: cannot parse version prefix: %w", e.Name(), err)
}
body, err := fs.ReadFile(migrationsFS, "migrations/"+e.Name())
if err != nil {
return nil, fmt.Errorf("read %s: %w", e.Name(), err)
}
out = append(out, migration{version: v, name: e.Name(), sql: string(body)})
}
sort.Slice(out, func(i, j int) bool { return out[i].version < out[j].version })
return out, nil
}
// migrate brings the db up to the highest known version. It is
// idempotent: running it on an already-current db is a no-op. There
// is no rollback path; we move forward only.
func migrate(ctx context.Context, db *sql.DB) error {
if _, err := db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_version (
version INTEGER PRIMARY KEY,
applied_at TEXT NOT NULL
)
`); err != nil {
return fmt.Errorf("create schema_version: %w", err)
}
migs, err := loadMigrations()
if err != nil {
return err
}
for _, m := range migs {
var applied int
row := db.QueryRowContext(ctx,
`SELECT COUNT(*) FROM schema_version WHERE version = ?`, m.version)
if err := row.Scan(&applied); err != nil {
return fmt.Errorf("check version %d: %w", m.version, err)
}
if applied > 0 {
continue
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin tx for migration %s: %w", m.name, err)
}
if _, err := tx.ExecContext(ctx, m.sql); err != nil {
_ = tx.Rollback()
return fmt.Errorf("apply %s: %w", m.name, err)
}
if _, err := tx.ExecContext(ctx,
`INSERT INTO schema_version (version, applied_at) VALUES (?, datetime('now'))`,
m.version); err != nil {
_ = tx.Rollback()
return fmt.Errorf("record %s: %w", m.name, err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("commit %s: %w", m.name, err)
}
}
return nil
}
+199
View File
@@ -0,0 +1,199 @@
-- 0001_initial.sql
--
-- Initial schema for restic-manager. Mirrors the domain model in
-- spec.md §5. We use TEXT primary keys (ULIDs) throughout: sortable,
-- URL-safe, no autoincrement contention. JSON blobs are stored as
-- TEXT; SQLite's json1 extension is available but we read/write
-- raw and parse in Go for portability.
--
-- All timestamps are stored as RFC 3339 TEXT (UTC). SQLite's INTEGER
-- (unix epoch) would be cheaper but text is human-readable in dumps
-- and the storage cost is negligible at this scale.
CREATE TABLE users (
id TEXT PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
role TEXT NOT NULL CHECK (role IN ('admin','operator','viewer')),
created_at TEXT NOT NULL,
last_login_at TEXT
);
CREATE TABLE sessions (
id TEXT PRIMARY KEY, -- session token (high-entropy)
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
created_at TEXT NOT NULL,
expires_at TEXT NOT NULL,
ip TEXT,
ua TEXT
);
CREATE INDEX sessions_user_id ON sessions(user_id);
CREATE INDEX sessions_expires_at ON sessions(expires_at);
CREATE TABLE credentials (
id TEXT PRIMARY KEY,
kind TEXT NOT NULL, -- 'rest','s3','local'
username TEXT,
-- secret_ref is the AEAD ciphertext (nonce || ciphertext, base64).
-- The plaintext never lands on disk.
secret_ref TEXT NOT NULL,
rotated_at TEXT NOT NULL
);
CREATE TABLE repos (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
url TEXT NOT NULL,
kind TEXT NOT NULL CHECK (kind IN ('rest','s3','local')),
credential_id TEXT REFERENCES credentials(id) ON DELETE RESTRICT,
password_secret_id TEXT REFERENCES credentials(id) ON DELETE RESTRICT,
-- Cached projection from `restic stats` + lock-file inspection.
size_bytes INTEGER NOT NULL DEFAULT 0,
snapshot_count INTEGER NOT NULL DEFAULT 0,
dedup_ratio REAL NOT NULL DEFAULT 0,
last_check_at TEXT,
last_check_status TEXT,
lock_state TEXT NOT NULL DEFAULT 'unlocked'
CHECK (lock_state IN ('locked','unlocked')),
append_only INTEGER NOT NULL DEFAULT 1, -- bool
credential_rotated_at TEXT
);
CREATE TABLE hosts (
id TEXT PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
os TEXT NOT NULL,
arch TEXT NOT NULL,
agent_version TEXT NOT NULL DEFAULT '',
restic_version TEXT NOT NULL DEFAULT '',
protocol_version INTEGER NOT NULL DEFAULT 0,
enrolled_at TEXT NOT NULL,
last_seen_at TEXT,
status TEXT NOT NULL DEFAULT 'offline'
CHECK (status IN ('online','offline','degraded')),
repo_id TEXT REFERENCES repos(id) ON DELETE SET NULL,
tags TEXT NOT NULL DEFAULT '[]', -- json array
current_job_id TEXT,
-- Denormalised projections (refreshed on job.finished etc).
last_backup_at TEXT,
last_backup_status TEXT
CHECK (last_backup_status IN
('succeeded','failed','cancelled') OR
last_backup_status IS NULL),
repo_size_bytes INTEGER NOT NULL DEFAULT 0,
snapshot_count INTEGER NOT NULL DEFAULT 0,
open_alert_count INTEGER NOT NULL DEFAULT 0,
applied_schedule_version INTEGER NOT NULL DEFAULT 0,
-- Server-issued credentials for the agent ↔ server WS.
agent_token_hash TEXT NOT NULL DEFAULT '',
cert_pin_sha256 TEXT NOT NULL DEFAULT ''
);
CREATE INDEX hosts_status ON hosts(status);
CREATE INDEX hosts_last_seen_at ON hosts(last_seen_at);
-- Pending one-time enrollment tokens (TTL'd, single-use).
CREATE TABLE enrollment_tokens (
token_hash TEXT PRIMARY KEY, -- argon2id of token
created_at TEXT NOT NULL,
expires_at TEXT NOT NULL,
consumed_at TEXT,
consumed_host TEXT REFERENCES hosts(id) ON DELETE SET NULL
);
CREATE INDEX enrollment_tokens_expires_at ON enrollment_tokens(expires_at);
CREATE TABLE schedules (
id TEXT PRIMARY KEY,
host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
kind TEXT NOT NULL CHECK (kind IN ('backup','forget','prune','check')),
cron_expr TEXT NOT NULL,
paths TEXT NOT NULL DEFAULT '[]', -- json array
excludes TEXT NOT NULL DEFAULT '[]',
tags TEXT NOT NULL DEFAULT '[]',
retention_policy TEXT NOT NULL DEFAULT '{}', -- json object
options TEXT NOT NULL DEFAULT '{}', -- json object (bandwidth)
-- Hooks are encrypted at rest (AEAD ciphertext). Constraint enforced
-- in application code: hooks must be empty unless kind='backup'.
pre_hook TEXT NOT NULL DEFAULT '',
post_hook TEXT NOT NULL DEFAULT '',
enabled INTEGER NOT NULL DEFAULT 1,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE INDEX schedules_host_id ON schedules(host_id);
-- Per-host monotonic schedule version. Bumped on any schedules INSERT/
-- UPDATE/DELETE for that host. Pushed to the agent in schedule.set;
-- the agent acks back the same version in schedule.ack.
CREATE TABLE host_schedule_version (
host_id TEXT PRIMARY KEY REFERENCES hosts(id) ON DELETE CASCADE,
version INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE jobs (
id TEXT PRIMARY KEY,
host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
kind TEXT NOT NULL CHECK (kind IN ('backup','forget','prune','check','unlock')),
status TEXT NOT NULL CHECK (status IN ('queued','running','succeeded','failed','cancelled')),
scheduled_id TEXT REFERENCES schedules(id) ON DELETE SET NULL,
actor_kind TEXT NOT NULL CHECK (actor_kind IN ('user','schedule','system')),
actor_id TEXT, -- user id, schedule id, or null
started_at TEXT,
finished_at TEXT,
exit_code INTEGER,
stats TEXT, -- json blob from restic
error TEXT,
created_at TEXT NOT NULL
);
CREATE INDEX jobs_host_id ON jobs(host_id);
CREATE INDEX jobs_status ON jobs(status);
CREATE INDEX jobs_created_at ON jobs(created_at);
CREATE TABLE job_logs (
job_id TEXT NOT NULL REFERENCES jobs(id) ON DELETE CASCADE,
seq INTEGER NOT NULL,
ts TEXT NOT NULL,
stream TEXT NOT NULL CHECK (stream IN ('stdout','stderr','event')),
payload TEXT NOT NULL,
PRIMARY KEY (job_id, seq)
);
CREATE TABLE snapshots (
id TEXT PRIMARY KEY, -- restic snapshot id
host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE,
repo_id TEXT NOT NULL REFERENCES repos(id) ON DELETE CASCADE,
time TEXT NOT NULL,
hostname TEXT NOT NULL,
paths TEXT NOT NULL DEFAULT '[]',
tags TEXT NOT NULL DEFAULT '[]',
size_bytes INTEGER NOT NULL DEFAULT 0,
file_count INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX snapshots_host_id ON snapshots(host_id);
CREATE INDEX snapshots_time ON snapshots(time);
CREATE TABLE alerts (
id TEXT PRIMARY KEY,
host_id TEXT REFERENCES hosts(id) ON DELETE CASCADE,
kind TEXT NOT NULL,
severity TEXT NOT NULL CHECK (severity IN ('info','warning','critical')),
message TEXT NOT NULL,
created_at TEXT NOT NULL,
acknowledged_at TEXT,
acknowledged_by TEXT REFERENCES users(id) ON DELETE SET NULL,
resolved_at TEXT
);
CREATE INDEX alerts_host_id ON alerts(host_id);
CREATE INDEX alerts_open ON alerts(host_id) WHERE resolved_at IS NULL;
CREATE TABLE audit_log (
id TEXT PRIMARY KEY,
user_id TEXT REFERENCES users(id) ON DELETE SET NULL,
actor TEXT NOT NULL CHECK (actor IN ('user','agent','system')),
action TEXT NOT NULL,
target_kind TEXT,
target_id TEXT,
ts TEXT NOT NULL,
payload TEXT NOT NULL DEFAULT '{}'
);
CREATE INDEX audit_log_ts ON audit_log(ts);
CREATE INDEX audit_log_user ON audit_log(user_id);
+88
View File
@@ -0,0 +1,88 @@
package store
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
)
// CreateSession persists a session row. The token is hashed before
// insert; the raw token is what the caller hands to the user (cookie).
func (s *Store) CreateSession(ctx context.Context, sess Session, tokenHash string) error {
_, err := s.db.ExecContext(ctx,
`INSERT INTO sessions (id, user_id, created_at, expires_at, ip, ua)
VALUES (?, ?, ?, ?, ?, ?)`,
tokenHash,
sess.UserID,
sess.CreatedAt.UTC().Format(time.RFC3339Nano),
sess.ExpiresAt.UTC().Format(time.RFC3339Nano),
sess.IP, sess.UA)
if err != nil {
return fmt.Errorf("store: create session: %w", err)
}
return nil
}
// LookupSession resolves a token hash to a session row, returning
// ErrNotFound if the hash is unknown OR the session has expired.
// We collapse "no row" and "expired" to the same error so the caller
// can't tell them apart in error messages — that prevents enumeration
// of valid token hashes.
func (s *Store) LookupSession(ctx context.Context, tokenHash string) (*Session, error) {
row := s.db.QueryRowContext(ctx,
`SELECT id, user_id, created_at, expires_at, ip, ua
FROM sessions
WHERE id = ? AND expires_at > ?`,
tokenHash, time.Now().UTC().Format(time.RFC3339Nano))
var sess Session
var created, expires string
var ip, ua sql.NullString
if err := row.Scan(&sess.ID, &sess.UserID, &created, &expires, &ip, &ua); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("store: lookup session: %w", err)
}
t, err := time.Parse(time.RFC3339Nano, created)
if err != nil {
return nil, fmt.Errorf("store: parse created_at: %w", err)
}
sess.CreatedAt = t
t, err = time.Parse(time.RFC3339Nano, expires)
if err != nil {
return nil, fmt.Errorf("store: parse expires_at: %w", err)
}
sess.ExpiresAt = t
if ip.Valid {
sess.IP = ip.String
}
if ua.Valid {
sess.UA = ua.String
}
return &sess, nil
}
// DeleteSession removes a session row by token hash. Used on logout.
func (s *Store) DeleteSession(ctx context.Context, tokenHash string) error {
_, err := s.db.ExecContext(ctx, `DELETE FROM sessions WHERE id = ?`, tokenHash)
if err != nil {
return fmt.Errorf("store: delete session: %w", err)
}
return nil
}
// PurgeExpiredSessions deletes session rows past their expires_at.
// Run periodically from a background goroutine.
func (s *Store) PurgeExpiredSessions(ctx context.Context) (int64, error) {
res, err := s.db.ExecContext(ctx,
`DELETE FROM sessions WHERE expires_at <= ?`,
time.Now().UTC().Format(time.RFC3339Nano))
if err != nil {
return 0, fmt.Errorf("store: purge sessions: %w", err)
}
n, _ := res.RowsAffected()
return n, nil
}
+84
View File
@@ -0,0 +1,84 @@
// Package store is the SQLite persistence layer (modernc.org/sqlite,
// no CGo). It owns the schema, exposes typed accessors, and hides
// the database/sql plumbing from the rest of the server.
package store
import (
"context"
"database/sql"
"errors"
"fmt"
"net/url"
"time"
_ "modernc.org/sqlite" // register the "sqlite" driver
)
// ErrNotFound is returned by accessors when a lookup misses.
var ErrNotFound = errors.New("store: not found")
// Store is a thin wrapper around *sql.DB that exposes the typed
// accessors used by the rest of the server. Callers should use the
// provided methods rather than reaching into DB() directly.
type Store struct {
db *sql.DB
}
// Open opens (or creates) the SQLite database at path, applies all
// pending migrations, and returns a ready-to-use Store.
//
// The DSN sets:
// - _pragma=foreign_keys(1) — referential integrity is on
// - _pragma=journal_mode(WAL) — concurrent reads vs writes
// - _pragma=busy_timeout(5000) — wait 5s on lock contention
// - _time_format=sqlite — RFC 3339 read/write of TEXT timestamps
//
// Empty path uses an in-memory DB (useful for tests).
func Open(ctx context.Context, path string) (*Store, error) {
dsn := buildDSN(path)
db, err := sql.Open("sqlite", dsn)
if err != nil {
return nil, fmt.Errorf("open %q: %w", path, err)
}
// modernc.org/sqlite is not safe for arbitrary high parallelism on
// a single file. WAL helps, but 1 writer + multiple readers is the
// only safe shape. Cap connections to keep that property explicit.
db.SetMaxOpenConns(8)
db.SetMaxIdleConns(4)
db.SetConnMaxLifetime(time.Hour)
pingCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := db.PingContext(pingCtx); err != nil {
_ = db.Close()
return nil, fmt.Errorf("ping: %w", err)
}
if err := migrate(ctx, db); err != nil {
_ = db.Close()
return nil, fmt.Errorf("migrate: %w", err)
}
return &Store{db: db}, nil
}
// Close releases the underlying DB handle.
func (s *Store) Close() error { return s.db.Close() }
// DB returns the underlying *sql.DB. Reserved for tests and migrations
// — production code should add a typed method to this package instead.
func (s *Store) DB() *sql.DB { return s.db }
func buildDSN(path string) string {
if path == "" {
// Shared cache + named in-memory db so multiple connections see
// the same data — needed because we cap MaxOpenConns above.
return "file::memory:?cache=shared&_pragma=foreign_keys(1)&_pragma=journal_mode(WAL)&_pragma=busy_timeout(5000)"
}
q := url.Values{}
q.Add("_pragma", "foreign_keys(1)")
q.Add("_pragma", "journal_mode(WAL)")
q.Add("_pragma", "busy_timeout(5000)")
q.Add("_pragma", "synchronous(NORMAL)")
return "file:" + path + "?" + q.Encode()
}
+93
View File
@@ -0,0 +1,93 @@
package store
import (
"context"
"path/filepath"
"testing"
)
// openTestStore opens an isolated file-backed db in a t.TempDir.
// In-memory + shared-cache works too but file makes failures easier
// to inspect when a test panics.
func openTestStore(t *testing.T) *Store {
t.Helper()
dir := t.TempDir()
s, err := Open(context.Background(), filepath.Join(dir, "rm.db"))
if err != nil {
t.Fatalf("open: %v", err)
}
t.Cleanup(func() { _ = s.Close() })
return s
}
func TestOpenAppliesMigrations(t *testing.T) {
t.Parallel()
s := openTestStore(t)
row := s.DB().QueryRow(`SELECT MAX(version) FROM schema_version`)
var v int
if err := row.Scan(&v); err != nil {
t.Fatalf("scan: %v", err)
}
if v < 1 {
t.Fatalf("expected at least migration 1 applied, got %d", v)
}
// Spot-check a few tables exist with expected columns.
tables := []string{"users", "sessions", "hosts", "repos",
"credentials", "schedules", "jobs", "job_logs",
"snapshots", "alerts", "audit_log",
"enrollment_tokens", "host_schedule_version"}
for _, tbl := range tables {
row := s.DB().QueryRow(
`SELECT name FROM sqlite_master WHERE type='table' AND name = ?`, tbl)
var got string
if err := row.Scan(&got); err != nil {
t.Errorf("table %q missing: %v", tbl, err)
}
}
}
func TestMigrateIsIdempotent(t *testing.T) {
t.Parallel()
dir := t.TempDir()
path := filepath.Join(dir, "rm.db")
for i := 0; i < 3; i++ {
s, err := Open(context.Background(), path)
if err != nil {
t.Fatalf("open #%d: %v", i, err)
}
_ = s.Close()
}
s, err := Open(context.Background(), path)
if err != nil {
t.Fatalf("final open: %v", err)
}
defer s.Close()
row := s.DB().QueryRow(`SELECT COUNT(*) FROM schema_version`)
var n int
if err := row.Scan(&n); err != nil {
t.Fatalf("scan: %v", err)
}
if n != 1 {
t.Errorf("re-running migrations should not insert duplicate rows; got %d", n)
}
}
func TestForeignKeysEnforced(t *testing.T) {
t.Parallel()
s := openTestStore(t)
// Inserting a session with a non-existent user should fail because
// FKs are on. Without the pragma, SQLite silently accepts this.
_, err := s.DB().Exec(
`INSERT INTO sessions (id, user_id, created_at, expires_at)
VALUES (?, ?, datetime('now'), datetime('now','+1 hour'))`,
"sess1", "no-such-user")
if err == nil {
t.Fatal("expected FK violation, got nil")
}
}
+82
View File
@@ -0,0 +1,82 @@
package store
import (
"encoding/json"
"time"
)
// User mirrors the users table.
type User struct {
ID string
Username string
PasswordHash string
Role Role
CreatedAt time.Time
LastLoginAt *time.Time
}
// Role enumerates the access tiers from spec.md §7.2.
type Role string
const (
RoleAdmin Role = "admin"
RoleOperator Role = "operator"
RoleViewer Role = "viewer"
)
// Session mirrors the sessions table. The ID is the (raw) session
// token; the DB stores its hash. Callers that hold a *Session have
// already authenticated.
type Session struct {
ID string // session token (raw); never persisted as-is
UserID string
CreatedAt time.Time
ExpiresAt time.Time
IP string
UA string
}
// Host mirrors the denormalised hosts table. JSON columns (tags) are
// returned decoded into Go slices for ergonomics.
type Host struct {
ID string
Name string
OS string
Arch string
AgentVersion string
ResticVersion string
ProtocolVersion int
EnrolledAt time.Time
LastSeenAt *time.Time
Status string
RepoID *string
Tags []string
CurrentJobID *string
LastBackupAt *time.Time
LastBackupStatus *string
RepoSizeBytes int64
SnapshotCount int
OpenAlertCount int
AppliedScheduleVersion int64
}
// EnrollmentToken is the issuer's view of a one-time token. The
// raw token is returned only at create time; the DB stores its hash.
type EnrollmentToken struct {
Raw string // populated on create only
TokenHash string
CreatedAt time.Time
ExpiresAt time.Time
}
// AuditEntry mirrors the audit_log table.
type AuditEntry struct {
ID string
UserID *string
Actor string // user|agent|system
Action string
TargetKind *string
TargetID *string
TS time.Time
Payload json.RawMessage
}
+87
View File
@@ -0,0 +1,87 @@
package store
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
)
// CreateUser inserts a new user. The caller is responsible for
// generating an ID (typically a ULID) and hashing the password.
func (s *Store) CreateUser(ctx context.Context, u User) error {
_, err := s.db.ExecContext(ctx,
`INSERT INTO users (id, username, password_hash, role, created_at)
VALUES (?, ?, ?, ?, ?)`,
u.ID, u.Username, u.PasswordHash, string(u.Role), u.CreatedAt.UTC().Format(time.RFC3339Nano))
if err != nil {
return fmt.Errorf("store: create user: %w", err)
}
return nil
}
// GetUserByUsername looks up a user by their (case-sensitive) username.
// Returns ErrNotFound if no row matches.
func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) {
row := s.db.QueryRowContext(ctx,
`SELECT id, username, password_hash, role, created_at, last_login_at
FROM users WHERE username = ?`, username)
return scanUser(row)
}
// GetUserByID looks up a user by id. Returns ErrNotFound on miss.
func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) {
row := s.db.QueryRowContext(ctx,
`SELECT id, username, password_hash, role, created_at, last_login_at
FROM users WHERE id = ?`, id)
return scanUser(row)
}
// CountUsers returns the total number of user rows. The first-run
// bootstrap uses this to detect a fresh install.
func (s *Store) CountUsers(ctx context.Context) (int, error) {
var n int
if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM users`).Scan(&n); err != nil {
return 0, fmt.Errorf("store: count users: %w", err)
}
return n, nil
}
// MarkUserLogin records a successful authentication.
func (s *Store) MarkUserLogin(ctx context.Context, id string, when time.Time) error {
_, err := s.db.ExecContext(ctx,
`UPDATE users SET last_login_at = ? WHERE id = ?`,
when.UTC().Format(time.RFC3339Nano), id)
if err != nil {
return fmt.Errorf("store: mark login: %w", err)
}
return nil
}
func scanUser(row *sql.Row) (*User, error) {
var u User
var role string
var lastLogin sql.NullString
var created string
if err := row.Scan(&u.ID, &u.Username, &u.PasswordHash, &role, &created, &lastLogin); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("store: scan user: %w", err)
}
u.Role = Role(role)
t, err := time.Parse(time.RFC3339Nano, created)
if err != nil {
return nil, fmt.Errorf("store: parse created_at: %w", err)
}
u.CreatedAt = t
if lastLogin.Valid {
t, err := time.Parse(time.RFC3339Nano, lastLogin.String)
if err != nil {
return nil, fmt.Errorf("store: parse last_login_at: %w", err)
}
u.LastLoginAt = &t
}
return &u, nil
}
+158
View File
@@ -0,0 +1,158 @@
package store
import (
"context"
"errors"
"testing"
"time"
)
func TestUserCRUD(t *testing.T) {
t.Parallel()
s := openTestStore(t)
ctx := context.Background()
now := time.Now().UTC()
u := User{
ID: "u1",
Username: "alice",
PasswordHash: "$argon2id$...",
Role: RoleAdmin,
CreatedAt: now,
}
if err := s.CreateUser(ctx, u); err != nil {
t.Fatalf("create: %v", err)
}
got, err := s.GetUserByUsername(ctx, "alice")
if err != nil {
t.Fatalf("get: %v", err)
}
if got.ID != "u1" || got.Role != RoleAdmin {
t.Errorf("unexpected user: %+v", got)
}
// Username uniqueness is enforced by the schema.
if err := s.CreateUser(ctx, u); err == nil {
t.Error("duplicate username should fail")
}
if _, err := s.GetUserByUsername(ctx, "bob"); !errors.Is(err, ErrNotFound) {
t.Errorf("missing user: want ErrNotFound, got %v", err)
}
if err := s.MarkUserLogin(ctx, "u1", now); err != nil {
t.Fatalf("mark login: %v", err)
}
got, _ = s.GetUserByUsername(ctx, "alice")
if got.LastLoginAt == nil {
t.Error("last_login_at not updated")
}
}
func TestCountUsers(t *testing.T) {
t.Parallel()
s := openTestStore(t)
ctx := context.Background()
n, _ := s.CountUsers(ctx)
if n != 0 {
t.Errorf("fresh db: want 0, got %d", n)
}
_ = s.CreateUser(ctx, User{
ID: "u1", Username: "a", PasswordHash: "x",
Role: RoleAdmin, CreatedAt: time.Now(),
})
n, _ = s.CountUsers(ctx)
if n != 1 {
t.Errorf("after insert: want 1, got %d", n)
}
}
func TestSessionLifecycle(t *testing.T) {
t.Parallel()
s := openTestStore(t)
ctx := context.Background()
// Need a user for FK.
_ = s.CreateUser(ctx, User{
ID: "u1", Username: "alice", PasswordHash: "x",
Role: RoleAdmin, CreatedAt: time.Now(),
})
now := time.Now().UTC()
sess := Session{
UserID: "u1",
CreatedAt: now,
ExpiresAt: now.Add(time.Hour),
IP: "10.0.0.1",
UA: "test/1.0",
}
hash := "deadbeef" + "00000000000000000000000000000000000000000000000000000000"
if err := s.CreateSession(ctx, sess, hash); err != nil {
t.Fatalf("create: %v", err)
}
got, err := s.LookupSession(ctx, hash)
if err != nil {
t.Fatalf("lookup: %v", err)
}
if got.UserID != "u1" {
t.Errorf("user mismatch: %s", got.UserID)
}
// Expired sessions should not resolve.
expiredHash := "expired-hash"
expired := Session{
UserID: "u1",
CreatedAt: now.Add(-2 * time.Hour),
ExpiresAt: now.Add(-time.Hour),
}
if err := s.CreateSession(ctx, expired, expiredHash); err != nil {
t.Fatalf("create expired: %v", err)
}
if _, err := s.LookupSession(ctx, expiredHash); !errors.Is(err, ErrNotFound) {
t.Errorf("expired session should look like ErrNotFound, got %v", err)
}
if err := s.DeleteSession(ctx, hash); err != nil {
t.Fatalf("delete: %v", err)
}
if _, err := s.LookupSession(ctx, hash); !errors.Is(err, ErrNotFound) {
t.Errorf("deleted session: want ErrNotFound, got %v", err)
}
n, err := s.PurgeExpiredSessions(ctx)
if err != nil {
t.Fatalf("purge: %v", err)
}
if n != 1 {
t.Errorf("purge should remove the 1 expired row, got %d", n)
}
}
func TestEnrollmentTokenSingleUse(t *testing.T) {
t.Parallel()
s := openTestStore(t)
ctx := context.Background()
hash := "tok-hash"
if err := s.CreateEnrollmentToken(ctx, hash, time.Hour); err != nil {
t.Fatalf("create: %v", err)
}
// Need a host for FK.
_, err := s.DB().Exec(`INSERT INTO hosts (id, name, os, arch, enrolled_at) VALUES (?,?,?,?,?)`,
"h1", "host1", "linux", "amd64", time.Now().UTC().Format(time.RFC3339Nano))
if err != nil {
t.Fatalf("insert host: %v", err)
}
if err := s.ConsumeEnrollmentToken(ctx, hash, "h1"); err != nil {
t.Fatalf("consume: %v", err)
}
// Second consume must fail — the whole point of one-time tokens.
if err := s.ConsumeEnrollmentToken(ctx, hash, "h1"); !errors.Is(err, ErrNotFound) {
t.Errorf("re-consume: want ErrNotFound, got %v", err)
}
}