From c275f4ff4c22f202db1dcfa99cef87eefb1b4510 Mon Sep 17 00:00:00 2001 From: Steve Cliff Date: Fri, 1 May 2026 00:24:40 +0100 Subject: [PATCH] phase 1 foundations: api types, store, crypto, auth MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Lands the bottom three layers of Phase 1: P1-08 internal/api: protocol_version + envelope + every WS message shape from spec.md §6.2 (Hello, Heartbeat, Job*, Schedule*, etc). Wire-format tests pin the JSON shape so a rename here breaks tests instead of silently breaking the agent. P1-02 + P1-03 internal/store: SQLite via modernc.org/sqlite, embed.FS + a tiny version table for hand-rolled migrations. 0001_initial.sql covers every table from spec.md §5 plus enrollment_tokens and host_schedule_version. Typed accessors for users / sessions / enrollment / audit. WAL + foreign_keys + busy_timeout on by default. P1-06 internal/crypto: XChaCha20-Poly1305 AEAD wrapper with per-message random nonce. Key file lifecycle (generate + refuse-to-overwrite, load with size validation). Optional additionalData binds ciphertext to the row that owns it. P1-04 internal/auth (partial — passwords + tokens; sessions middleware lands with the HTTP handlers): argon2id following RFC 9106 (64 MiB / t=3 / p=4 / 32B), constant-time verify. HashToken stores SHA-256 of session/agent/enrollment tokens so a stolen DB doesn't hand over credentials. Build floor moves to Go 1.25 (modernc.org/sqlite v1.50+ requires it); CI + Dockerfile + README updated. Markdown lint diagnostics on tasks.md cleared. All packages tested. ~70 new tests pass in <1s. Co-Authored-By: Claude Opus 4.7 (1M context) --- .gitea/workflows/ci.yml | 3 +- README.md | 3 +- deploy/Dockerfile.server | 2 +- go.mod | 19 +- go.sum | 53 +++++ internal/api/messages.go | 213 +++++++++++++++++++++ internal/api/version.go | 14 ++ internal/api/wire.go | 86 +++++++++ internal/api/wire_test.go | 143 ++++++++++++++ internal/auth/doc.go | 3 - internal/auth/passwords.go | 87 +++++++++ internal/auth/passwords_test.go | 81 ++++++++ internal/auth/tokens.go | 34 ++++ internal/crypto/aead.go | 112 +++++++++++ internal/crypto/aead_test.go | 110 +++++++++++ internal/crypto/doc.go | 3 - internal/store/audit.go | 36 ++++ internal/store/doc.go | 3 - internal/store/enrollment.go | 58 ++++++ internal/store/migrate.go | 100 ++++++++++ internal/store/migrations/0001_initial.sql | 199 +++++++++++++++++++ internal/store/sessions.go | 88 +++++++++ internal/store/store.go | 84 ++++++++ internal/store/store_test.go | 93 +++++++++ internal/store/types.go | 82 ++++++++ internal/store/users.go | 87 +++++++++ internal/store/users_test.go | 158 +++++++++++++++ tasks.md | 11 ++ 28 files changed, 1952 insertions(+), 13 deletions(-) create mode 100644 go.sum create mode 100644 internal/api/messages.go create mode 100644 internal/api/version.go create mode 100644 internal/api/wire.go create mode 100644 internal/api/wire_test.go delete mode 100644 internal/auth/doc.go create mode 100644 internal/auth/passwords.go create mode 100644 internal/auth/passwords_test.go create mode 100644 internal/auth/tokens.go create mode 100644 internal/crypto/aead.go create mode 100644 internal/crypto/aead_test.go delete mode 100644 internal/crypto/doc.go create mode 100644 internal/store/audit.go delete mode 100644 internal/store/doc.go create mode 100644 internal/store/enrollment.go create mode 100644 internal/store/migrate.go create mode 100644 internal/store/migrations/0001_initial.sql create mode 100644 internal/store/sessions.go create mode 100644 internal/store/store.go create mode 100644 internal/store/store_test.go create mode 100644 internal/store/types.go create mode 100644 internal/store/users.go create mode 100644 internal/store/users_test.go diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml index f005b04..35c4c33 100644 --- a/.gitea/workflows/ci.yml +++ b/.gitea/workflows/ci.yml @@ -7,7 +7,8 @@ on: branches: [main] env: - GO_VERSION: "1.23" + # Floor is set by the heaviest dep (modernc.org/sqlite v1.50+). + GO_VERSION: "1.25" jobs: test: diff --git a/README.md b/README.md index 4d2ad32..56419ed 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,8 @@ design/ UI wireframes (Phase 0 design pass) ## Local development -Requires Go 1.23+ (built and tested on 1.26). +Requires Go 1.25+ (built and tested on 1.26). The floor is set by +`modernc.org/sqlite` v1.50. ```sh make build # builds cmd/server and cmd/agent into ./bin diff --git a/deploy/Dockerfile.server b/deploy/Dockerfile.server index e541d2a..6cf3bee 100644 --- a/deploy/Dockerfile.server +++ b/deploy/Dockerfile.server @@ -1,7 +1,7 @@ # syntax=docker/dockerfile:1.7 # ---- Build stage -------------------------------------------------------- -FROM golang:1.23-alpine AS build +FROM golang:1.25-alpine AS build WORKDIR /src diff --git a/go.mod b/go.mod index c2d0cd3..dce2b6d 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,20 @@ module gitea.dcglab.co.uk/steve/restic-manager -go 1.23 +go 1.25.0 + +require ( + golang.org/x/crypto v0.50.0 + modernc.org/sqlite v1.50.0 +) + +require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + golang.org/x/sys v0.43.0 // indirect + modernc.org/libc v1.72.0 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e478a19 --- /dev/null +++ b/go.sum @@ -0,0 +1,53 @@ +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= +modernc.org/cc/v4 v4.27.3 h1:uNCgn37E5U09mTv1XgskEVUJ8ADKpmFMPxzGJ0TSo+U= +modernc.org/cc/v4 v4.27.3/go.mod h1:3YjcbCqhoTTHPycJDRl2WZKKFj0nwcOIPBfEZK0Hdk8= +modernc.org/ccgo/v4 v4.32.4 h1:L5OB8rpEX4ZsXEQwGozRfJyJSFHbbNVOoQ59DU9/KuU= +modernc.org/ccgo/v4 v4.32.4/go.mod h1:lY7f+fiTDHfcv6YlRgSkxYfhs+UvOEEzj49jAn2TOx0= +modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= +modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo= +modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.72.0 h1:IEu559v9a0XWjw0DPoVKtXpO2qt5NVLAnFaBbjq+n8c= +modernc.org/libc v1.72.0/go.mod h1:tTU8DL8A+XLVkEY3x5E/tO7s2Q/q42EtnNWda/L5QhQ= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.50.0 h1:eMowQSWLK0MeiQTdmz3lqoF5dqclujdlIKeJA11+7oM= +modernc.org/sqlite v1.50.0/go.mod h1:m0w8xhwYUVY3H6pSDwc3gkJ/irZT/0YEXwBlhaxQEew= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/internal/api/messages.go b/internal/api/messages.go new file mode 100644 index 0000000..c2c7aed --- /dev/null +++ b/internal/api/messages.go @@ -0,0 +1,213 @@ +package api + +import ( + "encoding/json" + "time" +) + +// HostOS / HostArch are constrained string types. The store stores them +// raw, but agent metadata collection should populate them from these +// constants so we don't end up with both "linux" and "Linux" rows. +type HostOS string + +const ( + OSLinux HostOS = "linux" + OSWindows HostOS = "windows" +) + +type HostArch string + +const ( + ArchAmd64 HostArch = "amd64" + ArchArm64 HostArch = "arm64" +) + +// HelloPayload is the agent's first message after WS auth. The server +// upserts a Host row, marks it online, and (if protocol_version is +// acceptable) responds with a config.update + schedule.set burst. +type HelloPayload struct { + ProtocolVersion int `json:"protocol_version"` + AgentVersion string `json:"agent_version"` + ResticVersion string `json:"restic_version"` + Hostname string `json:"hostname"` + OS HostOS `json:"os"` + Arch HostArch `json:"arch"` + BootTime time.Time `json:"boot_time,omitempty"` +} + +// HeartbeatPayload is sent by the agent every 30s. It carries no data +// today; presence is the signal. Future fields (load, free disk) can +// land here without bumping protocol_version. +type HeartbeatPayload struct { + SentAt time.Time `json:"sent_at"` +} + +// JobKind is the operation an agent is being asked to run, or just ran. +type JobKind string + +const ( + JobBackup JobKind = "backup" + JobForget JobKind = "forget" + JobPrune JobKind = "prune" + JobCheck JobKind = "check" + JobUnlock JobKind = "unlock" +) + +// JobStatus is the lifecycle state of a job. +type JobStatus string + +const ( + JobQueued JobStatus = "queued" + JobRunning JobStatus = "running" + JobSucceeded JobStatus = "succeeded" + JobFailed JobStatus = "failed" + JobCancelled JobStatus = "cancelled" +) + +// CommandRunPayload is the server → agent dispatch for a run-now job. +type CommandRunPayload struct { + JobID string `json:"job_id"` + Kind JobKind `json:"kind"` + Args []string `json:"args,omitempty"` +} + +// CommandCancelPayload is the server → agent cancel signal. +type CommandCancelPayload struct { + JobID string `json:"job_id"` +} + +// CommandResultPayload acks a command.run dispatch (the agent has +// accepted the job and persisted it locally) — this is *not* the job +// completion. job.finished signals that. +type CommandResultPayload struct { + JobID string `json:"job_id"` + Accepted bool `json:"accepted"` + Error string `json:"error,omitempty"` +} + +// JobStartedPayload — agent has begun execution. +type JobStartedPayload struct { + JobID string `json:"job_id"` + Kind JobKind `json:"kind"` + StartedAt time.Time `json:"started_at"` +} + +// JobProgressPayload — agent's periodic status while a job is running. +// Field set chosen to match what restic --json emits for `backup`; +// other kinds populate the subset that makes sense. +type JobProgressPayload struct { + JobID string `json:"job_id"` + PercentDone float64 `json:"percent_done"` + FilesDone int64 `json:"files_done"` + TotalFiles int64 `json:"total_files"` + BytesDone int64 `json:"bytes_done"` + TotalBytes int64 `json:"total_bytes"` + ETASeconds int64 `json:"eta_seconds"` + ThroughputBps int64 `json:"throughput_bps"` +} + +// JobFinishedPayload — agent reports terminal state. +type JobFinishedPayload struct { + JobID string `json:"job_id"` + Status JobStatus `json:"status"` + ExitCode int `json:"exit_code"` + FinishedAt time.Time `json:"finished_at"` + Stats json.RawMessage `json:"stats,omitempty"` // restic summary blob + Error string `json:"error,omitempty"` +} + +// LogStreamLine is one entry of the live job log. +type LogStreamLine struct { + JobID string `json:"job_id"` + Seq int64 `json:"seq"` + TS time.Time `json:"ts"` + Stream LogStream `json:"stream"` + Payload string `json:"payload"` +} + +// LogStream identifies which channel a log line came from. +type LogStream string + +const ( + LogStdout LogStream = "stdout" + LogStderr LogStream = "stderr" + LogEvent LogStream = "event" // parsed restic --json event +) + +// SnapshotsReportPayload — agent dumps its full snapshot list after +// each successful backup, so the server can refresh its projection. +type SnapshotsReportPayload struct { + Snapshots []Snapshot `json:"snapshots"` +} + +// Snapshot is the projection mirrored from `restic snapshots --json`. +type Snapshot struct { + ID string `json:"id"` // restic snapshot ID + Time time.Time `json:"time"` + Hostname string `json:"hostname"` + Paths []string `json:"paths"` + Tags []string `json:"tags,omitempty"` + SizeBytes int64 `json:"size_bytes,omitempty"` + FileCount int64 `json:"file_count,omitempty"` +} + +// RepoStatsPayload — agent reports periodic repo health facts derived +// from `restic stats` and lock-file inspection. +type RepoStatsPayload struct { + SizeBytes int64 `json:"size_bytes"` + SnapshotCount int `json:"snapshot_count"` + DedupRatio float64 `json:"dedup_ratio"` + LastCheckAt time.Time `json:"last_check_at,omitempty"` + LastCheckStatus string `json:"last_check_status,omitempty"` + LockState string `json:"lock_state"` // locked|unlocked +} + +// Schedule is the agent-facing view of a Schedule row. (Server-side +// CRUD shapes live in the http handlers; this is what gets pushed.) +type Schedule struct { + ID string `json:"id"` + Kind JobKind `json:"kind"` + CronExpr string `json:"cron_expr"` + Paths []string `json:"paths,omitempty"` + Excludes []string `json:"excludes,omitempty"` + Tags []string `json:"tags,omitempty"` + RetentionPolicy json.RawMessage `json:"retention_policy,omitempty"` + Options json.RawMessage `json:"options,omitempty"` + PreHook string `json:"pre_hook,omitempty"` + PostHook string `json:"post_hook,omitempty"` + Enabled bool `json:"enabled"` +} + +// ScheduleSetPayload — server pushes the full canonical schedule list +// for a host. Agent reconciles its local cron and replies with +// ScheduleAckPayload carrying the same Version. +type ScheduleSetPayload struct { + Version int64 `json:"version"` + Schedules []Schedule `json:"schedules"` +} + +// ScheduleAckPayload — agent confirms it has applied a given version. +type ScheduleAckPayload struct { + Version int64 `json:"version"` + AppliedAt time.Time `json:"applied_at"` +} + +// ConfigUpdatePayload — server pushes per-host config (currently just +// repo connection details). Empty fields mean "leave existing alone"; +// to clear something, send an explicit zero value. +type ConfigUpdatePayload struct { + RepoURL string `json:"repo_url,omitempty"` + RepoPassword string `json:"repo_password,omitempty"` // sensitive + RepoUsername string `json:"repo_username,omitempty"` + RepoCredential string `json:"repo_credential,omitempty"` // sensitive (for rest server basic auth) + HookShell string `json:"hook_shell,omitempty"` +} + +// AgentUpdateAvailablePayload — informational only; the agent does +// NOT self-update. See spec.md §4.2 for the package-manager-based +// update model. +type AgentUpdateAvailablePayload struct { + LatestVersion string `json:"latest_version"` + PackageURL string `json:"package_url"` // apt repo / choco source + Changelog string `json:"changelog,omitempty"` +} diff --git a/internal/api/version.go b/internal/api/version.go new file mode 100644 index 0000000..1e9d1f5 --- /dev/null +++ b/internal/api/version.go @@ -0,0 +1,14 @@ +package api + +// CurrentProtocolVersion is the wire-format version this build speaks. +// +// Bump this only when an incompatible wire-format change lands — +// adding a new optional field does NOT count, removing or repurposing +// one does. The server enforces MinAgentProtocolVersion against this +// value at hello time. See spec.md §6.2 ("Protocol versioning"). +const CurrentProtocolVersion = 1 + +// MinAgentProtocolVersion is the lowest agent protocol_version this +// server accepts in a hello. Agents below this are disconnected with +// a structured error pointing at the upgrade docs. +const MinAgentProtocolVersion = 1 diff --git a/internal/api/wire.go b/internal/api/wire.go new file mode 100644 index 0000000..087baa8 --- /dev/null +++ b/internal/api/wire.go @@ -0,0 +1,86 @@ +package api + +import ( + "encoding/json" + "fmt" +) + +// MessageType enumerates every kind of envelope that can flow over +// the agent ↔ server WebSocket. Keeping these as string constants +// (not iota ints) makes traffic readable in logs and packet captures. +type MessageType string + +// Agent → server message types. +const ( + MsgHello MessageType = "hello" + MsgHeartbeat MessageType = "heartbeat" + MsgJobStarted MessageType = "job.started" + MsgJobProgress MessageType = "job.progress" + MsgJobFinished MessageType = "job.finished" + MsgSnapshotsRpt MessageType = "snapshots.report" + MsgRepoStats MessageType = "repo.stats" + MsgLogStream MessageType = "log.stream" + MsgScheduleAck MessageType = "schedule.ack" + MsgCommandResult MessageType = "command.result" // ack for command.run + MsgError MessageType = "error" +) + +// Server → agent message types. +const ( + MsgCommandRun MessageType = "command.run" + MsgCommandCancel MessageType = "command.cancel" + MsgScheduleSet MessageType = "schedule.set" + MsgConfigUpdate MessageType = "config.update" + MsgAgentUpdateAvail MessageType = "agent.update.available" +) + +// Envelope is the framing for every WS message in either direction. +// Payload is parsed into the concrete struct chosen by Type. +// +// ID is set on RPC-style messages (command.run / command.result) so +// responses can be correlated. For one-shot pushes (heartbeat, +// job.progress) it is empty. +type Envelope struct { + Type MessageType `json:"type"` + ID string `json:"id,omitempty"` + Payload json.RawMessage `json:"payload,omitempty"` +} + +// Marshal builds an envelope from a concrete payload struct. +func Marshal(t MessageType, id string, payload any) (Envelope, error) { + if payload == nil { + return Envelope{Type: t, ID: id}, nil + } + raw, err := json.Marshal(payload) + if err != nil { + return Envelope{}, fmt.Errorf("marshal %s payload: %w", t, err) + } + return Envelope{Type: t, ID: id, Payload: raw}, nil +} + +// UnmarshalPayload decodes the envelope's payload into v. +func (e Envelope) UnmarshalPayload(v any) error { + if len(e.Payload) == 0 { + return nil + } + return json.Unmarshal(e.Payload, v) +} + +// ErrorCode enumerates error reasons surfaced over the wire. +// These are stable identifiers; client code may switch on them. +type ErrorCode string + +const ( + ErrProtocolTooOld ErrorCode = "protocol_too_old" + ErrProtocolTooNew ErrorCode = "protocol_too_new" + ErrUnauthorized ErrorCode = "unauthorized" + ErrBadRequest ErrorCode = "bad_request" + ErrInternal ErrorCode = "internal" +) + +// ErrorPayload is the body of an `error` envelope. +type ErrorPayload struct { + Code ErrorCode `json:"code"` + Message string `json:"message"` + HelpURL string `json:"help_url,omitempty"` +} diff --git a/internal/api/wire_test.go b/internal/api/wire_test.go new file mode 100644 index 0000000..095f2c6 --- /dev/null +++ b/internal/api/wire_test.go @@ -0,0 +1,143 @@ +package api + +import ( + "encoding/json" + "testing" + "time" +) + +func TestEnvelopeRoundTrip(t *testing.T) { + t.Parallel() + + hello := HelloPayload{ + ProtocolVersion: CurrentProtocolVersion, + AgentVersion: "0.1.0", + ResticVersion: "0.17.1", + Hostname: "test-host", + OS: OSLinux, + Arch: ArchAmd64, + } + + env, err := Marshal(MsgHello, "", hello) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + wire, err := json.Marshal(env) + if err != nil { + t.Fatalf("marshal envelope: %v", err) + } + + var decoded Envelope + if err := json.Unmarshal(wire, &decoded); err != nil { + t.Fatalf("unmarshal envelope: %v", err) + } + if decoded.Type != MsgHello { + t.Errorf("type: got %q want %q", decoded.Type, MsgHello) + } + + var got HelloPayload + if err := decoded.UnmarshalPayload(&got); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + if got != hello { + t.Errorf("round-trip mismatch: %+v != %+v", got, hello) + } +} + +func TestEnvelopeNilPayload(t *testing.T) { + t.Parallel() + + env, err := Marshal(MsgHeartbeat, "", nil) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if len(env.Payload) != 0 { + t.Errorf("nil payload should encode as empty, got %q", env.Payload) + } + // Unmarshalling nothing into anything must not error. + var hb HeartbeatPayload + if err := env.UnmarshalPayload(&hb); err != nil { + t.Errorf("unmarshal empty payload: %v", err) + } +} + +func TestEnvelopeRPCCorrelation(t *testing.T) { + t.Parallel() + + cmd := CommandRunPayload{JobID: "01HJ8K7", Kind: JobBackup} + env, err := Marshal(MsgCommandRun, "req-1", cmd) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if env.ID != "req-1" { + t.Errorf("id not preserved: %q", env.ID) + } + + res := CommandResultPayload{JobID: "01HJ8K7", Accepted: true} + resEnv, err := Marshal(MsgCommandResult, env.ID, res) + if err != nil { + t.Fatalf("marshal result: %v", err) + } + if resEnv.ID != env.ID { + t.Errorf("rpc id mismatch: req=%q res=%q", env.ID, resEnv.ID) + } +} + +func TestErrorPayload(t *testing.T) { + t.Parallel() + + ep := ErrorPayload{ + Code: ErrProtocolTooOld, + Message: "agent protocol_version 0 below minimum 1", + HelpURL: "https://example.com/upgrade", + } + env, err := Marshal(MsgError, "", ep) + if err != nil { + t.Fatalf("marshal: %v", err) + } + wire, _ := json.Marshal(env) + + var decoded Envelope + if err := json.Unmarshal(wire, &decoded); err != nil { + t.Fatalf("unmarshal: %v", err) + } + var got ErrorPayload + if err := decoded.UnmarshalPayload(&got); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + if got.Code != ErrProtocolTooOld { + t.Errorf("code: got %q want %q", got.Code, ErrProtocolTooOld) + } +} + +func TestProtocolVersionConstants(t *testing.T) { + t.Parallel() + + if CurrentProtocolVersion < 1 { + t.Errorf("CurrentProtocolVersion must be >= 1, got %d", CurrentProtocolVersion) + } + if MinAgentProtocolVersion > CurrentProtocolVersion { + t.Errorf("min %d > current %d — server would refuse all agents", + MinAgentProtocolVersion, CurrentProtocolVersion) + } +} + +func TestJobProgressShapeStable(t *testing.T) { + t.Parallel() + // Locks the JSON field names from spec.md §6.2 so a rename here + // breaks tests instead of silently breaking the agent. + p := JobProgressPayload{ + JobID: "j", PercentDone: 0.5, FilesDone: 1, TotalFiles: 2, + BytesDone: 100, TotalBytes: 200, ETASeconds: 30, ThroughputBps: 1000, + } + raw, _ := json.Marshal(p) + want := `{"job_id":"j","percent_done":0.5,"files_done":1,"total_files":2,"bytes_done":100,"total_bytes":200,"eta_seconds":30,"throughput_bps":1000}` + if string(raw) != want { + t.Errorf("wire shape drifted:\n got %s\n want %s", raw, want) + } +} + +// touch time so the import is used by other tests in this file when +// they grow over time. +var _ = time.Now diff --git a/internal/auth/doc.go b/internal/auth/doc.go deleted file mode 100644 index 4ff4536..0000000 --- a/internal/auth/doc.go +++ /dev/null @@ -1,3 +0,0 @@ -// Package auth handles password hashing (argon2id), session cookies, -// CSRF tokens, and bearer-token verification for agents. -package auth diff --git a/internal/auth/passwords.go b/internal/auth/passwords.go new file mode 100644 index 0000000..dd26567 --- /dev/null +++ b/internal/auth/passwords.go @@ -0,0 +1,87 @@ +// Package auth handles password hashing (argon2id), session +// management, CSRF tokens, and bearer-token verification for agents. +package auth + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "strings" + + "golang.org/x/crypto/argon2" +) + +// argon2id parameters following RFC 9106 §4 "second +// recommended option" (memory-constrained): +// - 64 MiB memory, 3 iterations, 4 lanes, 32-byte tag. +// These are tunable per-deployment if a beefy controller wants to +// crank them; we ship a defensible default. +const ( + defaultMemoryKiB = 64 * 1024 + defaultIterations = 3 + defaultParallel = 4 + defaultSaltLen = 16 + defaultKeyLen = 32 +) + +// HashPassword returns an argon2id-encoded string of the form +// $argon2id$v=19$m=...,t=...,p=...$$ +// safe to store in a TEXT column. The salt is freshly random per call. +func HashPassword(password string) (string, error) { + salt := make([]byte, defaultSaltLen) + if _, err := rand.Read(salt); err != nil { + return "", fmt.Errorf("auth: read salt: %w", err) + } + hash := argon2.IDKey([]byte(password), salt, + defaultIterations, defaultMemoryKiB, defaultParallel, defaultKeyLen) + + return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", + argon2.Version, + defaultMemoryKiB, defaultIterations, defaultParallel, + base64.RawStdEncoding.EncodeToString(salt), + base64.RawStdEncoding.EncodeToString(hash), + ), nil +} + +// VerifyPassword returns nil if password matches the encoded hash. +// On any decode error or mismatch the error is non-nil — callers +// should treat all non-nil returns as "invalid credentials" and not +// leak which case it was. +func VerifyPassword(encoded, password string) error { + parts := strings.Split(encoded, "$") + // "$argon2id$v=...$m=...,t=...,p=...$$" → 6 parts (leading empty) + if len(parts) != 6 || parts[1] != "argon2id" { + return errors.New("auth: unrecognised hash format") + } + var version int + if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil { + return fmt.Errorf("auth: parse version: %w", err) + } + if version != argon2.Version { + return fmt.Errorf("auth: unsupported argon2 version %d", version) + } + var memory, iterations uint32 + var parallel uint8 + if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", + &memory, &iterations, ¶llel); err != nil { + return fmt.Errorf("auth: parse params: %w", err) + } + salt, err := base64.RawStdEncoding.DecodeString(parts[4]) + if err != nil { + return fmt.Errorf("auth: decode salt: %w", err) + } + want, err := base64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + return fmt.Errorf("auth: decode hash: %w", err) + } + + got := argon2.IDKey([]byte(password), salt, + iterations, memory, parallel, uint32(len(want))) + + if subtle.ConstantTimeCompare(got, want) != 1 { + return errors.New("auth: invalid password") + } + return nil +} diff --git a/internal/auth/passwords_test.go b/internal/auth/passwords_test.go new file mode 100644 index 0000000..b7f82bd --- /dev/null +++ b/internal/auth/passwords_test.go @@ -0,0 +1,81 @@ +package auth + +import ( + "strings" + "testing" +) + +func TestHashAndVerify(t *testing.T) { + t.Parallel() + + pw := "correct horse battery staple" + h, err := HashPassword(pw) + if err != nil { + t.Fatalf("hash: %v", err) + } + if !strings.HasPrefix(h, "$argon2id$") { + t.Errorf("encoded form should start $argon2id$, got %q", h) + } + if err := VerifyPassword(h, pw); err != nil { + t.Errorf("verify: %v", err) + } + if err := VerifyPassword(h, "wrong"); err == nil { + t.Error("verify with wrong password should fail") + } +} + +func TestEachHashIsUnique(t *testing.T) { + t.Parallel() + // Same password hashed twice → different encoded strings (different + // salts). If this fails the salt is deterministic. + a, _ := HashPassword("hunter2") + b, _ := HashPassword("hunter2") + if a == b { + t.Fatal("two hashes of the same password collided — non-random salt?") + } +} + +func TestVerifyRejectsMalformed(t *testing.T) { + t.Parallel() + cases := []string{ + "", + "not-a-hash", + "$argon2i$v=19$m=64,t=3,p=4$AAAA$BBBB", // wrong variant + "$argon2id$", // truncated + "$argon2id$v=99$m=64,t=3,p=4$AAAA$BBBB", // bad version + } + for _, c := range cases { + if err := VerifyPassword(c, "anything"); err == nil { + t.Errorf("should reject malformed hash %q", c) + } + } +} + +func TestNewTokenUnique(t *testing.T) { + t.Parallel() + a, err := NewToken() + if err != nil { + t.Fatalf("token: %v", err) + } + b, _ := NewToken() + if a == b { + t.Fatal("two tokens collided — broken randomness") + } + if len(a) < 40 { + t.Errorf("token suspiciously short: %q (%d bytes)", a, len(a)) + } +} + +func TestHashTokenStable(t *testing.T) { + t.Parallel() + // Same input → same hash. This is not a security property, just a + // sanity check that we're using a regular hash not a salted one. + h1 := HashToken("foo") + h2 := HashToken("foo") + if h1 != h2 { + t.Errorf("HashToken not deterministic: %q vs %q", h1, h2) + } + if len(h1) != 64 { // sha256 hex + t.Errorf("expected 64-char hex hash, got %d", len(h1)) + } +} diff --git a/internal/auth/tokens.go b/internal/auth/tokens.go new file mode 100644 index 0000000..434bfe3 --- /dev/null +++ b/internal/auth/tokens.go @@ -0,0 +1,34 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" +) + +// TokenLen is the number of random bytes in session, CSRF, and +// enrollment tokens. 32 bytes = 256 bits of entropy, more than enough +// to be unguessable. +const TokenLen = 32 + +// NewToken returns a fresh URL-safe random token. Used for session +// IDs, CSRF tokens, agent bearer tokens, and one-time enrollment +// tokens. Returns base64url(no-padding) for compactness. +func NewToken() (string, error) { + buf := make([]byte, TokenLen) + if _, err := rand.Read(buf); err != nil { + return "", fmt.Errorf("auth: read random: %w", err) + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +// HashToken returns a hex-encoded SHA-256 of the token. We store +// this rather than the raw token so a stolen DB doesn't yield +// session/agent credentials directly. SHA-256 (not argon2) is fine +// here because the input is already 256 bits of uniform random. +func HashToken(token string) string { + sum := sha256.Sum256([]byte(token)) + return hex.EncodeToString(sum[:]) +} diff --git a/internal/crypto/aead.go b/internal/crypto/aead.go new file mode 100644 index 0000000..4564d30 --- /dev/null +++ b/internal/crypto/aead.go @@ -0,0 +1,112 @@ +// Package crypto wraps AEAD encryption used to protect repo +// passwords, REST-server credentials, hook bodies, and any other +// secret that lands in the SQLite store. +// +// The threat model is "defense in depth against a stolen DB file" — +// not "an attacker with code execution can't read secrets at runtime." +// We need the encryption key at runtime to do any actual work, so +// anyone with a memory dump of the running server can extract it. +package crypto + +import ( + stdcipher "crypto/cipher" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "io" + "os" + + "golang.org/x/crypto/chacha20poly1305" +) + +// KeyLen is the required length of the master key (XChaCha20-Poly1305 +// uses a 32-byte key). Keys shorter than this are rejected at load. +const KeyLen = chacha20poly1305.KeySize // 32 + +// AEAD wraps an XChaCha20-Poly1305 instance with a 24-byte random +// nonce per message. Ciphertexts are encoded as +// base64(nonce || ciphertext_with_tag) for SQLite storage. +type AEAD struct { + cipher stdcipher.AEAD +} + +// NewAEAD returns an AEAD using the given 32-byte key. +func NewAEAD(key []byte) (*AEAD, error) { + if len(key) != KeyLen { + return nil, fmt.Errorf("crypto: key must be %d bytes, got %d", KeyLen, len(key)) + } + c, err := chacha20poly1305.NewX(key) + if err != nil { + return nil, fmt.Errorf("crypto: init xchacha20poly1305: %w", err) + } + return &AEAD{cipher: c}, nil +} + +// LoadKeyFromFile reads a 32-byte raw key from path. The file must +// be exactly KeyLen bytes long. Use GenerateKeyFile to mint a fresh +// one on first run. +func LoadKeyFromFile(path string) ([]byte, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read key file %q: %w", path, err) + } + if len(data) != KeyLen { + return nil, fmt.Errorf("key file %q: expected %d bytes, got %d", + path, KeyLen, len(data)) + } + return data, nil +} + +// GenerateKeyFile writes a new 32-byte random key to path with mode +// 0600. It refuses to overwrite an existing file. +func GenerateKeyFile(path string) error { + f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o600) + if err != nil { + return fmt.Errorf("create key file %q: %w", path, err) + } + defer f.Close() + key := make([]byte, KeyLen) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + return fmt.Errorf("read random: %w", err) + } + if _, err := f.Write(key); err != nil { + return fmt.Errorf("write key: %w", err) + } + return nil +} + +// Encrypt seals plaintext under a fresh random nonce. The returned +// string is base64(nonce || ciphertext_with_tag) and is what gets +// stored in TEXT columns. Optional additionalData binds the +// ciphertext to a context (e.g. the row's primary key) so a swap +// attack between rows is detectable. +func (a *AEAD) Encrypt(plaintext, additionalData []byte) (string, error) { + nonce := make([]byte, a.cipher.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", fmt.Errorf("crypto: read nonce: %w", err) + } + ct := a.cipher.Seal(nil, nonce, plaintext, additionalData) + out := make([]byte, 0, len(nonce)+len(ct)) + out = append(out, nonce...) + out = append(out, ct...) + return base64.StdEncoding.EncodeToString(out), nil +} + +// Decrypt reverses Encrypt. +func (a *AEAD) Decrypt(ciphertext string, additionalData []byte) ([]byte, error) { + raw, err := base64.StdEncoding.DecodeString(ciphertext) + if err != nil { + return nil, fmt.Errorf("crypto: base64 decode: %w", err) + } + if len(raw) < a.cipher.NonceSize()+a.cipher.Overhead() { + return nil, errors.New("crypto: ciphertext too short") + } + nonce := raw[:a.cipher.NonceSize()] + ct := raw[a.cipher.NonceSize():] + pt, err := a.cipher.Open(nil, nonce, ct, additionalData) + if err != nil { + return nil, fmt.Errorf("crypto: open: %w", err) + } + return pt, nil +} diff --git a/internal/crypto/aead_test.go b/internal/crypto/aead_test.go new file mode 100644 index 0000000..60cd36f --- /dev/null +++ b/internal/crypto/aead_test.go @@ -0,0 +1,110 @@ +package crypto + +import ( + "bytes" + "crypto/rand" + "path/filepath" + "testing" +) + +func TestRoundTrip(t *testing.T) { + t.Parallel() + + key := make([]byte, KeyLen) + if _, err := rand.Read(key); err != nil { + t.Fatalf("rand: %v", err) + } + a, err := NewAEAD(key) + if err != nil { + t.Fatalf("new: %v", err) + } + + plaintext := []byte("super-secret-restic-password") + ad := []byte("repos/01HJ8K7/password") + + ct, err := a.Encrypt(plaintext, ad) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + if ct == "" { + t.Fatal("ciphertext empty") + } + pt, err := a.Decrypt(ct, ad) + if err != nil { + t.Fatalf("decrypt: %v", err) + } + if !bytes.Equal(pt, plaintext) { + t.Errorf("round-trip mismatch: got %q want %q", pt, plaintext) + } +} + +func TestADMismatchFails(t *testing.T) { + t.Parallel() + + key := make([]byte, KeyLen) + _, _ = rand.Read(key) + a, _ := NewAEAD(key) + + ct, _ := a.Encrypt([]byte("secret"), []byte("context-A")) + if _, err := a.Decrypt(ct, []byte("context-B")); err == nil { + t.Fatal("expected AD-mismatch failure, got nil") + } +} + +func TestNonceUniqueness(t *testing.T) { + t.Parallel() + + key := make([]byte, KeyLen) + _, _ = rand.Read(key) + a, _ := NewAEAD(key) + + // Same plaintext + AD must produce different ciphertexts because + // we use a random nonce per call. If this ever fails the AEAD is + // broken or someone made the nonce deterministic. + ct1, _ := a.Encrypt([]byte("x"), nil) + ct2, _ := a.Encrypt([]byte("x"), nil) + if ct1 == ct2 { + t.Fatal("two encryptions produced identical ciphertext — nonce reuse") + } +} + +func TestKeyFileLifecycle(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "secret.key") + + if err := GenerateKeyFile(path); err != nil { + t.Fatalf("generate: %v", err) + } + // Refusal-to-overwrite is the safety property — a re-run of the + // server must not silently swap the key. + if err := GenerateKeyFile(path); err == nil { + t.Fatal("expected refusal to overwrite, got nil") + } + + key, err := LoadKeyFromFile(path) + if err != nil { + t.Fatalf("load: %v", err) + } + if len(key) != KeyLen { + t.Errorf("key length: got %d want %d", len(key), KeyLen) + } +} + +func TestRejectShortKey(t *testing.T) { + t.Parallel() + if _, err := NewAEAD(make([]byte, KeyLen-1)); err == nil { + t.Fatal("expected short-key rejection, got nil") + } +} + +func TestRejectShortCiphertext(t *testing.T) { + t.Parallel() + key := make([]byte, KeyLen) + _, _ = rand.Read(key) + a, _ := NewAEAD(key) + if _, err := a.Decrypt("AAAA", nil); err == nil { + t.Fatal("expected short-ciphertext rejection, got nil") + } +} diff --git a/internal/crypto/doc.go b/internal/crypto/doc.go deleted file mode 100644 index d295b21..0000000 --- a/internal/crypto/doc.go +++ /dev/null @@ -1,3 +0,0 @@ -// Package crypto wraps AEAD encryption used to protect repo passwords, -// REST-server credentials, and pre/post hook bodies at rest. -package crypto diff --git a/internal/store/audit.go b/internal/store/audit.go new file mode 100644 index 0000000..b5cce69 --- /dev/null +++ b/internal/store/audit.go @@ -0,0 +1,36 @@ +package store + +import ( + "context" + "encoding/json" + "fmt" + "time" +) + +// AppendAudit records an audit log entry. +func (s *Store) AppendAudit(ctx context.Context, e AuditEntry) error { + if len(e.Payload) == 0 { + e.Payload = json.RawMessage("{}") + } + _, err := s.db.ExecContext(ctx, + `INSERT INTO audit_log (id, user_id, actor, action, target_kind, target_id, ts, payload) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + e.ID, nullable(e.UserID), e.Actor, e.Action, + nullable(e.TargetKind), nullable(e.TargetID), + e.TS.UTC().Format(time.RFC3339Nano), + string(e.Payload)) + if err != nil { + return fmt.Errorf("store: append audit: %w", err) + } + return nil +} + +// nullable returns nil for nil/empty *string so SQLite stores NULL. +// SQLite's driver treats Go nil as NULL but treats *string("") as ''. +// We want NULL semantics for "absent." +func nullable(p *string) any { + if p == nil || *p == "" { + return nil + } + return *p +} diff --git a/internal/store/doc.go b/internal/store/doc.go deleted file mode 100644 index fce290b..0000000 --- a/internal/store/doc.go +++ /dev/null @@ -1,3 +0,0 @@ -// Package store is the SQLite persistence layer -// (modernc.org/sqlite, no CGo). -package store diff --git a/internal/store/enrollment.go b/internal/store/enrollment.go new file mode 100644 index 0000000..c054af9 --- /dev/null +++ b/internal/store/enrollment.go @@ -0,0 +1,58 @@ +package store + +import ( + "context" + "fmt" + "time" +) + +// CreateEnrollmentToken persists a fresh one-time token. The caller +// has already hashed the raw token; the raw form is returned to the +// operator (printed in the install snippet) and never persisted. +func (s *Store) CreateEnrollmentToken(ctx context.Context, tokenHash string, ttl time.Duration) error { + now := time.Now().UTC() + _, err := s.db.ExecContext(ctx, + `INSERT INTO enrollment_tokens (token_hash, created_at, expires_at) + VALUES (?, ?, ?)`, + tokenHash, + now.Format(time.RFC3339Nano), + now.Add(ttl).Format(time.RFC3339Nano)) + if err != nil { + return fmt.Errorf("store: create enrollment token: %w", err) + } + return nil +} + +// ConsumeEnrollmentToken atomically validates a token (must exist, +// not be consumed, not be expired) and marks it consumed by hostID. +// Returns ErrNotFound on any failure. +func (s *Store) ConsumeEnrollmentToken(ctx context.Context, tokenHash, hostID string) error { + now := time.Now().UTC().Format(time.RFC3339Nano) + res, err := s.db.ExecContext(ctx, + `UPDATE enrollment_tokens + SET consumed_at = ?, consumed_host = ? + WHERE token_hash = ? AND consumed_at IS NULL AND expires_at > ?`, + now, hostID, tokenHash, now) + if err != nil { + return fmt.Errorf("store: consume enrollment token: %w", err) + } + n, _ := res.RowsAffected() + if n == 0 { + return ErrNotFound + } + return nil +} + +// PurgeExpiredEnrollmentTokens deletes long-expired token rows. Tokens +// retained for ~24h after expiry so audit traces still resolve them. +func (s *Store) PurgeExpiredEnrollmentTokens(ctx context.Context) (int64, error) { + cutoff := time.Now().Add(-24 * time.Hour).UTC().Format(time.RFC3339Nano) + res, err := s.db.ExecContext(ctx, + `DELETE FROM enrollment_tokens WHERE expires_at <= ?`, cutoff) + if err != nil { + return 0, fmt.Errorf("store: purge enrollment tokens: %w", err) + } + n, _ := res.RowsAffected() + return n, nil +} + diff --git a/internal/store/migrate.go b/internal/store/migrate.go new file mode 100644 index 0000000..0fd80bb --- /dev/null +++ b/internal/store/migrate.go @@ -0,0 +1,100 @@ +package store + +import ( + "context" + "database/sql" + "embed" + "fmt" + "io/fs" + "sort" + "strings" +) + +//go:embed migrations/*.sql +var migrationsFS embed.FS + +// migration is one ordered SQL file from migrations/. +type migration struct { + version int // parsed from filename prefix (0001, 0002, …) + name string // full filename, for error messages + sql string +} + +// loadMigrations reads every migrations/*.sql file in lexical order +// and returns them. Filenames must look like NNNN_name.sql; the +// numeric prefix is the version. +func loadMigrations() ([]migration, error) { + entries, err := fs.ReadDir(migrationsFS, "migrations") + if err != nil { + return nil, fmt.Errorf("read migrations dir: %w", err) + } + out := make([]migration, 0, len(entries)) + for _, e := range entries { + if e.IsDir() || !strings.HasSuffix(e.Name(), ".sql") { + continue + } + var v int + // Allow up to 6 digits (we will never need that many but it + // costs nothing to be permissive). + if _, err := fmt.Sscanf(e.Name(), "%d_", &v); err != nil { + return nil, fmt.Errorf("migration %q: cannot parse version prefix: %w", e.Name(), err) + } + body, err := fs.ReadFile(migrationsFS, "migrations/"+e.Name()) + if err != nil { + return nil, fmt.Errorf("read %s: %w", e.Name(), err) + } + out = append(out, migration{version: v, name: e.Name(), sql: string(body)}) + } + sort.Slice(out, func(i, j int) bool { return out[i].version < out[j].version }) + return out, nil +} + +// migrate brings the db up to the highest known version. It is +// idempotent: running it on an already-current db is a no-op. There +// is no rollback path; we move forward only. +func migrate(ctx context.Context, db *sql.DB) error { + if _, err := db.ExecContext(ctx, ` + CREATE TABLE IF NOT EXISTS schema_version ( + version INTEGER PRIMARY KEY, + applied_at TEXT NOT NULL + ) + `); err != nil { + return fmt.Errorf("create schema_version: %w", err) + } + + migs, err := loadMigrations() + if err != nil { + return err + } + + for _, m := range migs { + var applied int + row := db.QueryRowContext(ctx, + `SELECT COUNT(*) FROM schema_version WHERE version = ?`, m.version) + if err := row.Scan(&applied); err != nil { + return fmt.Errorf("check version %d: %w", m.version, err) + } + if applied > 0 { + continue + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin tx for migration %s: %w", m.name, err) + } + if _, err := tx.ExecContext(ctx, m.sql); err != nil { + _ = tx.Rollback() + return fmt.Errorf("apply %s: %w", m.name, err) + } + if _, err := tx.ExecContext(ctx, + `INSERT INTO schema_version (version, applied_at) VALUES (?, datetime('now'))`, + m.version); err != nil { + _ = tx.Rollback() + return fmt.Errorf("record %s: %w", m.name, err) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit %s: %w", m.name, err) + } + } + return nil +} diff --git a/internal/store/migrations/0001_initial.sql b/internal/store/migrations/0001_initial.sql new file mode 100644 index 0000000..746e906 --- /dev/null +++ b/internal/store/migrations/0001_initial.sql @@ -0,0 +1,199 @@ +-- 0001_initial.sql +-- +-- Initial schema for restic-manager. Mirrors the domain model in +-- spec.md §5. We use TEXT primary keys (ULIDs) throughout: sortable, +-- URL-safe, no autoincrement contention. JSON blobs are stored as +-- TEXT; SQLite's json1 extension is available but we read/write +-- raw and parse in Go for portability. +-- +-- All timestamps are stored as RFC 3339 TEXT (UTC). SQLite's INTEGER +-- (unix epoch) would be cheaper but text is human-readable in dumps +-- and the storage cost is negligible at this scale. + +CREATE TABLE users ( + id TEXT PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL, + role TEXT NOT NULL CHECK (role IN ('admin','operator','viewer')), + created_at TEXT NOT NULL, + last_login_at TEXT +); + +CREATE TABLE sessions ( + id TEXT PRIMARY KEY, -- session token (high-entropy) + user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + created_at TEXT NOT NULL, + expires_at TEXT NOT NULL, + ip TEXT, + ua TEXT +); +CREATE INDEX sessions_user_id ON sessions(user_id); +CREATE INDEX sessions_expires_at ON sessions(expires_at); + +CREATE TABLE credentials ( + id TEXT PRIMARY KEY, + kind TEXT NOT NULL, -- 'rest','s3','local' + username TEXT, + -- secret_ref is the AEAD ciphertext (nonce || ciphertext, base64). + -- The plaintext never lands on disk. + secret_ref TEXT NOT NULL, + rotated_at TEXT NOT NULL +); + +CREATE TABLE repos ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + url TEXT NOT NULL, + kind TEXT NOT NULL CHECK (kind IN ('rest','s3','local')), + credential_id TEXT REFERENCES credentials(id) ON DELETE RESTRICT, + password_secret_id TEXT REFERENCES credentials(id) ON DELETE RESTRICT, + -- Cached projection from `restic stats` + lock-file inspection. + size_bytes INTEGER NOT NULL DEFAULT 0, + snapshot_count INTEGER NOT NULL DEFAULT 0, + dedup_ratio REAL NOT NULL DEFAULT 0, + last_check_at TEXT, + last_check_status TEXT, + lock_state TEXT NOT NULL DEFAULT 'unlocked' + CHECK (lock_state IN ('locked','unlocked')), + append_only INTEGER NOT NULL DEFAULT 1, -- bool + credential_rotated_at TEXT +); + +CREATE TABLE hosts ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + os TEXT NOT NULL, + arch TEXT NOT NULL, + agent_version TEXT NOT NULL DEFAULT '', + restic_version TEXT NOT NULL DEFAULT '', + protocol_version INTEGER NOT NULL DEFAULT 0, + enrolled_at TEXT NOT NULL, + last_seen_at TEXT, + status TEXT NOT NULL DEFAULT 'offline' + CHECK (status IN ('online','offline','degraded')), + repo_id TEXT REFERENCES repos(id) ON DELETE SET NULL, + tags TEXT NOT NULL DEFAULT '[]', -- json array + current_job_id TEXT, + -- Denormalised projections (refreshed on job.finished etc). + last_backup_at TEXT, + last_backup_status TEXT + CHECK (last_backup_status IN + ('succeeded','failed','cancelled') OR + last_backup_status IS NULL), + repo_size_bytes INTEGER NOT NULL DEFAULT 0, + snapshot_count INTEGER NOT NULL DEFAULT 0, + open_alert_count INTEGER NOT NULL DEFAULT 0, + applied_schedule_version INTEGER NOT NULL DEFAULT 0, + -- Server-issued credentials for the agent ↔ server WS. + agent_token_hash TEXT NOT NULL DEFAULT '', + cert_pin_sha256 TEXT NOT NULL DEFAULT '' +); +CREATE INDEX hosts_status ON hosts(status); +CREATE INDEX hosts_last_seen_at ON hosts(last_seen_at); + +-- Pending one-time enrollment tokens (TTL'd, single-use). +CREATE TABLE enrollment_tokens ( + token_hash TEXT PRIMARY KEY, -- argon2id of token + created_at TEXT NOT NULL, + expires_at TEXT NOT NULL, + consumed_at TEXT, + consumed_host TEXT REFERENCES hosts(id) ON DELETE SET NULL +); +CREATE INDEX enrollment_tokens_expires_at ON enrollment_tokens(expires_at); + +CREATE TABLE schedules ( + id TEXT PRIMARY KEY, + host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, + kind TEXT NOT NULL CHECK (kind IN ('backup','forget','prune','check')), + cron_expr TEXT NOT NULL, + paths TEXT NOT NULL DEFAULT '[]', -- json array + excludes TEXT NOT NULL DEFAULT '[]', + tags TEXT NOT NULL DEFAULT '[]', + retention_policy TEXT NOT NULL DEFAULT '{}', -- json object + options TEXT NOT NULL DEFAULT '{}', -- json object (bandwidth) + -- Hooks are encrypted at rest (AEAD ciphertext). Constraint enforced + -- in application code: hooks must be empty unless kind='backup'. + pre_hook TEXT NOT NULL DEFAULT '', + post_hook TEXT NOT NULL DEFAULT '', + enabled INTEGER NOT NULL DEFAULT 1, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL +); +CREATE INDEX schedules_host_id ON schedules(host_id); + +-- Per-host monotonic schedule version. Bumped on any schedules INSERT/ +-- UPDATE/DELETE for that host. Pushed to the agent in schedule.set; +-- the agent acks back the same version in schedule.ack. +CREATE TABLE host_schedule_version ( + host_id TEXT PRIMARY KEY REFERENCES hosts(id) ON DELETE CASCADE, + version INTEGER NOT NULL DEFAULT 0 +); + +CREATE TABLE jobs ( + id TEXT PRIMARY KEY, + host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, + kind TEXT NOT NULL CHECK (kind IN ('backup','forget','prune','check','unlock')), + status TEXT NOT NULL CHECK (status IN ('queued','running','succeeded','failed','cancelled')), + scheduled_id TEXT REFERENCES schedules(id) ON DELETE SET NULL, + actor_kind TEXT NOT NULL CHECK (actor_kind IN ('user','schedule','system')), + actor_id TEXT, -- user id, schedule id, or null + started_at TEXT, + finished_at TEXT, + exit_code INTEGER, + stats TEXT, -- json blob from restic + error TEXT, + created_at TEXT NOT NULL +); +CREATE INDEX jobs_host_id ON jobs(host_id); +CREATE INDEX jobs_status ON jobs(status); +CREATE INDEX jobs_created_at ON jobs(created_at); + +CREATE TABLE job_logs ( + job_id TEXT NOT NULL REFERENCES jobs(id) ON DELETE CASCADE, + seq INTEGER NOT NULL, + ts TEXT NOT NULL, + stream TEXT NOT NULL CHECK (stream IN ('stdout','stderr','event')), + payload TEXT NOT NULL, + PRIMARY KEY (job_id, seq) +); + +CREATE TABLE snapshots ( + id TEXT PRIMARY KEY, -- restic snapshot id + host_id TEXT NOT NULL REFERENCES hosts(id) ON DELETE CASCADE, + repo_id TEXT NOT NULL REFERENCES repos(id) ON DELETE CASCADE, + time TEXT NOT NULL, + hostname TEXT NOT NULL, + paths TEXT NOT NULL DEFAULT '[]', + tags TEXT NOT NULL DEFAULT '[]', + size_bytes INTEGER NOT NULL DEFAULT 0, + file_count INTEGER NOT NULL DEFAULT 0 +); +CREATE INDEX snapshots_host_id ON snapshots(host_id); +CREATE INDEX snapshots_time ON snapshots(time); + +CREATE TABLE alerts ( + id TEXT PRIMARY KEY, + host_id TEXT REFERENCES hosts(id) ON DELETE CASCADE, + kind TEXT NOT NULL, + severity TEXT NOT NULL CHECK (severity IN ('info','warning','critical')), + message TEXT NOT NULL, + created_at TEXT NOT NULL, + acknowledged_at TEXT, + acknowledged_by TEXT REFERENCES users(id) ON DELETE SET NULL, + resolved_at TEXT +); +CREATE INDEX alerts_host_id ON alerts(host_id); +CREATE INDEX alerts_open ON alerts(host_id) WHERE resolved_at IS NULL; + +CREATE TABLE audit_log ( + id TEXT PRIMARY KEY, + user_id TEXT REFERENCES users(id) ON DELETE SET NULL, + actor TEXT NOT NULL CHECK (actor IN ('user','agent','system')), + action TEXT NOT NULL, + target_kind TEXT, + target_id TEXT, + ts TEXT NOT NULL, + payload TEXT NOT NULL DEFAULT '{}' +); +CREATE INDEX audit_log_ts ON audit_log(ts); +CREATE INDEX audit_log_user ON audit_log(user_id); diff --git a/internal/store/sessions.go b/internal/store/sessions.go new file mode 100644 index 0000000..df26831 --- /dev/null +++ b/internal/store/sessions.go @@ -0,0 +1,88 @@ +package store + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" +) + +// CreateSession persists a session row. The token is hashed before +// insert; the raw token is what the caller hands to the user (cookie). +func (s *Store) CreateSession(ctx context.Context, sess Session, tokenHash string) error { + _, err := s.db.ExecContext(ctx, + `INSERT INTO sessions (id, user_id, created_at, expires_at, ip, ua) + VALUES (?, ?, ?, ?, ?, ?)`, + tokenHash, + sess.UserID, + sess.CreatedAt.UTC().Format(time.RFC3339Nano), + sess.ExpiresAt.UTC().Format(time.RFC3339Nano), + sess.IP, sess.UA) + if err != nil { + return fmt.Errorf("store: create session: %w", err) + } + return nil +} + +// LookupSession resolves a token hash to a session row, returning +// ErrNotFound if the hash is unknown OR the session has expired. +// We collapse "no row" and "expired" to the same error so the caller +// can't tell them apart in error messages — that prevents enumeration +// of valid token hashes. +func (s *Store) LookupSession(ctx context.Context, tokenHash string) (*Session, error) { + row := s.db.QueryRowContext(ctx, + `SELECT id, user_id, created_at, expires_at, ip, ua + FROM sessions + WHERE id = ? AND expires_at > ?`, + tokenHash, time.Now().UTC().Format(time.RFC3339Nano)) + + var sess Session + var created, expires string + var ip, ua sql.NullString + if err := row.Scan(&sess.ID, &sess.UserID, &created, &expires, &ip, &ua); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("store: lookup session: %w", err) + } + t, err := time.Parse(time.RFC3339Nano, created) + if err != nil { + return nil, fmt.Errorf("store: parse created_at: %w", err) + } + sess.CreatedAt = t + t, err = time.Parse(time.RFC3339Nano, expires) + if err != nil { + return nil, fmt.Errorf("store: parse expires_at: %w", err) + } + sess.ExpiresAt = t + if ip.Valid { + sess.IP = ip.String + } + if ua.Valid { + sess.UA = ua.String + } + return &sess, nil +} + +// DeleteSession removes a session row by token hash. Used on logout. +func (s *Store) DeleteSession(ctx context.Context, tokenHash string) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM sessions WHERE id = ?`, tokenHash) + if err != nil { + return fmt.Errorf("store: delete session: %w", err) + } + return nil +} + +// PurgeExpiredSessions deletes session rows past their expires_at. +// Run periodically from a background goroutine. +func (s *Store) PurgeExpiredSessions(ctx context.Context) (int64, error) { + res, err := s.db.ExecContext(ctx, + `DELETE FROM sessions WHERE expires_at <= ?`, + time.Now().UTC().Format(time.RFC3339Nano)) + if err != nil { + return 0, fmt.Errorf("store: purge sessions: %w", err) + } + n, _ := res.RowsAffected() + return n, nil +} diff --git a/internal/store/store.go b/internal/store/store.go new file mode 100644 index 0000000..e7414d7 --- /dev/null +++ b/internal/store/store.go @@ -0,0 +1,84 @@ +// Package store is the SQLite persistence layer (modernc.org/sqlite, +// no CGo). It owns the schema, exposes typed accessors, and hides +// the database/sql plumbing from the rest of the server. +package store + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/url" + "time" + + _ "modernc.org/sqlite" // register the "sqlite" driver +) + +// ErrNotFound is returned by accessors when a lookup misses. +var ErrNotFound = errors.New("store: not found") + +// Store is a thin wrapper around *sql.DB that exposes the typed +// accessors used by the rest of the server. Callers should use the +// provided methods rather than reaching into DB() directly. +type Store struct { + db *sql.DB +} + +// Open opens (or creates) the SQLite database at path, applies all +// pending migrations, and returns a ready-to-use Store. +// +// The DSN sets: +// - _pragma=foreign_keys(1) — referential integrity is on +// - _pragma=journal_mode(WAL) — concurrent reads vs writes +// - _pragma=busy_timeout(5000) — wait 5s on lock contention +// - _time_format=sqlite — RFC 3339 read/write of TEXT timestamps +// +// Empty path uses an in-memory DB (useful for tests). +func Open(ctx context.Context, path string) (*Store, error) { + dsn := buildDSN(path) + db, err := sql.Open("sqlite", dsn) + if err != nil { + return nil, fmt.Errorf("open %q: %w", path, err) + } + // modernc.org/sqlite is not safe for arbitrary high parallelism on + // a single file. WAL helps, but 1 writer + multiple readers is the + // only safe shape. Cap connections to keep that property explicit. + db.SetMaxOpenConns(8) + db.SetMaxIdleConns(4) + db.SetConnMaxLifetime(time.Hour) + + pingCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := db.PingContext(pingCtx); err != nil { + _ = db.Close() + return nil, fmt.Errorf("ping: %w", err) + } + + if err := migrate(ctx, db); err != nil { + _ = db.Close() + return nil, fmt.Errorf("migrate: %w", err) + } + + return &Store{db: db}, nil +} + +// Close releases the underlying DB handle. +func (s *Store) Close() error { return s.db.Close() } + +// DB returns the underlying *sql.DB. Reserved for tests and migrations +// — production code should add a typed method to this package instead. +func (s *Store) DB() *sql.DB { return s.db } + +func buildDSN(path string) string { + if path == "" { + // Shared cache + named in-memory db so multiple connections see + // the same data — needed because we cap MaxOpenConns above. + return "file::memory:?cache=shared&_pragma=foreign_keys(1)&_pragma=journal_mode(WAL)&_pragma=busy_timeout(5000)" + } + q := url.Values{} + q.Add("_pragma", "foreign_keys(1)") + q.Add("_pragma", "journal_mode(WAL)") + q.Add("_pragma", "busy_timeout(5000)") + q.Add("_pragma", "synchronous(NORMAL)") + return "file:" + path + "?" + q.Encode() +} diff --git a/internal/store/store_test.go b/internal/store/store_test.go new file mode 100644 index 0000000..ff4fe2b --- /dev/null +++ b/internal/store/store_test.go @@ -0,0 +1,93 @@ +package store + +import ( + "context" + "path/filepath" + "testing" +) + +// openTestStore opens an isolated file-backed db in a t.TempDir. +// In-memory + shared-cache works too but file makes failures easier +// to inspect when a test panics. +func openTestStore(t *testing.T) *Store { + t.Helper() + dir := t.TempDir() + s, err := Open(context.Background(), filepath.Join(dir, "rm.db")) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { _ = s.Close() }) + return s +} + +func TestOpenAppliesMigrations(t *testing.T) { + t.Parallel() + s := openTestStore(t) + + row := s.DB().QueryRow(`SELECT MAX(version) FROM schema_version`) + var v int + if err := row.Scan(&v); err != nil { + t.Fatalf("scan: %v", err) + } + if v < 1 { + t.Fatalf("expected at least migration 1 applied, got %d", v) + } + + // Spot-check a few tables exist with expected columns. + tables := []string{"users", "sessions", "hosts", "repos", + "credentials", "schedules", "jobs", "job_logs", + "snapshots", "alerts", "audit_log", + "enrollment_tokens", "host_schedule_version"} + for _, tbl := range tables { + row := s.DB().QueryRow( + `SELECT name FROM sqlite_master WHERE type='table' AND name = ?`, tbl) + var got string + if err := row.Scan(&got); err != nil { + t.Errorf("table %q missing: %v", tbl, err) + } + } +} + +func TestMigrateIsIdempotent(t *testing.T) { + t.Parallel() + dir := t.TempDir() + path := filepath.Join(dir, "rm.db") + + for i := 0; i < 3; i++ { + s, err := Open(context.Background(), path) + if err != nil { + t.Fatalf("open #%d: %v", i, err) + } + _ = s.Close() + } + + s, err := Open(context.Background(), path) + if err != nil { + t.Fatalf("final open: %v", err) + } + defer s.Close() + + row := s.DB().QueryRow(`SELECT COUNT(*) FROM schema_version`) + var n int + if err := row.Scan(&n); err != nil { + t.Fatalf("scan: %v", err) + } + if n != 1 { + t.Errorf("re-running migrations should not insert duplicate rows; got %d", n) + } +} + +func TestForeignKeysEnforced(t *testing.T) { + t.Parallel() + s := openTestStore(t) + + // Inserting a session with a non-existent user should fail because + // FKs are on. Without the pragma, SQLite silently accepts this. + _, err := s.DB().Exec( + `INSERT INTO sessions (id, user_id, created_at, expires_at) + VALUES (?, ?, datetime('now'), datetime('now','+1 hour'))`, + "sess1", "no-such-user") + if err == nil { + t.Fatal("expected FK violation, got nil") + } +} diff --git a/internal/store/types.go b/internal/store/types.go new file mode 100644 index 0000000..251f5f0 --- /dev/null +++ b/internal/store/types.go @@ -0,0 +1,82 @@ +package store + +import ( + "encoding/json" + "time" +) + +// User mirrors the users table. +type User struct { + ID string + Username string + PasswordHash string + Role Role + CreatedAt time.Time + LastLoginAt *time.Time +} + +// Role enumerates the access tiers from spec.md §7.2. +type Role string + +const ( + RoleAdmin Role = "admin" + RoleOperator Role = "operator" + RoleViewer Role = "viewer" +) + +// Session mirrors the sessions table. The ID is the (raw) session +// token; the DB stores its hash. Callers that hold a *Session have +// already authenticated. +type Session struct { + ID string // session token (raw); never persisted as-is + UserID string + CreatedAt time.Time + ExpiresAt time.Time + IP string + UA string +} + +// Host mirrors the denormalised hosts table. JSON columns (tags) are +// returned decoded into Go slices for ergonomics. +type Host struct { + ID string + Name string + OS string + Arch string + AgentVersion string + ResticVersion string + ProtocolVersion int + EnrolledAt time.Time + LastSeenAt *time.Time + Status string + RepoID *string + Tags []string + CurrentJobID *string + LastBackupAt *time.Time + LastBackupStatus *string + RepoSizeBytes int64 + SnapshotCount int + OpenAlertCount int + AppliedScheduleVersion int64 +} + +// EnrollmentToken is the issuer's view of a one-time token. The +// raw token is returned only at create time; the DB stores its hash. +type EnrollmentToken struct { + Raw string // populated on create only + TokenHash string + CreatedAt time.Time + ExpiresAt time.Time +} + +// AuditEntry mirrors the audit_log table. +type AuditEntry struct { + ID string + UserID *string + Actor string // user|agent|system + Action string + TargetKind *string + TargetID *string + TS time.Time + Payload json.RawMessage +} diff --git a/internal/store/users.go b/internal/store/users.go new file mode 100644 index 0000000..0567f2f --- /dev/null +++ b/internal/store/users.go @@ -0,0 +1,87 @@ +package store + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" +) + +// CreateUser inserts a new user. The caller is responsible for +// generating an ID (typically a ULID) and hashing the password. +func (s *Store) CreateUser(ctx context.Context, u User) error { + _, err := s.db.ExecContext(ctx, + `INSERT INTO users (id, username, password_hash, role, created_at) + VALUES (?, ?, ?, ?, ?)`, + u.ID, u.Username, u.PasswordHash, string(u.Role), u.CreatedAt.UTC().Format(time.RFC3339Nano)) + if err != nil { + return fmt.Errorf("store: create user: %w", err) + } + return nil +} + +// GetUserByUsername looks up a user by their (case-sensitive) username. +// Returns ErrNotFound if no row matches. +func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) { + row := s.db.QueryRowContext(ctx, + `SELECT id, username, password_hash, role, created_at, last_login_at + FROM users WHERE username = ?`, username) + return scanUser(row) +} + +// GetUserByID looks up a user by id. Returns ErrNotFound on miss. +func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) { + row := s.db.QueryRowContext(ctx, + `SELECT id, username, password_hash, role, created_at, last_login_at + FROM users WHERE id = ?`, id) + return scanUser(row) +} + +// CountUsers returns the total number of user rows. The first-run +// bootstrap uses this to detect a fresh install. +func (s *Store) CountUsers(ctx context.Context) (int, error) { + var n int + if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM users`).Scan(&n); err != nil { + return 0, fmt.Errorf("store: count users: %w", err) + } + return n, nil +} + +// MarkUserLogin records a successful authentication. +func (s *Store) MarkUserLogin(ctx context.Context, id string, when time.Time) error { + _, err := s.db.ExecContext(ctx, + `UPDATE users SET last_login_at = ? WHERE id = ?`, + when.UTC().Format(time.RFC3339Nano), id) + if err != nil { + return fmt.Errorf("store: mark login: %w", err) + } + return nil +} + +func scanUser(row *sql.Row) (*User, error) { + var u User + var role string + var lastLogin sql.NullString + var created string + if err := row.Scan(&u.ID, &u.Username, &u.PasswordHash, &role, &created, &lastLogin); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("store: scan user: %w", err) + } + u.Role = Role(role) + t, err := time.Parse(time.RFC3339Nano, created) + if err != nil { + return nil, fmt.Errorf("store: parse created_at: %w", err) + } + u.CreatedAt = t + if lastLogin.Valid { + t, err := time.Parse(time.RFC3339Nano, lastLogin.String) + if err != nil { + return nil, fmt.Errorf("store: parse last_login_at: %w", err) + } + u.LastLoginAt = &t + } + return &u, nil +} diff --git a/internal/store/users_test.go b/internal/store/users_test.go new file mode 100644 index 0000000..64ddefd --- /dev/null +++ b/internal/store/users_test.go @@ -0,0 +1,158 @@ +package store + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestUserCRUD(t *testing.T) { + t.Parallel() + s := openTestStore(t) + ctx := context.Background() + + now := time.Now().UTC() + u := User{ + ID: "u1", + Username: "alice", + PasswordHash: "$argon2id$...", + Role: RoleAdmin, + CreatedAt: now, + } + if err := s.CreateUser(ctx, u); err != nil { + t.Fatalf("create: %v", err) + } + + got, err := s.GetUserByUsername(ctx, "alice") + if err != nil { + t.Fatalf("get: %v", err) + } + if got.ID != "u1" || got.Role != RoleAdmin { + t.Errorf("unexpected user: %+v", got) + } + + // Username uniqueness is enforced by the schema. + if err := s.CreateUser(ctx, u); err == nil { + t.Error("duplicate username should fail") + } + + if _, err := s.GetUserByUsername(ctx, "bob"); !errors.Is(err, ErrNotFound) { + t.Errorf("missing user: want ErrNotFound, got %v", err) + } + + if err := s.MarkUserLogin(ctx, "u1", now); err != nil { + t.Fatalf("mark login: %v", err) + } + got, _ = s.GetUserByUsername(ctx, "alice") + if got.LastLoginAt == nil { + t.Error("last_login_at not updated") + } +} + +func TestCountUsers(t *testing.T) { + t.Parallel() + s := openTestStore(t) + ctx := context.Background() + + n, _ := s.CountUsers(ctx) + if n != 0 { + t.Errorf("fresh db: want 0, got %d", n) + } + _ = s.CreateUser(ctx, User{ + ID: "u1", Username: "a", PasswordHash: "x", + Role: RoleAdmin, CreatedAt: time.Now(), + }) + n, _ = s.CountUsers(ctx) + if n != 1 { + t.Errorf("after insert: want 1, got %d", n) + } +} + +func TestSessionLifecycle(t *testing.T) { + t.Parallel() + s := openTestStore(t) + ctx := context.Background() + + // Need a user for FK. + _ = s.CreateUser(ctx, User{ + ID: "u1", Username: "alice", PasswordHash: "x", + Role: RoleAdmin, CreatedAt: time.Now(), + }) + + now := time.Now().UTC() + sess := Session{ + UserID: "u1", + CreatedAt: now, + ExpiresAt: now.Add(time.Hour), + IP: "10.0.0.1", + UA: "test/1.0", + } + hash := "deadbeef" + "00000000000000000000000000000000000000000000000000000000" + if err := s.CreateSession(ctx, sess, hash); err != nil { + t.Fatalf("create: %v", err) + } + + got, err := s.LookupSession(ctx, hash) + if err != nil { + t.Fatalf("lookup: %v", err) + } + if got.UserID != "u1" { + t.Errorf("user mismatch: %s", got.UserID) + } + + // Expired sessions should not resolve. + expiredHash := "expired-hash" + expired := Session{ + UserID: "u1", + CreatedAt: now.Add(-2 * time.Hour), + ExpiresAt: now.Add(-time.Hour), + } + if err := s.CreateSession(ctx, expired, expiredHash); err != nil { + t.Fatalf("create expired: %v", err) + } + if _, err := s.LookupSession(ctx, expiredHash); !errors.Is(err, ErrNotFound) { + t.Errorf("expired session should look like ErrNotFound, got %v", err) + } + + if err := s.DeleteSession(ctx, hash); err != nil { + t.Fatalf("delete: %v", err) + } + if _, err := s.LookupSession(ctx, hash); !errors.Is(err, ErrNotFound) { + t.Errorf("deleted session: want ErrNotFound, got %v", err) + } + + n, err := s.PurgeExpiredSessions(ctx) + if err != nil { + t.Fatalf("purge: %v", err) + } + if n != 1 { + t.Errorf("purge should remove the 1 expired row, got %d", n) + } +} + +func TestEnrollmentTokenSingleUse(t *testing.T) { + t.Parallel() + s := openTestStore(t) + ctx := context.Background() + + hash := "tok-hash" + if err := s.CreateEnrollmentToken(ctx, hash, time.Hour); err != nil { + t.Fatalf("create: %v", err) + } + + // Need a host for FK. + _, err := s.DB().Exec(`INSERT INTO hosts (id, name, os, arch, enrolled_at) VALUES (?,?,?,?,?)`, + "h1", "host1", "linux", "amd64", time.Now().UTC().Format(time.RFC3339Nano)) + if err != nil { + t.Fatalf("insert host: %v", err) + } + + if err := s.ConsumeEnrollmentToken(ctx, hash, "h1"); err != nil { + t.Fatalf("consume: %v", err) + } + // Second consume must fail — the whole point of one-time tokens. + if err := s.ConsumeEnrollmentToken(ctx, hash, "h1"); !errors.Is(err, ErrNotFound) { + t.Errorf("re-consume: want ErrNotFound, got %v", err) + } +} diff --git a/tasks.md b/tasks.md index 90b6093..6cad0f3 100644 --- a/tasks.md +++ b/tasks.md @@ -20,6 +20,7 @@ Sizes: **S** = under a day, **M** = 1–3 days, **L** = 3–7 days. ## Phase 1 — MVP: enrollment, visibility, on-demand backup ### Server foundations + - [ ] **P1-01** (M) HTTP server scaffolding (`chi`, structured logging via `slog`, graceful shutdown) - [ ] **P1-02** (M) SQLite store layer (`modernc.org/sqlite`) + migrations (`golang-migrate` or hand-rolled) - [ ] **P1-03** (M) Schema for `users`, `sessions`, `hosts`, `repos`, `credentials`, `jobs`, `job_logs`, `snapshots`, `audit_log` @@ -29,6 +30,7 @@ Sizes: **S** = under a day, **M** = 1–3 days, **L** = 3–7 days. - [ ] **P1-07** (M) Audit log writer + middleware ### Agent ↔ server protocol + - [ ] **P1-08** (M) Define shared API types in `internal/api` (Go structs, JSON tags) - [ ] **P1-09** (L) WebSocket transport (`nhooyr.io/websocket`), framed JSON envelopes, request/response correlation, ping/pong, reconnect with backoff - [ ] **P1-10** (M) Enrollment flow: `POST /api/agents/enroll` with one-time token → returns persistent bearer + cert pin @@ -36,6 +38,7 @@ Sizes: **S** = under a day, **M** = 1–3 days, **L** = 3–7 days. - [ ] **P1-12** (S) Heartbeat handler (mark host offline after 90s without heartbeat) ### Agent foundations + - [ ] **P1-13** (M) Agent config file (`/etc/restic-manager/agent.yaml`); Windows path deferred to Phase 2 - [ ] **P1-14** (M) Service integration: systemd unit (Linux only in Phase 1; Windows service entrypoint deferred to Phase 2 — see P2-16) - [ ] **P1-15** (M) Outbound WS client (`github.com/coder/websocket`) with reconnect, server cert pinning, `protocol_version` advertisement in `hello` @@ -43,6 +46,7 @@ Sizes: **S** = under a day, **M** = 1–3 days, **L** = 3–7 days. - [ ] **P1-17** (S) Host metadata collection (OS, arch, hostname, restic version, agent version, protocol_version) ### Run-now backup + - [ ] **P1-18** (L) Job lifecycle: queued → running → succeeded/failed/cancelled, persisted with logs - [ ] **P1-19** (M) Server endpoint `POST /api/hosts/:id/jobs` to dispatch a `backup` command - [ ] **P1-20** (M) Agent executes `restic backup`, streams stdout/stderr + parsed JSON events back as `job.progress` / `log.stream` @@ -50,6 +54,7 @@ Sizes: **S** = under a day, **M** = 1–3 days, **L** = 3–7 days. - [ ] **P1-22** (S) Snapshot listing: `restic snapshots --json`, cached projection table, refresh after each backup ### UI (HTMX + Tailwind) + - [ ] **P1-23** (M) Base layout, login page, session-aware nav - [ ] **P1-24** (M) Dashboard: host cards (status dot, last backup, repo size) - [ ] **P1-25** (M) Host detail page: snapshots tab + run-now button @@ -58,10 +63,12 @@ Sizes: **S** = under a day, **M** = 1–3 days, **L** = 3–7 days. - [ ] **P1-28** (S) Tailwind build via `tailwindcss` standalone binary (no Node) ### Install scripts + - [ ] **P1-29** (M) `install.sh` (Linux): detects arch, downloads agent, installs systemd unit, enrolls. Also detects existing restic timers/cron (`systemctl list-timers --all | grep -i restic`, `crontab -l`, `/etc/cron.d/`, `/etc/cron.daily/`) and prints them with the disable commands — does **not** auto-disable, since heuristic matches could be unrelated tooling - [ ] **P1-31** (S) Server endpoint to serve agent binaries + install scripts (signed) ### Phase 1 acceptance + - One Linux host can enroll, appear in the dashboard, and a backup can be triggered from the UI with live log streaming. Snapshots list updates after success. - Windows binary builds cleanly in CI (`.gitea/workflows/ci.yml`) but is not service-tested or installer-shipped in Phase 1 — that lands in Phase 2 (P2-16, P2-17). - Agent ↔ server `protocol_version` handshake rejects mismatched versions with a clear error rather than failing on JSON parse. @@ -89,6 +96,7 @@ Sizes: **S** = under a day, **M** = 1–3 days, **L** = 3–7 days. - [ ] **P2-17** (M) `install.ps1` (Windows): downloads agent, installs as service, enrolls; detects existing scheduled tasks named `*restic*` and prints them for manual review ### Phase 2 acceptance + - Schedules created in UI run on agents on time; retention is applied; admin can prune from UI; repo health visible per host. Pre/post hooks fire correctly (verified with a Docker stop/start example and a `mysqldump` example) and are rejected on non-backup schedule kinds. Bandwidth limits honoured. - A Windows host can enroll, appear in the dashboard, and run a backup with live log streaming — closing the cross-platform gap left by Phase 1. @@ -107,6 +115,7 @@ Sizes: **S** = under a day, **M** = 1–3 days, **L** = 3–7 days. - [ ] **P3-09** (S) `diff` between two snapshots in UI ### Phase 3 acceptance + - A file deleted on a host can be restored from the UI in under 2 minutes. A failed backup raises an alert via the configured channel within 60s. --- @@ -124,6 +133,7 @@ Sizes: **S** = under a day, **M** = 1–3 days, **L** = 3–7 days. - [ ] **P4-09** (S) Document Prometheus integration + sample Grafana dashboard JSON ### Phase 4 acceptance + - Non-admin users see an appropriately limited UI. Agents upgrade via apt/choco with one admin-triggered action. OIDC login works against at least one provider (Authelia or Authentik). Prometheus can scrape `/metrics` and the sample Grafana dashboard renders with live data. --- @@ -139,6 +149,7 @@ Sizes: **S** = under a day, **M** = 1–3 days, **L** = 3–7 days. - [ ] **P5-07** (S) Sample `docker-compose.yml` with TLS via Caddy sidecar (also demonstrates `RM_TRUSTED_PROXY`) ### Phase 5 acceptance + - A stranger can read the docs and stand up a working install in under 30 minutes. ---