Compare commits
103 Commits
v0.9.0
..
aa80a3418e
| Author | SHA1 | Date | |
|---|---|---|---|
| aa80a3418e | |||
| ac9d7b92ed | |||
| 556d65d77f | |||
| 7ee8d2311b | |||
| e4dd7f96d6 | |||
| 6afde3ce23 | |||
| 775e988340 | |||
| e2cf9a68f6 | |||
| 1e212db24e | |||
| a9185424d3 | |||
| 9c5037ec54 | |||
| ed839aacb4 | |||
| 7d2c2ae1c2 | |||
| a131419b1a | |||
| 9d727a7b3a | |||
| 1cbc856514 | |||
| fb24e42c6e | |||
| a899cc2d04 | |||
| ef2a30a82d | |||
| e7960151fb | |||
| 2d27a23e99 | |||
| 9d9773cad4 | |||
| b3d033fa11 | |||
| b7033fcfcd | |||
| b1b0d9d1e9 | |||
| f711593549 | |||
| dfc2cd314d | |||
| 5e2b88c6dd | |||
| 768972d870 | |||
| 82a73fad85 | |||
| 26bb881c12 | |||
| 3873bd9d34 | |||
| 1bb31b9c49 | |||
| 4985050a0a | |||
| 1c7b471e75 | |||
| 88216d29d0 | |||
| 0ae62261e3 | |||
| dd7b37a5c1 | |||
| 694d9d9bf3 | |||
| 2d40002355 | |||
| e871b05b38 | |||
| 18a9f6624e | |||
| 2a8dd1eba2 | |||
| fab99b4a38 | |||
| ffba7371c5 | |||
| 4035c44be3 | |||
| d62b173712 | |||
| 8b91d3037c | |||
| 64d2fcf7a3 | |||
| 67ca769686 | |||
| dede74fd3a | |||
| 0ed9c3d1ec | |||
| a535822ff3 | |||
| 21841e38c4 | |||
| e968abc042 | |||
| 713bc4a2bb | |||
| d000fe7ec1 | |||
| 337dcc0f0f | |||
| 813158b3d6 | |||
| 5667cdf13a | |||
| 666af41f46 | |||
| 7a7cac588c | |||
| fdecde0d5c | |||
| f62a90b4b3 | |||
| 1b947f5a2c | |||
| c565a7abd1 | |||
| 7e49b62e0e | |||
| e0037f0026 | |||
| 72d8081b0d | |||
| 8a05969953 | |||
| 148e61b33b | |||
| 160d788bae | |||
| 6450bf1b88 | |||
| 946b6db137 | |||
| 4b075840a1 | |||
| ee3ee241ea | |||
| 12b72e7dde | |||
| bd434bd1d0 | |||
| 26a2b85e13 | |||
| dad8c7fe99 | |||
| ee16bc7ce7 | |||
| 229f89fee2 | |||
| 136e1a1d8f | |||
| f9c2351ab6 | |||
| 81c7825937 | |||
| b6cfa99413 | |||
| 2418e585db | |||
| 5d1951ad94 | |||
| ec276dbc91 | |||
| 0ba56ed30d | |||
| e58917106d | |||
| 6c9558c703 | |||
| 3904a78f14 | |||
| 41a4043af3 | |||
| 77a305d064 | |||
| 95b49ecab9 | |||
| e8eccd20c2 | |||
| f34773b505 | |||
| 84fd31ccaa | |||
| c275f4ff4c | |||
| 595546afb9 | |||
| c9368de904 | |||
| 7612687a14 |
+6
-76
@@ -1,47 +1,3 @@
|
||||
# CI workflow — runs on every PR into main.
|
||||
#
|
||||
# Notes for anyone editing this file:
|
||||
#
|
||||
# Self-hosted runner expectations
|
||||
# The Gitea runners are provisioned out-of-band (the infra team owns
|
||||
# the script). Each runner host bind-mounts persistent volumes for
|
||||
# /root/go/pkg/mod (GOMODCACHE), /root/.cache/go-build (GOCACHE), and
|
||||
# /root/.cache/act (action clones) into every job container. As a
|
||||
# result:
|
||||
# * `cache: true` on actions/setup-go is intentionally OMITTED — the
|
||||
# action would otherwise tar/untar GOMODCACHE+GOCACHE through the
|
||||
# Gitea cache backend on every job, undoing the host-volume cache
|
||||
# and adding ~10s of redundant zstd round-trip per job.
|
||||
# * Common GitHub actions (actions/checkout, actions/setup-go,
|
||||
# actions/upload-artifact, golangci/golangci-lint-action) are
|
||||
# pre-cloned into /root/.cache/act on the runner, so the per-job
|
||||
# "git clone https://github.com/actions/..." step is a fetch, not
|
||||
# a full clone.
|
||||
# * golangci-lint is pre-installed at /usr/local/bin/golangci-lint
|
||||
# on the runner (latest v2.x). The golangci-lint-action below
|
||||
# still pins a specific version and re-downloads — that's fine
|
||||
# (deterministic CI > marginal speed) but means the host-installed
|
||||
# binary is currently unused. Drop the `version:` arg below to
|
||||
# use the host-installed one if you want to trade determinism
|
||||
# for speed.
|
||||
#
|
||||
# Build matrix
|
||||
# Linux amd64 + arm64 + Windows amd64. CGO_ENABLED=0 throughout —
|
||||
# modernc.org/sqlite is pure-Go so no cross-compile toolchain is
|
||||
# needed. -trimpath + -ldflags="-s -w" for reproducible, smaller
|
||||
# binaries.
|
||||
#
|
||||
# Go version
|
||||
# The GO_VERSION env var anchors all three jobs. Floor is set by the
|
||||
# heaviest dep (modernc.org/sqlite v1.50+ requires Go 1.23+ today;
|
||||
# we run 1.25 so golangci-lint's Go-version compatibility check is
|
||||
# happy — see the version pin in the lint job).
|
||||
#
|
||||
# upload-artifact
|
||||
# Pinned at v3 historically; v3 was deprecated upstream. v4 should
|
||||
# work but hasn't been validated against this runner's act_runner
|
||||
# version yet. Bump when convenient.
|
||||
|
||||
name: CI
|
||||
|
||||
on:
|
||||
@@ -49,49 +5,23 @@ on:
|
||||
branches: [main]
|
||||
|
||||
env:
|
||||
# Floor is set by the heaviest dep (modernc.org/sqlite v1.50+).
|
||||
GO_VERSION: "1.25"
|
||||
|
||||
jobs:
|
||||
test:
|
||||
# Sharded by package group. server/http and store are the two
|
||||
# heavy packages (~156s and ~75s in CI respectively under
|
||||
# `-race`); pulling them onto their own runners lets each shard
|
||||
# have all CPUs to itself instead of CPU-starving each other on
|
||||
# one runner. The third shard ("rest") covers everything else.
|
||||
name: Test (${{ matrix.name }})
|
||||
name: Test (linux/amd64)
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- name: server-http
|
||||
packages: ./internal/server/http/...
|
||||
- name: store
|
||||
packages: ./internal/store/...
|
||||
- name: rest
|
||||
# Computed at runtime — see the "go test" step below.
|
||||
packages: ""
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ env.GO_VERSION }}
|
||||
# cache: true intentionally omitted — see header notes.
|
||||
cache: true
|
||||
- name: go vet
|
||||
run: go vet ./...
|
||||
- name: go test
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if [ -n "${{ matrix.packages }}" ]; then
|
||||
pkgs="${{ matrix.packages }}"
|
||||
else
|
||||
# "rest" shard: everything except the dedicated shards.
|
||||
pkgs=$(go list ./... \
|
||||
| grep -v '/internal/server/http$' \
|
||||
| grep -v '/internal/store$')
|
||||
fi
|
||||
# shellcheck disable=SC2086
|
||||
go test -race -coverprofile=coverage.out $pkgs
|
||||
run: go test -race -coverprofile=coverage.out ./...
|
||||
- name: coverage summary
|
||||
run: go tool cover -func=coverage.out | tail -1
|
||||
|
||||
@@ -103,7 +33,7 @@ jobs:
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ env.GO_VERSION }}
|
||||
# cache: true intentionally omitted — see header notes.
|
||||
cache: true
|
||||
- uses: golangci/golangci-lint-action@v7
|
||||
with:
|
||||
# Must be built against the same Go release as go.mod targets,
|
||||
@@ -133,7 +63,7 @@ jobs:
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ env.GO_VERSION }}
|
||||
# cache: true intentionally omitted — see header notes.
|
||||
cache: true
|
||||
- name: build server + agent
|
||||
env:
|
||||
GOOS: ${{ matrix.goos }}
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
# Release workflow — P5-03 (docker-only release path).
|
||||
#
|
||||
# Spec : docs/superpowers/specs/2026-05-05-p5-03-docker-only-release.md
|
||||
# Plan : docs/superpowers/plans/2026-05-05-p5-03-docker-only-release.md
|
||||
#
|
||||
# What it does
|
||||
# * Triggered by either:
|
||||
# - tag push matching v[0-9]+.[0-9]+.[0-9]+ (real release), or
|
||||
# - workflow_dispatch (snapshot iteration without tagging).
|
||||
# * Cross-builds a multi-arch (linux/amd64,linux/arm64) image of the
|
||||
# server, with three agent binaries (linux amd64+arm64, windows amd64)
|
||||
# plus install.sh / install.ps1 / the systemd unit baked in under
|
||||
# /opt/restic-manager/dist (the read-only fallback path the server
|
||||
# handlers use when <DataDir>/... is empty).
|
||||
# * Pushes to this Gitea instance's container registry under
|
||||
# <gitea-host>/<owner>/restic-manager.
|
||||
#
|
||||
# Tag fan-out
|
||||
# * tag push: :vX.Y.Z, :X.Y, :X
|
||||
# * tag push and X >= 1: also :latest
|
||||
# * workflow_dispatch: only :snapshot-<shortsha>; nothing else moves.
|
||||
#
|
||||
# Why no goreleaser
|
||||
# The architecture already routes agent distribution through the
|
||||
# server's /agent/binary endpoint. The image is the only deliverable;
|
||||
# binary archives would just be a second source of truth.
|
||||
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v[0-9]+.[0-9]+.[0-9]+'
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
REGISTRY: gitea.dcglab.co.uk
|
||||
IMAGE_NAME: ${{ gitea.repository }}
|
||||
|
||||
jobs:
|
||||
image:
|
||||
name: Build + push image
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: docker/setup-qemu-action@v3
|
||||
- uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Gitea registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ gitea.actor }}
|
||||
password: ${{ secrets.DEV_TOKEN }}
|
||||
|
||||
- name: Compute tags + version
|
||||
id: meta
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
REG="${REGISTRY}/${IMAGE_NAME}"
|
||||
DATE="$(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||
SHORT_SHA="${GITHUB_SHA::7}"
|
||||
|
||||
if [ "${GITHUB_EVENT_NAME}" = "push" ] && [ "${GITHUB_REF_TYPE}" = "tag" ]; then
|
||||
TAG="${GITHUB_REF_NAME}" # vX.Y.Z
|
||||
VER="${TAG#v}" # X.Y.Z
|
||||
MAJOR="${VER%%.*}"
|
||||
MINOR="${VER#${MAJOR}.}"; MINOR="${MINOR%%.*}"
|
||||
|
||||
TAGS="${REG}:${TAG}"
|
||||
TAGS="${TAGS},${REG}:${MAJOR}.${MINOR}"
|
||||
TAGS="${TAGS},${REG}:${MAJOR}"
|
||||
# Pre-1.0 holds back :latest by design; operators must
|
||||
# pin a version explicitly until v1.0.0.
|
||||
if [ "${MAJOR}" -ge 1 ]; then
|
||||
TAGS="${TAGS},${REG}:latest"
|
||||
fi
|
||||
VERSION="${TAG}"
|
||||
else
|
||||
TAGS="${REG}:snapshot-${SHORT_SHA}"
|
||||
VERSION="0.0.0-snapshot-${SHORT_SHA}"
|
||||
fi
|
||||
|
||||
{
|
||||
echo "tags=${TAGS}"
|
||||
echo "version=${VERSION}"
|
||||
echo "date=${DATE}"
|
||||
} >> "${GITHUB_OUTPUT}"
|
||||
|
||||
- name: Build + push
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: deploy/Dockerfile.server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
build-args: |
|
||||
VERSION=${{ steps.meta.outputs.version }}
|
||||
COMMIT=${{ gitea.sha }}
|
||||
DATE=${{ steps.meta.outputs.date }}
|
||||
labels: |
|
||||
org.opencontainers.image.version=${{ steps.meta.outputs.version }}
|
||||
org.opencontainers.image.revision=${{ gitea.sha }}
|
||||
org.opencontainers.image.created=${{ steps.meta.outputs.date }}
|
||||
-11
@@ -26,18 +26,7 @@ coverage.html
|
||||
.env.local
|
||||
*.local
|
||||
|
||||
# Local docker-compose for the dev/test bench. Has host-specific IPs,
|
||||
# hostnames, and ports — never committed; the canonical reference
|
||||
# deployment lives in deploy/.
|
||||
/compose.yaml
|
||||
/compose.override.yaml
|
||||
|
||||
# Local diagnostic helpers (never shipped). Go's build tooling already
|
||||
# skips paths beginning with _ or ., but ignore explicitly so nothing
|
||||
# checked in here can leak into a release tarball.
|
||||
/_diag/
|
||||
|
||||
# Dev-only one-shot binaries (cmd/_*) — never shipped. Go's build
|
||||
# tooling already skips paths starting with _, but ignore explicitly
|
||||
# so an accidental `git add cmd/.` can't sneak them into a release.
|
||||
/cmd/_*/
|
||||
|
||||
+1
-1
@@ -26,7 +26,7 @@ linters:
|
||||
- name: exported
|
||||
arguments: ["disableStutteringCheck"]
|
||||
misspell:
|
||||
locale: UK
|
||||
locale: US
|
||||
exclusions:
|
||||
rules:
|
||||
- path: _test\.go
|
||||
|
||||
@@ -2,19 +2,6 @@
|
||||
|
||||
Project-specific rules for Claude when working in this repo.
|
||||
|
||||
## Commands
|
||||
|
||||
If the user types any of the following, follow the instructions in the table.
|
||||
|
||||
| Command | Action |
|
||||
| --- | --- |
|
||||
| :release | Trigger a subagent to commit (if needed), push (if needed), raise a PR, and wait for the PR to pass or fail. If it fails, report back. If it passes, merge into main. |
|
||||
|
||||
## Repo
|
||||
|
||||
The repo lives inside a Gitea instance; `tea` CLI is available for use by agents
|
||||
|
||||
|
||||
## Run `go vet` before every commit
|
||||
|
||||
CI runs `go vet ./...` and will fail the build on any vet error.
|
||||
@@ -56,8 +43,6 @@ cp bin/restic-manager-agent \
|
||||
/tmp/rm-smoke/data/agent-binaries/restic-manager-agent-linux-amd64
|
||||
cp deploy/install/install.sh \
|
||||
/tmp/rm-smoke/data/install/install.sh
|
||||
cp deploy/install/install.ps1 \
|
||||
/tmp/rm-smoke/data/install/install.ps1
|
||||
cp deploy/install/restic-manager-agent.service \
|
||||
/tmp/rm-smoke/data/install/restic-manager-agent.service
|
||||
|
||||
|
||||
@@ -5,11 +5,9 @@ BIN_DIR := bin
|
||||
SERVER_BIN := $(BIN_DIR)/restic-manager-server
|
||||
AGENT_BIN := $(BIN_DIR)/restic-manager-agent
|
||||
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo dev)
|
||||
COMMIT ?= $(shell git rev-parse HEAD 2>/dev/null || echo none)
|
||||
DATE ?= $(shell date -u +%Y-%m-%dT%H:%M:%SZ)
|
||||
LDFLAGS := -s -w -X main.version=$(VERSION) -X main.commit=$(COMMIT) -X main.date=$(DATE)
|
||||
LDFLAGS := -s -w -X main.version=$(VERSION)
|
||||
GOFLAGS := -trimpath
|
||||
DOCKER_IMAGE ?= gitea.dcglab.co.uk/steve/restic-manager
|
||||
DOCKER_IMAGE ?= ghcr.io/dcglab/restic-manager
|
||||
DOCKER_TAG ?= dev
|
||||
|
||||
# Tailwind standalone CLI — single binary, no Node toolchain.
|
||||
@@ -86,11 +84,7 @@ run-agent: agent ## Build and run the agent
|
||||
$(AGENT_BIN)
|
||||
|
||||
docker: ## Build the server Docker image
|
||||
docker build -f deploy/Dockerfile.server \
|
||||
--build-arg VERSION=$(VERSION) \
|
||||
--build-arg COMMIT=$(COMMIT) \
|
||||
--build-arg DATE=$(DATE) \
|
||||
-t $(DOCKER_IMAGE):$(DOCKER_TAG) .
|
||||
docker build -f deploy/Dockerfile.server --build-arg VERSION=$(VERSION) -t $(DOCKER_IMAGE):$(DOCKER_TAG) .
|
||||
|
||||
release: ## Cross-compile for all supported platforms
|
||||
@mkdir -p $(BIN_DIR)
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
# The ask!
|
||||
|
||||
I have numerous servers deployed out in a lab, mainly Linux but some Windows
|
||||
All have restic installed on them
|
||||
I need to build a browser-based management service that gives me a central single pane of glass to monitor and manage all the endpoints
|
||||
All endpoints will be enabled for SSH (unless other methods are better?)
|
||||
|
||||
Plan out how we would go about this please?
|
||||
@@ -1,262 +0,0 @@
|
||||
// announce.go — agent-side announce-and-approve enrolment (P2-18c).
|
||||
//
|
||||
// Run path: when the agent has no AgentToken set but RM_SERVER is
|
||||
// configured (and no -enroll-token was supplied), main() switches
|
||||
// into announce mode:
|
||||
// 1. Load (or mint+persist) an Ed25519 keypair in agent.yaml.
|
||||
// 2. POST {hostname, os, arch, agent_version, restic_version,
|
||||
// public_key} to /api/agents/announce.
|
||||
// 3. Print the fingerprint to stderr in a copy-friendly banner so
|
||||
// the operator can compare it against the dashboard.
|
||||
// 4. Open /ws/agent/pending?pending_id=…, sign the nonce with our
|
||||
// private key, wait for an `enrolled` message.
|
||||
// 5. On enrolled: persist the bearer + repo creds, return; main()
|
||||
// then drops into the normal WS run loop with the new bearer.
|
||||
// 6. On reject: server closes the socket with code 4001; we exit
|
||||
// with a clear message.
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
stdhttp "net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/agent/config"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/agent/secrets"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/agent/sysinfo"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// Wire types for the announce-and-approve handshake. They are declared
// locally (duplicating the server-side shapes) so that cmd/agent stays
// decoupled from the server's http package.
type (
	// announceRequest is the JSON body POSTed to /api/agents/announce.
	announceRequest struct {
		Hostname      string `json:"hostname"`
		OS            string `json:"os"`
		Arch          string `json:"arch"`
		AgentVersion  string `json:"agent_version"`
		ResticVersion string `json:"restic_version"`
		PublicKey     string `json:"public_key"` // base64 (std encoding) Ed25519 public key
	}

	// announceResponse is the JSON reply to a successful announce POST.
	announceResponse struct {
		PendingID         string `json:"pending_id"`
		Fingerprint       string `json:"fingerprint"`
		HostnameCollision bool   `json:"hostname_collision"`
	}

	// pendingNonceMessage is the first frame read from the pending
	// WebSocket; Nonce is base64 (std encoding).
	pendingNonceMessage struct {
		Type  string `json:"type"`
		Nonce string `json:"nonce"`
	}

	// pendingSignedMessage is the agent's reply carrying the base64
	// (std encoding) Ed25519 signature over the decoded nonce.
	pendingSignedMessage struct {
		Type      string `json:"type"`
		Signature string `json:"signature"`
	}

	// pendingEnrolledMessage is the frame that completes enrolment and
	// delivers the host ID and bearer token.
	pendingEnrolledMessage struct {
		Type   string `json:"type"`
		HostID string `json:"host_id"`
		Bearer string `json:"bearer"`
	}
)
|
||||
|
||||
// doAnnounce runs the full announce → wait-for-accept flow. On
|
||||
// success, persists the bearer + host_id into cfg + writes secrets
|
||||
// for the repo creds the admin supplied at accept time. Returns
|
||||
// only after the bearer has landed (or on hard error / reject).
|
||||
func doAnnounce(serverURL string, cfg *config.Config, agentVersion string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 24*time.Hour)
|
||||
defer cancel()
|
||||
|
||||
// Ensure we have a keypair.
|
||||
priv, pub, err := loadOrMintAnnounceKey(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("announce: keypair: %w", err)
|
||||
}
|
||||
fingerprint := store.FingerprintForKey(pub)
|
||||
|
||||
snap, err := sysinfo.Collect(ctx, cfg.ResticPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("announce: sysinfo: %w", err)
|
||||
}
|
||||
|
||||
// POST /api/agents/announce.
|
||||
body, _ := json.Marshal(announceRequest{
|
||||
Hostname: snap.Hostname, OS: string(snap.OS), Arch: string(snap.Arch),
|
||||
AgentVersion: agentVersion, ResticVersion: snap.ResticVersion,
|
||||
PublicKey: base64.StdEncoding.EncodeToString(pub),
|
||||
})
|
||||
req, _ := stdhttp.NewRequestWithContext(ctx, "POST",
|
||||
strings.TrimRight(serverURL, "/")+"/api/agents/announce",
|
||||
strings.NewReader(string(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
res, err := stdhttp.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("announce: POST: %w", err)
|
||||
}
|
||||
rawBody := readAllShort(res)
|
||||
_ = res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusOK {
|
||||
return fmt.Errorf("announce: server returned %d: %s", res.StatusCode, rawBody)
|
||||
}
|
||||
var ar announceResponse
|
||||
if err := json.Unmarshal(rawBody, &ar); err != nil {
|
||||
return fmt.Errorf("announce: parse response: %w", err)
|
||||
}
|
||||
|
||||
// Print the fingerprint banner.
|
||||
fmt.Fprintln(os.Stderr, strings.Repeat("=", 64))
|
||||
fmt.Fprintln(os.Stderr, " Restic-manager: announce-and-approve enrolment")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, " Hostname : "+snap.Hostname)
|
||||
fmt.Fprintln(os.Stderr, " Server : "+serverURL)
|
||||
fmt.Fprintln(os.Stderr, " Pending ID : "+ar.PendingID)
|
||||
fmt.Fprintln(os.Stderr, " Fingerprint : "+fingerprint)
|
||||
if ar.HostnameCollision {
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, " WARNING: another pending host already uses this hostname.")
|
||||
fmt.Fprintln(os.Stderr, " Confirm the fingerprint above matches what you see in the UI.")
|
||||
}
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, " Compare the fingerprint with the one in the UI before accepting.")
|
||||
fmt.Fprintln(os.Stderr, " Waiting for an admin to accept (1 hour timeout)…")
|
||||
fmt.Fprintln(os.Stderr, strings.Repeat("=", 64))
|
||||
|
||||
// Open /ws/agent/pending and run the nonce-sign handshake.
|
||||
wsURL := wsURLFromHTTP(serverURL) + "/ws/agent/pending?pending_id=" + ar.PendingID
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
c, dialRes, err := websocket.Dial(dialCtx, wsURL, nil)
|
||||
dialCancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("announce: dial pending ws: %w", err)
|
||||
}
|
||||
if dialRes != nil && dialRes.Body != nil {
|
||||
_ = dialRes.Body.Close()
|
||||
}
|
||||
defer func() { _ = c.CloseNow() }()
|
||||
|
||||
// Read nonce.
|
||||
rctx, rcancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
_, raw, err := c.Read(rctx)
|
||||
rcancel()
|
||||
if err != nil {
|
||||
return fmt.Errorf("announce: read nonce: %w", err)
|
||||
}
|
||||
var nm pendingNonceMessage
|
||||
if err := json.Unmarshal(raw, &nm); err != nil {
|
||||
return fmt.Errorf("announce: parse nonce: %w", err)
|
||||
}
|
||||
nonce, err := base64.StdEncoding.DecodeString(nm.Nonce)
|
||||
if err != nil {
|
||||
return fmt.Errorf("announce: decode nonce: %w", err)
|
||||
}
|
||||
sig := ed25519.Sign(priv, nonce)
|
||||
reply, _ := json.Marshal(pendingSignedMessage{
|
||||
Type: "signed_nonce", Signature: base64.StdEncoding.EncodeToString(sig),
|
||||
})
|
||||
wctx, wcancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
if err := c.Write(wctx, websocket.MessageText, reply); err != nil {
|
||||
wcancel()
|
||||
return fmt.Errorf("announce: write signed nonce: %w", err)
|
||||
}
|
||||
wcancel()
|
||||
|
||||
// Block until enrolled (or reject / disconnect).
|
||||
rctx2, rcancel2 := context.WithTimeout(ctx, 1*time.Hour)
|
||||
defer rcancel2()
|
||||
_, raw2, err := c.Read(rctx2)
|
||||
if err != nil {
|
||||
// CloseError with our reject code 4001 = admin rejected.
|
||||
var ce websocket.CloseError
|
||||
if errors.As(err, &ce) && ce.Code == 4001 {
|
||||
return errors.New("announce: rejected by admin")
|
||||
}
|
||||
return fmt.Errorf("announce: wait for enrolled: %w", err)
|
||||
}
|
||||
var em pendingEnrolledMessage
|
||||
if err := json.Unmarshal(raw2, &em); err != nil {
|
||||
return fmt.Errorf("announce: parse enrolled: %w", err)
|
||||
}
|
||||
if em.Type != "enrolled" || em.Bearer == "" {
|
||||
return fmt.Errorf("announce: bad enrolled payload: %s", raw2)
|
||||
}
|
||||
|
||||
// Persist the bearer + host_id.
|
||||
cfg.ServerURL = serverURL
|
||||
cfg.HostID = em.HostID
|
||||
cfg.AgentToken = em.Bearer
|
||||
if err := cfg.EnsureSecretsKey(); err != nil {
|
||||
return fmt.Errorf("announce: mint secrets key: %w", err)
|
||||
}
|
||||
// Note: repo creds aren't pushed in the enrolled message — the
|
||||
// server pushes them via `config.update` on first WS hello. The
|
||||
// secrets store will start empty and fill in then.
|
||||
if err := cfg.Save(); err != nil {
|
||||
return fmt.Errorf("announce: save config: %w", err)
|
||||
}
|
||||
// Touch the secrets store so it exists with the right perms.
|
||||
keyBytes, _ := cfg.SecretsKeyBytes()
|
||||
if _, err := secrets.New(cfg.ResolvedSecretsPath(), keyBytes); err != nil {
|
||||
return fmt.Errorf("announce: open secrets store: %w", err)
|
||||
}
|
||||
fmt.Fprintln(os.Stderr, "Accepted. Bearer persisted; reconnecting via the standard WS.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadOrMintAnnounceKey returns the (priv, pub) keypair, generating
|
||||
// + persisting one when AnnounceKey is empty. The private key holds
|
||||
// the public half in its tail 32 bytes per ed25519 convention.
|
||||
func loadOrMintAnnounceKey(cfg *config.Config) (ed25519.PrivateKey, ed25519.PublicKey, error) {
|
||||
if cfg.AnnounceKey != "" {
|
||||
raw, err := base64.StdEncoding.DecodeString(cfg.AnnounceKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("decode AnnounceKey: %w", err)
|
||||
}
|
||||
if len(raw) != ed25519.PrivateKeySize {
|
||||
return nil, nil, fmt.Errorf("AnnounceKey must be %d bytes, got %d",
|
||||
ed25519.PrivateKeySize, len(raw))
|
||||
}
|
||||
priv := ed25519.PrivateKey(raw)
|
||||
pub := priv.Public().(ed25519.PublicKey)
|
||||
return priv, pub, nil
|
||||
}
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generate keypair: %w", err)
|
||||
}
|
||||
cfg.AnnounceKey = base64.StdEncoding.EncodeToString(priv)
|
||||
if err := cfg.Save(); err != nil {
|
||||
return nil, nil, fmt.Errorf("persist AnnounceKey: %w", err)
|
||||
}
|
||||
return priv, pub, nil
|
||||
}
|
||||
|
||||
// wsURLFromHTTP swaps an http(s) scheme for the matching ws(s) scheme:
// "https://" becomes "wss://" and "http://" becomes "ws://". The scheme
// is matched case-insensitively (URI schemes are case-insensitive), and
// the rest of the URL is passed through untouched. Input with any other
// prefix, including the empty string, is returned unchanged.
func wsURLFromHTTP(httpURL string) string {
	for _, m := range [...]struct{ from, to string }{
		{"https://", "wss://"},
		{"http://", "ws://"},
	} {
		// Length guard keeps the prefix slice in range for short input.
		if len(httpURL) >= len(m.from) && strings.EqualFold(httpURL[:len(m.from)], m.from) {
			return m.to + httpURL[len(m.from):]
		}
	}
	return httpURL
}
|
||||
|
||||
// readAllShort reads at most 64KB of res.Body and returns what was read.
// The announce response is small; the cap guards against pathological
// server replies.
//
// A single Read call may legally return fewer bytes than are available
// (common with chunked or slowly arriving bodies), so the body is read in
// a loop until the buffer is full, EOF is reached, or a read error
// occurs. Read errors are not reported; the bytes gathered so far are
// returned and the caller's JSON decoding surfaces a truncated payload.
// The body is not closed here.
func readAllShort(res *stdhttp.Response) []byte {
	buf := make([]byte, 64*1024)
	n := 0
	for n < len(buf) {
		m, err := res.Body.Read(buf[n:])
		n += m
		if err != nil {
			break
		}
	}
	return buf[:n]
}
|
||||
+55
-338
@@ -9,7 +9,6 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -17,18 +16,13 @@ import (
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/agent/runner"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/agent/scheduler"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/agent/secrets"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/agent/service"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/agent/sysinfo"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/agent/wsclient"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/restic"
|
||||
)
|
||||
|
||||
var (
|
||||
version = "dev"
|
||||
commit = "none"
|
||||
date = "unknown"
|
||||
)
|
||||
var version = "dev"
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
@@ -38,27 +32,6 @@ func main() {
|
||||
}
|
||||
|
||||
func run() error {
|
||||
// Optional first positional verb for SCM control on Windows.
|
||||
// `restic-manager-agent install|uninstall|start|stop` route into
|
||||
// the service package; everything else falls through to the
|
||||
// flag-driven default (which is what systemd / interactive runs
|
||||
// hit). On non-Windows builds these verbs return a clear error.
|
||||
if len(os.Args) > 1 {
|
||||
switch os.Args[1] {
|
||||
case "install":
|
||||
return service.Install()
|
||||
case "uninstall":
|
||||
return service.Uninstall()
|
||||
case "start":
|
||||
return service.Start()
|
||||
case "stop":
|
||||
return service.Stop()
|
||||
case "run":
|
||||
// Strip the verb so flag.Parse sees the rest unchanged.
|
||||
os.Args = append([]string{os.Args[0]}, os.Args[2:]...)
|
||||
}
|
||||
}
|
||||
|
||||
configPath := flag.String("config", config.DefaultPath(), "path to agent.yaml")
|
||||
enrollServer := flag.String("enroll-server", "", "server URL (used with -enroll-token to perform first-run enrollment)")
|
||||
enrollToken := flag.String("enroll-token", "", "one-time enrollment token (operator copies this from the UI)")
|
||||
@@ -66,7 +39,7 @@ func run() error {
|
||||
flag.Parse()
|
||||
|
||||
if *showVersion {
|
||||
fmt.Printf("restic-manager-agent %s (commit %s, built %s)\n", version, commit, date)
|
||||
fmt.Println("restic-manager-agent", version)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -85,17 +58,8 @@ func run() error {
|
||||
return doEnroll(*enrollServer, *enrollToken, cfg, version)
|
||||
}
|
||||
|
||||
// Announce-and-approve: -enroll-server set, no token, agent not
|
||||
// yet enrolled. Run the announce flow inline; on success the cfg
|
||||
// has the bearer + host_id and we drop into the normal run loop.
|
||||
if !cfg.Enrolled() && *enrollServer != "" {
|
||||
if err := doAnnounce(*enrollServer, cfg, version); err != nil {
|
||||
return fmt.Errorf("announce: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !cfg.Enrolled() {
|
||||
return fmt.Errorf("agent is not enrolled; run with -enroll-server (and either -enroll-token or wait for admin to accept the announce) first (config %q)", *configPath)
|
||||
return fmt.Errorf("agent is not enrolled; run with -enroll-server and -enroll-token first (config %q)", *configPath)
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
@@ -115,12 +79,6 @@ func run() error {
|
||||
|
||||
resticBin, _ := restic.Locate(cfg.ResticPath) // empty is fine; commands fail with a clear error later
|
||||
|
||||
// Probe the actual restic binary for restore-flag support. We used
|
||||
// to gate --no-ownership on a SemVer comparison (added in 0.17),
|
||||
// but a restic 0.18.1 build was observed in the wild that still
|
||||
// rejects the flag. The help text is the only reliable signal.
|
||||
resticSupportsNoOwnership := restic.SupportsRestoreNoOwnership(ctx, resticBin)
|
||||
|
||||
// Open the secrets store. If the agent is enrolled but has no
|
||||
// secrets key yet (legacy YAML), mint one and migrate any
|
||||
// plaintext repo fields into the encrypted blob.
|
||||
@@ -145,11 +103,9 @@ func run() error {
|
||||
}
|
||||
|
||||
d := &dispatcher{
|
||||
resticBin: resticBin,
|
||||
resticVer: snap.ResticVersion,
|
||||
resticSupportsNoOwnership: resticSupportsNoOwnership,
|
||||
secrets: sec,
|
||||
scheduler: scheduler.New(),
|
||||
resticBin: resticBin,
|
||||
secrets: sec,
|
||||
scheduler: scheduler.New(),
|
||||
}
|
||||
if err := wsclient.Run(ctx, wsCfg, d.handle); err != nil {
|
||||
return fmt.Errorf("ws run: %w", err)
|
||||
@@ -211,58 +167,9 @@ func openSecretsStore(cfg *config.Config) (*secrets.Store, error) {
|
||||
// secrets store on each job — config.update writes through to disk,
|
||||
// so a job dispatched in the same session sees the latest values.
|
||||
type dispatcher struct {
|
||||
resticBin string
|
||||
resticVer string // e.g. "0.17.1"; empty if restic isn't installed yet
|
||||
resticSupportsNoOwnership bool // captured at startup from `restic restore --help`
|
||||
secrets *secrets.Store
|
||||
scheduler *scheduler.Scheduler
|
||||
|
||||
// Bandwidth caps in KB/s pushed via config.update. Mutated under
|
||||
// bwMu by the config.update handler; read by runJob when building
|
||||
// the runner. <=0 means "no cap" (do not pass --limit-* to restic).
|
||||
// Per-job overrides on CommandRunPayload take precedence.
|
||||
bwMu sync.Mutex
|
||||
bwUpKBps int
|
||||
bwDownKBps int
|
||||
|
||||
// Per-running-job cancellation handles. Populated when runJob
|
||||
// spawns the goroutine, removed when it returns. Looked up by
|
||||
// the command.cancel handler (server → agent) to abort an
|
||||
// in-flight restic invocation.
|
||||
cancelMu sync.Mutex
|
||||
cancels map[string]context.CancelFunc
|
||||
}
|
||||
|
||||
// trackJob registers a cancel func for an in-flight job and returns a
|
||||
// cleanup that removes it. Call cleanup when the job goroutine exits
|
||||
// regardless of outcome — runs even on panic.
|
||||
func (d *dispatcher) trackJob(jobID string, cancel context.CancelFunc) func() {
|
||||
d.cancelMu.Lock()
|
||||
if d.cancels == nil {
|
||||
d.cancels = make(map[string]context.CancelFunc)
|
||||
}
|
||||
d.cancels[jobID] = cancel
|
||||
d.cancelMu.Unlock()
|
||||
return func() {
|
||||
d.cancelMu.Lock()
|
||||
delete(d.cancels, jobID)
|
||||
d.cancelMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// cancelJob fires the cancel func for jobID if there is one and
|
||||
// returns whether the job was actually known. The runner is expected
|
||||
// to surface the resulting context.Canceled as a JobCancelled status
|
||||
// in its job.finished envelope (see runner.sendFinished).
|
||||
func (d *dispatcher) cancelJob(jobID string) bool {
|
||||
d.cancelMu.Lock()
|
||||
cancel, ok := d.cancels[jobID]
|
||||
d.cancelMu.Unlock()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
cancel()
|
||||
return true
|
||||
resticBin string
|
||||
secrets *secrets.Store
|
||||
scheduler *scheduler.Scheduler
|
||||
}
|
||||
|
||||
func (d *dispatcher) handle(ctx context.Context, env api.Envelope, tx wsclient.Sender) error {
|
||||
@@ -275,29 +182,8 @@ func (d *dispatcher) handle(ctx context.Context, env api.Envelope, tx wsclient.S
|
||||
return d.runJob(ctx, p, tx)
|
||||
|
||||
case api.MsgCommandCancel:
|
||||
var p api.CommandCancelPayload
|
||||
if err := env.UnmarshalPayload(&p); err != nil {
|
||||
return fmt.Errorf("command.cancel: %w", err)
|
||||
}
|
||||
if d.cancelJob(p.JobID) {
|
||||
slog.Info("ws agent: command.cancel applied", "job_id", p.JobID)
|
||||
} else {
|
||||
// Job already finished or was never seen on this agent.
|
||||
// Not an error — operator may have raced cancel against
|
||||
// natural completion. Server-side state is authoritative.
|
||||
slog.Info("ws agent: command.cancel for unknown job (already finished?)", "job_id", p.JobID)
|
||||
}
|
||||
|
||||
case api.MsgTreeList:
|
||||
// Synchronous RPC for the restore wizard's tree browser. The
|
||||
// server has serialised access; we just run restic ls and reply
|
||||
// with the same envelope ID. Run in a goroutine so the WS read
|
||||
// loop keeps draining.
|
||||
var p api.TreeListRequestPayload
|
||||
if err := env.UnmarshalPayload(&p); err != nil {
|
||||
return fmt.Errorf("tree.list: %w", err)
|
||||
}
|
||||
go d.handleTreeList(ctx, env.ID, p, tx)
|
||||
// TODO(P2): cancellation requires keeping a job→cancelFunc map.
|
||||
slog.Info("ws agent: command.cancel received (cancellation lands in P2)", "id", env.ID)
|
||||
|
||||
case api.MsgScheduleSet:
|
||||
var p api.ScheduleSetPayload
|
||||
@@ -377,24 +263,6 @@ func (d *dispatcher) handle(ctx context.Context, env api.Envelope, tx wsclient.S
|
||||
slog.Warn("ws agent: unknown config.update slot, ignoring", "slot", p.Slot)
|
||||
}
|
||||
|
||||
// Bandwidth caps ride independently of the slot — they're host-
|
||||
// wide and apply to every restic invocation regardless of which
|
||||
// credentials slot the job uses. nil pointer = no change in this
|
||||
// push; non-nil = set to that value (≤0 clears the cap).
|
||||
if p.BandwidthUpKBps != nil || p.BandwidthDownKBps != nil {
|
||||
d.bwMu.Lock()
|
||||
if p.BandwidthUpKBps != nil {
|
||||
d.bwUpKBps = *p.BandwidthUpKBps
|
||||
}
|
||||
if p.BandwidthDownKBps != nil {
|
||||
d.bwDownKBps = *p.BandwidthDownKBps
|
||||
}
|
||||
up, down := d.bwUpKBps, d.bwDownKBps
|
||||
d.bwMu.Unlock()
|
||||
slog.Info("ws agent: bandwidth caps updated",
|
||||
"up_kbps", up, "down_kbps", down)
|
||||
}
|
||||
|
||||
case api.MsgAgentUpdateAvail:
|
||||
var p api.AgentUpdateAvailablePayload
|
||||
_ = env.UnmarshalPayload(&p)
|
||||
@@ -406,113 +274,17 @@ func (d *dispatcher) handle(ctx context.Context, env api.Envelope, tx wsclient.S
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleTreeList runs `restic ls --json <snapshot> <path>` and ships
|
||||
// the matching tree.list.result envelope back, correlated by the
|
||||
// request envelope's ID. Errors (missing creds, restic failure)
|
||||
// surface in the result's Error field rather than as transport-level
|
||||
// failures so the server-side waiter can render a sensible message.
|
||||
func (d *dispatcher) handleTreeList(ctx context.Context, reqID string, p api.TreeListRequestPayload, tx wsclient.Sender) {
|
||||
reply := func(result api.TreeListResultPayload) {
|
||||
result.SnapshotID = p.SnapshotID
|
||||
result.Path = p.Path
|
||||
env, err := api.Marshal(api.MsgTreeListResult, reqID, result)
|
||||
if err != nil {
|
||||
slog.Warn("ws agent: marshal tree.list.result", "err", err)
|
||||
return
|
||||
}
|
||||
_ = tx.Send(env)
|
||||
}
|
||||
|
||||
if d.resticBin == "" {
|
||||
reply(api.TreeListResultPayload{Error: "restic binary not located on this agent"})
|
||||
return
|
||||
}
|
||||
creds, err := d.secrets.Load()
|
||||
if err != nil {
|
||||
reply(api.TreeListResultPayload{Error: "load credentials: " + err.Error()})
|
||||
return
|
||||
}
|
||||
if creds.Empty() {
|
||||
reply(api.TreeListResultPayload{Error: "repo credentials not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
d.bwMu.Lock()
|
||||
upKBps, downKBps := d.bwUpKBps, d.bwDownKBps
|
||||
d.bwMu.Unlock()
|
||||
|
||||
env := restic.Env{
|
||||
Bin: d.resticBin,
|
||||
RepoURL: creds.URL,
|
||||
RepoUsername: creds.Username,
|
||||
RepoPassword: creds.Password,
|
||||
LimitUploadKBps: upKBps,
|
||||
LimitDownloadKBps: downKBps,
|
||||
}
|
||||
|
||||
// 60s ceiling matches snapshots/stats — restic ls on a single
|
||||
// directory is normally sub-second; if the repo is unreachable we
|
||||
// want to surface the failure rather than block the wizard.
|
||||
listCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
entries, err := env.ListTreeChildren(listCtx, p.SnapshotID, p.Path)
|
||||
if err != nil {
|
||||
reply(api.TreeListResultPayload{Error: err.Error()})
|
||||
return
|
||||
}
|
||||
apiEntries := make([]api.TreeListEntry, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
apiEntries = append(apiEntries, api.TreeListEntry{
|
||||
Name: e.Name,
|
||||
Type: e.Type,
|
||||
Size: e.Size,
|
||||
})
|
||||
}
|
||||
reply(api.TreeListResultPayload{Entries: apiEntries})
|
||||
}
|
||||
|
||||
// failJob ships a synthetic job.started + job.finished(failed) pair
|
||||
// for a command.run we couldn't even spawn locally — missing restic
|
||||
// binary, missing credentials, or a malformed payload. Without these
|
||||
// envelopes the server has no way to know the job will never produce
|
||||
// output: the row sits in "running", the live stream stays stuck on
|
||||
// "awaiting agent output," and a subsequent command.cancel arrives
|
||||
// for a job_id the agent never registered (we log "unknown job"
|
||||
// because trackJob was never called). Sending a terminal envelope
|
||||
// here closes the loop on both fronts.
|
||||
func failJob(p api.CommandRunPayload, tx wsclient.Sender, errMsg string) {
|
||||
now := time.Now().UTC()
|
||||
if startedEnv, err := api.Marshal(api.MsgJobStarted, p.JobID, api.JobStartedPayload{
|
||||
JobID: p.JobID, Kind: p.Kind, StartedAt: now,
|
||||
}); err == nil {
|
||||
_ = tx.Send(startedEnv)
|
||||
}
|
||||
if finEnv, err := api.Marshal(api.MsgJobFinished, p.JobID, api.JobFinishedPayload{
|
||||
JobID: p.JobID,
|
||||
Status: api.JobFailed,
|
||||
ExitCode: -1,
|
||||
FinishedAt: now,
|
||||
Error: errMsg,
|
||||
}); err == nil {
|
||||
_ = tx.Send(finEnv)
|
||||
}
|
||||
}
|
||||
|
||||
// runJob spawns a runner for one job. We launch a goroutine so the
|
||||
// WS read loop keeps draining messages while restic chugs along.
|
||||
func (d *dispatcher) runJob(ctx context.Context, p api.CommandRunPayload, tx wsclient.Sender) error {
|
||||
if d.resticBin == "" {
|
||||
failJob(p, tx, "restic binary not located on this agent")
|
||||
return fmt.Errorf("restic binary not located on this agent")
|
||||
}
|
||||
creds, err := d.secrets.Load()
|
||||
if err != nil {
|
||||
failJob(p, tx, "load repo credentials: "+err.Error())
|
||||
return fmt.Errorf("load repo credentials: %w", err)
|
||||
}
|
||||
if creds.Empty() {
|
||||
failJob(p, tx, "repo credentials not configured (waiting for server config.update push)")
|
||||
return fmt.Errorf("repo credentials not configured (waiting for server config.update push)")
|
||||
}
|
||||
// r is the everyday runner — bound to the host's repo
|
||||
@@ -523,48 +295,13 @@ func (d *dispatcher) runJob(ctx context.Context, p api.CommandRunPayload, tx wsc
|
||||
// not on r). If you find yourself adding a new JobKind that
|
||||
// needs delete authority, mirror the JobPrune pattern below
|
||||
// — don't try to overload r.
|
||||
// Resolve bandwidth caps: per-job override (if set) wins over the
|
||||
// host-wide caps last pushed via config.update. <=0 means no cap.
|
||||
d.bwMu.Lock()
|
||||
upKBps, downKBps := d.bwUpKBps, d.bwDownKBps
|
||||
d.bwMu.Unlock()
|
||||
if p.BandwidthUpKBps != nil {
|
||||
upKBps = *p.BandwidthUpKBps
|
||||
}
|
||||
if p.BandwidthDownKBps != nil {
|
||||
downKBps = *p.BandwidthDownKBps
|
||||
}
|
||||
|
||||
r := runner.New(runner.Config{
|
||||
ResticBin: d.resticBin,
|
||||
ResticVersion: d.resticVer,
|
||||
RepoURL: creds.URL,
|
||||
RepoUsername: creds.Username,
|
||||
RepoPassword: creds.Password,
|
||||
SupportsRestoreNoOwnership: d.resticSupportsNoOwnership,
|
||||
LimitUploadKBps: upKBps,
|
||||
LimitDownloadKBps: downKBps,
|
||||
ResticBin: d.resticBin,
|
||||
RepoURL: creds.URL,
|
||||
RepoUsername: creds.Username,
|
||||
RepoPassword: creds.Password,
|
||||
}, tx, time.Second)
|
||||
|
||||
// spawn wraps the kind-specific goroutine: derives a per-job
|
||||
// cancellable context from the connection-scoped ctx, registers
|
||||
// the cancel func so command.cancel can fire it, deregisters on
|
||||
// completion. Per-job ctx means canceling one job doesn't kill
|
||||
// any other in-flight invocations.
|
||||
spawn := func(name string, fn func(ctx context.Context) error) {
|
||||
jobCtx, cancel := context.WithCancel(ctx)
|
||||
cleanup := d.trackJob(p.JobID, cancel)
|
||||
go func() {
|
||||
defer cleanup()
|
||||
defer cancel() // release ctx resources on goroutine exit
|
||||
if err := fn(jobCtx); err != nil {
|
||||
slog.Warn("agent: "+name+" job failed", "job_id", p.JobID, "err", err)
|
||||
return
|
||||
}
|
||||
slog.Info("agent: "+name+" job complete", "job_id", p.JobID)
|
||||
}()
|
||||
}
|
||||
|
||||
switch p.Kind {
|
||||
case api.JobBackup:
|
||||
// Includes/Excludes/Tag come from the source group resolved
|
||||
@@ -581,15 +318,22 @@ func (d *dispatcher) runJob(ctx context.Context, p api.CommandRunPayload, tx wsc
|
||||
}
|
||||
slog.Info("agent: accepting backup job",
|
||||
"job_id", p.JobID, "paths", paths, "excludes", p.Excludes, "tag", p.Tag)
|
||||
hooks := runner.BackupHooks{Pre: p.PreHook, Post: p.PostHook}
|
||||
spawn("backup", func(jobCtx context.Context) error {
|
||||
return r.RunBackup(jobCtx, p.JobID, paths, p.Excludes, tags, hooks)
|
||||
})
|
||||
go func() {
|
||||
if err := r.RunBackup(ctx, p.JobID, paths, p.Excludes, tags); err != nil {
|
||||
slog.Warn("agent: backup job failed", "job_id", p.JobID, "err", err)
|
||||
return
|
||||
}
|
||||
slog.Info("agent: backup job complete", "job_id", p.JobID)
|
||||
}()
|
||||
case api.JobInit:
|
||||
slog.Info("agent: accepting init job", "job_id", p.JobID)
|
||||
spawn("init", func(jobCtx context.Context) error {
|
||||
return r.RunInit(jobCtx, p.JobID)
|
||||
})
|
||||
go func() {
|
||||
if err := r.RunInit(ctx, p.JobID); err != nil {
|
||||
slog.Warn("agent: init job failed", "job_id", p.JobID, "err", err)
|
||||
return
|
||||
}
|
||||
slog.Info("agent: init job complete", "job_id", p.JobID)
|
||||
}()
|
||||
case api.JobForget:
|
||||
if len(p.ForgetGroups) == 0 {
|
||||
// Hard-error rather than fall back to a single-policy form:
|
||||
@@ -599,7 +343,6 @@ func (d *dispatcher) runJob(ctx context.Context, p api.CommandRunPayload, tx wsc
|
||||
// policy fallback was specced but skipped — see the
|
||||
// Phase 5 plan rationale and version.go's lockstep-deploy
|
||||
// note for why.
|
||||
failJob(p, tx, "forget: command.run carried no forget_groups (server didn't populate them)")
|
||||
return fmt.Errorf("forget: command.run carried no forget_groups (server didn't populate them)")
|
||||
}
|
||||
groups := make([]restic.ForgetGroup, 0, len(p.ForgetGroups))
|
||||
@@ -617,9 +360,13 @@ func (d *dispatcher) runJob(ctx context.Context, p api.CommandRunPayload, tx wsc
|
||||
})
|
||||
}
|
||||
slog.Info("agent: accepting forget job", "job_id", p.JobID, "groups", len(groups))
|
||||
spawn("forget", func(jobCtx context.Context) error {
|
||||
return r.RunForget(jobCtx, p.JobID, groups)
|
||||
})
|
||||
go func() {
|
||||
if err := r.RunForget(ctx, p.JobID, groups); err != nil {
|
||||
slog.Warn("agent: forget job failed", "job_id", p.JobID, "err", err)
|
||||
return
|
||||
}
|
||||
slog.Info("agent: forget job complete", "job_id", p.JobID)
|
||||
}()
|
||||
case api.JobPrune:
|
||||
// Prune may require admin creds (delete authority on rest-server).
|
||||
runCreds := creds
|
||||
@@ -634,66 +381,36 @@ func (d *dispatcher) runJob(ctx context.Context, p api.CommandRunPayload, tx wsc
|
||||
runCreds = ac
|
||||
}
|
||||
prr := runner.New(runner.Config{
|
||||
ResticBin: d.resticBin,
|
||||
ResticVersion: d.resticVer,
|
||||
RepoURL: runCreds.URL,
|
||||
RepoUsername: runCreds.Username,
|
||||
RepoPassword: runCreds.Password,
|
||||
SupportsRestoreNoOwnership: d.resticSupportsNoOwnership,
|
||||
LimitUploadKBps: upKBps,
|
||||
LimitDownloadKBps: downKBps,
|
||||
ResticBin: d.resticBin,
|
||||
RepoURL: runCreds.URL,
|
||||
RepoUsername: runCreds.Username,
|
||||
RepoPassword: runCreds.Password,
|
||||
}, tx, time.Second)
|
||||
slog.Info("agent: accepting prune job", "job_id", p.JobID, "admin_creds", p.RequiresAdminCreds)
|
||||
spawn("prune", func(jobCtx context.Context) error {
|
||||
return prr.RunPrune(jobCtx, p.JobID)
|
||||
})
|
||||
go func() {
|
||||
if err := prr.RunPrune(ctx, p.JobID); err != nil {
|
||||
slog.Warn("agent: prune job failed", "job_id", p.JobID, "err", err)
|
||||
}
|
||||
}()
|
||||
case api.JobCheck:
|
||||
subset := 0
|
||||
if len(p.Args) > 0 {
|
||||
subset, _ = strconv.Atoi(p.Args[0])
|
||||
}
|
||||
slog.Info("agent: accepting check job", "job_id", p.JobID, "subset_pct", subset)
|
||||
spawn("check", func(jobCtx context.Context) error {
|
||||
return r.RunCheck(jobCtx, p.JobID, subset)
|
||||
})
|
||||
go func() {
|
||||
if err := r.RunCheck(ctx, p.JobID, subset); err != nil {
|
||||
slog.Warn("agent: check job failed", "job_id", p.JobID, "err", err)
|
||||
}
|
||||
}()
|
||||
case api.JobUnlock:
|
||||
slog.Info("agent: accepting unlock job", "job_id", p.JobID)
|
||||
spawn("unlock", func(jobCtx context.Context) error {
|
||||
return r.RunUnlock(jobCtx, p.JobID)
|
||||
})
|
||||
case api.JobRestore:
|
||||
if p.Restore == nil {
|
||||
failJob(p, tx, "restore: command.run carried no restore payload")
|
||||
return fmt.Errorf("restore: command.run carried no restore payload")
|
||||
}
|
||||
rp := *p.Restore
|
||||
if rp.SnapshotID == "" {
|
||||
failJob(p, tx, "restore: snapshot_id is required")
|
||||
return fmt.Errorf("restore: snapshot_id is required")
|
||||
}
|
||||
if !rp.InPlace && rp.TargetDir == "" {
|
||||
failJob(p, tx, "restore: target_dir required for non-in-place restore")
|
||||
return fmt.Errorf("restore: target_dir required for non-in-place restore")
|
||||
}
|
||||
slog.Info("agent: accepting restore job",
|
||||
"job_id", p.JobID, "snapshot_id", rp.SnapshotID,
|
||||
"paths", rp.Paths, "in_place", rp.InPlace, "target", rp.TargetDir)
|
||||
spawn("restore", func(jobCtx context.Context) error {
|
||||
return r.RunRestore(jobCtx, p.JobID, rp.SnapshotID, rp.Paths, rp.InPlace, rp.TargetDir)
|
||||
})
|
||||
case api.JobDiff:
|
||||
if p.Diff == nil || p.Diff.SnapshotA == "" || p.Diff.SnapshotB == "" {
|
||||
failJob(p, tx, "diff: command.run carried incomplete diff payload")
|
||||
return fmt.Errorf("diff: command.run carried incomplete diff payload")
|
||||
}
|
||||
dp := *p.Diff
|
||||
slog.Info("agent: accepting diff job",
|
||||
"job_id", p.JobID, "a", dp.SnapshotA, "b", dp.SnapshotB)
|
||||
spawn("diff", func(jobCtx context.Context) error {
|
||||
return r.RunDiff(jobCtx, p.JobID, dp.SnapshotA, dp.SnapshotB)
|
||||
})
|
||||
go func() {
|
||||
if err := r.RunUnlock(ctx, p.JobID); err != nil {
|
||||
slog.Warn("agent: unlock job failed", "job_id", p.JobID, "err", err)
|
||||
}
|
||||
}()
|
||||
default:
|
||||
failJob(p, tx, fmt.Sprintf("kind %q not implemented on this agent", p.Kind))
|
||||
return fmt.Errorf("kind %q not implemented yet (Phase 2 lands the rest)", p.Kind)
|
||||
}
|
||||
return nil
|
||||
|
||||
+11
-48
@@ -12,24 +12,17 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/alert"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/notification"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/config"
|
||||
rmhttp "gitea.dcglab.co.uk/steve/restic-manager/internal/server/http"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/maintenance"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/ui"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
var (
|
||||
version = "dev"
|
||||
commit = "none"
|
||||
date = "unknown"
|
||||
)
|
||||
var version = "dev"
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
@@ -44,7 +37,7 @@ func run() error {
|
||||
flag.Parse()
|
||||
|
||||
if *showVersion {
|
||||
fmt.Printf("restic-manager-server %s (commit %s, built %s)\n", version, commit, date)
|
||||
fmt.Println("restic-manager-server", version)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -89,36 +82,19 @@ func run() error {
|
||||
hub := ws.NewHub()
|
||||
jobHub := ws.NewJobHub()
|
||||
|
||||
notifHub := notification.NewHub(st, aead, cfg.BaseURL)
|
||||
alertEngine := alert.NewEngine(st, notifHub)
|
||||
|
||||
renderer, err := ui.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ui: %w", err)
|
||||
}
|
||||
|
||||
var oidcClient *oidc.Client
|
||||
if cfg.OIDC != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
oidcClient, err = oidc.New(ctx, cfg.OIDC, cfg.BaseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oidc: %w", err)
|
||||
}
|
||||
slog.Info("oidc enabled", "issuer", cfg.OIDC.Issuer, "display", cfg.OIDC.DisplayName)
|
||||
}
|
||||
|
||||
deps := rmhttp.Deps{
|
||||
Cfg: cfg,
|
||||
Store: st,
|
||||
AEAD: aead,
|
||||
Hub: hub,
|
||||
JobHub: jobHub,
|
||||
AlertEngine: alertEngine,
|
||||
NotificationHub: notifHub,
|
||||
UI: renderer,
|
||||
Version: version,
|
||||
OIDC: oidcClient,
|
||||
Cfg: cfg,
|
||||
Store: st,
|
||||
AEAD: aead,
|
||||
Hub: hub,
|
||||
JobHub: jobHub,
|
||||
UI: renderer,
|
||||
Version: version,
|
||||
}
|
||||
|
||||
// First-run bootstrap: if the users table is empty, mint a one-time
|
||||
@@ -150,8 +126,6 @@ func run() error {
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
go alertEngine.Run(ctx)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
slog.Info("server listening", "addr", cfg.Listen, "version", version)
|
||||
@@ -182,10 +156,6 @@ func run() error {
|
||||
// shouldn't, but the queue exists either way).
|
||||
pendingDrainTick := time.NewTicker(30 * time.Second)
|
||||
defer pendingDrainTick.Stop()
|
||||
// Pending-hosts expiry sweeper: drops announce rows past their 1h
|
||||
// ceiling so the dashboard panel doesn't accumulate stale entries.
|
||||
pendingExpiryTick := time.NewTicker(60 * time.Second)
|
||||
defer pendingExpiryTick.Stop()
|
||||
mt := maintenance.New(st)
|
||||
go func() {
|
||||
for {
|
||||
@@ -201,18 +171,11 @@ func run() error {
|
||||
}
|
||||
case <-offlineTick.C:
|
||||
cutoff := time.Now().Add(-90 * time.Second)
|
||||
if ids, err := st.MarkHostsOfflineStaleReturnIDs(ctx, cutoff); err == nil && len(ids) > 0 {
|
||||
slog.Info("marked hosts offline (stale heartbeat)", "n", len(ids))
|
||||
for _, id := range ids {
|
||||
alertEngine.NotifyHostOffline(id)
|
||||
}
|
||||
if n, err := st.MarkHostsOfflineStale(ctx, cutoff); err == nil && n > 0 {
|
||||
slog.Info("marked hosts offline (stale heartbeat)", "n", n)
|
||||
}
|
||||
case <-pendingDrainTick.C:
|
||||
srv.DrainAllDue(ctx)
|
||||
case <-pendingExpiryTick.C:
|
||||
if n, err := st.DeleteExpiredPendingHosts(ctx, time.Now().UTC()); err == nil && n > 0 {
|
||||
slog.Info("expired pending hosts swept", "n", n)
|
||||
}
|
||||
case <-maintenanceTick.C:
|
||||
decisions, err := mt.Decide(ctx, time.Now().UTC())
|
||||
if err != nil {
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
# syntax=docker/dockerfile:1.7
|
||||
|
||||
# ---- Build stage --------------------------------------------------------
|
||||
# Cross-compiles:
|
||||
# * the server binary for the image's TARGETARCH (linux/amd64 or arm64),
|
||||
# * three agent binaries (linux/amd64, linux/arm64, windows/amd64) that
|
||||
# the running server hands out via /agent/binary.
|
||||
# Pure-Go SQLite (modernc.org/sqlite) means CGO stays off; static binaries
|
||||
# run on distroless/static.
|
||||
FROM --platform=$BUILDPLATFORM golang:1.25-alpine AS build
|
||||
FROM golang:1.25-alpine AS build
|
||||
|
||||
WORKDIR /src
|
||||
|
||||
# Pure-Go SQLite (modernc.org/sqlite) means we can keep CGO off and build a
|
||||
# fully static binary that runs on distroless/static.
|
||||
ENV CGO_ENABLED=0 \
|
||||
GOOS=linux \
|
||||
GOFLAGS="-trimpath"
|
||||
|
||||
# Cache module downloads in a separate layer.
|
||||
@@ -21,41 +18,9 @@ RUN go mod download
|
||||
COPY . .
|
||||
|
||||
ARG VERSION=dev
|
||||
ARG COMMIT=none
|
||||
ARG DATE=unknown
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
|
||||
ENV LDFLAGS="-s -w -X main.version=${VERSION} -X main.commit=${COMMIT} -X main.date=${DATE}"
|
||||
|
||||
# Server: built for the image's runtime arch.
|
||||
RUN GOOS=${TARGETOS} GOARCH=${TARGETARCH} \
|
||||
go build -ldflags="${LDFLAGS}" \
|
||||
-o /out/restic-manager-server \
|
||||
./cmd/server
|
||||
|
||||
# Empty /data skeleton so the runtime image carries an existing,
|
||||
# nonroot-owned mount point. Docker copies that ownership onto a
|
||||
# named volume the first time it's created, which avoids the
|
||||
# "permission denied" trap on /data/secret.key when the operator
|
||||
# uses a default `volumes: { rm-data: {} }` declaration.
|
||||
RUN mkdir -p /out/data
|
||||
|
||||
# Agents: identical across image arches — an arm64 server image still
|
||||
# ships an amd64 agent binary for amd64 endpoints to download.
|
||||
RUN mkdir -p /out/agent-binaries && \
|
||||
GOOS=linux GOARCH=amd64 \
|
||||
go build -ldflags="${LDFLAGS}" \
|
||||
-o /out/agent-binaries/restic-manager-agent-linux-amd64 \
|
||||
./cmd/agent && \
|
||||
GOOS=linux GOARCH=arm64 \
|
||||
go build -ldflags="${LDFLAGS}" \
|
||||
-o /out/agent-binaries/restic-manager-agent-linux-arm64 \
|
||||
./cmd/agent && \
|
||||
GOOS=windows GOARCH=amd64 \
|
||||
go build -ldflags="${LDFLAGS}" \
|
||||
-o /out/agent-binaries/restic-manager-agent-windows-amd64.exe \
|
||||
./cmd/agent
|
||||
RUN go build -ldflags="-s -w -X main.version=${VERSION}" \
|
||||
-o /out/restic-manager-server \
|
||||
./cmd/server
|
||||
|
||||
# ---- Runtime stage ------------------------------------------------------
|
||||
FROM gcr.io/distroless/static-debian12:nonroot
|
||||
@@ -66,22 +31,7 @@ LABEL org.opencontainers.image.licenses="PolyForm-Noncommercial-1.0.0"
|
||||
USER nonroot:nonroot
|
||||
WORKDIR /
|
||||
|
||||
# Server binary on PATH.
|
||||
COPY --from=build /out/restic-manager-server /usr/local/bin/restic-manager-server
|
||||
|
||||
# Image-baked bundled assets (P5-03). Read-only; the /agent/binary and
|
||||
# /install/* handlers fall back here when <DataDir>/... is empty, so a
|
||||
# fresh container Just Works without first-run staging. Operators can
|
||||
# still drop a custom build under <DataDir>/agent-binaries/<name> to
|
||||
# override per-host.
|
||||
COPY --from=build --chmod=0755 /out/agent-binaries/ /opt/restic-manager/dist/agent-binaries/
|
||||
COPY --chmod=0755 deploy/install/install.sh /opt/restic-manager/dist/install/install.sh
|
||||
COPY --chmod=0644 deploy/install/install.ps1 /opt/restic-manager/dist/install/install.ps1
|
||||
COPY --chmod=0644 deploy/install/restic-manager-agent.service /opt/restic-manager/dist/install/restic-manager-agent.service
|
||||
|
||||
# Pre-created data dir owned by nonroot so a fresh named volume
|
||||
# inherits the right ownership.
|
||||
COPY --from=build --chown=nonroot:nonroot /out/data /data
|
||||
|
||||
EXPOSE 8443
|
||||
ENTRYPOINT ["/usr/local/bin/restic-manager-server"]
|
||||
|
||||
@@ -1,52 +1,21 @@
|
||||
# Reference deployment for the restic-manager control plane.
|
||||
# Mirrors spec.md §10.1 and the P5-07 reference deployment.
|
||||
# Mirrors spec.md §10.1. Adjust image tag and RM_BASE_URL for your env.
|
||||
#
|
||||
# Scope: this compose stands up the server only. TLS termination and
|
||||
# the public hostname belong to a reverse proxy that lives outside
|
||||
# this stack (Caddy, Traefik, nginx, HAProxy, your existing edge —
|
||||
# whatever you already operate). See `docs/reverse-proxy.md` for the
|
||||
# headers + CIDRs that proxy needs to forward.
|
||||
#
|
||||
# Architecture:
|
||||
# * The server speaks plain HTTP on :8080.
|
||||
# * The agent binaries + install scripts ship inside the image under
|
||||
# /opt/restic-manager/dist/, so /agent/binary and /install/*
|
||||
# serve out of the box without first-run staging.
|
||||
# * The named volume holds *only* operator state (sqlite,
|
||||
# secrets.enc, audit log, the AEAD key). Image upgrades replace
|
||||
# the agents/scripts; the volume is untouched.
|
||||
# * Pre-1.0 releases never publish :latest — pin to an exact
|
||||
# vX.Y.Z tag and bump deliberately.
|
||||
#
|
||||
# Before first start:
|
||||
# 1. Pick a version: export RM_VERSION=vX.Y.Z (or substitute below).
|
||||
# 2. Set RM_BASE_URL to the public HTTPS URL the external proxy
|
||||
# serves on.
|
||||
# 3. Set RM_TRUSTED_PROXY to the IP/CIDR the proxy connects from
|
||||
# (the X-Forwarded-* headers are honoured only when the immediate
|
||||
# peer matches one of these).
|
||||
|
||||
# The server speaks plain HTTP. Front it with a TLS-terminating
|
||||
# reverse proxy (Caddy/Traefik/nginx). RM_TRUSTED_PROXY must contain
|
||||
# the proxy's IP/CIDR so X-Forwarded-* headers are honoured.
|
||||
services:
|
||||
restic-manager:
|
||||
image: gitea.dcglab.co.uk/steve/restic-manager:${RM_VERSION:?set RM_VERSION to a vX.Y.Z tag}
|
||||
image: ghcr.io/dcglab/restic-manager:latest
|
||||
restart: unless-stopped
|
||||
# Bind to localhost only — your reverse proxy reaches the server
|
||||
# over loopback (or, if it runs in a separate compose / on
|
||||
# another host, swap this for an internal docker network or a
|
||||
# private LAN bind).
|
||||
# Bind to localhost only — the proxy is what the public reaches.
|
||||
ports:
|
||||
- "127.0.0.1:8080:8080"
|
||||
volumes:
|
||||
- rm-data:/data
|
||||
- ./data:/data
|
||||
environment:
|
||||
- RM_DATA_DIR=/data
|
||||
- RM_LISTEN=:8080
|
||||
- RM_BASE_URL=${RM_BASE_URL:?set RM_BASE_URL to the public https URL}
|
||||
- RM_BASE_URL=https://restic.lab.example
|
||||
- RM_SECRET_KEY_FILE=/data/secret.key
|
||||
- RM_TRUSTED_PROXY=${RM_TRUSTED_PROXY:?set RM_TRUSTED_PROXY to the proxy CIDR}
|
||||
# Cookies are Secure by default; keep that. Override only for
|
||||
# local-HTTP smoke tests.
|
||||
# - RM_COOKIE_SECURE=true
|
||||
|
||||
volumes:
|
||||
rm-data:
|
||||
- RM_TRUSTED_PROXY=172.16.0.0/12
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
# install.ps1 — Windows installer for the restic-manager agent (P2-17).
|
||||
#
|
||||
# Usage (Run as administrator):
|
||||
# $env:RM_SERVER = "https://restic.lab.example"
|
||||
# $env:RM_TOKEN = "<one-time-token>" # omit for announce-and-approve
|
||||
# iwr "$env:RM_SERVER/install/install.ps1" -UseBasicParsing | iex
|
||||
#
|
||||
# What it does:
|
||||
# 1. checks for admin elevation
|
||||
# 2. downloads the matching agent binary from the server
|
||||
# 3. lays down C:\Program Files\restic-manager\ and
|
||||
# C:\ProgramData\restic-manager\ (config + state)
|
||||
# 4. registers the agent as a Windows service via the agent's own
|
||||
# `install` subcommand (which uses the SCM API)
|
||||
# 5. enrolls (token flow if RM_TOKEN set, otherwise announce flow)
|
||||
# by spawning the agent with the right CLI flags and waits
|
||||
# until config is written
|
||||
# 6. surfaces (but does NOT disable) any existing scheduled tasks
|
||||
# whose name contains "restic" so the operator can decide
|
||||
#
|
||||
# Idempotent — safe to re-run.
|
||||
|
||||
[CmdletBinding()]
|
||||
param(
|
||||
[string]$Server = $env:RM_SERVER,
|
||||
[string]$Token = $env:RM_TOKEN,
|
||||
[string]$InstallDir = 'C:\Program Files\restic-manager',
|
||||
[string]$DataDir = 'C:\ProgramData\restic-manager'
|
||||
)
|
||||
|
||||
$ErrorActionPreference = 'Stop'
|
||||
|
||||
# Test-Admin returns $true when the identity running this script is a
# member of the built-in Administrators role, $false otherwise.
function Test-Admin {
    $adminRole = [System.Security.Principal.WindowsBuiltInRole]::Administrator
    $identity  = [System.Security.Principal.WindowsIdentity]::GetCurrent()
    $principal = New-Object -TypeName System.Security.Principal.WindowsPrincipal -ArgumentList $identity
    return $principal.IsInRole($adminRole)
}
|
||||
|
||||
# Detect-Arch maps $env:PROCESSOR_ARCHITECTURE to the architecture
# name used in the agent download URL ('amd64' or 'arm64') and throws
# for any other value.
function Detect-Arch {
    $cpu = $env:PROCESSOR_ARCHITECTURE
    if ($cpu -eq 'AMD64') { return 'amd64' }
    if ($cpu -eq 'ARM64') { return 'arm64' }
    throw "unsupported PROCESSOR_ARCHITECTURE: $($env:PROCESSOR_ARCHITECTURE)"
}
|
||||
|
||||
# Detect-ResticTasks prints every scheduled task whose name or path
# matches 'restic', together with the command that would disable it.
# It only reports; nothing is disabled or modified here.
function Detect-ResticTasks {
    Write-Host ''
    Write-Host '— Existing restic-named scheduled tasks (review manually) —'
    try {
        $resticTasks = Get-ScheduledTask -ErrorAction SilentlyContinue |
            Where-Object { $_.TaskName -match 'restic' -or $_.TaskPath -match 'restic' }
        if (-not $resticTasks) {
            Write-Host ' (none found)'
        } else {
            foreach ($t in $resticTasks) {
                Write-Host " * $($t.TaskPath)$($t.TaskName) state=$($t.State)"
                Write-Host " Disable with: Disable-ScheduledTask -TaskName '$($t.TaskName)' -TaskPath '$($t.TaskPath)'"
            }
        }
    } catch {
        # Enumeration failed; point the operator at the Task Scheduler UI.
        Write-Host ' (Get-ScheduledTask failed; review the Task Scheduler UI manually)'
    }
    Write-Host ''
}
|
||||
|
||||
# --- preflight -------------------------------------------------------
|
||||
|
||||
if (-not (Test-Admin)) {
|
||||
throw 'install.ps1: must be run from an elevated PowerShell (Run as administrator).'
|
||||
}
|
||||
if (-not $Server) {
|
||||
throw 'install.ps1: -Server (or $env:RM_SERVER) is required, e.g. https://restic.lab.example'
|
||||
}
|
||||
|
||||
$arch = Detect-Arch
|
||||
Write-Host "install.ps1: server=$Server arch=$arch"
|
||||
|
||||
# --- directories -----------------------------------------------------
|
||||
|
||||
New-Item -ItemType Directory -Force -Path $InstallDir | Out-Null
|
||||
New-Item -ItemType Directory -Force -Path $DataDir | Out-Null
|
||||
|
||||
# --- download agent --------------------------------------------------
|
||||
|
||||
$agentExe = Join-Path $InstallDir 'restic-manager-agent.exe'
|
||||
$tmpExe = "$agentExe.tmp"
|
||||
$dlURL = "$Server/agent/binary?os=windows&arch=$arch"
|
||||
Write-Host "install.ps1: downloading $dlURL"
|
||||
Invoke-WebRequest -UseBasicParsing -Uri $dlURL -OutFile $tmpExe
|
||||
# Atomic-ish replace: stop service if running so the .exe isn't busy.
|
||||
try { Stop-Service -Name 'restic-manager-agent' -ErrorAction SilentlyContinue } catch {}
|
||||
Move-Item -Force -Path $tmpExe -Destination $agentExe
|
||||
|
||||
# --- enroll / announce -----------------------------------------------

# Build the agent command line. With a token the agent enrolls
# directly; without one it runs the announce-and-approve flow.
# The argument list is deliberately NOT called $args: that name is a
# PowerShell automatic variable and should not be assigned to.
$cfgPath    = Join-Path $DataDir 'agent.yaml'
$enrollArgs = @('-config', $cfgPath, '-enroll-server', $Server)
if ($Token) {
    $enrollArgs += @('-enroll-token', $Token)
    Write-Host 'install.ps1: enrolling with one-time token'
} else {
    Write-Host 'install.ps1: no RM_TOKEN — running announce-and-approve flow.'
    Write-Host ' The fingerprint will print below. Compare it with the dashboard before clicking Accept.'
}
& $agentExe @enrollArgs
# A native executable reports failure through its exit code, not an
# exception, so check it explicitly.
if ($LASTEXITCODE -ne 0) {
    throw "install.ps1: agent enrolment failed (exit $LASTEXITCODE)"
}
|
||||
|
||||
# --- install + start service ----------------------------------------

# The 'install' subcommand registers the service via the SCM. If
# already registered, it errors loudly — re-run with -Force only if
# you've manually verified.
#
# The agent is a native executable, so a failed registration shows up
# as a non-zero $LASTEXITCODE rather than a PowerShell exception. Check
# the exit code as well as catching, otherwise the "already registered"
# case passes silently. Either way the script carries on so that a
# re-run stays idempotent.
try {
    & $agentExe install
    if ($LASTEXITCODE -ne 0) {
        Write-Host "install.ps1: service may already be registered (exit $LASTEXITCODE); continuing."
    }
} catch {
    Write-Host "install.ps1: service may already be registered ($_); continuing."
}
# Starting the service is best-effort: report the failure and let the
# operator investigate rather than aborting the whole install.
try {
    Start-Service -Name 'restic-manager-agent'
} catch {
    Write-Host "install.ps1: Start-Service failed ($_); check Event Viewer."
}
|
||||
|
||||
Detect-ResticTasks
|
||||
|
||||
Write-Host ''
|
||||
Write-Host 'install.ps1: done.'
|
||||
Write-Host " config : $cfgPath"
|
||||
Write-Host " binary : $agentExe"
|
||||
Write-Host " service: restic-manager-agent (Get-Service to inspect)"
|
||||
@@ -49,11 +49,6 @@ detect_arch() {
|
||||
ensure_dirs() {
|
||||
install -d -m 0700 -o root -g root "$RM_CONFIG_DIR"
|
||||
install -d -m 0700 -o root -g root "$RM_STATE_DIR"
|
||||
# Default new-directory restore target: $HOME/rm-restore. With the
|
||||
# current unit (ProtectSystem=full, no ReadWritePaths pin) the agent
|
||||
# can mkdir anywhere on real filesystems, so this is just a courtesy
|
||||
# pre-create so the wizard's default lands in a tidy spot.
|
||||
install -d -m 0700 -o root -g root /root/rm-restore
|
||||
}
|
||||
|
||||
detect_existing_schedulers() {
|
||||
|
||||
@@ -33,26 +33,12 @@ CapabilityBoundingSet=CAP_DAC_READ_SEARCH CAP_DAC_OVERRIDE CAP_FOWNER CAP_CHOWN
|
||||
AmbientCapabilities=CAP_DAC_READ_SEARCH CAP_DAC_OVERRIDE CAP_FOWNER CAP_CHOWN
|
||||
|
||||
# Hardening — blocks privilege escalation even from root, and
|
||||
# confines kernel / namespace / privilege surface. Filesystem reads
|
||||
# stay open (that's the whole job) and restore writes are
|
||||
# unrestricted: a backup tool whose entire purpose is "put files
|
||||
# back where they belong" can't have ProtectHome=read-only or
|
||||
# ProtectSystem=strict without breaking on the first cross-user
|
||||
# restore. ProtectSystem=full keeps /usr, /boot, /efi read-only so a
|
||||
# compromised agent can't swap out /usr/bin/restic or drop a kernel
|
||||
# module, while leaving /home, /root, /var, /opt, /srv, /tmp etc.
|
||||
# writable for arbitrary restore targets. The agent is treated as a
|
||||
# high-trust component (it runs operator hooks as root and holds
|
||||
# repo credentials); the residual hardening is about kernel + privesc
|
||||
# protection, not write confinement.
|
||||
# confines writes / network / kernel access to what restic actually
|
||||
# needs. Filesystem reads stay open: that's the whole job.
|
||||
NoNewPrivileges=true
|
||||
ProtectSystem=full
|
||||
# ProtectSystem=full mounts /usr, /boot, /efi *and* /etc read-only.
|
||||
# The agent rewrites /etc/restic-manager/agent.yaml on enrolment and
|
||||
# whenever a new SecretsKey is minted, so we need a targeted
|
||||
# write-exemption for that dir. No exemption for the rest of /etc:
|
||||
# the agent has no business editing /etc/passwd, /etc/sudoers, etc.
|
||||
ReadWritePaths=/etc/restic-manager
|
||||
ProtectSystem=strict
|
||||
ReadWritePaths=/etc/restic-manager /var/lib/restic-manager
|
||||
ProtectHome=read-only
|
||||
ProtectHostname=true
|
||||
ProtectKernelTunables=true
|
||||
ProtectKernelModules=true
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
# Running behind a reverse proxy
|
||||
|
||||
The restic-manager server is HTTP-only by design (see `spec.md` §11):
|
||||
TLS termination, public hostname, ACME, HSTS, and edge-level rate
|
||||
limiting all belong to a reverse proxy that you already operate
|
||||
outside this project. The reference compose in `deploy/docker-compose.yml`
|
||||
stands up *only* the server; this page covers what your proxy needs
|
||||
to do to make the rest of it work.
|
||||
|
||||
## What the proxy must forward
|
||||
|
||||
The server reads four headers when (and only when) the immediate peer
|
||||
matches `RM_TRUSTED_PROXY`:
|
||||
|
||||
| Header | Value | Why |
|
||||
|---------------------|----------------------------------------------------------|-----|
|
||||
| `X-Forwarded-For` | The original client IP (single value, or comma chain) | Rate-limit keys, audit log entries, and OIDC redirect-URI checks all use the real client IP. |
|
||||
| `X-Forwarded-Proto` | `https` | The server emits absolute URLs (e.g. OIDC redirect URIs) using this. |
|
||||
| `Host` | The public hostname clients use | Cookies are scoped to this; `RM_BASE_URL` must match. |
|
||||
| `Connection`/`Upgrade` | Pass through unchanged | The agent connects on `/ws/agent` and the live-log viewer connects on `/api/jobs/{id}/stream` — both are WebSockets and need `Upgrade: websocket` to survive the hop. |
|
||||
|
||||
Set `RM_TRUSTED_PROXY` to the CIDR (or comma-separated list of CIDRs)
|
||||
the proxy connects from. Anything outside that range has its
|
||||
`X-Forwarded-*` headers ignored, so a stray request that bypasses the
|
||||
proxy can't spoof the client IP.
|
||||
|
||||
## Example: Caddy
|
||||
|
||||
```caddyfile
|
||||
restic.example.com {
|
||||
# Caddy's default reverse_proxy preserves Host, sets
|
||||
# X-Forwarded-For/Proto, and passes Connection: upgrade through,
|
||||
# so a single directive covers HTTP + WebSocket.
|
||||
reverse_proxy 127.0.0.1:8080
|
||||
|
||||
encode zstd gzip
|
||||
}
|
||||
```
|
||||
|
||||
`RM_TRUSTED_PROXY=127.0.0.1/32` if Caddy and the server share the
|
||||
host; the docker-bridge CIDR (commonly `172.16.0.0/12`) if Caddy
|
||||
runs in another container on the default bridge network.
|
||||
|
||||
## Example: nginx
|
||||
|
||||
```nginx
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
server_name restic.example.com;
|
||||
|
||||
ssl_certificate /etc/ssl/restic.example.com.fullchain.pem;
|
||||
ssl_certificate_key /etc/ssl/restic.example.com.key.pem;
|
||||
|
||||
location / {
|
||||
proxy_pass http://127.0.0.1:8080;
|
||||
proxy_http_version 1.1;
|
||||
|
||||
# WebSocket support — agent + live-log endpoints need this.
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
|
||||
# Trusted-proxy headers.
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto https;
|
||||
|
||||
# Live job logs are long-running streams. Bump read timeouts
|
||||
# so nginx doesn't drop them mid-backup.
|
||||
proxy_read_timeout 1h;
|
||||
proxy_send_timeout 1h;
|
||||
}
|
||||
}
|
||||
|
||||
# Standard websocket upgrade map (define once at the http {} level).
|
||||
map $http_upgrade $connection_upgrade {
|
||||
default upgrade;
|
||||
'' close;
|
||||
}
|
||||
```
|
||||
|
||||
`RM_TRUSTED_PROXY` for the same-host case: `127.0.0.1/32`.
|
||||
|
||||
## Example: Traefik (label-based)
|
||||
|
||||
```yaml
|
||||
labels:
|
||||
- "traefik.enable=true"
|
||||
- "traefik.http.routers.restic-manager.rule=Host(`restic.example.com`)"
|
||||
- "traefik.http.routers.restic-manager.entrypoints=websecure"
|
||||
- "traefik.http.routers.restic-manager.tls.certresolver=letsencrypt"
|
||||
- "traefik.http.services.restic-manager.loadbalancer.server.port=8080"
|
||||
```
|
||||
|
||||
Traefik handles `X-Forwarded-*` and `Connection: upgrade` by default.
|
||||
`RM_TRUSTED_PROXY` should be the docker network the Traefik container
|
||||
shares with the server (commonly `172.16.0.0/12` for the default
|
||||
bridge, or whatever your overlay network's CIDR is).
|
||||
|
||||
## Sanity-checking the wiring
|
||||
|
||||
After bringing the stack up:
|
||||
|
||||
1. `curl -fsS https://restic.example.com/healthz` — should return 200.
|
||||
2. The browser should show HTTPS in the address bar on the login page; cookies
|
||||
set after login should carry the `Secure` flag.
|
||||
3. Check the server log for the `config resolved` line:
|
||||
`trusted_proxies` must include the IP/CIDR your proxy actually
|
||||
connects from.
|
||||
4. Enrol a test agent — the WebSocket handshake hitting `/ws/agent`
|
||||
confirms `Upgrade` is being forwarded correctly.
|
||||
|
||||
If any of those fail, the proxy is the first place to look — the
|
||||
server itself is intentionally minimal.
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,26 +3,22 @@ module gitea.dcglab.co.uk/steve/restic-manager
|
||||
go 1.25.0
|
||||
|
||||
require (
|
||||
github.com/coder/websocket v1.8.14
|
||||
github.com/coreos/go-oidc/v3 v3.18.0
|
||||
github.com/go-chi/chi/v5 v5.2.5
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||
github.com/oklog/ulid/v2 v2.1.1
|
||||
github.com/robfig/cron/v3 v3.0.1
|
||||
golang.org/x/crypto v0.50.0
|
||||
golang.org/x/oauth2 v0.36.0
|
||||
golang.org/x/sys v0.43.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
modernc.org/sqlite v1.50.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/coder/websocket v1.8.14 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.1.4 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/robfig/cron/v3 v3.0.1 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
modernc.org/libc v1.72.0 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
|
||||
@@ -1,15 +1,9 @@
|
||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||
github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A=
|
||||
github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
|
||||
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
|
||||
github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA=
|
||||
github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
@@ -31,8 +25,6 @@ golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
||||
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -62,13 +62,6 @@ type Config struct {
|
||||
LegacyRepoURL string `yaml:"repo_url,omitempty"`
|
||||
LegacyRepoPassword string `yaml:"repo_password,omitempty"`
|
||||
|
||||
// AnnounceKey is the base64-encoded Ed25519 private key used by
|
||||
// announce-and-approve enrolment (P2-18). Generated on first
|
||||
// announce, persisted so the agent can re-attach to the same
|
||||
// pending row across restarts. 64 bytes when decoded.
|
||||
// Empty for token-flow enrolments.
|
||||
AnnounceKey string `yaml:"announce_key,omitempty"`
|
||||
|
||||
// path is the file we loaded from. Used by Save.
|
||||
path string `yaml:"-"`
|
||||
}
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
)
|
||||
|
||||
// (fakeSender is defined in runner_test.go; it's already lock-protected
|
||||
// because the runner's stdout + stderr pump goroutines call Send
|
||||
// concurrently. The original local 'safeSender' here was a workaround
|
||||
// from before fakeSender itself grew the mutex.)
|
||||
|
||||
// TestRunBackupCanceledMidRunReportsCanceled spawns a backup against
|
||||
// a fake restic that sleeps for 30 seconds, cancels the context after
|
||||
// a short delay, and confirms the resulting job.finished envelope
|
||||
// reports status=canceled (not failed).
|
||||
func TestRunBackupCanceledMidRunReportsCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Fake restic: replace the shell with a long sleep via `exec` so the
|
||||
// process tree is one process — SIGTERM goes directly to sleep and
|
||||
// it exits. Without `exec`, the shell stays in the foreground while
|
||||
// sleep is its child; SIGTERM-to-shell may or may not propagate to
|
||||
// sleep depending on the shell, leading to the WaitDelay-then-
|
||||
// SIGKILL fallback path firing — slower and noisier.
|
||||
bin := setupScript(t, `exec sleep 30`)
|
||||
|
||||
tx := &fakeSender{}
|
||||
r := New(Config{ResticBin: bin}, tx, 0)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- r.RunBackup(ctx, "job-cancel", []string{"/tmp/x"}, nil, nil, BackupHooks{})
|
||||
}()
|
||||
|
||||
// Wait long enough for the subprocess to actually start before
|
||||
// canceling. Without this, exec.CommandContext can race the
|
||||
// kill against Start and produce a different error path.
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("RunBackup did not return within 15s of cancel")
|
||||
}
|
||||
|
||||
// Locate the job.finished envelope and check its status.
|
||||
envs := tx.snapshot()
|
||||
var finEnv api.Envelope
|
||||
var found bool
|
||||
for _, e := range envs {
|
||||
if e.Type == api.MsgJobFinished {
|
||||
finEnv = e
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("no job.finished envelope was sent")
|
||||
}
|
||||
var fin api.JobFinishedPayload
|
||||
if err := finEnv.UnmarshalPayload(&fin); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if fin.Status != api.JobCancelled {
|
||||
t.Fatalf("status: got %q, want %q", fin.Status, api.JobCancelled)
|
||||
}
|
||||
if fin.ExitCode != 130 {
|
||||
t.Errorf("exit_code: got %d, want 130 (POSIX cancel convention)", fin.ExitCode)
|
||||
}
|
||||
// The error message should be empty for canceled jobs (see runner.sendFinished).
|
||||
if !strings.HasPrefix(fin.Error, "") || fin.Error != "" {
|
||||
t.Errorf("error: got %q, want empty for canceled jobs", fin.Error)
|
||||
}
|
||||
}
|
||||
@@ -1,106 +0,0 @@
|
||||
// hooks.go — pre/post backup hooks for the agent runner (P2R-11).
|
||||
//
|
||||
// Hooks fire only for backup jobs (the runner's other kinds —
|
||||
// init/forget/prune/check/unlock — call shell scripts that touch
|
||||
// repo internals; running operator hooks for those would be
|
||||
// surprising). Hook bodies arrive plaintext on the wire (server
|
||||
// decrypted before the WS push). The agent never persists them
|
||||
// to disk; they live in memory for the lifetime of one job.
|
||||
//
|
||||
// Failure semantics:
|
||||
// - pre_hook non-zero exit aborts the backup: the runner returns
|
||||
// the error, the job is recorded as failed, and the actual
|
||||
// restic invocation never runs.
|
||||
// - post_hook non-zero exit is logged with a warning prefix in
|
||||
// the job log but does NOT change the job status — the operator
|
||||
// wants the backup result preserved even if the cleanup step
|
||||
// misbehaved.
|
||||
//
|
||||
// Streaming: each line of the hook's stdout/stderr is shipped as a
|
||||
// log.stream envelope with payload prefixed `hook(<phase>): ` so the live
|
||||
// log viewer can visually separate it from restic's own output.
|
||||
package runner
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
)
|
||||
|
||||
// runHook executes script via the host shell. status is the value
|
||||
// passed as RM_JOB_STATUS in the env (empty for pre-hooks; the
|
||||
// final job status — "succeeded" or "failed" — for post-hooks).
|
||||
// Returns an error iff the hook exited non-zero. ctx cancellation
|
||||
// kills the subprocess.
|
||||
func (r *Runner) runHook(ctx context.Context, jobID, phase, script, status string, seq *atomic.Int64) error {
|
||||
if script == "" {
|
||||
return nil
|
||||
}
|
||||
shell, flag := defaultShell()
|
||||
cmd := exec.CommandContext(ctx, shell, flag, script)
|
||||
cmd.Env = []string{
|
||||
"PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
|
||||
}
|
||||
if status != "" {
|
||||
cmd.Env = append(cmd.Env, "RM_JOB_STATUS="+status)
|
||||
}
|
||||
cmd.Env = append(cmd.Env, "RM_JOB_ID="+jobID, "RM_HOOK_PHASE="+phase)
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("hook %s: stdout pipe: %w", phase, err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("hook %s: stderr pipe: %w", phase, err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("hook %s: start: %w", phase, err)
|
||||
}
|
||||
done := make(chan struct{}, 2)
|
||||
go func() { r.pumpHookLines(stdout, "stdout", phase, jobID, seq); done <- struct{}{} }()
|
||||
go func() { r.pumpHookLines(stderr, "stderr", phase, jobID, seq); done <- struct{}{} }()
|
||||
<-done
|
||||
<-done
|
||||
if werr := cmd.Wait(); werr != nil {
|
||||
return fmt.Errorf("hook %s exited non-zero: %w", phase, werr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pumpHookLines streams lines as log.stream envelopes prefixed with
|
||||
// "hook(<phase>): " so the live log can visually separate them.
|
||||
func (r *Runner) pumpHookLines(rd io.Reader, stream, phase, jobID string, seq *atomic.Int64) {
|
||||
scanner := bufio.NewScanner(rd)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 256*1024)
|
||||
for scanner.Scan() {
|
||||
line := "hook(" + phase + "): " + scanner.Text()
|
||||
env, _ := api.Marshal(api.MsgLogStream, "", api.LogStreamLine{
|
||||
JobID: jobID,
|
||||
Seq: seq.Add(1),
|
||||
TS: time.Now().UTC(),
|
||||
Stream: api.LogStream(stream),
|
||||
Payload: line,
|
||||
})
|
||||
_ = r.tx.Send(env)
|
||||
}
|
||||
}
|
||||
|
||||
// defaultShell reports the bootstrap interpreter and its single
// "run this string" flag, used as `<shell> <flag> "<script>"`:
// cmd.exe /C on Windows, /bin/sh -c everywhere else. Hook authors who
// want another shell (PowerShell, bash, ...) invoke it from inside the
// script body.
func defaultShell() (string, string) {
	switch runtime.GOOS {
	case "windows":
		return "cmd.exe", "/C"
	default:
		return "/bin/sh", "-c"
	}
}
|
||||
@@ -1,90 +0,0 @@
|
||||
// hooks_test.go — pre/post backup hook semantics (P2R-11).
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
)
|
||||
|
||||
// TestPreHookFailureAbortsBackup: pre_hook exits 1 → restic never
|
||||
// runs, job is recorded failed with the hook's error.
|
||||
func TestPreHookFailureAbortsBackup(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Restic script that records every invocation. If restic was
|
||||
// called we'll see "restic-was-here" in the captured log.
|
||||
bin := setupScript(t, `echo "restic-was-here"`)
|
||||
tx := &fakeSender{}
|
||||
r := New(Config{ResticBin: bin}, tx, 0)
|
||||
|
||||
err := r.RunBackup(context.Background(), "job-pre",
|
||||
[]string{"/etc"}, nil, []string{"tag"},
|
||||
BackupHooks{Pre: "exit 1"})
|
||||
if err == nil {
|
||||
t.Fatal("expected RunBackup to return an error from failed pre_hook")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "pre_hook failed") {
|
||||
t.Fatalf("error message: %q (want 'pre_hook failed')", err)
|
||||
}
|
||||
// job.finished arrived with status=failed.
|
||||
finEnv := firstEnvOfType(t, tx.envs, api.MsgJobFinished)
|
||||
var fin api.JobFinishedPayload
|
||||
_ = finEnv.UnmarshalPayload(&fin)
|
||||
if fin.Status != api.JobFailed {
|
||||
t.Fatalf("status: %q, want failed", fin.Status)
|
||||
}
|
||||
// restic must NOT have run.
|
||||
for _, env := range tx.envs {
|
||||
if env.Type != api.MsgLogStream {
|
||||
continue
|
||||
}
|
||||
var l api.LogStreamLine
|
||||
_ = env.UnmarshalPayload(&l)
|
||||
if strings.Contains(l.Payload, "restic-was-here") {
|
||||
t.Fatal("restic was invoked despite pre_hook failure")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestPostHookRunsAfterBackup: post_hook fires after a successful
|
||||
// backup and receives RM_JOB_STATUS=succeeded in the env.
|
||||
func TestPostHookRunsAfterBackup(t *testing.T) {
|
||||
t.Parallel()
|
||||
bin := setupScript(t, `
|
||||
case "$1" in
|
||||
backup) echo '{"message_type":"summary","snapshot_id":"abc"}' ;;
|
||||
snapshots) echo '[]' ;;
|
||||
stats) echo '{"total_size":0,"total_uncompressed_size":0,"snapshots_count":0,"total_file_count":0,"total_blob_count":0}' ;;
|
||||
*) exit 0 ;;
|
||||
esac
|
||||
`)
|
||||
tx := &fakeSender{}
|
||||
r := New(Config{ResticBin: bin}, tx, 0)
|
||||
|
||||
post := `echo "post-status=$RM_JOB_STATUS phase=$RM_HOOK_PHASE"`
|
||||
if err := r.RunBackup(context.Background(), "job-post",
|
||||
[]string{"/etc"}, nil, nil, BackupHooks{Post: post}); err != nil {
|
||||
t.Fatalf("RunBackup: %v", err)
|
||||
}
|
||||
|
||||
// Walk log.stream envelopes; one of them should be the post-hook
|
||||
// line with the expected status.
|
||||
var found bool
|
||||
for _, env := range tx.envs {
|
||||
if env.Type != api.MsgLogStream {
|
||||
continue
|
||||
}
|
||||
var l api.LogStreamLine
|
||||
_ = env.UnmarshalPayload(&l)
|
||||
if strings.Contains(l.Payload, "post-status=succeeded") &&
|
||||
strings.Contains(l.Payload, "phase=post") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("post_hook output not found in log.stream envelopes")
|
||||
}
|
||||
}
|
||||
@@ -1,266 +0,0 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
)
|
||||
|
||||
// TestRunRestoreShipsExpectedEnvelopes: a fake restic emits a couple
|
||||
// of restore status lines and a summary; the runner translates them
|
||||
// into job.progress envelopes and finishes the job successfully.
|
||||
func TestRunRestoreShipsExpectedEnvelopes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bin := setupScript(t, `
|
||||
case "$1" in
|
||||
restore)
|
||||
echo '{"message_type":"status","seconds_elapsed":1,"percent_done":0.5,"total_files":10,"files_restored":5,"total_bytes":1000,"bytes_restored":500}'
|
||||
echo '{"message_type":"status","seconds_elapsed":2,"percent_done":1.0,"total_files":10,"files_restored":10,"total_bytes":1000,"bytes_restored":1000}'
|
||||
echo '{"message_type":"summary","seconds_elapsed":2,"total_files":10,"files_restored":10,"total_bytes":1000,"bytes_restored":1000}'
|
||||
;;
|
||||
*)
|
||||
echo "unknown: $*" ;;
|
||||
esac
|
||||
`)
|
||||
|
||||
tx := &fakeSender{}
|
||||
r := New(Config{ResticBin: bin}, tx, 0)
|
||||
|
||||
if err := r.RunRestore(context.Background(), "job-r1", "f3a7b2c1",
|
||||
[]string{"/etc/nginx/sites-available/alfa.conf"},
|
||||
false, "/tmp/restore-out"); err != nil {
|
||||
t.Fatalf("RunRestore: %v", err)
|
||||
}
|
||||
|
||||
// Confirm landmarks: started → progress → finished.
|
||||
order := envelopeOrder(tx.envs)
|
||||
wants := []api.MessageType{api.MsgJobStarted, api.MsgJobProgress, api.MsgJobFinished}
|
||||
positions := map[api.MessageType]int{}
|
||||
for i, mt := range order {
|
||||
if _, seen := positions[mt]; !seen {
|
||||
positions[mt] = i
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(wants)-1; i++ {
|
||||
a, b := wants[i], wants[i+1]
|
||||
pa, aOK := positions[a]
|
||||
pb, bOK := positions[b]
|
||||
if !aOK {
|
||||
t.Fatalf("envelope %q not found in %v", a, order)
|
||||
}
|
||||
if !bOK {
|
||||
t.Fatalf("envelope %q not found in %v", b, order)
|
||||
}
|
||||
if pa >= pb {
|
||||
t.Fatalf("expected %q before %q (positions %d, %d)", a, b, pa, pb)
|
||||
}
|
||||
}
|
||||
|
||||
// Started carries the right kind.
|
||||
startEnv := firstEnvOfType(t, tx.envs, api.MsgJobStarted)
|
||||
var startP api.JobStartedPayload
|
||||
if err := startEnv.UnmarshalPayload(&startP); err != nil {
|
||||
t.Fatalf("unmarshal started: %v", err)
|
||||
}
|
||||
if startP.Kind != api.JobRestore {
|
||||
t.Fatalf("kind: got %q want %q", startP.Kind, api.JobRestore)
|
||||
}
|
||||
|
||||
// Finished is succeeded.
|
||||
finEnv := firstEnvOfType(t, tx.envs, api.MsgJobFinished)
|
||||
var finP api.JobFinishedPayload
|
||||
if err := finEnv.UnmarshalPayload(&finP); err != nil {
|
||||
t.Fatalf("unmarshal finished: %v", err)
|
||||
}
|
||||
if finP.Status != api.JobSucceeded {
|
||||
t.Fatalf("status: got %q want %q", finP.Status, api.JobSucceeded)
|
||||
}
|
||||
// Progress envelope reflects the last status line: 100% with 10 files.
|
||||
progEnv := firstEnvOfType(t, tx.envs, api.MsgJobProgress)
|
||||
var progP api.JobProgressPayload
|
||||
if err := progEnv.UnmarshalPayload(&progP); err != nil {
|
||||
t.Fatalf("unmarshal progress: %v", err)
|
||||
}
|
||||
// First progress will be from line 1 (50%) since we send first status
|
||||
// immediately. Verify we at least see a sensible value.
|
||||
if progP.PercentDone <= 0 {
|
||||
t.Fatalf("expected non-zero progress, got %v", progP.PercentDone)
|
||||
}
|
||||
if progP.FilesDone <= 0 || progP.TotalFiles <= 0 {
|
||||
t.Fatalf("expected file counters set, got %+v", progP)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRunRestoreInPlaceArgvHasNoNoOwnership: indirectly verifies that
|
||||
// in-place mode doesn't pass --no-ownership. We can't see the actual
|
||||
// argv without a custom test harness, so we use a fake restic that
|
||||
// echoes its args and check the captured log.stream.
|
||||
func TestRunRestoreInPlaceArgvHasNoNoOwnership(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bin := setupScript(t, `
|
||||
case "$1" in
|
||||
restore)
|
||||
# Print all args on stderr so they're forwarded as log.stream.
|
||||
echo "argv: $*" 1>&2
|
||||
echo '{"message_type":"summary","seconds_elapsed":0,"total_files":0,"files_restored":0,"total_bytes":0,"bytes_restored":0}'
|
||||
;;
|
||||
esac
|
||||
`)
|
||||
|
||||
tx := &fakeSender{}
|
||||
r := New(Config{ResticBin: bin}, tx, 0)
|
||||
if err := r.RunRestore(context.Background(), "job-r2", "abc",
|
||||
nil, true, ""); err != nil {
|
||||
t.Fatalf("RunRestore: %v", err)
|
||||
}
|
||||
|
||||
// Reconstruct the argv from the captured stderr log line.
|
||||
var argv string
|
||||
for _, e := range tx.envs {
|
||||
if e.Type == api.MsgLogStream {
|
||||
var p api.LogStreamLine
|
||||
_ = e.UnmarshalPayload(&p)
|
||||
if p.Stream == api.LogStderr && strings.HasPrefix(p.Payload, "argv:") {
|
||||
argv = p.Payload
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if argv == "" {
|
||||
t.Fatal("never captured argv echo from fake restic")
|
||||
}
|
||||
if strings.Contains(argv, "--no-ownership") {
|
||||
t.Errorf("in-place restore should NOT pass --no-ownership; got argv=%q", argv)
|
||||
}
|
||||
if !strings.Contains(argv, "--target /") {
|
||||
t.Errorf("in-place restore should pass --target /; got argv=%q", argv)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRunRestoreNewDirArgvShape: non-in-place restore passes --target
|
||||
// to the operator-chosen new directory and includes the path filters.
|
||||
// We deliberately do NOT pass --no-ownership (added in restic 0.17;
|
||||
// older versions error out — the comment in restore.go explains why).
|
||||
func TestRunRestoreNewDirArgvShape(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bin := setupScript(t, `
|
||||
case "$1" in
|
||||
restore)
|
||||
echo "argv: $*" 1>&2
|
||||
echo '{"message_type":"summary","seconds_elapsed":0,"total_files":0,"files_restored":0,"total_bytes":0,"bytes_restored":0}'
|
||||
;;
|
||||
esac
|
||||
`)
|
||||
tx := &fakeSender{}
|
||||
r := New(Config{ResticBin: bin}, tx, 0)
|
||||
if err := r.RunRestore(context.Background(), "job-r3", "abc",
|
||||
[]string{"/etc/foo"}, false, "/tmp/restore-out"); err != nil {
|
||||
t.Fatalf("RunRestore: %v", err)
|
||||
}
|
||||
|
||||
var argv string
|
||||
for _, e := range tx.envs {
|
||||
if e.Type == api.MsgLogStream {
|
||||
var p api.LogStreamLine
|
||||
_ = e.UnmarshalPayload(&p)
|
||||
if p.Stream == api.LogStderr && strings.HasPrefix(p.Payload, "argv:") {
|
||||
argv = p.Payload
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if argv == "" {
|
||||
t.Fatal("no argv echo")
|
||||
}
|
||||
if strings.Contains(argv, "--no-ownership") {
|
||||
t.Errorf("restic 0.16 doesn't accept --no-ownership; got argv=%q", argv)
|
||||
}
|
||||
if !strings.Contains(argv, "--target /tmp/restore-out") {
|
||||
t.Errorf("expected --target /tmp/restore-out; got argv=%q", argv)
|
||||
}
|
||||
if !strings.Contains(argv, "--include /etc/foo") {
|
||||
t.Errorf("expected --include /etc/foo; got argv=%q", argv)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRunRestoreNewDirAutoCreatesTarget: a new-directory restore
|
||||
// should mkdir the requested target chain before invoking restic, so
|
||||
// operators don't have to pre-create the per-job subdir.
|
||||
func TestRunRestoreNewDirAutoCreatesTarget(t *testing.T) {
|
||||
t.Parallel()
|
||||
bin := setupScript(t, `
|
||||
case "$1" in
|
||||
restore)
|
||||
echo '{"message_type":"summary","seconds_elapsed":0,"total_files":0,"files_restored":0,"total_bytes":0,"bytes_restored":0}'
|
||||
;;
|
||||
esac
|
||||
`)
|
||||
tx := &fakeSender{}
|
||||
r := New(Config{ResticBin: bin}, tx, 0)
|
||||
|
||||
// Multi-level path the operator hasn't created yet.
|
||||
target := filepath.Join(t.TempDir(), "deep", "deeper", "deepest")
|
||||
if err := r.RunRestore(context.Background(), "job-rmkdir", "abc",
|
||||
[]string{"/etc/foo"}, false, target); err != nil {
|
||||
t.Fatalf("RunRestore: %v", err)
|
||||
}
|
||||
|
||||
if st, err := os.Stat(target); err != nil {
|
||||
t.Fatalf("expected target dir to exist: %v", err)
|
||||
} else if !st.IsDir() {
|
||||
t.Fatalf("expected directory, got %v", st.Mode())
|
||||
}
|
||||
}
|
||||
|
||||
// TestRunDiffShipsLogLines: diff output is forwarded as log.stream.
|
||||
func TestRunDiffShipsLogLines(t *testing.T) {
|
||||
t.Parallel()
|
||||
bin := setupScript(t, `
|
||||
case "$1" in
|
||||
diff)
|
||||
echo '{"message_type":"change","path":"/etc/nginx/nginx.conf","modifier":"M"}'
|
||||
echo '{"message_type":"statistics","added":{"files":0,"dirs":0}}'
|
||||
;;
|
||||
esac
|
||||
`)
|
||||
tx := &fakeSender{}
|
||||
r := New(Config{ResticBin: bin}, tx, 0)
|
||||
if err := r.RunDiff(context.Background(), "job-d1", "snap-a", "snap-b"); err != nil {
|
||||
t.Fatalf("RunDiff: %v", err)
|
||||
}
|
||||
|
||||
startEnv := firstEnvOfType(t, tx.envs, api.MsgJobStarted)
|
||||
var startP api.JobStartedPayload
|
||||
_ = startEnv.UnmarshalPayload(&startP)
|
||||
if startP.Kind != api.JobDiff {
|
||||
t.Fatalf("kind: got %q want %q", startP.Kind, api.JobDiff)
|
||||
}
|
||||
finEnv := firstEnvOfType(t, tx.envs, api.MsgJobFinished)
|
||||
var finP api.JobFinishedPayload
|
||||
_ = finEnv.UnmarshalPayload(&finP)
|
||||
if finP.Status != api.JobSucceeded {
|
||||
t.Fatalf("status: %q", finP.Status)
|
||||
}
|
||||
// At least one log line should carry the change payload.
|
||||
var sawChange bool
|
||||
for _, e := range tx.envs {
|
||||
if e.Type != api.MsgLogStream {
|
||||
continue
|
||||
}
|
||||
var p api.LogStreamLine
|
||||
_ = e.UnmarshalPayload(&p)
|
||||
if strings.Contains(p.Payload, `"message_type":"change"`) {
|
||||
sawChange = true
|
||||
}
|
||||
}
|
||||
if !sawChange {
|
||||
t.Fatal("never saw a change log line in diff output")
|
||||
}
|
||||
}
|
||||
+20
-177
@@ -26,22 +26,10 @@ type Sender interface {
|
||||
// from the agent's config file (server-pushed config.update payloads
|
||||
// override these in memory).
|
||||
type Config struct {
|
||||
ResticBin string
|
||||
ResticVersion string // e.g. "0.17.1" — empty if unknown
|
||||
RepoURL string
|
||||
RepoUsername string
|
||||
RepoPassword string
|
||||
|
||||
// SupportsRestoreNoOwnership comes from a startup probe of
|
||||
// `restic restore --help`; gates the new-dir-restore flag without
|
||||
// relying on version sniffing.
|
||||
SupportsRestoreNoOwnership bool
|
||||
|
||||
// Bandwidth caps in KB/s applied to every restic invocation.
|
||||
// <=0 means "no cap". Per-job override: callers that build a
|
||||
// runner per-dispatch can pass the override value here directly.
|
||||
LimitUploadKBps int
|
||||
LimitDownloadKBps int
|
||||
ResticBin string
|
||||
RepoURL string
|
||||
RepoUsername string
|
||||
RepoPassword string
|
||||
}
|
||||
|
||||
// Runner owns the restic invocations.
|
||||
@@ -66,14 +54,10 @@ func New(cfg Config, tx Sender, progressMinPeriod time.Duration) *Runner {
|
||||
// resticEnv builds the shared restic.Env from r.cfg.
|
||||
func (r *Runner) resticEnv() restic.Env {
|
||||
return restic.Env{
|
||||
Bin: r.cfg.ResticBin,
|
||||
Version: r.cfg.ResticVersion,
|
||||
RepoURL: r.cfg.RepoURL,
|
||||
RepoUsername: r.cfg.RepoUsername,
|
||||
RepoPassword: r.cfg.RepoPassword,
|
||||
SupportsRestoreNoOwnership: r.cfg.SupportsRestoreNoOwnership,
|
||||
LimitUploadKBps: r.cfg.LimitUploadKBps,
|
||||
LimitDownloadKBps: r.cfg.LimitDownloadKBps,
|
||||
Bin: r.cfg.ResticBin,
|
||||
RepoURL: r.cfg.RepoURL,
|
||||
RepoUsername: r.cfg.RepoUsername,
|
||||
RepoPassword: r.cfg.RepoPassword,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,10 +87,8 @@ func (r *Runner) streamHandler(jobID string, seq *atomic.Int64) restic.LineHandl
|
||||
}
|
||||
|
||||
// sendFinished ships a job.finished envelope. err==nil → succeeded;
|
||||
// otherwise failed (or canceled if ctx was canceled — operator
|
||||
// hit the Cancel button or the agent is shutting down).
|
||||
// statsBlob is forwarded as JobFinishedPayload.Stats.
|
||||
func (r *Runner) sendFinished(ctx context.Context, jobID string, finishedAt time.Time, err error, statsBlob json.RawMessage) {
|
||||
// otherwise failed. statsBlob is forwarded as JobFinishedPayload.Stats.
|
||||
func (r *Runner) sendFinished(jobID string, finishedAt time.Time, err error, statsBlob json.RawMessage) {
|
||||
status := api.JobSucceeded
|
||||
exit := 0
|
||||
errMsg := ""
|
||||
@@ -114,16 +96,6 @@ func (r *Runner) sendFinished(ctx context.Context, jobID string, finishedAt time
|
||||
status = api.JobFailed
|
||||
exit = -1
|
||||
errMsg = err.Error()
|
||||
// If the context was canceled, the failure is operator-driven
|
||||
// (or shutdown). Surface as JobCancelled so the UI shows a
|
||||
// neutral "canceled" state rather than a red "failed" one.
|
||||
// exec.CommandContext returns the process's exit error on
|
||||
// ctx-cancel, which we'd otherwise rebadge as failed.
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
status = api.JobCancelled
|
||||
exit = 130 // POSIX convention for SIGINT/SIGTERM-killed
|
||||
errMsg = "" // no need to surface the underlying restic error
|
||||
}
|
||||
}
|
||||
finEnv, _ := api.Marshal(api.MsgJobFinished, jobID, api.JobFinishedPayload{
|
||||
JobID: jobID,
|
||||
@@ -136,35 +108,16 @@ func (r *Runner) sendFinished(ctx context.Context, jobID string, finishedAt time
|
||||
_ = r.tx.Send(finEnv)
|
||||
}
|
||||
|
||||
// BackupHooks bundles the optional pre/post shell snippets that fire
|
||||
// around a backup. Empty fields skip that phase. Resolved server-side
|
||||
// (group → host default) before dispatch; the agent just executes
|
||||
// whatever arrives in the payload.
|
||||
type BackupHooks struct {
|
||||
Pre string
|
||||
Post string
|
||||
}
|
||||
|
||||
// RunBackup executes a backup job and reports back via the sender.
|
||||
// Returns nil on a clean (or "incomplete-but-snapshot-created") finish.
|
||||
func (r *Runner) RunBackup(ctx context.Context, jobID string, paths, excludes, tags []string, hooks BackupHooks) error {
|
||||
func (r *Runner) RunBackup(ctx context.Context, jobID string, paths, excludes, tags []string) error {
|
||||
startedAt := time.Now().UTC()
|
||||
r.sendStarted(jobID, api.JobBackup, startedAt)
|
||||
|
||||
var seq atomic.Int64
|
||||
|
||||
// pre_hook: non-zero exit aborts the backup. The job is recorded
|
||||
// as failed with the hook's error and restic never runs.
|
||||
if hooks.Pre != "" {
|
||||
if err := r.runHook(ctx, jobID, "pre", hooks.Pre, "", &seq); err != nil {
|
||||
finishedAt := time.Now().UTC()
|
||||
r.sendFinished(ctx, jobID, finishedAt, err, nil)
|
||||
return fmt.Errorf("pre_hook failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
env := r.resticEnv()
|
||||
lastProgress := time.Time{} // zero time → first status event always emits
|
||||
|
||||
var seq atomic.Int64
|
||||
lastProgress := time.Now()
|
||||
|
||||
handle := func(stream string, line string, ev any) {
|
||||
// Throttled progress events come from restic's `status` JSON.
|
||||
@@ -212,21 +165,7 @@ func (r *Runner) RunBackup(ctx context.Context, jobID string, paths, excludes, t
|
||||
if summary != nil {
|
||||
statsBlob, _ = json.Marshal(summary)
|
||||
}
|
||||
|
||||
// post_hook: always runs regardless of backup outcome. Receives
|
||||
// RM_JOB_STATUS=succeeded|failed in env. Non-zero exit is logged
|
||||
// but does not change the recorded job status.
|
||||
if hooks.Post != "" {
|
||||
status := "succeeded"
|
||||
if err != nil {
|
||||
status = "failed"
|
||||
}
|
||||
if perr := r.runHook(ctx, jobID, "post", hooks.Post, status, &seq); perr != nil {
|
||||
slog.Warn("runner: post_hook exited non-zero", "job_id", jobID, "err", perr)
|
||||
}
|
||||
}
|
||||
|
||||
r.sendFinished(ctx, jobID, finishedAt, err, statsBlob)
|
||||
r.sendFinished(jobID, finishedAt, err, statsBlob)
|
||||
|
||||
// On a successful backup, refresh the server's snapshot projection.
|
||||
// We do this *after* job.finished so the UI sees the job land first;
|
||||
@@ -260,7 +199,7 @@ func (r *Runner) RunInit(ctx context.Context, jobID string) error {
|
||||
var seq atomic.Int64
|
||||
err := env.RunInit(ctx, r.streamHandler(jobID, &seq))
|
||||
finishedAt := time.Now().UTC()
|
||||
r.sendFinished(ctx, jobID, finishedAt, err, nil)
|
||||
r.sendFinished(jobID, finishedAt, err, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("runner init: %w", err)
|
||||
}
|
||||
@@ -282,7 +221,7 @@ func (r *Runner) RunForget(ctx context.Context, jobID string, groups []restic.Fo
|
||||
var seq atomic.Int64
|
||||
err := env.RunForget(ctx, groups, r.streamHandler(jobID, &seq))
|
||||
finishedAt := time.Now().UTC()
|
||||
r.sendFinished(ctx, jobID, finishedAt, err, nil)
|
||||
r.sendFinished(jobID, finishedAt, err, nil)
|
||||
|
||||
// Refresh the server's snapshot projection — forget rewrites the
|
||||
// index so the host's snapshot list almost certainly shrunk.
|
||||
@@ -320,7 +259,7 @@ func (r *Runner) RunPrune(ctx context.Context, jobID string) error {
|
||||
}
|
||||
}
|
||||
|
||||
r.sendFinished(ctx, jobID, finishedAt, err, nil)
|
||||
r.sendFinished(jobID, finishedAt, err, nil)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("runner prune: %w", err)
|
||||
@@ -359,7 +298,7 @@ func (r *Runner) RunCheck(ctx context.Context, jobID string, subsetPct int) erro
|
||||
slog.Warn("runner: stats.report after check failed", "job_id", jobID, "err", rerr)
|
||||
}
|
||||
|
||||
r.sendFinished(ctx, jobID, finishedAt, err, nil)
|
||||
r.sendFinished(jobID, finishedAt, err, nil)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("runner check: %w", err)
|
||||
@@ -367,102 +306,6 @@ func (r *Runner) RunCheck(ctx context.Context, jobID string, subsetPct int) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunRestore executes a restic restore job and reports back via the
|
||||
// sender. paths is the operator-selected file/dir list to restore.
|
||||
// inPlace=true preserves uid/gid/mode and writes at "/"; inPlace=false
|
||||
// writes at targetDir with --no-ownership.
|
||||
//
|
||||
// Status events from restic are throttled into job.progress in the
|
||||
// same shape as backup; raw status lines are dropped from log.stream
|
||||
// (they would drown the log on a fast restore — the progress widget
|
||||
// already covers them).
|
||||
func (r *Runner) RunRestore(ctx context.Context, jobID, snapshotID string, paths []string, inPlace bool, targetDir string) error {
|
||||
startedAt := time.Now().UTC()
|
||||
r.sendStarted(jobID, api.JobRestore, startedAt)
|
||||
|
||||
env := r.resticEnv()
|
||||
var seq atomic.Int64
|
||||
lastProgress := time.Time{} // zero time → first status event always emits
|
||||
|
||||
handle := func(stream string, line string, ev any) {
|
||||
status, isStatus := ev.(restic.RestoreStatus)
|
||||
if !isStatus {
|
||||
now := time.Now().UTC()
|
||||
logEnv, _ := api.Marshal(api.MsgLogStream, "", api.LogStreamLine{
|
||||
JobID: jobID,
|
||||
Seq: seq.Add(1),
|
||||
TS: now,
|
||||
Stream: api.LogStream(stream),
|
||||
Payload: line,
|
||||
})
|
||||
_ = r.tx.Send(logEnv)
|
||||
}
|
||||
if isStatus {
|
||||
if time.Since(lastProgress) < r.progressMinPeriod {
|
||||
return
|
||||
}
|
||||
lastProgress = time.Now()
|
||||
progEnv, _ := api.Marshal(api.MsgJobProgress, jobID, api.JobProgressPayload{
|
||||
JobID: jobID,
|
||||
PercentDone: status.PercentDone,
|
||||
FilesDone: status.FilesRestored,
|
||||
TotalFiles: status.TotalFiles,
|
||||
BytesDone: status.BytesRestored,
|
||||
TotalBytes: status.TotalBytes,
|
||||
ETASeconds: estimateETA(status.BytesRestored, status.TotalBytes, status.SecondsElapsed),
|
||||
ThroughputBps: throughput(status.BytesRestored, status.SecondsElapsed),
|
||||
})
|
||||
_ = r.tx.Send(progEnv)
|
||||
}
|
||||
}
|
||||
|
||||
summary, err := env.RunRestore(ctx, snapshotID, paths, inPlace, targetDir, handle)
|
||||
finishedAt := time.Now().UTC()
|
||||
|
||||
var statsBlob json.RawMessage
|
||||
if summary != nil {
|
||||
statsBlob, _ = json.Marshal(summary)
|
||||
}
|
||||
r.sendFinished(ctx, jobID, finishedAt, err, statsBlob)
|
||||
if err != nil {
|
||||
return fmt.Errorf("runner restore: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// estimateETA computes an ETA in seconds by linear extrapolation from
// the bytes restored so far and the seconds elapsed. Restic restore's
// --json doesn't emit an ETA field of its own (unlike backup), so this
// is an approximation.
//
// It returns 0 when there is not enough data to extrapolate: any
// argument is non-positive, or bytesDone has already reached totalBytes.
// If the extrapolated value does not fit in an int64 the result
// saturates at the maximum int64 rather than relying on an out-of-range
// float-to-int conversion.
func estimateETA(bytesDone, totalBytes, secondsElapsed int64) int64 {
	if bytesDone <= 0 || totalBytes <= 0 || secondsElapsed <= 0 || bytesDone >= totalBytes {
		return 0
	}
	const maxInt64 = 1<<63 - 1

	// The guard above guarantees rate > 0, so no further zero check is needed.
	rate := float64(bytesDone) / float64(secondsElapsed)
	eta := float64(totalBytes-bytesDone) / rate
	if eta >= float64(maxInt64) {
		return maxInt64
	}
	return int64(eta)
}
|
||||
|
||||
// RunDiff executes `restic diff --json <a> <b>` and forwards output
|
||||
// as log.stream lines. No snapshot-list refresh, no stats update —
|
||||
// diff is purely informational.
|
||||
func (r *Runner) RunDiff(ctx context.Context, jobID, snapshotA, snapshotB string) error {
|
||||
startedAt := time.Now().UTC()
|
||||
r.sendStarted(jobID, api.JobDiff, startedAt)
|
||||
|
||||
env := r.resticEnv()
|
||||
var seq atomic.Int64
|
||||
err := env.RunDiff(ctx, snapshotA, snapshotB, r.streamHandler(jobID, &seq))
|
||||
finishedAt := time.Now().UTC()
|
||||
r.sendFinished(ctx, jobID, finishedAt, err, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("runner diff: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunUnlock executes a `restic unlock` job. On success it ships a
|
||||
// repo.stats envelope with LockPresent=false so the UI banner clears.
|
||||
func (r *Runner) RunUnlock(ctx context.Context, jobID string) error {
|
||||
@@ -482,7 +325,7 @@ func (r *Runner) RunUnlock(ctx context.Context, jobID string) error {
|
||||
}
|
||||
}
|
||||
|
||||
r.sendFinished(ctx, jobID, finishedAt, err, nil)
|
||||
r.sendFinished(jobID, finishedAt, err, nil)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("runner unlock: %w", err)
|
||||
|
||||
@@ -4,42 +4,20 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/restic"
|
||||
)
|
||||
|
||||
// fakeSender collects sent envelopes for assertions. Lock-protected
|
||||
// because the runner's pumpStdout / pumpStderr goroutines call Send
|
||||
// concurrently — without the mutex, -race in CI flags every test
|
||||
// that exercises a Run* method with both pumps active.
|
||||
type fakeSender struct {
|
||||
mu sync.Mutex
|
||||
envs []api.Envelope
|
||||
}
|
||||
// fakeSender collects sent envelopes for assertions.
|
||||
type fakeSender struct{ envs []api.Envelope }
|
||||
|
||||
func (s *fakeSender) Send(e api.Envelope) error {
|
||||
s.mu.Lock()
|
||||
s.envs = append(s.envs, e)
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// snapshot returns a copy of the captured envelopes safe to read
|
||||
// without holding the lock. Tests use this when iterating envs while
|
||||
// other goroutines may still be writing — though in practice all
|
||||
// runner Run* methods join their pumps before returning, so callers
|
||||
// can also read .envs directly post-return.
|
||||
func (s *fakeSender) snapshot() []api.Envelope {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]api.Envelope, len(s.envs))
|
||||
copy(out, s.envs)
|
||||
return out
|
||||
}
|
||||
|
||||
// setupScript writes a shell script (without shebang) to a temp dir,
|
||||
// names it "restic", makes it executable, and returns the path.
|
||||
//
|
||||
@@ -342,7 +320,7 @@ esac
|
||||
// still produces job.started and job.finished envelopes.
|
||||
func TestRunInitShipsStartedAndFinished(t *testing.T) {
|
||||
t.Parallel()
|
||||
bin := setupScript(t, `echo "initialised repository"`)
|
||||
bin := setupScript(t, `echo "initialized repository"`)
|
||||
tx := &fakeSender{}
|
||||
r := New(Config{ResticBin: bin}, tx, 0)
|
||||
if err := r.RunInit(context.Background(), "job-init"); err != nil {
|
||||
|
||||
@@ -110,7 +110,7 @@ func (s *Scheduler) Apply(payload api.ScheduleSetPayload, tx Sender) {
|
||||
"received", len(payload.Schedules), "active", added)
|
||||
|
||||
// Ack outside the lock — Send() shouldn't take long, but holding
|
||||
// s.mu across an external call would needlessly serialise other
|
||||
// s.mu across an external call would needlessly serialize other
|
||||
// callers (e.g. a future Status() inspection from the UI).
|
||||
ackEnv, err := api.Marshal(api.MsgScheduleAck, "", api.ScheduleAckPayload{
|
||||
Version: payload.Version,
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
|
||||
// additionalData binds ciphertexts to the agent-secrets context, so a
|
||||
// blob lifted from one role's file can't be replayed into another's
|
||||
// row in some unrelated table that uses the same key. (Defence in
|
||||
// row in some unrelated table that uses the same key. (Defense in
|
||||
// depth — the key is per-host today, but cheap to be careful.)
|
||||
const additionalData = "rm-agent-repo-creds-v1"
|
||||
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
// install_windows.go — thin wrappers around the Service Control
|
||||
// Manager via golang.org/x/sys/windows/svc/mgr. Used by the agent's
|
||||
// `install` / `uninstall` / `start` / `stop` subcommands.
|
||||
//
|
||||
// UNTESTED in CI. Mirrors the canonical example shape; if you need
|
||||
// to extend this, prefer copying from x/sys/windows/svc/example
|
||||
// over inventing new patterns.
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
// Install registers the service with the SCM, pointing it at the
|
||||
// currently-running binary. The service starts on every boot and
|
||||
// runs as LocalSystem (default).
|
||||
func Install() error {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("install: locate executable: %w", err)
|
||||
}
|
||||
exe, err = filepath.Abs(exe)
|
||||
if err != nil {
|
||||
return fmt.Errorf("install: absolutise path: %w", err)
|
||||
}
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return fmt.Errorf("install: connect SCM: %w", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
if existing, err := m.OpenService(ServiceName); err == nil {
|
||||
_ = existing.Close()
|
||||
return fmt.Errorf("service %q already installed; uninstall first", ServiceName)
|
||||
}
|
||||
s, err := m.CreateService(ServiceName, exe, mgr.Config{
|
||||
StartType: mgr.StartAutomatic,
|
||||
DisplayName: "Restic-manager agent",
|
||||
Description: "Backs up this host on the schedule the central restic-manager dictates.",
|
||||
}, "run")
|
||||
if err != nil {
|
||||
return fmt.Errorf("install: create service: %w", err)
|
||||
}
|
||||
defer s.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Uninstall removes the service from the SCM. Caller is expected to
|
||||
// stop the service first; this returns the SCM's error if it's
|
||||
// still running.
|
||||
func Uninstall() error {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return fmt.Errorf("uninstall: connect SCM: %w", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
s, err := m.OpenService(ServiceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("uninstall: open service: %w", err)
|
||||
}
|
||||
defer s.Close()
|
||||
if err := s.Delete(); err != nil {
|
||||
return fmt.Errorf("uninstall: delete service: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start asks the SCM to start the installed service. No-op if it's
|
||||
// already running (the SCM returns an error which we surface).
|
||||
func Start() error {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer m.Disconnect()
|
||||
s, err := m.OpenService(ServiceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Close()
|
||||
return s.Start()
|
||||
}
|
||||
|
||||
// Stop sends a stop control to the service.
|
||||
func Stop() error {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer m.Disconnect()
|
||||
s, err := m.OpenService(ServiceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Close()
|
||||
_, err = s.Control(0x00000001) // SERVICE_CONTROL_STOP
|
||||
return err
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
// service_other.go — non-Windows fallback for the service package.
|
||||
// Linux uses systemd to wrap the agent; the binary itself just runs
|
||||
// in the foreground. Run() therefore just executes the agent loop
|
||||
// and returns. install/uninstall sub-commands return a clear error
|
||||
// directing the operator at the install.sh + systemd unit shipped
|
||||
// in deploy/install/.
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// AgentRun is the signature of the agent's main loop as handed over by
// main. It has the same shape as the Windows variant so the call site
// stays portable.
type AgentRun func(ctx context.Context) error

// Run executes the agent loop in the foreground and returns its result.
// On Unix the lifecycle is supplied by whatever launched the process
// (the systemd unit, typically). The context passed to agentRun is
// cancelled once agentRun has returned.
func Run(agentRun AgentRun) error {
	runCtx, release := context.WithCancel(context.Background())
	defer release()

	return agentRun(runCtx)
}
|
||||
|
||||
// Install would register the agent as an OS service. That is only
// implemented on Windows; here it returns an error pointing the
// operator at the systemd unit.
func Install() error {
	return errUnsupported("install")
}

// Uninstall is the counterpart of Install and is likewise Windows-only.
func Uninstall() error {
	return errUnsupported("uninstall")
}

// Start would ask the OS service manager to start the installed
// service. Windows-only.
func Start() error {
	return errUnsupported("start")
}

// Stop would send a stop signal to the installed service. Windows-only.
func Stop() error {
	return errUnsupported("stop")
}

// errUnsupported builds the error returned by every service-management
// verb on non-Windows builds; verb is embedded in the message.
func errUnsupported(verb string) error {
	msg := "service " + verb + " is Windows-only; use the systemd unit on Linux"
	return errors.New(msg)
}
|
||||
@@ -1,93 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
// service_windows.go — Service Control Manager integration for the
|
||||
// agent on Windows (P2-16). Implements the svc.Handler interface so
|
||||
// `restic-manager-agent run` works under both interactive and SCM
|
||||
// contexts. install/uninstall live in install_windows.go.
|
||||
//
|
||||
// UNTESTED on Windows in this repo's CI (the runners are Linux).
|
||||
// The shape mirrors the canonical example in
|
||||
// golang.org/x/sys/windows/svc/example. Treat any deviation from
|
||||
// that example as suspicious.
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
|
||||
"golang.org/x/sys/windows/svc"
|
||||
)
|
||||
|
||||
// ServiceName is the SCM identifier for the agent service.
|
||||
const ServiceName = "restic-manager-agent"
|
||||
|
||||
// AgentRun is the function the service handler calls to start the
|
||||
// agent's main loop. Pass cmd/agent's run-loop entry point at the
|
||||
// call site so this package stays free of cross-cmd imports.
|
||||
type AgentRun func(ctx context.Context) error
|
||||
|
||||
// Run delegates to the SCM dispatcher when running under Windows
|
||||
// service control, otherwise runs the agent loop in the foreground
|
||||
// (for `restic-manager-agent run` from a console, e.g. while
|
||||
// debugging on a developer's box).
|
||||
func Run(agentRun AgentRun) error {
|
||||
isService, err := svc.IsWindowsService()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !isService {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
return agentRun(ctx)
|
||||
}
|
||||
return svc.Run(ServiceName, &handler{run: agentRun})
|
||||
}
|
||||
|
||||
// handler implements svc.Handler. Execute is called once when the
|
||||
// service is started. We spawn the agent loop in a goroutine and
|
||||
// listen for SCM Stop / Shutdown notifications, cancelling the
|
||||
// context to wind down cleanly.
|
||||
type handler struct {
|
||||
run AgentRun
|
||||
}
|
||||
|
||||
func (h *handler) Execute(_ []string, req <-chan svc.ChangeRequest, status chan<- svc.Status) (bool, uint32) {
|
||||
const accepted = svc.AcceptStop | svc.AcceptShutdown
|
||||
status <- svc.Status{State: svc.StartPending}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
doneCh := make(chan error, 1)
|
||||
go func() {
|
||||
doneCh <- h.run(ctx)
|
||||
}()
|
||||
status <- svc.Status{State: svc.Running, Accepts: accepted}
|
||||
|
||||
for {
|
||||
select {
|
||||
case c := <-req:
|
||||
switch c.Cmd {
|
||||
case svc.Interrogate:
|
||||
status <- c.CurrentStatus
|
||||
case svc.Stop, svc.Shutdown:
|
||||
slog.Info("svc: stop requested")
|
||||
cancel()
|
||||
status <- svc.Status{State: svc.StopPending}
|
||||
if err := <-doneCh; err != nil && !errors.Is(err, context.Canceled) {
|
||||
slog.Warn("svc: agent loop exited with error", "err", err)
|
||||
return false, 1
|
||||
}
|
||||
return false, 0
|
||||
}
|
||||
case err := <-doneCh:
|
||||
// Agent loop exited on its own — uncommon (only via signal
|
||||
// or fatal error). Surface as an SCM stop.
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
slog.Warn("svc: agent loop exited unexpectedly", "err", err)
|
||||
return false, 1
|
||||
}
|
||||
return false, 0
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -76,5 +76,5 @@ func detectResticVersion(ctx context.Context, override string) (string, error) {
|
||||
if len(parts) >= 2 && parts[0] == "restic" {
|
||||
return parts[1], nil
|
||||
}
|
||||
return "", fmt.Errorf("sysinfo: unrecognised restic version output: %q", first)
|
||||
return "", fmt.Errorf("sysinfo: unrecognized restic version output: %q", first)
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ type Config struct {
|
||||
// Sender is what handlers use to push agent → server messages
|
||||
// (job.progress, job.finished, log.stream, command.result, …).
|
||||
// Returned by the WS client to the dispatch handler. Write operations
|
||||
// serialise behind a single mutex on the conn; concurrent calls are
|
||||
// serialize behind a single mutex on the conn; concurrent calls are
|
||||
// safe.
|
||||
type Sender interface {
|
||||
Send(env api.Envelope) error
|
||||
|
||||
@@ -1,226 +0,0 @@
|
||||
// Package alert evaluates the hardcoded rule set and persists raises
|
||||
// / acknowledges / resolves. Three event sources feed it:
|
||||
// - JobFinishedEvent — pushed when a job lands a terminal state
|
||||
// (the existing MarkJobFinished site)
|
||||
// - HostOfflineEvent / HostOnlineEvent — pushed by the offline
|
||||
// sweeper and by the ws hello handler
|
||||
// - 60s ticker (internal) — drives stale-schedule + auto-resolve
|
||||
//
|
||||
// All output goes through store.RaiseOrTouch / Acknowledge / Resolve
|
||||
// and the notification.Hub. The engine is one goroutine started at
|
||||
// boot; non-blocking sends from hot paths.
|
||||
package alert
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/notification"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// JobFinishedEvent carries everything the engine needs to evaluate
|
||||
// the failed-X rules. Pushed via Engine.NotifyJobFinished from the
|
||||
// MarkJobFinished site.
|
||||
type JobFinishedEvent struct {
|
||||
HostID string
|
||||
JobID string
|
||||
Kind string // backup | forget | prune | check | unlock | restore | diff
|
||||
Status string // succeeded | failed | cancelled
|
||||
SourceGroupID string // dedup key for backup/forget/prune/check; empty otherwise
|
||||
When time.Time
|
||||
}
|
||||
|
||||
// Engine evaluates hardcoded alert rules and dispatches via notification.Hub.
|
||||
type Engine struct {
|
||||
store *store.Store
|
||||
hub *notification.Hub
|
||||
|
||||
jobs chan JobFinishedEvent
|
||||
hostDown chan string // host_id
|
||||
hostUp chan string
|
||||
|
||||
// agentOfflineFloor is the duration a host must be offline before
|
||||
// we raise. Configurable for tests; default 15m.
|
||||
agentOfflineFloor time.Duration
|
||||
tickPeriod time.Duration
|
||||
|
||||
closeOnce sync.Once
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewEngine builds the engine. agentOfflineFloor + tickPeriod default
|
||||
// to 15min and 60s respectively when zero.
|
||||
func NewEngine(st *store.Store, hub *notification.Hub) *Engine {
|
||||
return &Engine{
|
||||
store: st,
|
||||
hub: hub,
|
||||
jobs: make(chan JobFinishedEvent, 32),
|
||||
hostDown: make(chan string, 32),
|
||||
hostUp: make(chan string, 32),
|
||||
agentOfflineFloor: 15 * time.Minute,
|
||||
tickPeriod: 60 * time.Second,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Run drives the event loop. Returns when ctx is done. Blocks; call in
|
||||
// its own goroutine.
|
||||
func (e *Engine) Run(ctx context.Context) {
|
||||
t := time.NewTicker(e.tickPeriod)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
e.closeOnce.Do(func() { close(e.done) })
|
||||
return
|
||||
case ev := <-e.jobs:
|
||||
e.handleJobFinished(ctx, ev)
|
||||
case hostID := <-e.hostDown:
|
||||
e.handleHostOffline(ctx, hostID)
|
||||
case hostID := <-e.hostUp:
|
||||
e.handleHostOnline(ctx, hostID)
|
||||
case now := <-t.C:
|
||||
e.tick(ctx, now)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyJobFinished is the hot-path hook called from MarkJobFinished's
|
||||
// caller (ws.handler.dispatchAgentMessage). Non-blocking: drops on a
|
||||
// full channel with a slog warning.
|
||||
func (e *Engine) NotifyJobFinished(ev JobFinishedEvent) {
|
||||
select {
|
||||
case e.jobs <- ev:
|
||||
default:
|
||||
slog.Warn("alert: jobs channel full; dropping event", "kind", ev.Kind, "host_id", ev.HostID)
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyHostOffline notifies the engine that a host is offline.
|
||||
func (e *Engine) NotifyHostOffline(hostID string) {
|
||||
select {
|
||||
case e.hostDown <- hostID:
|
||||
default:
|
||||
slog.Warn("alert: hostDown channel full; dropping", "host_id", hostID)
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyHostOnline notifies the engine that a host is online.
|
||||
func (e *Engine) NotifyHostOnline(hostID string) {
|
||||
select {
|
||||
case e.hostUp <- hostID:
|
||||
default:
|
||||
slog.Warn("alert: hostUp channel full; dropping", "host_id", hostID)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) handleJobFinished(ctx context.Context, ev JobFinishedEvent) {
|
||||
// Determine which kind/severity pair this job maps to. Jobs not
|
||||
// listed here (init, unlock, restore, diff) produce no alerts in v1.
|
||||
var kind, severity string
|
||||
switch ev.Kind {
|
||||
case "backup":
|
||||
kind, severity = KindBackupFailed, "warning"
|
||||
case "forget":
|
||||
kind, severity = KindForgetFailed, "warning"
|
||||
case "prune":
|
||||
kind, severity = KindPruneFailed, "warning"
|
||||
case "check":
|
||||
kind, severity = KindCheckFailed, "critical"
|
||||
default:
|
||||
return
|
||||
}
|
||||
// dedupKey scopes the alert to a specific subject. For backups it's
|
||||
// the source-group id (each group = its own restic run = its own
|
||||
// failure surface). forget/prune/check are repo-scoped — leave the
|
||||
// key empty so we get one alert per host per kind, matching the
|
||||
// "is this repo healthy?" mental model.
|
||||
dedupKey := ""
|
||||
if ev.Kind == "backup" {
|
||||
dedupKey = ev.SourceGroupID
|
||||
}
|
||||
switch ev.Status {
|
||||
case "failed":
|
||||
e.raiseAndNotify(ctx, ev.HostID, kind, dedupKey, severity,
|
||||
fmt.Sprintf("%s job %s failed", ev.Kind, ev.JobID), ev.When)
|
||||
case "succeeded":
|
||||
e.resolveAndNotify(ctx, ev.HostID, kind, dedupKey, ev.When)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) handleHostOffline(ctx context.Context, hostID string) {
|
||||
host, err := e.store.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Apply the 15-min floor — raise only when last_seen_at is older
|
||||
// than agentOfflineFloor. A nil last_seen_at (host enrolled but
|
||||
// never connected) is treated as "now" so we don't raise
|
||||
// immediately on enrolment.
|
||||
if host.LastSeenAt == nil {
|
||||
return
|
||||
}
|
||||
if time.Since(*host.LastSeenAt) < e.agentOfflineFloor {
|
||||
return
|
||||
}
|
||||
e.raiseAndNotify(ctx, hostID, KindAgentOffline, "", "warning",
|
||||
fmt.Sprintf("Agent offline for %s (threshold %s)",
|
||||
roundDur(time.Since(*host.LastSeenAt)), e.agentOfflineFloor),
|
||||
time.Now().UTC())
|
||||
}
|
||||
|
||||
func (e *Engine) handleHostOnline(ctx context.Context, hostID string) {
|
||||
e.resolveAndNotify(ctx, hostID, KindAgentOffline, "", time.Now().UTC())
|
||||
}
|
||||
|
||||
// tick is the 60-second sweep. Responsibilities:
|
||||
// 1. Re-evaluate agent_offline for every offline host that may have
|
||||
// crossed the floor between explicit events.
|
||||
// 2. Stale-schedule detection — declared in the spec but intentionally
|
||||
// left as a no-op in v1. The precise "expected to have fired but
|
||||
// didn't" trigger requires a store helper that lands in a later
|
||||
// task. The KindStaleSchedule constant is exported so UI code can
|
||||
// reference the tag string today.
|
||||
func (e *Engine) tick(ctx context.Context, now time.Time) {
|
||||
// User-management cleanup piggy-backed here for now. Setup tokens
|
||||
// have a 1h expiry; the alert engine tick is the cheapest existing
|
||||
// 60s loop. If more housekeeping queries appear, extract a
|
||||
// dedicated maintenance loop.
|
||||
if _, err := e.store.CleanupExpiredSetupTokens(ctx, now); err != nil {
|
||||
slog.Warn("alert: cleanup expired setup tokens", "err", err)
|
||||
}
|
||||
if _, err := e.store.CleanupExpiredOIDCState(ctx, now.Add(-5*time.Minute)); err != nil {
|
||||
slog.Warn("alert: cleanup expired oidc state", "err", err)
|
||||
}
|
||||
|
||||
hosts, err := e.store.ListHosts(ctx)
|
||||
if err != nil {
|
||||
slog.Warn("alert: tick list hosts", "err", err)
|
||||
return
|
||||
}
|
||||
for _, h := range hosts {
|
||||
if h.Status != "offline" || h.LastSeenAt == nil {
|
||||
continue
|
||||
}
|
||||
if now.Sub(*h.LastSeenAt) >= e.agentOfflineFloor {
|
||||
e.raiseAndNotify(ctx, h.ID, KindAgentOffline, "", "warning",
|
||||
fmt.Sprintf("Agent offline for %s (threshold %s)",
|
||||
roundDur(now.Sub(*h.LastSeenAt)), e.agentOfflineFloor), now)
|
||||
}
|
||||
}
|
||||
// Stale-schedule sweep — no-op in v1. See KindStaleSchedule doc comment.
|
||||
}
|
||||
|
||||
// roundDur formats d for display. Anything shorter than one minute
// (including zero and negative values) is reported as "less than a
// minute"; everything else is rounded to the nearest minute and
// rendered with time.Duration's String form (e.g. "20m0s").
func roundDur(d time.Duration) string {
	switch {
	case d < time.Minute:
		return "less than a minute"
	default:
		rounded := d.Round(time.Minute)
		return rounded.String()
	}
}
|
||||
@@ -1,164 +0,0 @@
|
||||
package alert
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/notification"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// Alert kind constants. These strings are the alert "kind" tags; keep
// them in lockstep with the engine logic and the UI tag-colour table.
const (
	// KindBackupFailed: raised when a backup job finishes with status
	// "failed"; resolved by the next successful backup.
	KindBackupFailed = "backup_failed"

	// KindForgetFailed: the forget-job counterpart of KindBackupFailed.
	KindForgetFailed = "forget_failed"

	// KindPruneFailed: the prune-job counterpart of KindBackupFailed.
	KindPruneFailed = "prune_failed"

	// KindCheckFailed: raised when a check job fails. Uses "critical"
	// severity because repository integrity is at risk.
	KindCheckFailed = "check_failed"

	// KindStaleSchedule: declared for completeness but intentionally a
	// no-op in v1. The precise "expected to have fired but didn't"
	// logic requires a store helper that lands in a follow-up task.
	// Ask the team before implementing.
	KindStaleSchedule = "stale_schedule"

	// KindAgentOffline: raised when a host's last_seen_at is older than
	// the agent-offline floor (15 minutes); resolved when the host
	// reconnects.
	KindAgentOffline = "agent_offline"
)
|
||||
|
||||
// raiseAndNotify is the standard raise pattern: store.RaiseOrTouch
|
||||
// deduplicates, and notification.Hub.Dispatch fires only on the first
|
||||
// raise (didRaise=true). Subsequent occurrences of the same open alert
|
||||
// are "touched" (last_seen_at bumped) without a second notification.
|
||||
func (e *Engine) raiseAndNotify(ctx context.Context, hostID, kind, dedupKey, severity, message string, when time.Time) {
|
||||
id, didRaise, err := e.store.RaiseOrTouch(ctx, hostID, kind, dedupKey, severity, message, when)
|
||||
if err != nil {
|
||||
slog.Warn("alert: raise", "kind", kind, "host_id", hostID, "dedup_key", dedupKey, "err", err)
|
||||
return
|
||||
}
|
||||
if !didRaise {
|
||||
return
|
||||
}
|
||||
host, err := e.store.GetHost(ctx, hostID)
|
||||
hostName := hostID
|
||||
if err == nil {
|
||||
hostName = host.Name
|
||||
}
|
||||
go e.hub.Dispatch(ctx, notification.Payload{
|
||||
Event: notification.EventRaised,
|
||||
AlertID: id,
|
||||
Severity: severity,
|
||||
Kind: kind,
|
||||
HostID: hostID,
|
||||
HostName: hostName,
|
||||
Message: message,
|
||||
RaisedAt: when,
|
||||
})
|
||||
}
|
||||
|
||||
// Acknowledge updates the alert row and fans out alert.acknowledged to
|
||||
// every enabled channel. Best-effort: store errors are logged but the
|
||||
// dispatch still fires only when the store update succeeds.
|
||||
func (e *Engine) Acknowledge(ctx context.Context, alertID, userID string, when time.Time) error {
|
||||
if err := e.store.Acknowledge(ctx, alertID, userID, when); err != nil {
|
||||
return err
|
||||
}
|
||||
a, lerr := e.store.GetAlert(ctx, alertID)
|
||||
if lerr != nil || a == nil {
|
||||
// Acknowledge already succeeded; dispatch is best-effort.
|
||||
return nil //nolint:nilerr
|
||||
}
|
||||
p := alertPayload(ctx, e.store, notification.EventAcknowledged, a)
|
||||
go e.hub.Dispatch(context.WithoutCancel(ctx), p)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve marks the alert resolved and fans out alert.resolved.
|
||||
func (e *Engine) Resolve(ctx context.Context, alertID string, when time.Time) error {
|
||||
a, _ := e.store.GetAlert(ctx, alertID)
|
||||
if err := e.store.Resolve(ctx, alertID, when); err != nil {
|
||||
return err
|
||||
}
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
p := alertPayload(ctx, e.store, notification.EventResolved, a)
|
||||
go e.hub.Dispatch(context.WithoutCancel(ctx), p)
|
||||
return nil
|
||||
}
|
||||
|
||||
// alertPayload builds a Payload from a stored Alert, looking up the host
|
||||
// name when HostID is set.
|
||||
func alertPayload(ctx context.Context, st *store.Store, ev notification.Event, a *store.Alert) notification.Payload {
|
||||
hostID, hostName := "", ""
|
||||
if a.HostID != nil {
|
||||
hostID = *a.HostID
|
||||
hostName = hostID
|
||||
if h, err := st.GetHost(ctx, hostID); err == nil && h != nil {
|
||||
hostName = h.Name
|
||||
}
|
||||
}
|
||||
return notification.Payload{
|
||||
Event: ev,
|
||||
AlertID: a.ID,
|
||||
Severity: a.Severity,
|
||||
Kind: a.Kind,
|
||||
HostID: hostID,
|
||||
HostName: hostName,
|
||||
Message: a.Message,
|
||||
RaisedAt: a.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// resolveAndNotify clears the open (or acknowledged) alert matching
|
||||
// (host_id, kind, dedup_key) via store.AutoResolve, then fires
|
||||
// alert.resolved for the row(s) actually closed. Best-effort —
|
||||
// errors are logged but do not propagate.
|
||||
func (e *Engine) resolveAndNotify(ctx context.Context, hostID, kind, dedupKey string, when time.Time) {
|
||||
open, err := e.store.ListAlerts(ctx, store.AlertFilter{
|
||||
Status: "open", HostID: hostID,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
openAcked, _ := e.store.ListAlerts(ctx, store.AlertFilter{
|
||||
Status: "acknowledged", HostID: hostID,
|
||||
})
|
||||
all := append(open, openAcked...)
|
||||
if err := e.store.AutoResolve(ctx, hostID, kind, dedupKey, when); err != nil {
|
||||
slog.Warn("alert: auto-resolve", "kind", kind, "host_id", hostID, "dedup_key", dedupKey, "err", err)
|
||||
return
|
||||
}
|
||||
host, _ := e.store.GetHost(ctx, hostID)
|
||||
hostName := hostID
|
||||
if host != nil {
|
||||
hostName = host.Name
|
||||
}
|
||||
for _, a := range all {
|
||||
if a.Kind != kind || a.DedupKey != dedupKey {
|
||||
continue
|
||||
}
|
||||
go e.hub.Dispatch(ctx, notification.Payload{
|
||||
Event: notification.EventResolved,
|
||||
AlertID: a.ID,
|
||||
Severity: a.Severity,
|
||||
Kind: a.Kind,
|
||||
HostID: hostID,
|
||||
HostName: hostName,
|
||||
Message: fmt.Sprintf("Auto-resolved (%s)", kind),
|
||||
RaisedAt: when,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,125 +0,0 @@
|
||||
package alert
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/notification"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
func setupEngine(t *testing.T) (*Engine, *store.Store, string) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
st, _ := store.Open(context.Background(), filepath.Join(dir, "rm.db"))
|
||||
t.Cleanup(func() { _ = st.Close() })
|
||||
keyPath := filepath.Join(dir, "secret.key")
|
||||
_ = crypto.GenerateKeyFile(keyPath)
|
||||
key, _ := crypto.LoadKeyFromFile(keyPath)
|
||||
aead, _ := crypto.NewAEAD(key)
|
||||
hub := notification.NewHub(st, aead, "https://rm.example")
|
||||
eng := NewEngine(st, hub)
|
||||
hostID := ulid.Make().String()
|
||||
if err := st.CreateHost(context.Background(), store.Host{
|
||||
ID: hostID, Name: "alfa-01", OS: "linux", Arch: "amd64",
|
||||
EnrolledAt: time.Now().UTC(),
|
||||
}, "deadbeef", ""); err != nil {
|
||||
t.Fatalf("create host: %v", err)
|
||||
}
|
||||
return eng, st, hostID
|
||||
}
|
||||
|
||||
func TestEngineBackupFailedRaisesThenResolves(t *testing.T) {
|
||||
t.Parallel()
|
||||
eng, st, hostID := setupEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
eng.handleJobFinished(ctx, JobFinishedEvent{
|
||||
HostID: hostID, JobID: "j1", Kind: "backup", Status: "failed",
|
||||
When: time.Now().UTC(),
|
||||
})
|
||||
open, _ := st.ListAlerts(ctx, store.AlertFilter{Status: "open", HostID: hostID})
|
||||
if len(open) != 1 || open[0].Kind != KindBackupFailed {
|
||||
t.Fatalf("expected one backup_failed open; got %+v", open)
|
||||
}
|
||||
|
||||
// Second failed job should TOUCH (not raise a fresh row).
|
||||
eng.handleJobFinished(ctx, JobFinishedEvent{
|
||||
HostID: hostID, JobID: "j2", Kind: "backup", Status: "failed",
|
||||
When: time.Now().UTC().Add(time.Minute),
|
||||
})
|
||||
open, _ = st.ListAlerts(ctx, store.AlertFilter{Status: "open", HostID: hostID})
|
||||
if len(open) != 1 {
|
||||
t.Fatalf("expected dedup to stay at 1 open; got %d", len(open))
|
||||
}
|
||||
|
||||
// Success auto-resolves.
|
||||
eng.handleJobFinished(ctx, JobFinishedEvent{
|
||||
HostID: hostID, JobID: "j3", Kind: "backup", Status: "succeeded",
|
||||
When: time.Now().UTC().Add(2 * time.Minute),
|
||||
})
|
||||
open, _ = st.ListAlerts(ctx, store.AlertFilter{Status: "open", HostID: hostID})
|
||||
if len(open) != 0 {
|
||||
t.Fatalf("expected zero open after success; got %d", len(open))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineCheckFailedSeverityCritical(t *testing.T) {
|
||||
t.Parallel()
|
||||
eng, st, hostID := setupEngine(t)
|
||||
eng.handleJobFinished(context.Background(), JobFinishedEvent{
|
||||
HostID: hostID, Kind: "check", Status: "failed", When: time.Now().UTC(),
|
||||
})
|
||||
open, _ := st.ListAlerts(context.Background(),
|
||||
store.AlertFilter{Status: "open", HostID: hostID})
|
||||
if len(open) != 1 || open[0].Severity != "critical" {
|
||||
t.Fatalf("got %+v", open)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineAgentOfflineRespects15MinFloor(t *testing.T) {
|
||||
t.Parallel()
|
||||
eng, st, hostID := setupEngine(t)
|
||||
// Host's last_seen_at defaulted to NULL via CreateHost (enrolled but never
|
||||
// seen). Force a stale value for the test by direct DB update.
|
||||
if _, err := st.DB().Exec(
|
||||
`UPDATE hosts SET last_seen_at = ? WHERE id = ?`,
|
||||
time.Now().UTC().Add(-20*time.Minute).Format(time.RFC3339Nano), hostID,
|
||||
); err != nil {
|
||||
t.Fatalf("update last_seen_at: %v", err)
|
||||
}
|
||||
eng.handleHostOffline(context.Background(), hostID)
|
||||
open, _ := st.ListAlerts(context.Background(),
|
||||
store.AlertFilter{Status: "open", HostID: hostID})
|
||||
if len(open) != 1 {
|
||||
t.Fatalf("expected agent_offline raised; got %d", len(open))
|
||||
}
|
||||
|
||||
// Bring back online — should auto-resolve.
|
||||
eng.handleHostOnline(context.Background(), hostID)
|
||||
open, _ = st.ListAlerts(context.Background(),
|
||||
store.AlertFilter{Status: "open", HostID: hostID})
|
||||
if len(open) != 0 {
|
||||
t.Fatalf("expected agent_offline resolved; got %d", len(open))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineAgentOfflineUnderFloorNoRaise(t *testing.T) {
|
||||
t.Parallel()
|
||||
eng, st, hostID := setupEngine(t)
|
||||
// last_seen_at is NULL from CreateHost (never touched). A nil
|
||||
// last_seen_at means the host was enrolled but never connected —
|
||||
// treat that as "now" for the floor check so we don't raise
|
||||
// immediately. handleHostOffline must skip the raise.
|
||||
eng.handleHostOffline(context.Background(), hostID)
|
||||
open, _ := st.ListAlerts(context.Background(),
|
||||
store.AlertFilter{Status: "open", HostID: hostID})
|
||||
if len(open) != 0 {
|
||||
t.Fatalf("expected no raise within 15-min floor; got %d", len(open))
|
||||
}
|
||||
}
|
||||
@@ -52,17 +52,14 @@ type JobKind string
|
||||
|
||||
// Allowed JobKind values. backup is operator/cron driven; init runs
|
||||
// once per host on first connect; forget/prune/check fire from the
|
||||
// server-side maintenance ticker; unlock and restore are operator-
|
||||
// only; diff is operator-only and read-only.
|
||||
// server-side maintenance ticker; unlock is operator-only.
|
||||
const (
|
||||
JobBackup JobKind = "backup"
|
||||
JobInit JobKind = "init"
|
||||
JobForget JobKind = "forget"
|
||||
JobPrune JobKind = "prune"
|
||||
JobCheck JobKind = "check"
|
||||
JobUnlock JobKind = "unlock"
|
||||
JobRestore JobKind = "restore"
|
||||
JobDiff JobKind = "diff"
|
||||
JobBackup JobKind = "backup"
|
||||
JobInit JobKind = "init"
|
||||
JobForget JobKind = "forget"
|
||||
JobPrune JobKind = "prune"
|
||||
JobCheck JobKind = "check"
|
||||
JobUnlock JobKind = "unlock"
|
||||
)
|
||||
|
||||
// JobStatus is the lifecycle state of a job.
|
||||
@@ -133,48 +130,6 @@ type CommandRunPayload struct {
|
||||
Tag string `json:"tag,omitempty"`
|
||||
ForgetGroups []ForgetGroup `json:"forget_groups,omitempty"`
|
||||
RequiresAdminCreds bool `json:"requires_admin_creds,omitempty"`
|
||||
|
||||
// Per-job bandwidth caps in KB/s. When nil, the agent uses the
|
||||
// host-wide caps it received via config.update. When non-nil,
|
||||
// the override wins for this job only — even a non-nil zero
|
||||
// pointer means "no cap for this job" (caller's explicit choice).
|
||||
BandwidthUpKBps *int `json:"bandwidth_up_kbps,omitempty"`
|
||||
BandwidthDownKBps *int `json:"bandwidth_down_kbps,omitempty"`
|
||||
|
||||
// Hooks run only for kind=backup. Server resolves source-group
|
||||
// hook → host default → empty before dispatching, so the agent
|
||||
// just executes whatever is here.
|
||||
PreHook string `json:"pre_hook,omitempty"`
|
||||
PostHook string `json:"post_hook,omitempty"`
|
||||
|
||||
// Restore is populated only for kind=restore. See RestorePayload
|
||||
// for the shape; nil for every other kind.
|
||||
Restore *RestorePayload `json:"restore,omitempty"`
|
||||
|
||||
// Diff is populated only for kind=diff. See DiffPayload for
|
||||
// shape; nil for every other kind.
|
||||
Diff *DiffPayload `json:"diff,omitempty"`
|
||||
}
|
||||
|
||||
// RestorePayload holds the restore-specific arguments of a JobRestore
// command.run.
//
// Two modes exist. With InPlace set, the agent restores at root
// (`--target /`) and preserves uid/gid/mode. Otherwise it restores into
// TargetDir with --no-ownership so the operator can inspect the files
// as the agent user.
type RestorePayload struct {
	// SnapshotID identifies the snapshot to restore from.
	SnapshotID string `json:"snapshot_id"`
	// Paths are absolute paths inside the snapshot, in the same shape
	// restic accepts as positional args.
	Paths []string `json:"paths"`
	// InPlace selects restore-at-root mode.
	InPlace bool `json:"in_place"`
	// TargetDir is the destination directory; ignored when in_place=true.
	TargetDir string `json:"target_dir,omitempty"`
}
|
||||
|
||||
// DiffPayload holds the two snapshots to compare on a JobDiff
// command.run. Either ID may be given in short or long form; restic
// accepts both.
type DiffPayload struct {
	// SnapshotA is the first snapshot of the comparison.
	SnapshotA string `json:"snapshot_a"`
	// SnapshotB is the second snapshot of the comparison.
	SnapshotB string `json:"snapshot_b"`
}
|
||||
|
||||
// CommandCancelPayload is the server → agent cancel signal.
|
||||
@@ -351,14 +306,6 @@ type ConfigUpdatePayload struct {
|
||||
RepoCredential string `json:"repo_credential,omitempty"` // sensitive (for rest server basic auth)
|
||||
HookShell string `json:"hook_shell,omitempty"`
|
||||
Slot string `json:"slot,omitempty"`
|
||||
|
||||
// Bandwidth caps in KB/s. Pointer semantics so the server can
|
||||
// disambiguate "no change in this push" (nil → omitted on the
|
||||
// wire) from "explicitly clear the cap" (zero or negative value).
|
||||
// Applied to every restic invocation as --limit-upload /
|
||||
// --limit-download. Per-job overrides ride on CommandRunPayload.
|
||||
BandwidthUpKBps *int `json:"bandwidth_up_kbps,omitempty"`
|
||||
BandwidthDownKBps *int `json:"bandwidth_down_kbps,omitempty"`
|
||||
}
|
||||
|
||||
// AgentUpdateAvailablePayload — informational only; the agent does
|
||||
@@ -369,37 +316,3 @@ type AgentUpdateAvailablePayload struct {
|
||||
PackageURL string `json:"package_url"` // apt repo / choco source
|
||||
Changelog string `json:"changelog,omitempty"`
|
||||
}
|
||||
|
||||
// TreeListRequestPayload is the request body of the tree.list RPC,
// which the restore wizard uses to lazy-load directory contents from a
// snapshot.
//
// The exchange is synchronous. The server marshals MsgTreeList with a
// fresh Envelope.ID, sends it to the agent and blocks on a channel
// keyed by that ID. The agent runs
// `restic ls --json <SnapshotID> <Path>`, emits the direct children and
// replies with MsgTreeListResult carrying the same ID; the server-side
// handler matches on the ID and forwards to the waiting channel. See
// internal/server/ws/rpc.go for the helper.
type TreeListRequestPayload struct {
	// SnapshotID is the snapshot to list.
	SnapshotID string `json:"snapshot_id"`
	// Path is an absolute path inside the snapshot; "/" means the root.
	Path string `json:"path"`
}
|
||||
|
||||
// TreeListEntry describes one direct child returned by a tree.list
// call.
type TreeListEntry struct {
	// Name is the child's name.
	Name string `json:"name"`
	// Type is "dir" | "file" | "symlink".
	Type string `json:"type"`
	// Size is best-effort (zero on directories and symlinks).
	Size int64 `json:"size,omitempty"`
}

// TreeListResultPayload is the reply to a tree.list.
//
// Error is set when the agent couldn't fulfil the request (missing
// snapshot, path doesn't exist, restic invocation failed), and Entries
// is empty in that case. A successful listing of an empty directory has
// Error="" and nil Entries.
type TreeListResultPayload struct {
	// SnapshotID is the snapshot ID carried in the reply.
	SnapshotID string `json:"snapshot_id"`
	// Path is the path carried in the reply.
	Path string `json:"path"`
	// Entries holds the direct children of Path.
	Entries []TreeListEntry `json:"entries,omitempty"`
	// Error is the failure description; empty on success.
	Error string `json:"error,omitempty"`
}
|
||||
|
||||
+13
-15
@@ -12,19 +12,18 @@ type MessageType string
|
||||
|
||||
// Agent → server message types.
|
||||
const (
|
||||
MsgHello MessageType = "hello"
|
||||
MsgHeartbeat MessageType = "heartbeat"
|
||||
MsgJobStarted MessageType = "job.started"
|
||||
MsgJobProgress MessageType = "job.progress"
|
||||
MsgJobFinished MessageType = "job.finished"
|
||||
MsgSnapshotsRpt MessageType = "snapshots.report"
|
||||
MsgRepoStats MessageType = "repo.stats"
|
||||
MsgLogStream MessageType = "log.stream"
|
||||
MsgScheduleAck MessageType = "schedule.ack"
|
||||
MsgScheduleFire MessageType = "schedule.fire" // agent: a local cron entry fired, please dispatch a job
|
||||
MsgCommandResult MessageType = "command.result" // ack for command.run
|
||||
MsgTreeListResult MessageType = "tree.list.result" // reply to a server-driven tree.list
|
||||
MsgError MessageType = "error"
|
||||
MsgHello MessageType = "hello"
|
||||
MsgHeartbeat MessageType = "heartbeat"
|
||||
MsgJobStarted MessageType = "job.started"
|
||||
MsgJobProgress MessageType = "job.progress"
|
||||
MsgJobFinished MessageType = "job.finished"
|
||||
MsgSnapshotsRpt MessageType = "snapshots.report"
|
||||
MsgRepoStats MessageType = "repo.stats"
|
||||
MsgLogStream MessageType = "log.stream"
|
||||
MsgScheduleAck MessageType = "schedule.ack"
|
||||
MsgScheduleFire MessageType = "schedule.fire" // agent: a local cron entry fired, please dispatch a job
|
||||
MsgCommandResult MessageType = "command.result" // ack for command.run
|
||||
MsgError MessageType = "error"
|
||||
)
|
||||
|
||||
// Server → agent message types.
|
||||
@@ -34,7 +33,6 @@ const (
|
||||
MsgScheduleSet MessageType = "schedule.set"
|
||||
MsgConfigUpdate MessageType = "config.update"
|
||||
MsgAgentUpdateAvail MessageType = "agent.update.available"
|
||||
MsgTreeList MessageType = "tree.list" // sync RPC: list a snapshot's children
|
||||
)
|
||||
|
||||
// Envelope is the framing for every WS message in either direction.
|
||||
@@ -78,7 +76,7 @@ type ErrorCode string
|
||||
const (
|
||||
ErrProtocolTooOld ErrorCode = "protocol_too_old"
|
||||
ErrProtocolTooNew ErrorCode = "protocol_too_new"
|
||||
ErrUnauthorized ErrorCode = "unauthorised"
|
||||
ErrUnauthorized ErrorCode = "unauthorized"
|
||||
ErrBadRequest ErrorCode = "bad_request"
|
||||
ErrInternal ErrorCode = "internal"
|
||||
)
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
@@ -28,38 +27,22 @@ const (
|
||||
defaultKeyLen = 32
|
||||
)
|
||||
|
||||
// Cheap params used only when the binary is a `go test` binary
|
||||
// (testing.Testing() == true). Argon2id at production params costs
|
||||
// 300–500 ms per hash and dominates wall time on CI runners under
|
||||
// `-race`. Tests don't need real KDF strength — VerifyPassword reads
|
||||
// params from the encoded hash, so verifying a cheap-params hash
|
||||
// works the same way.
|
||||
const (
|
||||
testMemoryKiB = 8
|
||||
testIterations = 1
|
||||
testParallel = 1
|
||||
)
|
||||
|
||||
// HashPassword returns an argon2id-encoded string of the form
|
||||
//
|
||||
// $argon2id$v=19$m=...,t=...,p=...$<salt>$<hash>
|
||||
//
|
||||
// safe to store in a TEXT column. The salt is freshly random per call.
|
||||
func HashPassword(password string) (string, error) {
|
||||
mem, iter, par := uint32(defaultMemoryKiB), uint32(defaultIterations), uint8(defaultParallel)
|
||||
if testing.Testing() {
|
||||
mem, iter, par = testMemoryKiB, testIterations, testParallel
|
||||
}
|
||||
salt := make([]byte, defaultSaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return "", fmt.Errorf("auth: read salt: %w", err)
|
||||
}
|
||||
hash := argon2.IDKey([]byte(password), salt,
|
||||
iter, mem, par, defaultKeyLen)
|
||||
defaultIterations, defaultMemoryKiB, defaultParallel, defaultKeyLen)
|
||||
|
||||
return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version,
|
||||
mem, iter, par,
|
||||
defaultMemoryKiB, defaultIterations, defaultParallel,
|
||||
base64.RawStdEncoding.EncodeToString(salt),
|
||||
base64.RawStdEncoding.EncodeToString(hash),
|
||||
), nil
|
||||
@@ -73,7 +56,7 @@ func VerifyPassword(encoded, password string) error {
|
||||
parts := strings.Split(encoded, "$")
|
||||
// "$argon2id$v=...$m=...,t=...,p=...$<salt>$<hash>" → 6 parts (leading empty)
|
||||
if len(parts) != 6 || parts[1] != "argon2id" {
|
||||
return errors.New("auth: unrecognised hash format")
|
||||
return errors.New("auth: unrecognized hash format")
|
||||
}
|
||||
var version int
|
||||
if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// passwords, REST-server credentials, hook bodies, and any other
|
||||
// secret that lands in the SQLite store.
|
||||
//
|
||||
// The threat model is "defence in depth against a stolen DB file" —
|
||||
// The threat model is "defense in depth against a stolen DB file" —
|
||||
// not "an attacker with code execution can't read secrets at runtime."
|
||||
// We need the encryption key at runtime to do any actual work, so
|
||||
// anyone with a memory dump of the running server can extract it.
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Channel is the per-kind transport. Implementations live in
|
||||
// webhook.go / ntfy.go / smtp.go. Send must respect ctx (5s for HTTP,
|
||||
// 10s for SMTP) and never panic.
|
||||
type Channel interface {
|
||||
// Kind returns the kind string ("webhook", "ntfy", "smtp"). Used
|
||||
// for log enrichment and dispatcher routing.
|
||||
Kind() string
|
||||
|
||||
// Send delivers one payload. Returns (statusCode, latency, err).
|
||||
// statusCode is HTTP for HTTP channels, the SMTP final-line code
|
||||
// (e.g. 250) for SMTP, 0 if the call didn't reach a wire response.
|
||||
Send(ctx context.Context, p Payload) (statusCode int, latency time.Duration, err error)
|
||||
}
|
||||
@@ -1,187 +0,0 @@
|
||||
package notification
|
||||
|
||||
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"log/slog"
	"strings"
	"sync"
	"time"

	"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
	"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
|
||||
|
||||
// Hub fans Payload events out to every enabled channel and persists
|
||||
// the result to notification_log. One Hub per process; thread-safe.
|
||||
type Hub struct {
|
||||
store *store.Store
|
||||
aead *crypto.AEAD
|
||||
baseURL string // e.g. https://restic-manager.example
|
||||
msgIDDomain string // hostname extracted from baseURL for SMTP Message-ID
|
||||
}
|
||||
|
||||
// NewHub constructs a Hub. baseURL is the public root of the server
|
||||
// (used to build /alerts/<id> links and the SMTP Message-ID domain).
|
||||
func NewHub(st *store.Store, aead *crypto.AEAD, baseURL string) *Hub {
|
||||
return &Hub{
|
||||
store: st,
|
||||
aead: aead,
|
||||
baseURL: baseURL,
|
||||
msgIDDomain: extractDomain(baseURL),
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch fans out to every enabled channel. Best-effort — failures
|
||||
// are logged to notification_log but do not propagate to the caller.
|
||||
// Each channel runs in its own goroutine; Dispatch returns only when
|
||||
// all goroutines have settled, so the caller can block briefly for
|
||||
// the test-button case.
|
||||
func (h *Hub) Dispatch(ctx context.Context, p Payload) {
|
||||
chans, err := h.store.ListEnabledNotificationChannels(ctx)
|
||||
if err != nil {
|
||||
slog.Error("notification: list channels", "err", err)
|
||||
return
|
||||
}
|
||||
// Stamp the alert link if the caller left it empty.
|
||||
if p.Link == "" {
|
||||
p.Link = h.baseURL + "/alerts/" + p.AlertID
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, c := range chans {
|
||||
wg.Add(1)
|
||||
go func(c store.NotificationChannel) {
|
||||
defer wg.Done()
|
||||
h.send(ctx, c, p)
|
||||
}(c)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// DispatchOne fires a single channel — used by the "Send test
|
||||
// notification" button. Returns the log entry that was persisted so
|
||||
// the handler can render the result inline.
|
||||
func (h *Hub) DispatchOne(ctx context.Context, channelID string, p Payload) (store.NotificationLogEntry, error) {
|
||||
c, err := h.store.GetNotificationChannel(ctx, channelID)
|
||||
if err != nil {
|
||||
return store.NotificationLogEntry{}, err
|
||||
}
|
||||
if p.Link == "" {
|
||||
p.Link = h.baseURL + "/alerts/" + p.AlertID
|
||||
}
|
||||
return h.send(ctx, *c, p), nil
|
||||
}
|
||||
|
||||
// send builds the channel impl, delivers the payload, and persists a
|
||||
// notification_log row regardless of success or failure.
|
||||
func (h *Hub) send(ctx context.Context, c store.NotificationChannel, p Payload) store.NotificationLogEntry {
|
||||
ch, buildErr := h.buildChannel(c)
|
||||
logEntry := store.NotificationLogEntry{
|
||||
ID: newID(),
|
||||
ChannelID: c.ID,
|
||||
Event: string(p.Event),
|
||||
FiredAt: time.Now().UTC(),
|
||||
}
|
||||
if p.AlertID != "" {
|
||||
aid := p.AlertID
|
||||
logEntry.AlertID = &aid
|
||||
}
|
||||
if buildErr != nil {
|
||||
errStr := buildErr.Error()
|
||||
logEntry.OK = false
|
||||
logEntry.Error = &errStr
|
||||
_ = h.store.AppendNotificationLog(ctx, logEntry)
|
||||
return logEntry
|
||||
}
|
||||
|
||||
code, latency, sendErr := ch.Send(ctx, p)
|
||||
statusCode := code
|
||||
latencyMS := int(latency.Milliseconds())
|
||||
logEntry.StatusCode = &statusCode
|
||||
logEntry.LatencyMS = &latencyMS
|
||||
if sendErr != nil {
|
||||
errStr := sendErr.Error()
|
||||
logEntry.OK = false
|
||||
logEntry.Error = &errStr
|
||||
} else {
|
||||
logEntry.OK = true
|
||||
}
|
||||
if err := h.store.AppendNotificationLog(ctx, logEntry); err != nil {
|
||||
slog.Warn("notification: persist log", "err", err)
|
||||
}
|
||||
return logEntry
|
||||
}
|
||||
|
||||
// buildChannel decrypts the channel config and returns a concrete
|
||||
// Channel implementation for the channel's kind.
|
||||
func (h *Hub) buildChannel(row store.NotificationChannel) (Channel, error) {
|
||||
plain, err := h.aead.Decrypt(string(row.Config), []byte("notification-channel:"+row.ID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch row.Kind {
|
||||
case "webhook":
|
||||
var cfg WebhookConfig
|
||||
if err := json.Unmarshal(plain, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewWebhookChannel(cfg), nil
|
||||
case "ntfy":
|
||||
var cfg NtfyConfig
|
||||
if err := json.Unmarshal(plain, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dp := ""
|
||||
if row.DefaultPriority != nil {
|
||||
dp = *row.DefaultPriority
|
||||
}
|
||||
return NewNtfyChannel(cfg, dp), nil
|
||||
case "smtp":
|
||||
var cfg SMTPConfig
|
||||
if err := json.Unmarshal(plain, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewSMTPChannel(cfg, h.msgIDDomain), nil
|
||||
}
|
||||
return nil, errUnknownKind(row.Kind)
|
||||
}
|
||||
|
||||
// newID generates the identifier for a notification_log row: 16 bytes
// from crypto/rand, hex-encoded to 32 characters. The result of
// rand.Read is deliberately not checked.
func newID() string {
	raw := make([]byte, 16)
	_, _ = rand.Read(raw)
	return hex.EncodeToString(raw)
}
|
||||
|
||||
// extractDomain strips the scheme and path from baseURL, leaving only
// the host[:port] component. It is used as the right-hand side of SMTP
// Message-IDs.
//
// Everything up to and including the first "://" is dropped, then
// everything from the first "/" onwards. If nothing remains, the
// fallback "restic-manager.local" is returned.
func extractDomain(baseURL string) string {
	host := baseURL
	if _, afterScheme, found := strings.Cut(host, "://"); found {
		host = afterScheme
	}
	if beforePath, _, found := strings.Cut(host, "/"); found {
		host = beforePath
	}
	if host == "" {
		return "restic-manager.local"
	}
	return host
}

// indexOf returns the index of the first occurrence of sub in s, or -1
// when sub is not present. It is a thin wrapper around strings.Index,
// kept so existing callers of the helper continue to compile.
func indexOf(s, sub string) int {
	return strings.Index(s, sub)
}
|
||||
|
||||
// errUnknownKind is the error reported for a channel kind that has no
// implementation; the underlying string is the offending kind.
type errUnknownKind string

// Error implements the error interface.
func (e errUnknownKind) Error() string {
	const prefix = "notification: unknown kind: "
	return prefix + string(e)
}
|
||||
@@ -1,99 +0,0 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
func setupHub(t *testing.T) (*Hub, *store.Store) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("store: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = st.Close() })
|
||||
keyPath := filepath.Join(dir, "secret.key")
|
||||
_ = crypto.GenerateKeyFile(keyPath)
|
||||
key, _ := crypto.LoadKeyFromFile(keyPath)
|
||||
aead, _ := crypto.NewAEAD(key)
|
||||
return NewHub(st, aead, "https://rm.example"), st
|
||||
}
|
||||
|
||||
func TestHubDispatchRecordsLogEntries(t *testing.T) {
|
||||
t.Parallel()
|
||||
hub, st := setupHub(t)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
cfg, _ := json.Marshal(WebhookConfig{URL: srv.URL})
|
||||
enc, err := hub.aead.Encrypt(cfg, []byte("notification-channel:test-ch"))
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt: %v", err)
|
||||
}
|
||||
if err := st.CreateNotificationChannel(context.Background(), store.NotificationChannel{
|
||||
ID: "test-ch", Kind: "webhook", Name: "test", Enabled: true,
|
||||
Config: []byte(enc), CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
|
||||
}); err != nil {
|
||||
t.Fatalf("create channel: %v", err)
|
||||
}
|
||||
|
||||
hub.Dispatch(context.Background(), Payload{
|
||||
Event: EventRaised,
|
||||
Severity: "warning",
|
||||
Kind: "backup_failed",
|
||||
HostName: "alfa-01",
|
||||
Message: "x",
|
||||
RaisedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
// Verify a log row landed with ok=1.
|
||||
var n int
|
||||
if err := st.DB().QueryRow(
|
||||
`SELECT COUNT(*) FROM notification_log WHERE channel_id = ? AND ok = 1`, "test-ch",
|
||||
).Scan(&n); err != nil {
|
||||
t.Fatalf("count: %v", err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Fatalf("expected 1 log row, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHubSkipsDisabledChannels(t *testing.T) {
|
||||
t.Parallel()
|
||||
hub, st := setupHub(t)
|
||||
|
||||
cfg, _ := json.Marshal(WebhookConfig{URL: "http://no-such-host.invalid"})
|
||||
enc, _ := hub.aead.Encrypt(cfg, []byte("notification-channel:dis"))
|
||||
_ = st.CreateNotificationChannel(context.Background(), store.NotificationChannel{
|
||||
ID: "dis", Kind: "webhook", Name: "off", Enabled: false,
|
||||
Config: []byte(enc), CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
hub.Dispatch(context.Background(), Payload{
|
||||
Event: EventRaised,
|
||||
AlertID: "x",
|
||||
Severity: "warning",
|
||||
Kind: "backup_failed",
|
||||
HostName: "h",
|
||||
Message: "m",
|
||||
RaisedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
var n int
|
||||
_ = st.DB().QueryRow(`SELECT COUNT(*) FROM notification_log`).Scan(&n)
|
||||
if n != 0 {
|
||||
t.Errorf("disabled channel produced log rows: %d", n)
|
||||
}
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NtfyConfig is the JSON document describing one ntfy channel. It is the
// plaintext of the AEAD-encrypted notification_channels.config column.
//
// Authentication is optional. When AccessToken is set it is used and
// Username/Password are ignored; otherwise a non-empty Username enables
// HTTP basic auth.
type NtfyConfig struct {
	// ServerURL is the ntfy base URL; NewNtfyChannel substitutes
	// "https://ntfy.sh" when it is empty.
	ServerURL string `json:"server_url"`
	// Topic is appended to ServerURL to form the publish URL.
	Topic string `json:"topic"`
	// AccessToken is sent as a bearer token.
	AccessToken string `json:"access_token,omitempty"`
	// Username and Password are sent as HTTP basic credentials.
	Username string `json:"username,omitempty"`
	Password string `json:"password,omitempty"`
}
|
||||
|
||||
// NtfyChannel delivers alerts to an ntfy server using POST with
|
||||
// ntfy-specific headers (Title, Priority, Tags, Click). One instance
|
||||
// per configured channel row. Reused across sends — http.Client is
|
||||
// goroutine-safe.
|
||||
type NtfyChannel struct {
|
||||
cfg NtfyConfig
|
||||
defaultPriority string // "min"/"low"/"default"/"high"/"urgent" or ""
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewNtfyChannel builds an ntfy channel with a 5s http.Client timeout.
|
||||
// defaultPriority is the channel-configured fallback when no
|
||||
// severity-specific mapping applies; pass "" to use the built-in
|
||||
// fallbacks (4 for warning, 3 for everything else).
|
||||
func NewNtfyChannel(cfg NtfyConfig, defaultPriority string) *NtfyChannel {
|
||||
if cfg.ServerURL == "" {
|
||||
cfg.ServerURL = "https://ntfy.sh"
|
||||
}
|
||||
return &NtfyChannel{
|
||||
cfg: cfg,
|
||||
defaultPriority: defaultPriority,
|
||||
client: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// Kind returns "ntfy" for log enrichment and dispatcher routing.
|
||||
func (c *NtfyChannel) Kind() string { return "ntfy" }
|
||||
|
||||
// Send delivers the payload as a plain-text POST to <server>/<topic>
|
||||
// with ntfy headers. Returns (statusCode, latency, err). 4xx/5xx
|
||||
// responses are returned as errors with the status code set.
|
||||
func (c *NtfyChannel) Send(ctx context.Context, p Payload) (int, time.Duration, error) {
|
||||
server := strings.TrimRight(c.cfg.ServerURL, "/")
|
||||
url := server + "/" + c.cfg.Topic
|
||||
|
||||
// Body carries the event verb so the body alone is unambiguous when
|
||||
// it shows up on a phone lockscreen without the title.
|
||||
body := p.Message
|
||||
switch p.Event {
|
||||
case EventResolved:
|
||||
body = "Resolved · " + p.Message
|
||||
case EventAcknowledged:
|
||||
body = "Acknowledged · " + p.Message
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBufferString(body))
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("ntfy: build request: %w", err)
|
||||
}
|
||||
|
||||
// Title prefix tracks the event so raise vs ack vs resolve are
|
||||
// visually distinct in the ntfy notification list.
|
||||
verb := "raised"
|
||||
switch p.Event {
|
||||
case EventAcknowledged:
|
||||
verb = "ack"
|
||||
case EventResolved:
|
||||
verb = "resolved"
|
||||
case EventTest:
|
||||
verb = "test"
|
||||
}
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
req.Header.Set("Title", fmt.Sprintf("[%s · %s] %s %s", verb, p.Severity, p.HostName, p.Kind))
|
||||
req.Header.Set("Tags", verb+","+p.Severity+","+p.Kind)
|
||||
req.Header.Set("Priority", priorityForSeverity(p.Severity, c.defaultPriority))
|
||||
if p.Link != "" {
|
||||
req.Header.Set("Click", p.Link)
|
||||
}
|
||||
switch {
|
||||
case c.cfg.AccessToken != "":
|
||||
req.Header.Set("Authorization", "Bearer "+c.cfg.AccessToken)
|
||||
case c.cfg.Username != "":
|
||||
creds := c.cfg.Username + ":" + c.cfg.Password
|
||||
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(creds)))
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
res, err := c.client.Do(req)
|
||||
latency := time.Since(t0)
|
||||
if err != nil {
|
||||
return 0, latency, fmt.Errorf("ntfy: do: %w", err)
|
||||
}
|
||||
defer func() { _ = res.Body.Close() }()
|
||||
// Drain body to keep the connection reusable.
|
||||
_, _ = io.Copy(io.Discard, res.Body)
|
||||
if res.StatusCode >= 400 {
|
||||
return res.StatusCode, latency, fmt.Errorf("ntfy: http %d", res.StatusCode)
|
||||
}
|
||||
return res.StatusCode, latency, nil
|
||||
}
|
||||
|
||||
// priorityForSeverity picks the value of the ntfy Priority header.
//
// "critical" always yields "5", whatever defaultPri says. For every other
// severity a non-empty defaultPri is returned as-is; without one the result
// is "4" for "warning" and "3" for anything else.
func priorityForSeverity(severity, defaultPri string) string {
	if severity == "critical" {
		return "5"
	}
	if defaultPri != "" {
		return defaultPri
	}
	if severity == "warning" {
		return "4"
	}
	return "3"
}
|
||||
@@ -1,97 +0,0 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNtfySendsHeadersAndBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
gotTitle string
|
||||
gotPri string
|
||||
gotTags string
|
||||
gotClick string
|
||||
gotAuth string
|
||||
gotContentType string
|
||||
gotBody string
|
||||
)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotTitle = r.Header.Get("Title")
|
||||
gotPri = r.Header.Get("Priority")
|
||||
gotTags = r.Header.Get("Tags")
|
||||
gotClick = r.Header.Get("Click")
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
gotContentType = r.Header.Get("Content-Type")
|
||||
b, _ := io.ReadAll(r.Body)
|
||||
gotBody = string(b)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
cfg := NtfyConfig{
|
||||
ServerURL: srv.URL,
|
||||
Topic: "alerts",
|
||||
AccessToken: "tk1",
|
||||
}
|
||||
ch := NewNtfyChannel(cfg, "") // no default priority; critical must still be "5"
|
||||
|
||||
p := Payload{
|
||||
Event: EventRaised,
|
||||
AlertID: "01HZ",
|
||||
Severity: "critical",
|
||||
Kind: "check_failed",
|
||||
HostName: "alfa-01",
|
||||
Message: "errors found",
|
||||
RaisedAt: time.Now(),
|
||||
Link: "https://rm.example/a",
|
||||
}
|
||||
|
||||
code, _, err := ch.Send(t.Context(), p)
|
||||
if err != nil {
|
||||
t.Fatalf("Send: %v", err)
|
||||
}
|
||||
if code != http.StatusOK {
|
||||
t.Fatalf("want 200, got %d", code)
|
||||
}
|
||||
|
||||
if want := "[raised · critical] alfa-01 check_failed"; gotTitle != want {
|
||||
t.Errorf("Title: got %q want %q", gotTitle, want)
|
||||
}
|
||||
if gotPri != "5" {
|
||||
t.Errorf("Priority: got %q want \"5\"", gotPri)
|
||||
}
|
||||
if want := "raised,critical,check_failed"; gotTags != want {
|
||||
t.Errorf("Tags: got %q want %q", gotTags, want)
|
||||
}
|
||||
if gotClick != "https://rm.example/a" {
|
||||
t.Errorf("Click: got %q want %q", gotClick, "https://rm.example/a")
|
||||
}
|
||||
if want := "Bearer tk1"; gotAuth != want {
|
||||
t.Errorf("Authorization: got %q want %q", gotAuth, want)
|
||||
}
|
||||
if gotContentType != "text/plain" {
|
||||
t.Errorf("Content-Type: got %q want %q", gotContentType, "text/plain")
|
||||
}
|
||||
if gotBody != "errors found" {
|
||||
t.Errorf("body: got %q want %q", gotBody, "errors found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNtfyDefaultPriorityRespected(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// info + defaultPri="min" → "min"
|
||||
if got := priorityForSeverity("info", "min"); got != "min" {
|
||||
t.Errorf("info+min: got %q want \"min\"", got)
|
||||
}
|
||||
// critical → "5" regardless of default
|
||||
if got := priorityForSeverity("critical", "min"); got != "5" {
|
||||
t.Errorf("critical+min: got %q want \"5\"", got)
|
||||
}
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
// Package notification owns the fan-out of alert events to operator-
|
||||
// configured channels. Three channels in v1: webhook, ntfy, smtp.
|
||||
// Each channel implements Channel.Send for one Payload at a time;
|
||||
// the Hub orchestrates fan-out, persists to notification_log.
|
||||
package notification
|
||||
|
||||
import "time"
|
||||
|
||||
// Event names the alert lifecycle hook a notification is sent for.
type Event string

// Lifecycle events understood by the notification channels.
const (
	// EventRaised is sent when an alert is first raised.
	EventRaised Event = "alert.raised"
	// EventAcknowledged is sent when an alert is acknowledged.
	EventAcknowledged Event = "alert.acknowledged"
	// EventResolved is sent when an alert is resolved.
	EventResolved Event = "alert.resolved"
	// EventTest marks a test notification.
	EventTest Event = "alert.test"
)
|
||||
|
||||
// Payload is the per-event blob every channel renders into its own
|
||||
// shape. Severity maps to channel-specific priority (ntfy) or stays
|
||||
// in the body (webhook/smtp).
|
||||
type Payload struct {
|
||||
Event Event // alert.raised | … | alert.test
|
||||
AlertID string // ULID
|
||||
Severity string // info | warning | critical
|
||||
Kind string // backup_failed | …
|
||||
HostID string
|
||||
HostName string
|
||||
Message string
|
||||
RaisedAt time.Time
|
||||
Link string // Absolute URL to /alerts/<id>; built by Hub
|
||||
}
|
||||
@@ -1,151 +0,0 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SMTPConfig is the JSON configuration of one SMTP notification channel.
type SMTPConfig struct {
	// Host and Port address the mail server.
	Host string `json:"host"`
	Port int    `json:"port"`
	// Encryption is "starttls", "tls" or "none".
	Encryption string `json:"encryption"`
	// Username enables authentication when non-empty.
	Username string `json:"username"`
	Password string `json:"password"`
	// From may be a bare address or the "Name <addr@host>" form.
	From string `json:"from"`
	// To is the recipient address.
	To string `json:"to"`
}
|
||||
|
||||
// SMTPChannel delivers alert notifications via plain-text email.
|
||||
type SMTPChannel struct {
|
||||
cfg SMTPConfig
|
||||
// messageIDDomain holds the public base hostname of restic-manager so
|
||||
// Message-IDs include a stable right-hand-side. Falls back to
|
||||
// "restic-manager.local" when unset.
|
||||
messageIDDomain string
|
||||
}
|
||||
|
||||
// NewSMTPChannel builds an SMTP channel. messageIDDomain comes from
|
||||
// cfg.Cfg.BaseURL — caller passes it through.
|
||||
func NewSMTPChannel(cfg SMTPConfig, messageIDDomain string) *SMTPChannel {
|
||||
if messageIDDomain == "" {
|
||||
messageIDDomain = "restic-manager.local"
|
||||
}
|
||||
return &SMTPChannel{cfg: cfg, messageIDDomain: messageIDDomain}
|
||||
}
|
||||
|
||||
// Kind returns "smtp".
|
||||
func (c *SMTPChannel) Kind() string { return "smtp" }
|
||||
|
||||
// Send delivers the payload as a plain-text email via SMTP.
|
||||
// Returns (250, latency, nil) on success.
|
||||
func (c *SMTPChannel) Send(ctx context.Context, p Payload) (int, time.Duration, error) {
|
||||
t0 := time.Now()
|
||||
addr := fmt.Sprintf("%s:%d", c.cfg.Host, c.cfg.Port)
|
||||
|
||||
// Dial respects ctx (we use net.Dialer).
|
||||
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||
rawConn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return 0, time.Since(t0), fmt.Errorf("smtp: dial %s: %w", addr, err)
|
||||
}
|
||||
|
||||
var client *smtp.Client
|
||||
switch strings.ToLower(c.cfg.Encryption) {
|
||||
case "tls":
|
||||
conn := tls.Client(rawConn, &tls.Config{ServerName: c.cfg.Host, MinVersion: tls.VersionTLS12})
|
||||
client, err = smtp.NewClient(conn, c.cfg.Host)
|
||||
case "starttls", "":
|
||||
client, err = smtp.NewClient(rawConn, c.cfg.Host)
|
||||
if err == nil {
|
||||
err = client.StartTLS(&tls.Config{ServerName: c.cfg.Host, MinVersion: tls.VersionTLS12})
|
||||
}
|
||||
case "none":
|
||||
client, err = smtp.NewClient(rawConn, c.cfg.Host)
|
||||
default:
|
||||
_ = rawConn.Close()
|
||||
return 0, time.Since(t0), fmt.Errorf("smtp: unknown encryption %q", c.cfg.Encryption)
|
||||
}
|
||||
if err != nil {
|
||||
_ = rawConn.Close()
|
||||
return 0, time.Since(t0), fmt.Errorf("smtp: handshake: %w", err)
|
||||
}
|
||||
defer func() { _ = client.Quit() }()
|
||||
|
||||
if c.cfg.Username != "" {
|
||||
auth := smtp.PlainAuth("", c.cfg.Username, c.cfg.Password, c.cfg.Host)
|
||||
if err := client.Auth(auth); err != nil {
|
||||
return 0, time.Since(t0), fmt.Errorf("smtp: auth: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := client.Mail(extractAddr(c.cfg.From)); err != nil {
|
||||
return 0, time.Since(t0), fmt.Errorf("smtp: MAIL FROM: %w", err)
|
||||
}
|
||||
if err := client.Rcpt(c.cfg.To); err != nil {
|
||||
return 0, time.Since(t0), fmt.Errorf("smtp: RCPT TO: %w", err)
|
||||
}
|
||||
wc, err := client.Data()
|
||||
if err != nil {
|
||||
return 0, time.Since(t0), fmt.Errorf("smtp: DATA: %w", err)
|
||||
}
|
||||
msg := buildEmailBody(c.cfg, c.messageIDDomain, p)
|
||||
if _, err := wc.Write(msg); err != nil {
|
||||
return 0, time.Since(t0), fmt.Errorf("smtp: write: %w", err)
|
||||
}
|
||||
if err := wc.Close(); err != nil {
|
||||
return 0, time.Since(t0), fmt.Errorf("smtp: close DATA: %w", err)
|
||||
}
|
||||
|
||||
return 250, time.Since(t0), nil
|
||||
}
|
||||
|
||||
// extractAddr returns the bare address from a "Name <addr@host>" string: the
// text between the last '<' and the last '>'. When there is no such
// well-ordered pair the input is returned unchanged.
func extractAddr(s string) string {
	open := strings.LastIndex(s, "<")
	if open < 0 {
		return s
	}
	end := strings.LastIndex(s, ">")
	if end <= open {
		return s
	}
	return s[open+1 : end]
}
|
||||
|
||||
// buildEmailBody assembles the RFC 5322 message bytes per the spec.
|
||||
// Plain text only; subject hardcoded.
|
||||
func buildEmailBody(cfg SMTPConfig, msgIDDomain string, p Payload) []byte {
|
||||
var b strings.Builder
|
||||
// Subject prefix tracks the event verb so raise vs ack vs resolve
|
||||
// are visually distinct in the inbox (and threaded by Message-ID).
|
||||
verb := "raised"
|
||||
switch p.Event {
|
||||
case EventAcknowledged:
|
||||
verb = "ack"
|
||||
case EventResolved:
|
||||
verb = "resolved"
|
||||
case EventTest:
|
||||
verb = "test"
|
||||
}
|
||||
b.WriteString("From: " + cfg.From + "\r\n")
|
||||
b.WriteString("To: " + cfg.To + "\r\n")
|
||||
b.WriteString(fmt.Sprintf("Subject: [restic-manager] [%s · %s] %s: %s\r\n", verb, p.Severity, p.HostName, p.Kind))
|
||||
b.WriteString("Date: " + p.RaisedAt.UTC().Format(time.RFC1123Z) + "\r\n")
|
||||
b.WriteString("Message-ID: <" + p.AlertID + "@" + msgIDDomain + ">\r\n")
|
||||
b.WriteString("MIME-Version: 1.0\r\n")
|
||||
b.WriteString("Content-Type: text/plain; charset=utf-8\r\n")
|
||||
b.WriteString("\r\n")
|
||||
b.WriteString(p.Message + "\r\n\r\n")
|
||||
b.WriteString("—\r\n")
|
||||
b.WriteString("Raised at: " + p.RaisedAt.UTC().Format(time.RFC3339) + "\r\n")
|
||||
b.WriteString("Severity: " + p.Severity + "\r\n")
|
||||
b.WriteString("Host: " + p.HostName + "\r\n")
|
||||
b.WriteString("Kind: " + p.Kind + "\r\n")
|
||||
if p.Link != "" {
|
||||
b.WriteString("\r\nOpen in restic-manager:\r\n")
|
||||
b.WriteString(p.Link + "\r\n")
|
||||
}
|
||||
b.WriteString("\r\n(This message was sent by restic-manager. Acknowledge or resolve in the UI.)\r\n")
|
||||
return []byte(b.String())
|
||||
}
|
||||
@@ -1,154 +0,0 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fakeSMTPServer records what a single client connection sent during a
// minimal SMTP dialogue (HELO/EHLO, AUTH, MAIL FROM, RCPT TO, DATA, QUIT).
// It speaks plain SMTP without TLS: the tests check the protocol shape, not
// the crypto. mu guards every other field.
type fakeSMTPServer struct {
	mu sync.Mutex

	mailFrom string // argument of MAIL FROM
	rcptTo   string // argument of RCPT TO
	data     string // everything received after DATA
	authed   bool   // set once an AUTH command was seen
}
|
||||
|
||||
func startFakeSMTP(t *testing.T) (string, *fakeSMTPServer) {
|
||||
t.Helper()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
srv := &fakeSMTPServer{}
|
||||
t.Cleanup(func() { _ = ln.Close() })
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
readLine := func() string {
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(buf[:n])
|
||||
}
|
||||
write := func(s string) { _, _ = conn.Write([]byte(s)) }
|
||||
|
||||
write("220 fake.smtp ESMTP\r\n")
|
||||
for {
|
||||
line := readLine()
|
||||
if line == "" {
|
||||
return
|
||||
}
|
||||
cmd := strings.ToUpper(strings.TrimSpace(line))
|
||||
switch {
|
||||
case strings.HasPrefix(cmd, "EHLO"), strings.HasPrefix(cmd, "HELO"):
|
||||
write("250-fake.smtp\r\n250 AUTH PLAIN\r\n")
|
||||
case strings.HasPrefix(cmd, "AUTH "):
|
||||
srv.mu.Lock()
|
||||
srv.authed = true
|
||||
srv.mu.Unlock()
|
||||
write("235 OK\r\n")
|
||||
case strings.HasPrefix(cmd, "MAIL FROM"):
|
||||
srv.mu.Lock()
|
||||
srv.mailFrom = strings.TrimSpace(strings.TrimPrefix(line, "MAIL FROM:"))
|
||||
srv.mu.Unlock()
|
||||
write("250 OK\r\n")
|
||||
case strings.HasPrefix(cmd, "RCPT TO"):
|
||||
srv.mu.Lock()
|
||||
srv.rcptTo = strings.TrimSpace(strings.TrimPrefix(line, "RCPT TO:"))
|
||||
srv.mu.Unlock()
|
||||
write("250 OK\r\n")
|
||||
case cmd == "DATA":
|
||||
write("354 OK\r\n")
|
||||
// read until "\r\n.\r\n"
|
||||
var data strings.Builder
|
||||
for {
|
||||
chunk := readLine()
|
||||
if chunk == "" {
|
||||
break
|
||||
}
|
||||
data.WriteString(chunk)
|
||||
if strings.Contains(data.String(), "\r\n.\r\n") {
|
||||
break
|
||||
}
|
||||
}
|
||||
srv.mu.Lock()
|
||||
srv.data = data.String()
|
||||
srv.mu.Unlock()
|
||||
write("250 OK\r\n")
|
||||
case cmd == "QUIT":
|
||||
write("221 bye\r\n")
|
||||
return
|
||||
default:
|
||||
write("500 unknown\r\n")
|
||||
}
|
||||
}
|
||||
}()
|
||||
return ln.Addr().String(), srv
|
||||
}
|
||||
|
||||
func TestSMTPSendsExpectedHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
addr, srv := startFakeSMTP(t)
|
||||
host, port := splitHostPort(addr)
|
||||
|
||||
ch := NewSMTPChannel(SMTPConfig{
|
||||
Host: host, Port: port, Encryption: "none",
|
||||
Username: "u", Password: "p",
|
||||
From: "Restic-Manager <alerts@example.com>",
|
||||
To: "ops@example.com",
|
||||
}, "rm.example")
|
||||
|
||||
_, _, err := ch.Send(context.Background(), Payload{
|
||||
Event: EventRaised, AlertID: "01ABC",
|
||||
Severity: "warning", Kind: "backup_failed",
|
||||
HostName: "alfa-01", Message: "Backup failed: 401",
|
||||
RaisedAt: time.Date(2026, 5, 4, 15, 42, 1, 0, time.UTC),
|
||||
Link: "https://rm.example/alerts/01ABC",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("send: %v", err)
|
||||
}
|
||||
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
if !srv.authed {
|
||||
t.Errorf("AUTH never sent")
|
||||
}
|
||||
if !strings.Contains(srv.mailFrom, "alerts@example.com") {
|
||||
t.Errorf("MAIL FROM: %q", srv.mailFrom)
|
||||
}
|
||||
if !strings.Contains(srv.rcptTo, "ops@example.com") {
|
||||
t.Errorf("RCPT TO: %q", srv.rcptTo)
|
||||
}
|
||||
if !strings.Contains(srv.data, "Subject: [restic-manager] [raised · warning] alfa-01: backup_failed") {
|
||||
t.Errorf("subject missing or wrong: %q", srv.data)
|
||||
}
|
||||
if !strings.Contains(srv.data, "Message-ID: <01ABC@rm.example>") {
|
||||
t.Errorf("Message-ID wrong: %q", srv.data)
|
||||
}
|
||||
if !strings.Contains(srv.data, "Backup failed: 401") {
|
||||
t.Errorf("body missing: %q", srv.data)
|
||||
}
|
||||
}
|
||||
|
||||
// splitHostPort splits a "host:port" address into the host and a numeric
// port. The port text is converted digit by digit without validation, and
// the error from net.SplitHostPort is ignored (yielding "" and 0), which is
// adequate for listener addresses produced inside tests.
func splitHostPort(addr string) (string, int) {
	host, digits, _ := net.SplitHostPort(addr)
	port := 0
	for _, d := range digits {
		port = 10*port + int(d-'0')
	}
	return host, port
}
|
||||
@@ -1,98 +0,0 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// WebhookConfig is the JSON document describing one webhook channel. It is
// the plaintext of the AEAD-encrypted notification_channels.config column.
type WebhookConfig struct {
	// URL receives the POST.
	URL string `json:"url"`
	// BearerToken, when set, is sent as "Authorization: Bearer <token>".
	BearerToken string `json:"bearer_token,omitempty"`
	// HeaderName/HeaderValue add one custom request header when HeaderName
	// is non-empty.
	HeaderName  string `json:"header_name,omitempty"`
	HeaderValue string `json:"header_value,omitempty"`
}
|
||||
|
||||
// WebhookChannel is the HTTP-POST channel. One per configured channel
|
||||
// row. Reused across sends — the http.Client is goroutine-safe.
|
||||
type WebhookChannel struct {
|
||||
cfg WebhookConfig
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewWebhookChannel builds a webhook with a 5s overall timeout enforced
|
||||
// by the http.Client; ctx in Send is layered on top for caller-driven
|
||||
// cancel.
|
||||
func NewWebhookChannel(cfg WebhookConfig) *WebhookChannel {
|
||||
return &WebhookChannel{
|
||||
cfg: cfg,
|
||||
client: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// Kind returns "webhook" for log enrichment and dispatcher routing.
|
||||
func (c *WebhookChannel) Kind() string { return "webhook" }
|
||||
|
||||
// webhookBody is the JSON envelope POSTed to webhook receivers. The format
// is documented in the spec and must stay stable: operators write switch
// statements on "event" and "severity", so do not rename or reorder fields
// freely.
type webhookBody struct {
	Event    string `json:"event"`
	AlertID  string `json:"alert_id"`
	Severity string `json:"severity"`
	Kind     string `json:"kind"`
	HostID   string `json:"host_id"`
	HostName string `json:"host_name"`
	Message  string `json:"message"`
	RaisedAt string `json:"raised_at"` // RFC 3339 with nanoseconds, UTC
	Link     string `json:"link"`
}
|
||||
|
||||
// Send delivers the payload as a JSON POST. Returns (statusCode, latency, err).
|
||||
// 4xx/5xx responses are returned as errors with the status code set.
|
||||
func (c *WebhookChannel) Send(ctx context.Context, p Payload) (int, time.Duration, error) {
|
||||
body := webhookBody{
|
||||
Event: string(p.Event), AlertID: p.AlertID,
|
||||
Severity: p.Severity, Kind: p.Kind,
|
||||
HostID: p.HostID, HostName: p.HostName,
|
||||
Message: p.Message,
|
||||
RaisedAt: p.RaisedAt.UTC().Format(time.RFC3339Nano),
|
||||
Link: p.Link,
|
||||
}
|
||||
buf, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("webhook: marshal body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.URL, bytes.NewReader(buf))
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("webhook: build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if c.cfg.BearerToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.cfg.BearerToken)
|
||||
}
|
||||
if c.cfg.HeaderName != "" {
|
||||
req.Header.Set(c.cfg.HeaderName, c.cfg.HeaderValue)
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
res, err := c.client.Do(req)
|
||||
latency := time.Since(t0)
|
||||
if err != nil {
|
||||
return 0, latency, fmt.Errorf("webhook: do: %w", err)
|
||||
}
|
||||
defer func() { _ = res.Body.Close() }()
|
||||
// Drain body — keep the connection reusable.
|
||||
_, _ = io.Copy(io.Discard, res.Body)
|
||||
if res.StatusCode >= 400 {
|
||||
return res.StatusCode, latency, fmt.Errorf("webhook: http %d", res.StatusCode)
|
||||
}
|
||||
return res.StatusCode, latency, nil
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWebhookSendsCorrectPayloadAndHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
var got webhookBody
|
||||
var auth, custom string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
auth = r.Header.Get("Authorization")
|
||||
custom = r.Header.Get("X-Test")
|
||||
_ = json.NewDecoder(r.Body).Decode(&got)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
ch := NewWebhookChannel(WebhookConfig{
|
||||
URL: srv.URL, BearerToken: "tok-123",
|
||||
HeaderName: "X-Test", HeaderValue: "yes",
|
||||
})
|
||||
code, _, err := ch.Send(context.Background(), Payload{
|
||||
Event: EventRaised, AlertID: "01K",
|
||||
Severity: "warning", Kind: "backup_failed",
|
||||
HostID: "h1", HostName: "alfa-01",
|
||||
Message: "Backup failed",
|
||||
RaisedAt: time.Date(2026, 5, 4, 15, 42, 1, 0, time.UTC),
|
||||
Link: "https://rm.example/alerts/01K",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("send: %v", err)
|
||||
}
|
||||
if code != 200 {
|
||||
t.Errorf("status: %d", code)
|
||||
}
|
||||
if got.Event != "alert.raised" || got.Kind != "backup_failed" || got.Message != "Backup failed" {
|
||||
t.Errorf("body: %+v", got)
|
||||
}
|
||||
if auth != "Bearer tok-123" {
|
||||
t.Errorf("auth: %q", auth)
|
||||
}
|
||||
if custom != "yes" {
|
||||
t.Errorf("custom header: %q", custom)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookReturnsErrorOn4xx(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
defer srv.Close()
|
||||
ch := NewWebhookChannel(WebhookConfig{URL: srv.URL})
|
||||
code, _, err := ch.Send(context.Background(), Payload{Event: EventRaised})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 401")
|
||||
}
|
||||
if code != 401 {
|
||||
t.Errorf("code: %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookRespectsCtxTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
time.Sleep(2 * time.Second)
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
defer srv.Close()
|
||||
ch := NewWebhookChannel(WebhookConfig{URL: srv.URL})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
_, _, err := ch.Send(ctx, Payload{Event: EventRaised})
|
||||
if err == nil {
|
||||
t.Fatal("expected timeout error")
|
||||
}
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package restic
|
||||
|
||||
import "syscall"
|
||||
|
||||
// sigterm is the termination signal for non-Windows builds (this file carries
// the !windows build constraint): syscall.SIGTERM.
var sigterm = syscall.SIGTERM
|
||||
@@ -1,12 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package restic
|
||||
|
||||
import "os"
|
||||
|
||||
// sigterm is the Windows stand-in for SIGTERM.
//
// Windows has no SIGTERM. The closest equivalent is os.Interrupt
// (CTRL_BREAK_EVENT), but Go's exec.Cmd.Process.Signal() on Windows
// only supports os.Kill — sending anything else returns an error and
// no signal is delivered. Fall back to os.Kill so Cancel still works
// (immediate force-kill); WaitDelay is unused but harmless.
var sigterm = os.Kill
|
||||
@@ -1,140 +0,0 @@
|
||||
package restic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// LsEntry is a single node reported by `restic ls --json`, which emits one
// JSON object per line. Only the fields the restore wizard needs are
// decoded.
type LsEntry struct {
	Name string `json:"name"`
	Type string `json:"type"`
	Path string `json:"path"`
	Size int64  `json:"size,omitempty"`
	// Struct is restic's "struct_type" discriminator; the preamble line
	// carries the value "snapshot".
	Struct string `json:"struct_type,omitempty"`
}
|
||||
|
||||
// ListTreeChildren runs `restic ls --json <snapshot> <dirPath>` and
|
||||
// returns only the direct children of dirPath. Restic ls is recursive
|
||||
// by default, so we filter post-hoc — for a typical interactive
|
||||
// drill-down ("expand /etc/nginx") the subtree is small (a few KB of
|
||||
// JSON); for huge subtrees this is suboptimal but correct.
|
||||
//
|
||||
// The first emitted line is restic's "snapshot" preamble (struct_type
|
||||
// = "snapshot") which we discard. Subsequent lines are nodes; we
|
||||
// match on path equal to dirPath + "/" + name (with normalisation so
|
||||
// trailing slashes don't break the comparison).
|
||||
//
|
||||
// dirPath="" or "/" lists the snapshot root.
|
||||
func (e Env) ListTreeChildren(ctx context.Context, snapshotID, dirPath string) ([]LsEntry, error) {
|
||||
if snapshotID == "" {
|
||||
return nil, fmt.Errorf("restic ls: snapshot id required")
|
||||
}
|
||||
parent := normalizeTreePath(dirPath)
|
||||
|
||||
args := []string{"ls", "--json", snapshotID}
|
||||
if parent != "/" {
|
||||
args = append(args, parent)
|
||||
}
|
||||
cmd := e.resticCmd(ctx, args...)
|
||||
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("restic ls: stdout pipe: %w", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("restic ls: start: %w", err)
|
||||
}
|
||||
|
||||
out, parseErr := parseLsChildren(stdout, parent)
|
||||
|
||||
werr := cmd.Wait()
|
||||
if werr != nil {
|
||||
var ee *exec.ExitError
|
||||
if errors.As(werr, &ee) {
|
||||
return nil, fmt.Errorf("restic ls: exit %d: %s",
|
||||
ee.ExitCode(), strings.TrimSpace(stderr.String()))
|
||||
}
|
||||
return nil, fmt.Errorf("restic ls: %w", werr)
|
||||
}
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// parseLsChildren reads line-delimited JSON from r and returns nodes
|
||||
// whose Path is a direct child of parent. Exposed for testing.
|
||||
func parseLsChildren(r io.Reader, parent string) ([]LsEntry, error) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
var out []LsEntry
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
var entry LsEntry
|
||||
if err := json.Unmarshal(line, &entry); err != nil {
|
||||
return nil, fmt.Errorf("restic ls: parse line: %w", err)
|
||||
}
|
||||
// Skip the snapshot preamble and any future struct_type
|
||||
// entries we don't care about.
|
||||
if entry.Struct == "snapshot" || entry.Path == "" {
|
||||
continue
|
||||
}
|
||||
if isDirectChild(entry.Path, parent) {
|
||||
out = append(out, entry)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("restic ls: read output: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// normalizeTreePath maps any spelling of a snapshot path ("", "/",
// "/etc/", "etc", "/etc//nginx", ...) onto one canonical form:
// absolute, slash-separated, cleaned, and without a trailing slash.
// The root is the only result that ends in a slash ("/").
//
// Surrounding whitespace is stripped first. Prefixing "/" before
// cleaning is safe even for already-absolute input because path.Clean
// collapses the resulting double slash.
func normalizeTreePath(p string) string {
	return path.Clean("/" + strings.TrimSpace(p))
}
|
||||
|
||||
// isDirectChild reports whether childPath is a direct child of parent.
|
||||
// "/etc/nginx" is a direct child of "/etc"; "/etc/nginx/conf" is not.
|
||||
// "/etc" is a direct child of "/".
|
||||
func isDirectChild(childPath, parent string) bool {
|
||||
cp := normalizeTreePath(childPath)
|
||||
pp := normalizeTreePath(parent)
|
||||
if pp == "/" {
|
||||
// Direct children of root: exactly one slash-delimited segment.
|
||||
return cp != "/" && strings.Count(cp, "/") == 1
|
||||
}
|
||||
// Must start with parent + "/" and have no further slashes.
|
||||
prefix := pp + "/"
|
||||
if !strings.HasPrefix(cp, prefix) {
|
||||
return false
|
||||
}
|
||||
rest := cp[len(prefix):]
|
||||
return rest != "" && !strings.Contains(rest, "/")
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
package restic
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// realistic restic ls --json output sample. First line is the
|
||||
// snapshot preamble, subsequent lines are nodes. Trimmed to a few
|
||||
// entries that exercise depth filtering.
|
||||
const sampleLsOutput = `{"struct_type":"snapshot","time":"2026-05-04T09:14:00Z","id":"f3a7b2c1"}
|
||||
{"name":"etc","type":"dir","path":"/etc","permissions":"drwxr-xr-x","struct_type":"node"}
|
||||
{"name":"nginx","type":"dir","path":"/etc/nginx","permissions":"drwxr-xr-x","struct_type":"node"}
|
||||
{"name":"nginx.conf","type":"file","path":"/etc/nginx/nginx.conf","size":2400,"struct_type":"node"}
|
||||
{"name":"sites-available","type":"dir","path":"/etc/nginx/sites-available","struct_type":"node"}
|
||||
{"name":"alfa.conf","type":"file","path":"/etc/nginx/sites-available/alfa.conf","size":3100,"struct_type":"node"}
|
||||
{"name":"default.conf","type":"file","path":"/etc/nginx/sites-available/default.conf","size":2900,"struct_type":"node"}
|
||||
`
|
||||
|
||||
func TestParseLsChildrenAtRoot(t *testing.T) {
|
||||
t.Parallel()
|
||||
entries, err := parseLsChildren(strings.NewReader(sampleLsOutput), "/")
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("entries: got %d (%+v), want 1", len(entries), entries)
|
||||
}
|
||||
if entries[0].Name != "etc" || entries[0].Path != "/etc" || entries[0].Type != "dir" {
|
||||
t.Fatalf("entry: %+v", entries[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLsChildrenAtEtc(t *testing.T) {
|
||||
t.Parallel()
|
||||
entries, err := parseLsChildren(strings.NewReader(sampleLsOutput), "/etc")
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("entries: got %d, want 1 (just nginx, not nested children)", len(entries))
|
||||
}
|
||||
if entries[0].Name != "nginx" {
|
||||
t.Fatalf("entry: %+v", entries[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLsChildrenAtNginx(t *testing.T) {
|
||||
t.Parallel()
|
||||
entries, err := parseLsChildren(strings.NewReader(sampleLsOutput), "/etc/nginx")
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(entries) != 2 {
|
||||
t.Fatalf("entries: got %d (%+v), want 2 (nginx.conf + sites-available, not nested)",
|
||||
len(entries), entries)
|
||||
}
|
||||
gotNames := []string{entries[0].Name, entries[1].Name}
|
||||
want := map[string]bool{"nginx.conf": true, "sites-available": true}
|
||||
for _, n := range gotNames {
|
||||
if !want[n] {
|
||||
t.Errorf("unexpected name %q in result", n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLsChildrenAtSitesAvailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
entries, err := parseLsChildren(strings.NewReader(sampleLsOutput), "/etc/nginx/sites-available")
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(entries) != 2 {
|
||||
t.Fatalf("entries: got %d, want 2", len(entries))
|
||||
}
|
||||
for _, e := range entries {
|
||||
if e.Type != "file" {
|
||||
t.Errorf("expected file type, got %q on %q", e.Type, e.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeTreePath(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct{ in, want string }{
|
||||
{"", "/"},
|
||||
{"/", "/"},
|
||||
{"/etc", "/etc"},
|
||||
{"/etc/", "/etc"},
|
||||
{"etc/nginx", "/etc/nginx"},
|
||||
{"/etc//nginx", "/etc/nginx"},
|
||||
{"/etc/./nginx", "/etc/nginx"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := normalizeTreePath(c.in)
|
||||
if got != c.want {
|
||||
t.Errorf("normalizeTreePath(%q): got %q, want %q", c.in, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDirectChild(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
child, parent string
|
||||
want bool
|
||||
}{
|
||||
{"/etc", "/", true},
|
||||
{"/etc/nginx", "/", false},
|
||||
{"/etc/nginx", "/etc", true},
|
||||
{"/etc/nginx/conf", "/etc", false},
|
||||
{"/etc/nginx/conf", "/etc/nginx", true},
|
||||
{"/etc", "/etc", false},
|
||||
{"/etcc", "/etc", false}, // prefix match guard
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := isDirectChild(c.child, c.parent)
|
||||
if got != c.want {
|
||||
t.Errorf("isDirectChild(%q, %q): got %v, want %v",
|
||||
c.child, c.parent, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,271 +0,0 @@
|
||||
package restic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// RestoreStatus is the decoded form of a `status` progress line that
// `restic restore --json` prints while a restore is running. The JSON
// keys follow restic's wire format; only a subset of the fields restic
// sends is projected here.
type RestoreStatus struct {
	// MessageType is restic's message_type discriminator ("status").
	MessageType string `json:"message_type"`
	// SecondsElapsed and PercentDone are restic's own progress figures.
	SecondsElapsed int64   `json:"seconds_elapsed"`
	PercentDone    float64 `json:"percent_done"`
	// File counters as reported by restic.
	TotalFiles    int64 `json:"total_files"`
	FilesRestored int64 `json:"files_restored"`
	FilesSkipped  int64 `json:"files_skipped"`
	// Byte counters as reported by restic.
	TotalBytes    int64 `json:"total_bytes"`
	BytesRestored int64 `json:"bytes_restored"`
	BytesSkipped  int64 `json:"bytes_skipped"`
}
|
||||
|
||||
// RestoreSummary is the decoded form of the closing `summary` line of
// a successful `restic restore --json` run. Newer restic versions
// print it; older ones do not, in which case the agent has no stats to
// report and the live UI simply sees the percentage reach 100%.
type RestoreSummary struct {
	// MessageType is restic's message_type discriminator ("summary").
	MessageType string `json:"message_type"`
	// SecondsElapsed is the run time restic reports.
	SecondsElapsed int64 `json:"seconds_elapsed"`
	// File counters as reported by restic.
	TotalFiles    int64 `json:"total_files"`
	FilesRestored int64 `json:"files_restored"`
	FilesSkipped  int64 `json:"files_skipped"`
	// Byte counters as reported by restic.
	TotalBytes    int64 `json:"total_bytes"`
	BytesRestored int64 `json:"bytes_restored"`
	BytesSkipped  int64 `json:"bytes_skipped"`
}
|
||||
|
||||
// RunRestore executes `restic restore <snapshotID> --target <dir>
|
||||
// [--include <p>...]` with --json and pumps progress events into
|
||||
// handle. paths is the operator-selected list (each becomes an
|
||||
// `--include` flag); preserveOwner controls --no-ownership.
|
||||
//
|
||||
// inPlace toggles target semantics:
|
||||
// - true → target is "/" and ownership is preserved
|
||||
// - false → target is targetDir and --no-ownership is passed
|
||||
//
|
||||
// targetDir is created on demand by restic itself.
|
||||
func (e Env) RunRestore(ctx context.Context, snapshotID string, paths []string, inPlace bool, targetDir string, handle LineHandler) (*RestoreSummary, error) {
|
||||
if snapshotID == "" {
|
||||
return nil, fmt.Errorf("restic restore: snapshot id required")
|
||||
}
|
||||
if !inPlace && targetDir == "" {
|
||||
return nil, fmt.Errorf("restic restore: target dir required for non-in-place restore")
|
||||
}
|
||||
|
||||
args := []string{"restore", "--json", snapshotID}
|
||||
target := targetDir
|
||||
if inPlace {
|
||||
target = "/"
|
||||
} else {
|
||||
// Expand $HOME / ${HOME} / leading ~/ in the operator-supplied
|
||||
// path, using the agent's own HOME (typically /root for the
|
||||
// User=root unit). The expansion runs agent-side so the
|
||||
// operator can specify a portable default like
|
||||
// $HOME/rm-restore/<job-id>/ in the wizard without the server
|
||||
// needing to know which user the agent runs as.
|
||||
target = expandHome(target)
|
||||
// Ensure the target directory exists. Restic itself creates
|
||||
// missing leaves but won't traverse multiple missing levels
|
||||
// (and we don't want the operator to have to pre-create the
|
||||
// per-job subdir). 0700 keeps the data root-only — the agent
|
||||
// runs as root, and operators who want a different mode can
|
||||
// chmod after the fact. If MkdirAll fails (operator typed a
|
||||
// path inside a read-only sandbox mount, ENOSPC, etc.) we
|
||||
// surface a clean error rather than letting restic fail with
|
||||
// something cryptic.
|
||||
if err := os.MkdirAll(target, 0o700); err != nil {
|
||||
return nil, fmt.Errorf("restic restore: prepare target %q: %w", target, err)
|
||||
}
|
||||
}
|
||||
args = append(args, "--target", target)
|
||||
// --no-ownership is nominally a restic 0.17+ flag, but at least
|
||||
// one downstream 0.18.1 build still rejects it. We rely on a
|
||||
// runtime probe captured at agent startup (see
|
||||
// SupportsRestoreNoOwnership) rather than version sniffing.
|
||||
// In-place restores always preserve ownership — that's the whole
|
||||
// point of in-place — so we only add the flag for new-dir mode.
|
||||
if !inPlace && e.SupportsRestoreNoOwnership {
|
||||
args = append(args, "--no-ownership")
|
||||
}
|
||||
for _, p := range paths {
|
||||
args = append(args, "--include", p)
|
||||
}
|
||||
|
||||
cmd := e.resticCmd(ctx, args...)
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("restic restore: stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("restic restore: stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("restic restore: start: %w", err)
|
||||
}
|
||||
|
||||
var summary *RestoreSummary
|
||||
done := make(chan error, 2)
|
||||
go func() { done <- pumpRestoreStdout(stdout, handle, &summary) }()
|
||||
go func() { done <- pumpStderr(stderr, handle) }()
|
||||
for i := 0; i < 2; i++ {
|
||||
if err := <-done; err != nil && handle != nil {
|
||||
handle("event", fmt.Sprintf("pump error: %v", err), nil)
|
||||
}
|
||||
}
|
||||
werr := cmd.Wait()
|
||||
if werr != nil {
|
||||
var ee *exec.ExitError
|
||||
if errors.As(werr, &ee) {
|
||||
return summary, fmt.Errorf("restic restore: exit %d", ee.ExitCode())
|
||||
}
|
||||
return summary, fmt.Errorf("restic restore: %w", werr)
|
||||
}
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
// pumpRestoreStdout is the restore variant of pumpStdout: it emits
|
||||
// `event` lines for the parsed status/summary objects (so the runner
|
||||
// can shape them into job.progress) and forwards everything else as
|
||||
// stdout — but unlike backup we include the raw status JSON in
|
||||
// log.stream too because restore is short and the live log audience
|
||||
// genuinely benefits from the per-file traffic. Actually — we mirror
|
||||
// backup's behaviour and DROP raw status lines from log.stream
|
||||
// (they'd drown the log on a fast restore); the progress envelope
|
||||
// covers them.
|
||||
func pumpRestoreStdout(r io.Reader, handle LineHandler, summary **RestoreSummary) error {
|
||||
scanner := bufio.NewScanner(r)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if handle == nil {
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(line, "{") {
|
||||
handle("stdout", line, nil)
|
||||
continue
|
||||
}
|
||||
var probe struct {
|
||||
MessageType string `json:"message_type"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(line), &probe); err != nil {
|
||||
handle("stdout", line, nil)
|
||||
continue
|
||||
}
|
||||
switch probe.MessageType {
|
||||
case "status":
|
||||
var ev RestoreStatus
|
||||
if json.Unmarshal([]byte(line), &ev) == nil {
|
||||
// Don't tee status lines to log.stream — too chatty.
|
||||
handle("event", line, ev)
|
||||
continue
|
||||
}
|
||||
case "summary":
|
||||
var ev RestoreSummary
|
||||
if json.Unmarshal([]byte(line), &ev) == nil {
|
||||
if summary != nil {
|
||||
s := ev
|
||||
*summary = &s
|
||||
}
|
||||
handle("event", line, ev)
|
||||
continue
|
||||
}
|
||||
case "verbose_status":
|
||||
handle("event", line, nil)
|
||||
continue
|
||||
}
|
||||
handle("stdout", line, nil)
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// expandHome replaces a leading home-directory marker in p — "$HOME",
// "${HOME}" or "~", either alone or followed by "/" — with the agent
// process's home directory. No other environment reference is ever
// expanded: operator-supplied paths must not be able to read agent
// variables such as $PATH or $RESTIC_PASSWORD.
//
// p is returned unchanged when it is empty, carries no leading
// marker, or when the home directory cannot be determined.
func expandHome(p string) string {
	if p == "" {
		return p
	}
	home, err := os.UserHomeDir()
	if err != nil || home == "" {
		return p
	}
	// The markers are mutually exclusive as prefixes, so the order in
	// which they are tried does not matter.
	for _, marker := range [...]string{"$HOME", "${HOME}", "~"} {
		if p == marker {
			return home
		}
		if strings.HasPrefix(p, marker+"/") {
			return filepath.Join(home, p[len(marker)+1:])
		}
	}
	return p
}
|
||||
|
||||
// RunDiff executes `restic diff --json <a> <b>` and forwards every
|
||||
// line to handle as stdout. Restic emits per-line "change" objects
|
||||
// plus a final "statistics" object; we don't parse them server-side —
|
||||
// the operator reads the raw output on the live job log page.
|
||||
func (e Env) RunDiff(ctx context.Context, snapshotA, snapshotB string, handle LineHandler) error {
|
||||
if snapshotA == "" || snapshotB == "" {
|
||||
return fmt.Errorf("restic diff: two snapshot ids required")
|
||||
}
|
||||
cmd := e.resticCmd(ctx, "diff", "--json", snapshotA, snapshotB)
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("restic diff: stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("restic diff: stderr pipe: %w", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("restic diff: start: %w", err)
|
||||
}
|
||||
done := make(chan error, 2)
|
||||
// diff output isn't huge; pumpStderr-ish line-by-line forwarding
|
||||
// is fine.
|
||||
go func() {
|
||||
s := bufio.NewScanner(stdout)
|
||||
s.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for s.Scan() {
|
||||
if handle != nil {
|
||||
handle("stdout", s.Text(), nil)
|
||||
}
|
||||
}
|
||||
done <- s.Err()
|
||||
}()
|
||||
go func() { done <- pumpStderr(stderr, handle) }()
|
||||
for i := 0; i < 2; i++ {
|
||||
if err := <-done; err != nil && handle != nil {
|
||||
handle("event", fmt.Sprintf("pump error: %v", err), nil)
|
||||
}
|
||||
}
|
||||
werr := cmd.Wait()
|
||||
if werr != nil {
|
||||
var ee *exec.ExitError
|
||||
if errors.As(werr, &ee) {
|
||||
return fmt.Errorf("restic diff: exit %d", ee.ExitCode())
|
||||
}
|
||||
return fmt.Errorf("restic diff: %w", werr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
+35
-140
@@ -15,27 +15,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// SupportsRestoreNoOwnership probes the running restic for the
|
||||
// `--no-ownership` flag on the `restore` subcommand. Some restic
|
||||
// builds (≥ 0.17 in theory; observed missing on a downstream 0.18.1)
|
||||
// do not expose it, so we ask the binary directly rather than
|
||||
// inferring from the version string. Empty `bin` or any failure to
|
||||
// run the help command returns false — the caller stays on the
|
||||
// conservative path of not adding the flag.
|
||||
func SupportsRestoreNoOwnership(ctx context.Context, bin string) bool {
|
||||
if bin == "" {
|
||||
return false
|
||||
}
|
||||
probeCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
out, err := exec.CommandContext(probeCtx, bin, "restore", "--help").CombinedOutput()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(string(out), "--no-ownership")
|
||||
}
|
||||
|
||||
// Locate resolves the path to the restic binary. Honour an explicit
|
||||
// Locate resolves the path to the restic binary. Honor an explicit
|
||||
// override if provided, else fall back to PATH.
|
||||
func Locate(override string) (string, error) {
|
||||
if override != "" {
|
||||
@@ -62,110 +42,11 @@ func Locate(override string) (string, error) {
|
||||
// in this package ever needs to *log* a URL, use RedactURL.
|
||||
type Env struct {
	// Bin is the path to the restic binary.
	Bin string
	// Version is the restic version, e.g. "0.17.1"; empty if unknown.
	Version string
	// RepoURL is exported as RESTIC_REPOSITORY (no embedded creds).
	RepoURL string
	// RepoUsername is the optional HTTP basic-auth user for rest: URLs.
	RepoUsername string
	// RepoPassword doubles as RESTIC_PASSWORD and (for rest:) the HTTP
	// basic-auth password.
	RepoPassword string
	// ExtraEnv holds any other RESTIC_* / passthrough variables.
	ExtraEnv map[string]string
	// WorkDir is the command's CWD; empty means the current directory.
	WorkDir string

	// SupportsRestoreNoOwnership records whether the running restic's
	// `restore --help` advertises the --no-ownership flag. The flag was
	// added in 0.17, but at least one downstream build of 0.18.1 still
	// rejects it ("unknown flag: --no-ownership") — version sniffing
	// proved unreliable, so the agent now probes for the actual flag at
	// startup (see internal/restic.SupportsRestoreNoOwnership) and
	// passes the resulting boolean down here.
	SupportsRestoreNoOwnership bool

	// LimitUploadKBps and LimitDownloadKBps are bandwidth caps in KB/s.
	// <=0 means "no cap" (omit the flag). They are emitted as restic
	// global flags --limit-upload / --limit-download before the
	// subcommand on every invocation.
	LimitUploadKBps   int
	LimitDownloadKBps int
}
|
||||
|
||||
// AtLeastVersion reports whether e.Version >= the given major/minor.
|
||||
// Comparison is best-effort: empty / unparseable versions return false
|
||||
// (callers stay on the conservative path). Patch level is ignored.
|
||||
func (e Env) AtLeastVersion(major, minor int) bool {
|
||||
v := strings.TrimSpace(e.Version)
|
||||
if v == "" {
|
||||
return false
|
||||
}
|
||||
parts := strings.SplitN(v, ".", 3)
|
||||
if len(parts) < 2 {
|
||||
return false
|
||||
}
|
||||
maj, err1 := atoi(parts[0])
|
||||
min, err2 := atoi(parts[1])
|
||||
if err1 != nil || err2 != nil {
|
||||
return false
|
||||
}
|
||||
if maj != major {
|
||||
return maj > major
|
||||
}
|
||||
return min >= minor
|
||||
}
|
||||
|
||||
// atoi parses s as a non-negative decimal integer. It is a small
// stand-in for strconv.Atoi so this file does not need the import for
// a single helper, and it is deliberately stricter: only the ASCII
// digits 0-9 are accepted, so signs, spaces and other characters are
// rejected.
//
// It returns an error for an empty string, for any non-digit rune,
// and for values that do not fit in an int.
func atoi(s string) (int, error) {
	if len(s) == 0 {
		return 0, fmt.Errorf("empty")
	}
	// Largest value an int can hold on this platform.
	const maxInt = int(^uint(0) >> 1)
	n := 0
	for _, r := range s {
		if r < '0' || r > '9' {
			return 0, fmt.Errorf("not a digit: %q", r)
		}
		d := int(r - '0')
		// n*10 + d would exceed maxInt; refuse instead of wrapping.
		if n > (maxInt-d)/10 {
			return 0, fmt.Errorf("out of range: %q", s)
		}
		n = n*10 + d
	}
	return n, nil
}
|
||||
|
||||
// globalArgs returns restic's pre-subcommand global flags derived
|
||||
// from the Env. Currently just bandwidth caps.
|
||||
func (e Env) globalArgs() []string {
|
||||
var out []string
|
||||
if e.LimitUploadKBps > 0 {
|
||||
out = append(out, "--limit-upload", fmt.Sprintf("%d", e.LimitUploadKBps))
|
||||
}
|
||||
if e.LimitDownloadKBps > 0 {
|
||||
out = append(out, "--limit-download", fmt.Sprintf("%d", e.LimitDownloadKBps))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// resticCmd builds an exec.Cmd with bandwidth-limit globals prefixed
|
||||
// before the supplied subcommand args. Centralising this so every
|
||||
// command (backup/forget/prune/check/unlock/init/stats) honours
|
||||
// the caps without each call site having to remember.
|
||||
//
|
||||
// Cancellation: by default exec.CommandContext sends SIGKILL when
|
||||
// ctx is canceled, which leaves restic no chance to clean up its
|
||||
// repository lock. Override Cmd.Cancel to send SIGTERM first, and
|
||||
// set Cmd.WaitDelay so the process is force-killed if it doesn't
|
||||
// exit within five seconds. Restic responds to SIGTERM by removing
|
||||
// its lock file before exiting, which is what we want when an
|
||||
// operator cancels a long-running backup/restore from the UI.
|
||||
func (e Env) resticCmd(ctx context.Context, sub ...string) *exec.Cmd {
|
||||
args := append(e.globalArgs(), sub...)
|
||||
cmd := exec.CommandContext(ctx, e.Bin, args...)
|
||||
cmd.Env = e.envSlice()
|
||||
cmd.Dir = e.WorkDir
|
||||
cmd.Cancel = func() error {
|
||||
// Cmd.Process is set after Start; Cancel only fires post-Start
|
||||
// so the nil check is defensive against the documented but
|
||||
// unlikely race. Signal returns ErrProcessDone if the process
|
||||
// already exited; that's not a problem here either.
|
||||
if cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
return cmd.Process.Signal(sigterm)
|
||||
}
|
||||
cmd.WaitDelay = 5 * time.Second
|
||||
return cmd
|
||||
}
|
||||
|
||||
// EventKind enumerates what we care about in restic's --json output
|
||||
@@ -211,7 +92,7 @@ type BackupSummary struct {
|
||||
}
|
||||
|
||||
// LineHandler receives every stdout/stderr line. event is non-nil
|
||||
// when the line is a recognised JSON status; raw always carries the
|
||||
// when the line is a recognized JSON status; raw always carries the
|
||||
// original text (so we can also tee to job_logs as `stdout`).
|
||||
type LineHandler func(stream string, raw string, event any)
|
||||
|
||||
@@ -229,7 +110,9 @@ func (e Env) RunBackup(ctx context.Context, paths, excludes, tags []string, hand
|
||||
}
|
||||
args = append(args, paths...)
|
||||
|
||||
cmd := e.resticCmd(ctx, args...)
|
||||
cmd := exec.CommandContext(ctx, e.Bin, args...)
|
||||
cmd.Env = e.envSlice()
|
||||
cmd.Dir = e.WorkDir
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
@@ -332,7 +215,9 @@ func (e Env) RunForget(ctx context.Context, groups []ForgetGroup, handle LineHan
|
||||
}
|
||||
args := []string{"forget", "--json", "--tag", g.Tag}
|
||||
args = append(args, g.Policy.args()...)
|
||||
cmd := e.resticCmd(ctx, args...)
|
||||
cmd := exec.CommandContext(ctx, e.Bin, args...)
|
||||
cmd.Env = e.envSlice()
|
||||
cmd.Dir = e.WorkDir
|
||||
if err := runWithPump(cmd, handle); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -347,11 +232,13 @@ func (e Env) RunForget(ctx context.Context, groups []ForgetGroup, handle LineHan
|
||||
// <id> at <url>" on success, "config file already exists" on a
|
||||
// re-init attempt, etc.).
|
||||
func (e Env) RunInit(ctx context.Context, handle LineHandler) error {
|
||||
cmd := e.resticCmd(ctx, "init")
|
||||
cmd := exec.CommandContext(ctx, e.Bin, "init")
|
||||
cmd.Env = e.envSlice()
|
||||
cmd.Dir = e.WorkDir
|
||||
|
||||
// Sniff for "config file already exists" on stderr; if we see it
|
||||
// we'll treat the non-zero exit as a soft success — running init
|
||||
// against an already-initialised repo is a no-op semantically,
|
||||
// against an already-initialized repo is a no-op semantically,
|
||||
// not a failure. Wraps the caller's handle so the line still
|
||||
// gets streamed verbatim to the operator-facing log.
|
||||
alreadyInited := false
|
||||
@@ -367,7 +254,7 @@ func (e Env) RunInit(ctx context.Context, handle LineHandler) error {
|
||||
if err := runWithPump(cmd, sniff); err != nil {
|
||||
if alreadyInited {
|
||||
if handle != nil {
|
||||
handle("event", "repo already initialised — treating as success", nil)
|
||||
handle("event", "repo already initialized — treating as success", nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -385,7 +272,10 @@ func (e Env) RunInit(ctx context.Context, handle LineHandler) error {
|
||||
// support that's useful for our purposes). We tee everything to the
|
||||
// handler so the live log is the operator's progress bar.
|
||||
func (e Env) RunPrune(ctx context.Context, handle LineHandler) error {
|
||||
return runWithPump(e.resticCmd(ctx, "prune"), handle)
|
||||
cmd := exec.CommandContext(ctx, e.Bin, "prune")
|
||||
cmd.Env = e.envSlice()
|
||||
cmd.Dir = e.WorkDir
|
||||
return runWithPump(cmd, handle)
|
||||
}
|
||||
|
||||
// runWithPump starts the configured cmd, fans stdout+stderr into
|
||||
@@ -423,7 +313,10 @@ func runWithPump(cmd *exec.Cmd, handle LineHandler) error {
|
||||
|
||||
// RunUnlock executes `restic unlock`. Returns nil on a clean exit.
|
||||
func (e Env) RunUnlock(ctx context.Context, handle LineHandler) error {
|
||||
return runWithPump(e.resticCmd(ctx, "unlock"), handle)
|
||||
cmd := exec.CommandContext(ctx, e.Bin, "unlock")
|
||||
cmd.Env = e.envSlice()
|
||||
cmd.Dir = e.WorkDir
|
||||
return runWithPump(cmd, handle)
|
||||
}
|
||||
|
||||
// RepoStats mirrors `restic stats --json --mode raw-data` output.
|
||||
@@ -440,7 +333,9 @@ type RepoStats struct {
|
||||
// caller can still log it. Returns an error if no JSON-shaped line
|
||||
// arrived on stdout.
|
||||
func (e Env) RunStats(ctx context.Context, handle LineHandler) (*RepoStats, error) {
|
||||
cmd := e.resticCmd(ctx, "stats", "--json", "--mode", "raw-data")
|
||||
cmd := exec.CommandContext(ctx, e.Bin, "stats", "--json", "--mode", "raw-data")
|
||||
cmd.Env = e.envSlice()
|
||||
cmd.Dir = e.WorkDir
|
||||
var out *RepoStats
|
||||
capture := func(stream, line string, ev any) {
|
||||
if stream == "stdout" && strings.HasPrefix(line, "{") {
|
||||
@@ -463,7 +358,7 @@ func (e Env) RunStats(ctx context.Context, handle LineHandler) (*RepoStats, erro
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// CheckResult summarises a `restic check` invocation. LockPresent is
|
||||
// CheckResult summarizes a `restic check` invocation. LockPresent is
|
||||
// true if the stderr stream contained a stale-lock signal (caller is
|
||||
// expected to surface this in the UI so the operator can run unlock).
|
||||
// ErrorsFound is true if check exited with a non-zero status (errors
|
||||
@@ -475,7 +370,7 @@ type CheckResult struct {
|
||||
|
||||
// RunCheck executes `restic check` with optional --read-data-subset.
|
||||
// subsetPct of 0 omits the flag (full data check); >0 passes
|
||||
// --read-data-subset N%. Returns a CheckResult summarising what was
|
||||
// --read-data-subset N%. Returns a CheckResult summarizing what was
|
||||
// sniffed from stderr; the result is set even if check itself
|
||||
// returns an error (so the caller can persist last_check_status).
|
||||
func (e Env) RunCheck(ctx context.Context, subsetPct int, handle LineHandler) (CheckResult, error) {
|
||||
@@ -483,7 +378,9 @@ func (e Env) RunCheck(ctx context.Context, subsetPct int, handle LineHandler) (C
|
||||
if subsetPct > 0 {
|
||||
args = append(args, "--read-data-subset", fmt.Sprintf("%d%%", subsetPct))
|
||||
}
|
||||
cmd := e.resticCmd(ctx, args...)
|
||||
cmd := exec.CommandContext(ctx, e.Bin, args...)
|
||||
cmd.Env = e.envSlice()
|
||||
cmd.Dir = e.WorkDir
|
||||
|
||||
var res CheckResult
|
||||
sniff := func(stream, line string, ev any) {
|
||||
@@ -536,14 +433,12 @@ func pumpPlain(r io.Reader, stream string, handle LineHandler) error {
|
||||
// on one or the other for its cache dir; without it the command
|
||||
// fails before ever talking to the repo.
|
||||
//
|
||||
// Default to /var/lib/restic-manager. The unit no longer pins
|
||||
// ProtectHome=read-only (a backup tool needs to restore anywhere),
|
||||
// but the explicit HOME stays for two reasons: the parent's HOME
|
||||
// can be unset under unusual init shapes, and pinning the cache
|
||||
// under a known agent-owned dir keeps restic's metadata isolated
|
||||
// from the actual operator home dirs that the agent can now write
|
||||
// to. ExtraEnv overrides win for callers that want a different
|
||||
// cache location.
|
||||
// Default to /var/lib/restic-manager — that's in the systemd unit's
|
||||
// ReadWritePaths and survives ProtectHome=read-only. We do NOT fall
|
||||
// back to the parent's HOME env var: the agent runs as root with
|
||||
// HOME=/root, but ProtectHome makes /root read-only, so restic's
|
||||
// `mkdir /root/.cache/restic` fails. ExtraEnv overrides win for
|
||||
// callers that explicitly want a different cache location.
|
||||
func (e Env) envSlice() []string {
|
||||
home := "/var/lib/restic-manager"
|
||||
if h, ok := e.ExtraEnv["HOME"]; ok && h != "" {
|
||||
|
||||
@@ -174,43 +174,6 @@ func TestRunStatsErrorsWithoutJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBandwidthLimitFlagsInjected(t *testing.T) {
|
||||
// Script echoes its argv to stdout. Each variant should produce
|
||||
// the right --limit-* flags before the subcommand.
|
||||
cases := []struct {
|
||||
name string
|
||||
env Env
|
||||
want []string
|
||||
}{
|
||||
{"both caps", Env{LimitUploadKBps: 1024, LimitDownloadKBps: 512}, []string{"--limit-upload 1024", "--limit-download 512"}},
|
||||
{"only upload", Env{LimitUploadKBps: 256}, []string{"--limit-upload 256"}},
|
||||
{"zero means omit", Env{LimitUploadKBps: 0, LimitDownloadKBps: 0}, nil},
|
||||
{"negative means omit", Env{LimitUploadKBps: -1}, nil},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
bin := setupScriptBin(t, `echo "$@"`)
|
||||
env := c.env
|
||||
env.Bin = bin
|
||||
lines, h := captureLines()
|
||||
if err := env.RunUnlock(context.Background(), h); err != nil {
|
||||
t.Fatalf("RunUnlock: %v", err)
|
||||
}
|
||||
joined := strings.Join(*lines, "\n")
|
||||
for _, want := range c.want {
|
||||
if !strings.Contains(joined, want) {
|
||||
t.Fatalf("want %q in argv; got: %s", want, joined)
|
||||
}
|
||||
}
|
||||
if len(c.want) == 0 {
|
||||
if strings.Contains(joined, "--limit-upload") || strings.Contains(joined, "--limit-download") {
|
||||
t.Fatalf("expected no limit flags; got: %s", joined)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStatsZeroSnapshots(t *testing.T) {
|
||||
// Confirms RunStats succeeds and returns a valid *RepoStats when the
|
||||
// repo has no snapshots (snapshots_count=0). A regression that
|
||||
|
||||
@@ -13,11 +13,9 @@ import (
|
||||
// decode only the fields we project to the server; restic's full
|
||||
// shape has more (parent, tree, program version) that we don't need.
|
||||
//
|
||||
// Summary is only populated by restic 0.17+ (which embeds the backup
|
||||
// summary inside each snapshot record). Older clients leave it nil
|
||||
// and the agent reports zero size/file-count — the UI degrades to
|
||||
// "—" and the column headers carry a tooltip explaining the version
|
||||
// requirement (see web/templates/pages/host_detail.html).
|
||||
// Summary is only populated by restic 0.16+ (which embeds the backup
|
||||
// summary inside each snapshot). Older clients leave it nil and the
|
||||
// agent reports zero size/file-count — the UI degrades to "—".
|
||||
type Snapshot struct {
|
||||
ID string `json:"id"`
|
||||
ShortID string `json:"short_id"`
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
package restic
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEnvAtLeastVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
ver string
|
||||
major int
|
||||
minor int
|
||||
want bool
|
||||
shortDesc string
|
||||
}{
|
||||
{"0.17.0", 0, 17, true, "exact match"},
|
||||
{"0.17.1", 0, 17, true, "patch above"},
|
||||
{"0.18.0", 0, 17, true, "minor above"},
|
||||
{"1.0.0", 0, 17, true, "major above"},
|
||||
{"0.16.4", 0, 17, false, "minor below"},
|
||||
{"0.16", 0, 17, false, "two-part minor below"},
|
||||
{"", 0, 17, false, "empty"},
|
||||
{"v0.17", 0, 17, false, "prefixed v rejected"},
|
||||
{"unknown", 0, 17, false, "non-numeric rejected"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := Env{Version: c.ver}.AtLeastVersion(c.major, c.minor)
|
||||
if got != c.want {
|
||||
t.Errorf("AtLeastVersion(%q, %d, %d): got %v want %v · %s",
|
||||
c.ver, c.major, c.minor, got, c.want, c.shortDesc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandHome(t *testing.T) {
|
||||
// Not parallel — t.Setenv on HOME would race with sibling tests.
|
||||
tmp := t.TempDir()
|
||||
t.Setenv("HOME", tmp)
|
||||
|
||||
cases := []struct {
|
||||
in, want string
|
||||
}{
|
||||
{"$HOME/rm-restore/job-1/", filepath.Join(tmp, "rm-restore/job-1")},
|
||||
{"${HOME}/rm-restore/job-2/", filepath.Join(tmp, "rm-restore/job-2")},
|
||||
{"~/rm-restore/job-3/", filepath.Join(tmp, "rm-restore/job-3")},
|
||||
{"$HOME", tmp},
|
||||
{"~", tmp},
|
||||
{"/var/lib/x/y", "/var/lib/x/y"}, // absolute path passes through
|
||||
{"", ""},
|
||||
{"$PATH/foo", "$PATH/foo"}, // other env vars not expanded
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := expandHome(c.in)
|
||||
if got != c.want {
|
||||
t.Errorf("expandHome(%q): got %q want %q", c.in, got, c.want)
|
||||
}
|
||||
}
|
||||
|
||||
// Sanity: an absolute path always passes through regardless of HOME.
|
||||
if got := expandHome("/abs"); got != "/abs" {
|
||||
t.Errorf("expandHome(/abs): got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -30,17 +30,7 @@ type Config struct {
|
||||
// Defaults to true. Set RM_COOKIE_SECURE=false only for local HTTP
|
||||
// testing — production deployments are always behind a TLS proxy
|
||||
// and the cookie must be Secure.
|
||||
CookieSecure bool `yaml:"cookie_secure"`
|
||||
OIDCRaw *OIDCConfig `yaml:"oidc"`
|
||||
OIDC *OIDCConfig `yaml:"-"`
|
||||
|
||||
// BundledAssetsDir is the read-only path inside the image that
|
||||
// holds agent binaries (under agent-binaries/) and install
|
||||
// scripts (under install/). The /agent/binary and /install/*
|
||||
// handlers fall back here when the file is not present in
|
||||
// DataDir. Source-build deployments can override via
|
||||
// RM_BUNDLED_ASSETS_DIR.
|
||||
BundledAssetsDir string `yaml:"bundled_assets_dir"`
|
||||
CookieSecure bool `yaml:"cookie_secure"`
|
||||
}
|
||||
|
||||
// Load resolves config in this order:
|
||||
@@ -52,10 +42,9 @@ type Config struct {
|
||||
// safe to start.
|
||||
func Load(yamlPath string) (Config, error) {
|
||||
c := Config{
|
||||
Listen: ":8080",
|
||||
DataDir: "/data",
|
||||
CookieSecure: true,
|
||||
BundledAssetsDir: "/opt/restic-manager/dist",
|
||||
Listen: ":8080",
|
||||
DataDir: "/data",
|
||||
CookieSecure: true,
|
||||
}
|
||||
|
||||
if yamlPath != "" {
|
||||
@@ -90,9 +79,6 @@ func Load(yamlPath string) (Config, error) {
|
||||
c.CookieSecure = true
|
||||
}
|
||||
}
|
||||
if v, ok := os.LookupEnv("RM_BUNDLED_ASSETS_DIR"); ok {
|
||||
c.BundledAssetsDir = v
|
||||
}
|
||||
if v, ok := os.LookupEnv("RM_TRUSTED_PROXY"); ok {
|
||||
// Comma-separated CIDRs; allow whitespace for readability.
|
||||
parts := strings.Split(v, ",")
|
||||
@@ -105,16 +91,6 @@ func Load(yamlPath string) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var rawOIDC OIDCConfig
|
||||
if c.OIDCRaw != nil {
|
||||
rawOIDC = *c.OIDCRaw
|
||||
}
|
||||
oidc, err := loadOIDC(envSnapshot(), rawOIDC)
|
||||
if err != nil {
|
||||
return c, err
|
||||
}
|
||||
c.OIDC = oidc
|
||||
|
||||
return c, c.validate()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
// internal/server/config/oidc.go — OIDC subsection of the server
|
||||
// config. Disabled when oidc.issuer is empty or absent.
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// OIDCConfig is the OIDC sub-block. The struct doubles as YAML schema;
// loadOIDC applies env overlays on top and fills defaults.
type OIDCConfig struct {
	Issuer       string            `yaml:"issuer"`
	ClientID     string            `yaml:"client_id"`
	ClientSecret string            `yaml:"client_secret"`
	DisplayName  string            `yaml:"display_name"`
	Scopes       []string          `yaml:"scopes"`
	RoleClaim    string            `yaml:"role_claim"`
	RoleMapping  map[string]string `yaml:"role_mapping"`
	RedirectURL  string            `yaml:"redirect_url"`
}

// loadOIDC merges YAML + env, applies defaults, validates. Returns
// nil + nil when OIDC is disabled (issuer empty after merge); a
// non-nil OIDCConfig means the caller should wire OIDC.
//
// Env vars (override YAML when set):
//
//	RM_OIDC_ISSUER, RM_OIDC_CLIENT_ID, RM_OIDC_CLIENT_SECRET,
//	RM_OIDC_CLIENT_SECRET_FILE, RM_OIDC_DISPLAY_NAME,
//	RM_OIDC_REDIRECT_URL.
//
// RM_OIDC_CLIENT_SECRET_FILE wins over RM_OIDC_CLIENT_SECRET when both
// are set. Trailing CR/LF bytes in the file are stripped, because
// secret files are usually written with a final newline that is not
// part of the secret.
//
// envs is passed in (rather than read with os.LookupEnv) so unit
// tests can supply a fake env map.
//
// Errors: a read failure on the secret file, or a missing client_id,
// client_secret or role_mapping once an issuer is configured.
func loadOIDC(envs map[string]string, yaml OIDCConfig) (*OIDCConfig, error) {
	c := yaml
	if v, ok := envs["RM_OIDC_ISSUER"]; ok {
		c.Issuer = v
	}
	if v, ok := envs["RM_OIDC_CLIENT_ID"]; ok {
		c.ClientID = v
	}
	if v, ok := envs["RM_OIDC_CLIENT_SECRET"]; ok {
		c.ClientSecret = v
	}
	if v, ok := envs["RM_OIDC_CLIENT_SECRET_FILE"]; ok && v != "" {
		body, err := os.ReadFile(v)
		if err != nil {
			return nil, fmt.Errorf("config: oidc client_secret_file: %w", err)
		}
		// Drop only trailing line terminators; other whitespace could
		// legitimately be part of the secret.
		end := len(body)
		for end > 0 && (body[end-1] == '\n' || body[end-1] == '\r') {
			end--
		}
		c.ClientSecret = string(body[:end])
	}
	if v, ok := envs["RM_OIDC_DISPLAY_NAME"]; ok {
		c.DisplayName = v
	}
	if v, ok := envs["RM_OIDC_REDIRECT_URL"]; ok {
		c.RedirectURL = v
	}

	// No issuer after the merge means OIDC is switched off.
	if c.Issuer == "" {
		return nil, nil
	}

	if c.ClientID == "" {
		return nil, errors.New("config: oidc.client_id required when issuer is set")
	}
	if c.ClientSecret == "" {
		return nil, errors.New("config: oidc.client_secret required when issuer is set")
	}
	if len(c.RoleMapping) == 0 {
		return nil, errors.New("config: oidc.role_mapping must have at least one entry")
	}

	if c.DisplayName == "" {
		c.DisplayName = "SSO"
	}
	if c.RoleClaim == "" {
		c.RoleClaim = "groups"
	}
	if len(c.Scopes) == 0 {
		c.Scopes = []string{"openid", "profile", "email", "groups"}
	}
	return &c, nil
}
|
||||
|
||||
// envSnapshot copies the OIDC-related environment variables that are
// currently set into a fresh map. Variables that are unset are left out;
// variables set to the empty string are kept. Production code feeds the
// result to loadOIDC, while tests hand loadOIDC an explicit map.
func envSnapshot() map[string]string {
	names := [...]string{
		"RM_OIDC_ISSUER",
		"RM_OIDC_CLIENT_ID",
		"RM_OIDC_CLIENT_SECRET",
		"RM_OIDC_CLIENT_SECRET_FILE",
		"RM_OIDC_DISPLAY_NAME",
		"RM_OIDC_REDIRECT_URL",
	}
	snapshot := make(map[string]string, len(names))
	for i := range names {
		value, present := os.LookupEnv(names[i])
		if !present {
			continue
		}
		snapshot[names[i]] = value
	}
	return snapshot
}
|
||||
@@ -1,72 +0,0 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestOIDCParseDisabledWhenIssuerEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
c, err := loadOIDC(map[string]string{}, OIDCConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("load: %v", err)
|
||||
}
|
||||
if c != nil {
|
||||
t.Errorf("expected nil OIDC config when issuer empty; got %+v", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCRejectMissingClientID(t *testing.T) {
|
||||
t.Parallel()
|
||||
yaml := OIDCConfig{Issuer: "https://x", ClientSecret: "s"}
|
||||
if _, err := loadOIDC(map[string]string{}, yaml); err == nil {
|
||||
t.Error("expected error for missing client_id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCRejectMissingClientSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
yaml := OIDCConfig{Issuer: "https://x", ClientID: "rm"}
|
||||
if _, err := loadOIDC(map[string]string{}, yaml); err == nil {
|
||||
t.Error("expected error for missing client_secret")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCDefaultsApplied(t *testing.T) {
|
||||
t.Parallel()
|
||||
yaml := OIDCConfig{
|
||||
Issuer: "https://x", ClientID: "rm", ClientSecret: "s",
|
||||
RoleMapping: map[string]string{"a": "admin"},
|
||||
}
|
||||
c, err := loadOIDC(map[string]string{}, yaml)
|
||||
if err != nil {
|
||||
t.Fatalf("load: %v", err)
|
||||
}
|
||||
if c.RoleClaim != "groups" {
|
||||
t.Errorf("role_claim default: got %q want groups", c.RoleClaim)
|
||||
}
|
||||
if c.DisplayName != "SSO" {
|
||||
t.Errorf("display_name default: got %q want SSO", c.DisplayName)
|
||||
}
|
||||
wantScopes := []string{"openid", "profile", "email", "groups"}
|
||||
if len(c.Scopes) != len(wantScopes) {
|
||||
t.Errorf("scopes default: got %v want %v", c.Scopes, wantScopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCEnvOverrides(t *testing.T) {
|
||||
t.Parallel()
|
||||
yaml := OIDCConfig{
|
||||
Issuer: "https://from-yaml", ClientID: "yaml-id", ClientSecret: "yaml-secret",
|
||||
RoleMapping: map[string]string{"x": "admin"},
|
||||
}
|
||||
envs := map[string]string{
|
||||
"RM_OIDC_ISSUER": "https://from-env",
|
||||
"RM_OIDC_CLIENT_ID": "env-id",
|
||||
"RM_OIDC_CLIENT_SECRET": "env-secret",
|
||||
}
|
||||
c, err := loadOIDC(envs, yaml)
|
||||
if err != nil {
|
||||
t.Fatalf("load: %v", err)
|
||||
}
|
||||
if c.Issuer != "https://from-env" || c.ClientID != "env-id" || c.ClientSecret != "env-secret" {
|
||||
t.Errorf("env override: got %+v", c)
|
||||
}
|
||||
}
|
||||
@@ -11,23 +11,19 @@ import (
|
||||
)
|
||||
|
||||
// agent_assets.go serves the agent binary (one per OS/arch) and the
|
||||
// install scripts. Lookup is dual-path:
|
||||
//
|
||||
// 1. <DataDir>/agent-binaries/<name> (or <DataDir>/install/<name>) —
|
||||
// operator-managed override; lets the operator hot-patch a
|
||||
// pre-release agent without rebuilding the server image.
|
||||
// 2. <BundledAssetsDir>/agent-binaries/<name> — read-only, baked
|
||||
// into the server image at build time (P5-03). This is what
|
||||
// makes a fresh container Just Work without first-run staging.
|
||||
// install scripts. The binaries live under <DataDir>/agent-binaries/,
|
||||
// laid down by the release pipeline (or copied by hand for now).
|
||||
// The install scripts live in <DataDir>/install/ alongside the
|
||||
// systemd unit.
|
||||
//
|
||||
// Both endpoints are intentionally unauthenticated: the install
|
||||
// payload is unprivileged on its own — it's the one-time enrollment
|
||||
// token that grants access. Anyone can pull the binary; only
|
||||
// someone with a valid token can use it productively.
|
||||
//
|
||||
// P1-31: signed-binary verification is deferred. The image is the
|
||||
// unit of trust; pull-by-digest is the verification primitive.
|
||||
// Future work bumps standalone-binary delivery to minisign/cosign.
|
||||
// P1-31: signed-binary verification is deferred. Today we serve
|
||||
// whatever the operator dropped on disk. Future work bumps this to
|
||||
// minisign/cosign signed bundles.
|
||||
|
||||
// installAssetsRoutes adds /agent/binary and /install/* to r.
|
||||
func (s *Server) handleAgentBinary(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
@@ -49,8 +45,8 @@ func (s *Server) handleAgentBinary(w stdhttp.ResponseWriter, r *stdhttp.Request)
|
||||
ext = ".exe"
|
||||
}
|
||||
name := fmt.Sprintf("restic-manager-agent-%s-%s%s", osTag, archTag, ext)
|
||||
path, ok := s.resolveBundledAsset("agent-binaries", name)
|
||||
if !ok {
|
||||
path := filepath.Join(s.deps.Cfg.DataDir, "agent-binaries", name)
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusNotFound, "binary_not_published",
|
||||
fmt.Sprintf("agent binary for %s/%s not published on this server", osTag, archTag))
|
||||
return
|
||||
@@ -61,41 +57,21 @@ func (s *Server) handleAgentBinary(w stdhttp.ResponseWriter, r *stdhttp.Request)
|
||||
}
|
||||
|
||||
func (s *Server) handleInstallAsset(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
// chi's TrimPrefix-like behaviour: r.URL.Path is "/install/<file>".
|
||||
// chi's TrimPrefix-like behavior: r.URL.Path is "/install/<file>".
|
||||
rel := strings.TrimPrefix(r.URL.Path, "/install/")
|
||||
// Reject any path traversal — must be a flat filename.
|
||||
if rel == "" || strings.ContainsAny(rel, "/\\") {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "bad_path", "")
|
||||
return
|
||||
}
|
||||
path, ok := s.resolveBundledAsset("install", rel)
|
||||
if !ok {
|
||||
path := filepath.Join(s.deps.Cfg.DataDir, "install", rel)
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusNotFound, "not_found", "")
|
||||
return
|
||||
}
|
||||
stdhttp.ServeFile(w, r, path)
|
||||
}
|
||||
|
||||
// resolveBundledAsset looks up an asset by (subdir, name). DataDir
|
||||
// wins so an operator can override the image-baked copy by dropping
|
||||
// a file into <DataDir>/<subdir>/<name>. If neither path resolves,
|
||||
// returns ("", false).
|
||||
func (s *Server) resolveBundledAsset(subdir, name string) (string, bool) {
|
||||
candidates := []string{
|
||||
filepath.Join(s.deps.Cfg.DataDir, subdir, name),
|
||||
}
|
||||
if s.deps.Cfg.BundledAssetsDir != "" {
|
||||
candidates = append(candidates,
|
||||
filepath.Join(s.deps.Cfg.BundledAssetsDir, subdir, name))
|
||||
}
|
||||
for _, p := range candidates {
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return p, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func validOS(s string) bool {
|
||||
switch api.HostOS(s) {
|
||||
case api.OSLinux, api.OSWindows:
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
stdhttp "net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/config"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// newAssetsTestServer is a minimal scaffold for the /agent/binary and
|
||||
// /install/* handlers. Two roots: one acts as DataDir, the other as
|
||||
// the image-baked BundledAssetsDir. Either or both may be empty.
|
||||
func newAssetsTestServer(t *testing.T, populate func(dataDir, bundleDir string)) string {
|
||||
t.Helper()
|
||||
root := t.TempDir()
|
||||
dataDir := filepath.Join(root, "data")
|
||||
bundleDir := filepath.Join(root, "dist")
|
||||
for _, d := range []string{
|
||||
filepath.Join(dataDir, "agent-binaries"),
|
||||
filepath.Join(dataDir, "install"),
|
||||
filepath.Join(bundleDir, "agent-binaries"),
|
||||
filepath.Join(bundleDir, "install"),
|
||||
} {
|
||||
if err := os.MkdirAll(d, 0o755); err != nil {
|
||||
t.Fatalf("mkdir: %v", err)
|
||||
}
|
||||
}
|
||||
if populate != nil {
|
||||
populate(dataDir, bundleDir)
|
||||
}
|
||||
|
||||
st, err := store.Open(context.Background(), filepath.Join(root, "rm.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("store: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = st.Close() })
|
||||
|
||||
keyPath := filepath.Join(root, "secret.key")
|
||||
_ = crypto.GenerateKeyFile(keyPath)
|
||||
key, _ := crypto.LoadKeyFromFile(keyPath)
|
||||
aead, _ := crypto.NewAEAD(key)
|
||||
|
||||
deps := Deps{
|
||||
Cfg: config.Config{
|
||||
Listen: ":0",
|
||||
DataDir: dataDir,
|
||||
SecretKeyFile: keyPath,
|
||||
BundledAssetsDir: bundleDir,
|
||||
},
|
||||
Store: st,
|
||||
AEAD: aead,
|
||||
Hub: ws.NewHub(),
|
||||
BootstrapToken: "test-token",
|
||||
}
|
||||
s := New(deps)
|
||||
ts := httptest.NewServer(s.srv.Handler)
|
||||
t.Cleanup(ts.Close)
|
||||
return ts.URL
|
||||
}
|
||||
|
||||
// writeFile writes body to path with mode 0644 and fails the test
// immediately if the write does not succeed.
func writeFile(t *testing.T, path string, body []byte) {
	t.Helper()
	err := os.WriteFile(path, body, 0o644)
	if err == nil {
		return
	}
	t.Fatalf("write %s: %v", path, err)
}
|
||||
|
||||
// get issues a GET against url and returns the response status code and
// the full body. A transport error or a failure while reading the body
// aborts the test, so callers never assert against a truncated payload.
func get(t *testing.T, url string) (int, []byte) {
	t.Helper()
	res, err := stdhttp.Get(url)
	if err != nil {
		t.Fatalf("GET %s: %v", url, err)
	}
	defer res.Body.Close()
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("GET %s: read body: %v", url, err)
	}
	return res.StatusCode, body
}
|
||||
|
||||
func TestAgentBinary_DataDirHit(t *testing.T) {
|
||||
t.Parallel()
|
||||
url := newAssetsTestServer(t, func(dataDir, _ string) {
|
||||
writeFile(t, filepath.Join(dataDir, "agent-binaries", "restic-manager-agent-linux-amd64"),
|
||||
[]byte("from-datadir"))
|
||||
})
|
||||
code, body := get(t, url+"/agent/binary?os=linux&arch=amd64")
|
||||
if code != 200 || string(body) != "from-datadir" {
|
||||
t.Fatalf("got %d %q", code, string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentBinary_BundleFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
url := newAssetsTestServer(t, func(_, bundleDir string) {
|
||||
writeFile(t, filepath.Join(bundleDir, "agent-binaries", "restic-manager-agent-linux-amd64"),
|
||||
[]byte("from-bundle"))
|
||||
})
|
||||
code, body := get(t, url+"/agent/binary?os=linux&arch=amd64")
|
||||
if code != 200 || string(body) != "from-bundle" {
|
||||
t.Fatalf("got %d %q", code, string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentBinary_DataDirShadowsBundle(t *testing.T) {
|
||||
t.Parallel()
|
||||
url := newAssetsTestServer(t, func(dataDir, bundleDir string) {
|
||||
writeFile(t, filepath.Join(dataDir, "agent-binaries", "restic-manager-agent-linux-amd64"),
|
||||
[]byte("from-datadir"))
|
||||
writeFile(t, filepath.Join(bundleDir, "agent-binaries", "restic-manager-agent-linux-amd64"),
|
||||
[]byte("from-bundle"))
|
||||
})
|
||||
code, body := get(t, url+"/agent/binary?os=linux&arch=amd64")
|
||||
if code != 200 || string(body) != "from-datadir" {
|
||||
t.Fatalf("operator override should win: got %d %q", code, string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentBinary_BothMiss(t *testing.T) {
|
||||
t.Parallel()
|
||||
url := newAssetsTestServer(t, nil)
|
||||
code, _ := get(t, url+"/agent/binary?os=linux&arch=amd64")
|
||||
if code != 404 {
|
||||
t.Fatalf("expected 404, got %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentBinary_WindowsNameHasExe(t *testing.T) {
|
||||
t.Parallel()
|
||||
url := newAssetsTestServer(t, func(_, bundleDir string) {
|
||||
writeFile(t, filepath.Join(bundleDir, "agent-binaries", "restic-manager-agent-windows-amd64.exe"),
|
||||
[]byte("win"))
|
||||
})
|
||||
code, body := get(t, url+"/agent/binary?os=windows&arch=amd64")
|
||||
if code != 200 || string(body) != "win" {
|
||||
t.Fatalf("got %d %q", code, string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallAsset_BundleFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
url := newAssetsTestServer(t, func(_, bundleDir string) {
|
||||
writeFile(t, filepath.Join(bundleDir, "install", "install.sh"), []byte("#!/bin/sh\n"))
|
||||
})
|
||||
code, body := get(t, url+"/install/install.sh")
|
||||
if code != 200 || string(body) != "#!/bin/sh\n" {
|
||||
t.Fatalf("got %d %q", code, string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallAsset_PathTraversalRejected(t *testing.T) {
|
||||
t.Parallel()
|
||||
url := newAssetsTestServer(t, nil)
|
||||
// chi will normalise some traversal attempts, but the handler
|
||||
// also rejects any rel containing a slash or backslash. The
|
||||
// path component of the URL after /install/ is the rel.
|
||||
code, _ := get(t, url+"/install/..%2fpasswd")
|
||||
if code == 200 {
|
||||
t.Fatalf("traversal should not return 200")
|
||||
}
|
||||
}
|
||||
@@ -1,211 +0,0 @@
|
||||
// announce.go — POST /api/agents/announce: agent without a token
|
||||
// announces itself with a freshly-minted Ed25519 public key, server
|
||||
// stashes a pending_hosts row, admin compares fingerprints in the
|
||||
// UI before accepting (P2-18a).
|
||||
//
|
||||
// Guards (per spec):
|
||||
// - Per-source-IP token-bucket rate limit (10/min).
|
||||
// - Global cap of 100 in-flight pending rows; further announces
|
||||
// get 503 with a hint.
|
||||
// - Public key must be exactly 32 bytes (Ed25519). Anything else
|
||||
// 400-rejected.
|
||||
//
|
||||
// Hostname collisions are NOT rejected — multiple announces with
|
||||
// the same hostname can be legitimate (re-running install on the
|
||||
// same box). The UI flags collisions for the admin to disambiguate.
|
||||
package http
|
||||
|
||||
import (
	"crypto/ed25519"
	"encoding/base64"
	"encoding/json"
	"net"
	stdhttp "net/http"
	"strings"
	"sync"
	"time"

	"github.com/oklog/ulid/v2"

	"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
|
||||
|
||||
// Tunables — exposed as vars so tests can lower them. Defaults mirror
// the spec's recommendations.
var (
	announceMaxPerMin = 10
	announceGlobalCap = 100
)

// announceRequest is the wire shape POST /api/agents/announce takes.
// PublicKey is standard-alphabet base64; the handler accepts it with or
// without '=' padding.
type announceRequest struct {
	Hostname      string `json:"hostname"`
	OS            string `json:"os"`
	Arch          string `json:"arch"`
	AgentVersion  string `json:"agent_version"`
	ResticVersion string `json:"restic_version"`
	PublicKey     string `json:"public_key"` // base64
}

// announceResponse is what the agent gets back. Fingerprint is the
// canonical "SHA256:hex" the operator compares against the UI.
// HostnameCollision warns the install script that another pending
// row already uses the same hostname.
type announceResponse struct {
	PendingID         string `json:"pending_id"`
	Fingerprint       string `json:"fingerprint"`
	HostnameCollision bool   `json:"hostname_collision"`
}

// rateBucket is a tiny per-IP token-bucket. last is the timestamp of
// the most recent refill; tokens is the current bucket level. Refill
// rate is announceMaxPerMin tokens/minute, burst = announceMaxPerMin.
type rateBucket struct {
	tokens float64
	last   time.Time
}

// announceLimiter holds one bucket per source IP, guarded by mu.
// Buckets that have been idle for a full minute are dropped by a sweep
// that runs at most once per minute from inside allow, so the map stays
// bounded by the number of distinct sources seen recently. lastSweep is
// the time of the most recent sweep.
type announceLimiter struct {
	mu        sync.Mutex
	buckets   map[string]*rateBucket
	lastSweep time.Time
}

// newAnnounceLimiter returns a limiter with an empty bucket map.
func newAnnounceLimiter() *announceLimiter {
	return &announceLimiter{buckets: map[string]*rateBucket{}}
}

// allow returns true and consumes a token if the IP's bucket has at
// least one token, else returns false. Capacity = announceMaxPerMin.
// now is supplied by the caller so tests can drive the clock.
func (l *announceLimiter) allow(ip string, now time.Time) bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	l.sweep(now)

	limit := float64(announceMaxPerMin)
	b, ok := l.buckets[ip]
	if !ok {
		b = &rateBucket{tokens: limit, last: now}
		l.buckets[ip] = b
	}
	// Refill at limit tokens per minute, never above the burst size.
	if elapsed := now.Sub(b.last).Seconds(); elapsed > 0 {
		b.tokens = min(b.tokens+(elapsed/60.0)*limit, limit)
		b.last = now
	}
	if b.tokens < 1.0 {
		return false
	}
	b.tokens--
	return true
}

// sweep drops buckets that have been idle for at least a minute. Such a
// bucket would refill to full capacity on its next use, which is
// exactly the state a freshly created bucket starts in, so removing it
// cannot change any allow decision. The caller must hold l.mu.
func (l *announceLimiter) sweep(now time.Time) {
	if now.Sub(l.lastSweep) < time.Minute {
		return
	}
	l.lastSweep = now
	for ip, b := range l.buckets {
		if now.Sub(b.last) >= time.Minute {
			delete(l.buckets, ip)
		}
	}
}
|
||||
|
||||
// handleAnnounce is the public POST handler. Public — no auth.
|
||||
func (s *Server) handleAnnounce(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Rate limit by source IP. Strip port — the limit is per host,
|
||||
// not per outbound source port.
|
||||
ip := remoteIP(r)
|
||||
if !s.announceRL.allow(ip, now) {
|
||||
w.Header().Set("Retry-After", "60")
|
||||
writeJSONError(w, stdhttp.StatusTooManyRequests, "rate_limited",
|
||||
"too many announces from this source; retry in a minute")
|
||||
return
|
||||
}
|
||||
|
||||
var req announceRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error())
|
||||
return
|
||||
}
|
||||
if req.Hostname == "" || req.OS == "" || req.Arch == "" || req.PublicKey == "" {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "missing_field",
|
||||
"hostname, os, arch, public_key are required")
|
||||
return
|
||||
}
|
||||
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(req.PublicKey)
|
||||
if err != nil {
|
||||
// Try URL-safe / no-padding flavours before giving up.
|
||||
if k2, e2 := base64.RawStdEncoding.DecodeString(req.PublicKey); e2 == nil {
|
||||
keyBytes = k2
|
||||
} else {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_public_key",
|
||||
"public_key must be base64")
|
||||
return
|
||||
}
|
||||
}
|
||||
if len(keyBytes) != ed25519.PublicKeySize {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_public_key",
|
||||
"public_key must be 32 bytes (Ed25519)")
|
||||
return
|
||||
}
|
||||
|
||||
// Global cap (cheap query — index on expires_at).
|
||||
count, err := s.deps.Store.CountPendingHosts(r.Context(), now)
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
if count >= announceGlobalCap {
|
||||
writeJSONError(w, stdhttp.StatusServiceUnavailable, "pending_cap_reached",
|
||||
"too many in-flight pending hosts; ask an admin to clear the queue")
|
||||
return
|
||||
}
|
||||
|
||||
// Hostname collision flag (informational).
|
||||
colls, err := s.deps.Store.CountPendingHostsByHostname(r.Context(), req.Hostname, now)
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ph := &store.PendingHost{
|
||||
ID: ulid.Make().String(),
|
||||
Hostname: req.Hostname,
|
||||
OS: req.OS,
|
||||
Arch: req.Arch,
|
||||
AgentVersion: req.AgentVersion,
|
||||
ResticVersion: req.ResticVersion,
|
||||
PublicKey: keyBytes,
|
||||
Fingerprint: store.FingerprintForKey(keyBytes),
|
||||
AnnouncedFromIP: ip,
|
||||
FirstSeenAt: now,
|
||||
LastSeenAt: now,
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
}
|
||||
if err := s.deps.Store.CreatePendingHost(r.Context(), ph); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
writeJSON(w, stdhttp.StatusOK, announceResponse{
|
||||
PendingID: ph.ID,
|
||||
Fingerprint: ph.Fingerprint,
|
||||
HostnameCollision: colls > 0,
|
||||
})
|
||||
}
|
||||
|
||||
// remoteIP returns the client address for r with any port removed.
//
// When an X-Forwarded-For header is present, its first entry (the hop
// closest to the original client) is returned, trimmed of whitespace.
// If that entry is empty the header is ignored. Otherwise r.RemoteAddr
// is split with net.SplitHostPort, so IPv6 literals come back without
// brackets; an address that has no port is returned unchanged.
//
// NOTE(review): the header is honoured unconditionally here — this
// function has no view of the trusted-proxy configuration. A client
// that reaches the server directly can therefore choose its own value.
// Confirm that an upstream middleware strips or validates the header
// before relying on the result for rate limiting.
func remoteIP(r *stdhttp.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port to strip (or not host:port shaped) — use as-is.
		return r.RemoteAddr
	}
	return host
}
|
||||
@@ -1,164 +0,0 @@
|
||||
// announce_test.go — covers POST /api/agents/announce: happy path,
|
||||
// invalid public key, hostname collision flag, rate limit, global
|
||||
// cap (P2-18a).
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
stdhttp "net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// newKeypair generates a fresh Ed25519 key pair from crypto/rand and
// returns only the public half; generation failure aborts the test.
func newKeypair(t *testing.T) ed25519.PublicKey {
	t.Helper()
	public, _, genErr := ed25519.GenerateKey(rand.Reader)
	if genErr == nil {
		return public
	}
	t.Fatalf("ed25519: %v", genErr)
	return nil // unreachable: Fatalf stops the test
}
|
||||
|
||||
func postAnnounce(t *testing.T, url string, req announceRequest) (status int, header stdhttp.Header, body []byte) {
|
||||
t.Helper()
|
||||
b, _ := json.Marshal(req)
|
||||
r, _ := stdhttp.NewRequest("POST", url+"/api/agents/announce", bytes.NewReader(b))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
res, err := stdhttp.DefaultClient.Do(r)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
out := make([]byte, 4096)
|
||||
n, _ := res.Body.Read(out)
|
||||
return res.StatusCode, res.Header, out[:n]
|
||||
}
|
||||
|
||||
func TestAnnounceHappyPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, url, st := newTestServerWithHub(t)
|
||||
pub := newKeypair(t)
|
||||
status, _, body := postAnnounce(t, url, announceRequest{
|
||||
Hostname: "alice", OS: "linux", Arch: "amd64",
|
||||
AgentVersion: "1.0", ResticVersion: "0.17",
|
||||
PublicKey: base64.StdEncoding.EncodeToString(pub),
|
||||
})
|
||||
if status != stdhttp.StatusOK {
|
||||
t.Fatalf("status: %d body=%s", status, body)
|
||||
}
|
||||
var ar announceResponse
|
||||
if err := json.Unmarshal(body, &ar); err != nil {
|
||||
t.Fatalf("unmarshal: %v body=%s", err, body)
|
||||
}
|
||||
if ar.PendingID == "" {
|
||||
t.Fatal("missing pending_id")
|
||||
}
|
||||
if !strings.HasPrefix(ar.Fingerprint, "SHA256:") {
|
||||
t.Fatalf("fingerprint shape: %q", ar.Fingerprint)
|
||||
}
|
||||
if ar.HostnameCollision {
|
||||
t.Fatal("first announce shouldn't be a collision")
|
||||
}
|
||||
// Row exists in the store.
|
||||
if _, err := st.GetPendingHost(context.Background(), ar.PendingID); err != nil {
|
||||
t.Fatalf("pending row missing: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnnounceRejectsBadKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, url, _ := newTestServerWithHub(t)
|
||||
status, _, _ := postAnnounce(t, url, announceRequest{
|
||||
Hostname: "x", OS: "linux", Arch: "amd64",
|
||||
PublicKey: base64.StdEncoding.EncodeToString([]byte("too-short")),
|
||||
})
|
||||
if status != stdhttp.StatusBadRequest {
|
||||
t.Fatalf("status: got %d, want 400", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnnounceHostnameCollisionFlag(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, url, _ := newTestServerWithHub(t)
|
||||
pub1 := newKeypair(t)
|
||||
pub2 := newKeypair(t)
|
||||
_, _, _ = postAnnounce(t, url, announceRequest{
|
||||
Hostname: "dup-host", OS: "linux", Arch: "amd64",
|
||||
PublicKey: base64.StdEncoding.EncodeToString(pub1),
|
||||
})
|
||||
status, _, body := postAnnounce(t, url, announceRequest{
|
||||
Hostname: "dup-host", OS: "linux", Arch: "amd64",
|
||||
PublicKey: base64.StdEncoding.EncodeToString(pub2),
|
||||
})
|
||||
if status != stdhttp.StatusOK {
|
||||
t.Fatalf("status: %d", status)
|
||||
}
|
||||
var ar announceResponse
|
||||
_ = json.Unmarshal(body, &ar)
|
||||
if !ar.HostnameCollision {
|
||||
t.Fatal("expected hostname_collision=true on second announce")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnnounceRateLimit(t *testing.T) {
|
||||
// Not t.Parallel — mutates the package-level announceMaxPerMin
|
||||
// var, which would otherwise race other announce tests.
|
||||
_, url, _ := newTestServerWithHub(t)
|
||||
prev := announceMaxPerMin
|
||||
announceMaxPerMin = 2
|
||||
t.Cleanup(func() { announceMaxPerMin = prev })
|
||||
|
||||
pub := newKeypair(t)
|
||||
body := announceRequest{
|
||||
Hostname: "rl-host", OS: "linux", Arch: "amd64",
|
||||
PublicKey: base64.StdEncoding.EncodeToString(pub),
|
||||
}
|
||||
for i := 0; i < 2; i++ {
|
||||
status, _, _ := postAnnounce(t, url, body)
|
||||
if status != stdhttp.StatusOK {
|
||||
t.Fatalf("call %d: status %d", i, status)
|
||||
}
|
||||
}
|
||||
status, _, _ := postAnnounce(t, url, body)
|
||||
if status != stdhttp.StatusTooManyRequests {
|
||||
t.Fatalf("3rd call: want 429, got %d", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnnounceGlobalCap(t *testing.T) {
|
||||
// Not t.Parallel — mutates the package-level announceGlobalCap.
|
||||
_, url, st := newTestServerWithHub(t)
|
||||
prev := announceGlobalCap
|
||||
announceGlobalCap = 1
|
||||
t.Cleanup(func() { announceGlobalCap = prev })
|
||||
|
||||
// Pre-seed one row directly via the store so the cap is hit.
|
||||
pub := newKeypair(t)
|
||||
if err := st.CreatePendingHost(context.Background(), &store.PendingHost{
|
||||
ID: ulid.Make().String(), Hostname: "x", OS: "linux", Arch: "amd64",
|
||||
PublicKey: pub, Fingerprint: store.FingerprintForKey(pub),
|
||||
AnnouncedFromIP: "127.0.0.1",
|
||||
FirstSeenAt: time.Now().UTC(),
|
||||
LastSeenAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(time.Hour),
|
||||
}); err != nil {
|
||||
t.Fatalf("seed: %v", err)
|
||||
}
|
||||
status, _, _ := postAnnounce(t, url, announceRequest{
|
||||
Hostname: "next", OS: "linux", Arch: "amd64",
|
||||
PublicKey: base64.StdEncoding.EncodeToString(newKeypair(t)),
|
||||
})
|
||||
if status != stdhttp.StatusServiceUnavailable {
|
||||
t.Fatalf("want 503 over cap, got %d", status)
|
||||
}
|
||||
}
|
||||
@@ -1,391 +0,0 @@
|
||||
// api_users.go — JSON handlers for the user-management surface.
|
||||
//
|
||||
// All endpoints in this file are admin-only; gating happens at the
|
||||
// route-mount site (server.go's admin band).
|
||||
package http
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
stdhttp "net/http"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
type listUsersResponse struct {
|
||||
Users []apiUser `json:"users"`
|
||||
}
|
||||
|
||||
// apiUser is the wire representation of one user account.
//
// Field order is kept stable because it determines the key order of
// the encoded JSON object.
type apiUser struct {
	// ID is the account identifier.
	ID string `json:"id"`
	// Username is the login name.
	Username string `json:"username"`
	// Role is the role name as a plain string.
	Role string `json:"role"`
	// Email is left out of the JSON when nil.
	Email *string `json:"email,omitempty"`
	// Disabled reports whether the account has a disabled timestamp.
	Disabled bool `json:"disabled"`
	// MustChangePassword mirrors the stored flag of the same name.
	MustChangePassword bool `json:"must_change_password"`
	// CreatedAt is a pre-formatted timestamp string.
	CreatedAt string `json:"created_at"`
	// LastLoginAt is left out of the JSON when nil.
	LastLoginAt *string `json:"last_login_at,omitempty"`
}
|
||||
|
||||
func (s *Server) handleAPIUsersList(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
users, err := s.deps.Store.ListUsers(r.Context(), store.UserSort{})
|
||||
if err != nil {
|
||||
slog.Error("api users: list", "err", err)
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
out := make([]apiUser, len(users))
|
||||
for i, u := range users {
|
||||
var lastLogin *string
|
||||
if u.LastLoginAt != nil {
|
||||
s := u.LastLoginAt.UTC().Format("2006-01-02T15:04:05Z")
|
||||
lastLogin = &s
|
||||
}
|
||||
out[i] = apiUser{
|
||||
ID: u.ID, Username: u.Username, Role: string(u.Role),
|
||||
Email: u.Email, Disabled: u.DisabledAt != nil,
|
||||
MustChangePassword: u.MustChangePassword,
|
||||
CreatedAt: u.CreatedAt.UTC().Format("2006-01-02T15:04:05Z"),
|
||||
LastLoginAt: lastLogin,
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
_ = json.NewEncoder(w).Encode(listUsersResponse{Users: out})
|
||||
}
|
||||
|
||||
// createUserRequest is the JSON body accepted by the create-user
// endpoint.
type createUserRequest struct {
	// Username is the requested login name.
	Username string `json:"username"`
	// Email is optional; an empty string means "none".
	Email string `json:"email,omitempty"`
	// Role is the role name as sent on the wire.
	Role string `json:"role"`
}
|
||||
|
||||
// createUserResponse is the JSON body returned after a user has been
// created.
type createUserResponse struct {
	// ID is the identifier of the new account.
	ID string `json:"id"`
	// SetupURL is the link carrying the one-time setup token.
	SetupURL string `json:"setup_url"`
}
|
||||
|
||||
// generateSetupToken draws 32 bytes from crypto/rand and returns them
// as a 64-character lowercase hex string. The random source's error,
// if any, is passed through with an empty token.
func generateSetupToken() (string, error) {
	raw := make([]byte, 32)
	_, err := rand.Read(raw)
	if err != nil {
		return "", err
	}
	token := hex.EncodeToString(raw)
	return token, nil
}
|
||||
|
||||
// validRole maps a wire role string to the typed constant. Returns
|
||||
// ("", false) for anything unknown.
|
||||
func validRole(r string) (store.Role, bool) {
|
||||
switch r {
|
||||
case "admin":
|
||||
return store.RoleAdmin, true
|
||||
case "operator":
|
||||
return store.RoleOperator, true
|
||||
case "viewer":
|
||||
return store.RoleViewer, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (s *Server) handleAPIUserCreate(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
actor, _ := s.requireUser(r) // already gated by middleware
|
||||
var req createUserRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error())
|
||||
return
|
||||
}
|
||||
uname := strings.ToLower(strings.TrimSpace(req.Username))
|
||||
if uname == "" {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "username_required", "")
|
||||
return
|
||||
}
|
||||
role, ok := validRole(req.Role)
|
||||
if !ok {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_role", "")
|
||||
return
|
||||
}
|
||||
if req.Email != "" {
|
||||
if _, err := mail.ParseAddress(req.Email); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_email", err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check for collision against existing user (case-insensitive).
|
||||
existing, err := s.deps.Store.GetUserByUsername(r.Context(), uname)
|
||||
if err == nil {
|
||||
body := map[string]any{
|
||||
"error": "username_taken",
|
||||
"existing_user_id": existing.ID,
|
||||
"disabled": existing.DisabledAt != nil,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(stdhttp.StatusConflict)
|
||||
_ = json.NewEncoder(w).Encode(body)
|
||||
return
|
||||
} else if !errors.Is(err, store.ErrNotFound) {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
id := ulid.Make().String()
|
||||
now := time.Now().UTC()
|
||||
var emailPtr *string
|
||||
if req.Email != "" {
|
||||
em := strings.ToLower(strings.TrimSpace(req.Email))
|
||||
emailPtr = &em
|
||||
}
|
||||
if err := s.deps.Store.CreateUser(r.Context(), store.User{
|
||||
ID: id, Username: uname, PasswordHash: "",
|
||||
Role: role, Email: emailPtr, CreatedAt: now,
|
||||
MustChangePassword: true,
|
||||
}); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
rawToken, err := generateSetupToken()
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
var actorID *string
|
||||
if actor != nil {
|
||||
actorID = &actor.ID
|
||||
}
|
||||
if err := s.deps.Store.SetSetupToken(r.Context(), store.SetupToken{
|
||||
UserID: id, TokenHash: hashSetupToken(rawToken),
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
CreatedAt: now, CreatedBy: actorID,
|
||||
}); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(), UserID: actorID, Actor: "user",
|
||||
Action: "user.created", TargetKind: ptr("user"), TargetID: &id,
|
||||
TS: now,
|
||||
})
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(stdhttp.StatusCreated)
|
||||
_ = json.NewEncoder(w).Encode(createUserResponse{
|
||||
ID: id,
|
||||
SetupURL: s.deps.Cfg.BaseURL + "/setup?token=" + rawToken,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleAPIUserGet(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
u, err := s.deps.Store.GetUserByID(r.Context(), id)
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNotFound) {
|
||||
writeJSONError(w, stdhttp.StatusNotFound, "user_not_found", "")
|
||||
return
|
||||
}
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
out := apiUser{
|
||||
ID: u.ID, Username: u.Username, Role: string(u.Role),
|
||||
Email: u.Email, Disabled: u.DisabledAt != nil,
|
||||
MustChangePassword: u.MustChangePassword,
|
||||
CreatedAt: u.CreatedAt.UTC().Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
if u.LastLoginAt != nil {
|
||||
ll := u.LastLoginAt.UTC().Format("2006-01-02T15:04:05Z")
|
||||
out.LastLoginAt = &ll
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
_ = json.NewEncoder(w).Encode(out)
|
||||
}
|
||||
|
||||
// patchUserRequest is the JSON body of a partial user update. Both
// fields are pointers so an absent key (nil) can be told apart from an
// explicit value.
type patchUserRequest struct {
	// Role, when non-nil, is the new role name.
	Role *string `json:"role,omitempty"`
	// Email, when non-nil, is the new email address.
	Email *string `json:"email,omitempty"`
}
|
||||
|
||||
func (s *Server) handleAPIUserPatch(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
actor, _ := s.requireUser(r)
|
||||
id := chi.URLParam(r, "id")
|
||||
u, err := s.deps.Store.GetUserByID(r.Context(), id)
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusNotFound, "user_not_found", "")
|
||||
return
|
||||
}
|
||||
var req patchUserRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error())
|
||||
return
|
||||
}
|
||||
if req.Role != nil {
|
||||
newRole, ok := validRole(*req.Role)
|
||||
if !ok {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_role", "")
|
||||
return
|
||||
}
|
||||
// Last-admin guard: cannot demote the only enabled admin.
|
||||
if u.Role == store.RoleAdmin && newRole != store.RoleAdmin && u.DisabledAt == nil {
|
||||
n, _ := s.deps.Store.CountEnabledAdmins(r.Context())
|
||||
if n <= 1 {
|
||||
writeJSONError(w, stdhttp.StatusConflict, "last_admin", "")
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := s.deps.Store.SetUserRole(r.Context(), id, newRole); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
if req.Email != nil {
|
||||
em := strings.TrimSpace(*req.Email)
|
||||
if em != "" {
|
||||
if _, err := mail.ParseAddress(em); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_email", err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := s.deps.Store.SetUserEmail(r.Context(), id, em); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
var actorID *string
|
||||
if actor != nil {
|
||||
actorID = &actor.ID
|
||||
}
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(), UserID: actorID, Actor: "user",
|
||||
Action: "user.updated", TargetKind: ptr("user"), TargetID: &id,
|
||||
TS: time.Now().UTC(),
|
||||
})
|
||||
w.WriteHeader(stdhttp.StatusOK)
|
||||
}
|
||||
|
||||
func (s *Server) handleAPIUserDisable(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
actor, _ := s.requireUser(r)
|
||||
id := chi.URLParam(r, "id")
|
||||
u, err := s.deps.Store.GetUserByID(r.Context(), id)
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusNotFound, "user_not_found", "")
|
||||
return
|
||||
}
|
||||
if u.Role == store.RoleAdmin && u.DisabledAt == nil {
|
||||
n, _ := s.deps.Store.CountEnabledAdmins(r.Context())
|
||||
if n <= 1 {
|
||||
writeJSONError(w, stdhttp.StatusConflict, "last_admin", "")
|
||||
return
|
||||
}
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if err := s.deps.Store.DisableUser(r.Context(), id, now); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
// Kick existing sessions so the user is bounced immediately.
|
||||
_, _ = s.deps.Store.DeleteSessionsByUserID(r.Context(), id)
|
||||
var actorID *string
|
||||
if actor != nil {
|
||||
actorID = &actor.ID
|
||||
}
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(), UserID: actorID, Actor: "user",
|
||||
Action: "user.disabled", TargetKind: ptr("user"), TargetID: &id,
|
||||
TS: now,
|
||||
})
|
||||
w.WriteHeader(stdhttp.StatusOK)
|
||||
}
|
||||
|
||||
func (s *Server) handleAPIUserEnable(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
actor, _ := s.requireUser(r)
|
||||
id := chi.URLParam(r, "id")
|
||||
if err := s.deps.Store.EnableUser(r.Context(), id); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
var actorID *string
|
||||
if actor != nil {
|
||||
actorID = &actor.ID
|
||||
}
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(), UserID: actorID, Actor: "user",
|
||||
Action: "user.enabled", TargetKind: ptr("user"), TargetID: &id,
|
||||
TS: time.Now().UTC(),
|
||||
})
|
||||
w.WriteHeader(stdhttp.StatusOK)
|
||||
}
|
||||
|
||||
// regenerateSetupResponse is the JSON body returned after a setup
// token has been regenerated.
type regenerateSetupResponse struct {
	// SetupURL is the link carrying the new one-time setup token.
	SetupURL string `json:"setup_url"`
}
|
||||
|
||||
func (s *Server) handleAPIUserRegenerateSetup(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
actor, _ := s.requireUser(r)
|
||||
id := chi.URLParam(r, "id")
|
||||
if _, err := s.deps.Store.GetUserByID(r.Context(), id); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusNotFound, "user_not_found", "")
|
||||
return
|
||||
}
|
||||
rawToken, err := generateSetupToken()
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
var actorID *string
|
||||
if actor != nil {
|
||||
actorID = &actor.ID
|
||||
}
|
||||
if err := s.deps.Store.SetSetupToken(r.Context(), store.SetupToken{
|
||||
UserID: id, TokenHash: hashSetupToken(rawToken),
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
CreatedAt: now, CreatedBy: actorID,
|
||||
}); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
if err := s.deps.Store.SetMustChangePassword(r.Context(), id, true); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(), UserID: actorID, Actor: "user",
|
||||
Action: "user.setup_token.regenerated",
|
||||
TargetKind: ptr("user"), TargetID: &id, TS: now,
|
||||
})
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
_ = json.NewEncoder(w).Encode(regenerateSetupResponse{
|
||||
SetupURL: s.deps.Cfg.BaseURL + "/setup?token=" + rawToken,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleAPIUserForceLogout(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
actor, _ := s.requireUser(r)
|
||||
id := chi.URLParam(r, "id")
|
||||
n, err := s.deps.Store.DeleteSessionsByUserID(r.Context(), id)
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
var actorID *string
|
||||
if actor != nil {
|
||||
actorID = &actor.ID
|
||||
}
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(), UserID: actorID, Actor: "user",
|
||||
Action: "user.force_logout",
|
||||
TargetKind: ptr("user"), TargetID: &id,
|
||||
TS: time.Now().UTC(),
|
||||
})
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
_ = json.NewEncoder(w).Encode(map[string]int64{"sessions_killed": n})
|
||||
}
|
||||
@@ -56,15 +56,9 @@ func (s *Server) authenticateAndSession(w stdhttp.ResponseWriter, r *stdhttp.Req
|
||||
// existence to a probing attacker.
|
||||
return nil, errInvalidCredentials
|
||||
}
|
||||
if u.AuthSource == "oidc" {
|
||||
return nil, errInvalidCredentials
|
||||
}
|
||||
if err := auth.VerifyPassword(u.PasswordHash, password); err != nil {
|
||||
return nil, errInvalidCredentials
|
||||
}
|
||||
if u.DisabledAt != nil {
|
||||
return nil, errInvalidCredentials
|
||||
}
|
||||
|
||||
token, err := auth.NewToken()
|
||||
if err != nil {
|
||||
@@ -143,7 +137,7 @@ func (s *Server) handleBootstrap(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
return
|
||||
}
|
||||
if n > 0 {
|
||||
writeJSONError(w, stdhttp.StatusConflict, "already_initialised",
|
||||
writeJSONError(w, stdhttp.StatusConflict, "already_initialized",
|
||||
"a user already exists; bootstrap is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
// bootstrap_handler.go — public landing page for the first-run admin
|
||||
// flow. While the server has no users and still holds the in-memory
|
||||
// one-shot bootstrap token printed at startup, /bootstrap renders a
|
||||
// form that takes a username + password and creates the first admin.
|
||||
//
|
||||
// The operator never sees or types the token: the server already has
|
||||
// it in memory, so the UI handler uses it directly. The token printed
|
||||
// to stderr remains a break-glass fallback for the JSON
|
||||
// /api/bootstrap path.
|
||||
//
|
||||
// Routes (wired in server.go):
|
||||
//
|
||||
// GET /bootstrap → handleUIBootstrapGet
|
||||
// POST /bootstrap → handleUIBootstrapPost
|
||||
//
|
||||
// Both routes self-disable the moment a user row exists; subsequent
|
||||
// hits redirect to /login.
|
||||
package http
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
stdhttp "net/http"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// bootstrapPage is the view data handed to the "bootstrap" template.
type bootstrapPage struct {
	// Username echoes the previously submitted username.
	Username string
	// Error is the message to display; empty when there is none.
	Error string
}
|
||||
|
||||
func (s *Server) handleUIBootstrapGet(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if !s.bootstrapAvailable(r) {
|
||||
stdhttp.Redirect(w, r, "/login", stdhttp.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
s.renderBootstrap(w, r, "", "")
|
||||
}
|
||||
|
||||
func (s *Server) handleUIBootstrapPost(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if !s.bootstrapAvailable(r) {
|
||||
stdhttp.Redirect(w, r, "/login", stdhttp.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
stdhttp.Error(w, "bad request", stdhttp.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
username := r.PostForm.Get("username")
|
||||
pw := r.PostForm.Get("password")
|
||||
pw2 := r.PostForm.Get("password_confirm")
|
||||
|
||||
if username == "" {
|
||||
s.renderBootstrap(w, r, username, "Pick a username.")
|
||||
return
|
||||
}
|
||||
if pw == "" || pw2 == "" || pw != pw2 || len(pw) < 12 {
|
||||
s.renderBootstrap(w, r, username,
|
||||
"Passwords must match and be at least 12 characters.")
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := auth.HashPassword(pw)
|
||||
if err != nil {
|
||||
slog.Error("bootstrap: hash password", "err", err)
|
||||
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
u := store.User{
|
||||
ID: ulid.Make().String(),
|
||||
Username: username,
|
||||
PasswordHash: hash,
|
||||
Role: store.RoleAdmin,
|
||||
CreatedAt: now,
|
||||
}
|
||||
if err := s.deps.Store.CreateUser(r.Context(), u); err != nil {
|
||||
slog.Error("bootstrap: create user", "err", err)
|
||||
s.renderBootstrap(w, r, username,
|
||||
"Could not create the administrator account. Check the server logs.")
|
||||
return
|
||||
}
|
||||
// Clear the in-memory token so /api/bootstrap also stops accepting
|
||||
// further calls. CountUsers > 0 already gates both surfaces, but
|
||||
// blanking the token kills the constant-time-compare branch as
|
||||
// well — defence in depth, plus stops the token from sitting in
|
||||
// process memory longer than necessary.
|
||||
s.deps.BootstrapToken = ""
|
||||
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(),
|
||||
UserID: &u.ID,
|
||||
Actor: "system",
|
||||
Action: "auth.bootstrap",
|
||||
TS: now,
|
||||
})
|
||||
|
||||
// Mint a session so the new admin lands authenticated on /.
|
||||
rawSession, err := auth.NewToken()
|
||||
if err != nil {
|
||||
slog.Error("bootstrap: session token", "err", err)
|
||||
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := s.deps.Store.CreateSession(r.Context(), store.Session{
|
||||
UserID: u.ID,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(sessionTTL),
|
||||
IP: r.RemoteAddr,
|
||||
UA: r.UserAgent(),
|
||||
}, auth.HashToken(rawSession)); err != nil {
|
||||
slog.Error("bootstrap: create session", "err", err)
|
||||
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_ = s.deps.Store.MarkUserLogin(r.Context(), u.ID, now)
|
||||
|
||||
stdhttp.SetCookie(w, &stdhttp.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: rawSession,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: s.deps.Cfg.CookieSecure,
|
||||
SameSite: stdhttp.SameSiteLaxMode,
|
||||
Expires: now.Add(sessionTTL),
|
||||
})
|
||||
stdhttp.Redirect(w, r, "/", stdhttp.StatusSeeOther)
|
||||
}
|
||||
|
||||
// bootstrapAvailable reports whether a fresh-install bootstrap can
|
||||
// still proceed: a one-shot token is held in memory and no user rows
|
||||
// exist yet.
|
||||
func (s *Server) bootstrapAvailable(r *stdhttp.Request) bool {
|
||||
if s.deps.BootstrapToken == "" {
|
||||
return false
|
||||
}
|
||||
n, err := s.deps.Store.CountUsers(r.Context())
|
||||
if err != nil {
|
||||
slog.Error("bootstrap: count users", "err", err)
|
||||
return false
|
||||
}
|
||||
return n == 0
|
||||
}
|
||||
|
||||
func (s *Server) renderBootstrap(w stdhttp.ResponseWriter, r *stdhttp.Request, username, errMsg string) {
|
||||
view := s.baseView(r, nil)
|
||||
view.Title = "Welcome · restic-manager"
|
||||
view.Page = bootstrapPage{Username: username, Error: errMsg}
|
||||
if err := s.deps.UI.Render(w, "bootstrap", view); err != nil {
|
||||
slog.Error("ui bootstrap: render", "err", err)
|
||||
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
stdhttp "net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// handleCancelJob is POST /api/jobs/{id}/cancel. Sends a command.cancel
|
||||
// envelope to the host that owns the job; the agent kills the running
|
||||
// restic subprocess, and the resulting job.finished envelope (status =
|
||||
// canceled) is what actually transitions the job row — this handler
|
||||
// does not touch the jobs table directly. Returning 202 makes that
|
||||
// asynchronicity explicit.
|
||||
//
|
||||
// 4xx cases:
|
||||
// - job not found (404)
|
||||
// - job already in a terminal state (409 — nothing to cancel)
|
||||
// - host offline (503 — same code path the run-now endpoint uses)
|
||||
//
|
||||
// Audit-logged as job.cancel with the job ID as target.
|
||||
func (s *Server) handleCancelJob(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
user, ok := s.requireUser(r)
|
||||
if !ok {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
return
|
||||
}
|
||||
jobID := chi.URLParam(r, "id")
|
||||
if jobID == "" {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "missing_job_id", "")
|
||||
return
|
||||
}
|
||||
|
||||
job, err := s.deps.Store.GetJob(r.Context(), jobID)
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusNotFound, "job_not_found", "")
|
||||
return
|
||||
}
|
||||
switch api.JobStatus(job.Status) {
|
||||
case api.JobSucceeded, api.JobFailed, api.JobCancelled:
|
||||
writeJSONError(w, stdhttp.StatusConflict, "job_terminal",
|
||||
"job is already in a terminal state ("+job.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
if !s.deps.Hub.Connected(job.HostID) {
|
||||
writeJSONError(w, stdhttp.StatusServiceUnavailable, "host_offline",
|
||||
"agent is not connected; can't deliver cancel signal")
|
||||
return
|
||||
}
|
||||
|
||||
env, err := api.Marshal(api.MsgCommandCancel, jobID, api.CommandCancelPayload{
|
||||
JobID: jobID,
|
||||
})
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "")
|
||||
return
|
||||
}
|
||||
if err := s.deps.Hub.Send(r.Context(), job.HostID, env); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusServiceUnavailable, "host_offline", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var actorID *string
|
||||
actor := "system"
|
||||
if user != nil {
|
||||
actor = "user"
|
||||
actorID = &user.ID
|
||||
}
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(),
|
||||
UserID: actorID,
|
||||
Actor: actor,
|
||||
Action: "job.cancel",
|
||||
TargetKind: ptr("job"),
|
||||
TargetID: &jobID,
|
||||
TS: time.Now().UTC(),
|
||||
})
|
||||
|
||||
w.WriteHeader(stdhttp.StatusAccepted)
|
||||
}
|
||||
@@ -1,204 +0,0 @@
|
||||
// cancel_test.go — covers POST /api/jobs/{id}/cancel.
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
stdhttp "net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// TestCancelJobRunningHappyPath: a running job's cancel endpoint sends
|
||||
// a command.cancel envelope with the right job id, returns 202, and
|
||||
// writes a job.cancel audit row.
|
||||
func TestCancelJobRunningHappyPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, st := rawTestServer(t)
|
||||
hostID, token := enrolHostForWS(t, srv, st, "cancel-host")
|
||||
c := agentDial(t, srv, ts, hostID, token)
|
||||
sendHello(t, c, "cancel-host")
|
||||
_ = drainUntil(t, c, api.MsgScheduleSet)
|
||||
|
||||
// Seed a running job we can target.
|
||||
jobID := ulid.Make().String()
|
||||
now := time.Now().UTC()
|
||||
if err := st.CreateJob(context.Background(), store.Job{
|
||||
ID: jobID, HostID: hostID, Kind: "backup",
|
||||
ActorKind: "user", CreatedAt: now,
|
||||
}); err != nil {
|
||||
t.Fatalf("create job: %v", err)
|
||||
}
|
||||
if err := st.MarkJobStarted(context.Background(), jobID, now); err != nil {
|
||||
t.Fatalf("mark started: %v", err)
|
||||
}
|
||||
|
||||
cookie := loginAsAdmin(t, st)
|
||||
req, _ := stdhttp.NewRequest("POST",
|
||||
ts.URL+"/api/jobs/"+jobID+"/cancel", nil)
|
||||
req.AddCookie(cookie)
|
||||
res, err := stdhttp.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusAccepted {
|
||||
t.Fatalf("status: got %d, want 202", res.StatusCode)
|
||||
}
|
||||
|
||||
// Read the dispatched command.cancel envelope.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
var got api.Envelope
|
||||
for time.Now().Before(deadline) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
mt, raw, rerr := c.Read(ctx)
|
||||
cancel()
|
||||
if rerr != nil {
|
||||
break
|
||||
}
|
||||
if mt != websocket.MessageText {
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(string(raw), `"command.cancel"`) {
|
||||
continue
|
||||
}
|
||||
if err := json.Unmarshal(raw, &got); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
if got.Type != api.MsgCommandCancel {
|
||||
t.Fatalf("never received command.cancel envelope")
|
||||
}
|
||||
var cp api.CommandCancelPayload
|
||||
if err := got.UnmarshalPayload(&cp); err != nil {
|
||||
t.Fatalf("unmarshal payload: %v", err)
|
||||
}
|
||||
if cp.JobID != jobID {
|
||||
t.Fatalf("payload job_id: got %q want %q", cp.JobID, jobID)
|
||||
}
|
||||
|
||||
// Audit row exists.
|
||||
var n int
|
||||
if err := st.DB().QueryRow(
|
||||
`SELECT COUNT(*) FROM audit_log WHERE action = 'job.cancel' AND target_id = ?`,
|
||||
jobID).Scan(&n); err != nil {
|
||||
t.Fatalf("audit count: %v", err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Fatalf("audit rows: got %d, want 1", n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCancelJobAlreadyTerminal: a job in succeeded/failed/canceled
|
||||
// state returns 409 and does NOT send a WS envelope.
|
||||
func TestCancelJobAlreadyTerminal(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, st := rawTestServer(t)
|
||||
hostID, token := enrolHostForWS(t, srv, st, "term-host")
|
||||
c := agentDial(t, srv, ts, hostID, token)
|
||||
sendHello(t, c, "term-host")
|
||||
_ = drainUntil(t, c, api.MsgScheduleSet)
|
||||
|
||||
jobID := ulid.Make().String()
|
||||
now := time.Now().UTC()
|
||||
if err := st.CreateJob(context.Background(), store.Job{
|
||||
ID: jobID, HostID: hostID, Kind: "backup",
|
||||
ActorKind: "user", CreatedAt: now,
|
||||
}); err != nil {
|
||||
t.Fatalf("create job: %v", err)
|
||||
}
|
||||
if err := st.MarkJobFinished(context.Background(), jobID, "succeeded", 0, nil, "", now); err != nil {
|
||||
t.Fatalf("mark finished: %v", err)
|
||||
}
|
||||
|
||||
cookie := loginAsAdmin(t, st)
|
||||
req, _ := stdhttp.NewRequest("POST",
|
||||
ts.URL+"/api/jobs/"+jobID+"/cancel", nil)
|
||||
req.AddCookie(cookie)
|
||||
res, err := stdhttp.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusConflict {
|
||||
t.Fatalf("status: got %d, want 409", res.StatusCode)
|
||||
}
|
||||
|
||||
// Drain — no command.cancel should arrive.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
|
||||
defer cancel()
|
||||
for {
|
||||
mt, raw, rerr := c.Read(ctx)
|
||||
if rerr != nil {
|
||||
break
|
||||
}
|
||||
if mt == websocket.MessageText && strings.Contains(string(raw), `"command.cancel"`) {
|
||||
t.Fatalf("unexpected command.cancel envelope for terminal job")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestCancelJobNotFound: 404 for a job id that doesn't exist.
|
||||
func TestCancelJobNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ts, st := rawTestServer(t)
|
||||
cookie := loginAsAdmin(t, st)
|
||||
req, _ := stdhttp.NewRequest("POST",
|
||||
ts.URL+"/api/jobs/"+ulid.Make().String()+"/cancel", nil)
|
||||
req.AddCookie(cookie)
|
||||
res, err := stdhttp.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusNotFound {
|
||||
t.Fatalf("status: got %d, want 404", res.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCancelJobHostOffline: a queued/running job whose host has no
|
||||
// active WS connection returns 503.
|
||||
func TestCancelJobHostOffline(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ts, st := rawTestServer(t)
|
||||
// Create a host but don't connect a WS for it.
|
||||
hostID := ulid.Make().String()
|
||||
if err := st.CreateHost(context.Background(), store.Host{
|
||||
ID: hostID, Name: "offline-host", OS: "linux", Arch: "amd64",
|
||||
EnrolledAt: time.Now().UTC(),
|
||||
}, "deadbeef", ""); err != nil {
|
||||
t.Fatalf("create host: %v", err)
|
||||
}
|
||||
jobID := ulid.Make().String()
|
||||
now := time.Now().UTC()
|
||||
if err := st.CreateJob(context.Background(), store.Job{
|
||||
ID: jobID, HostID: hostID, Kind: "backup",
|
||||
ActorKind: "user", CreatedAt: now,
|
||||
}); err != nil {
|
||||
t.Fatalf("create job: %v", err)
|
||||
}
|
||||
if err := st.MarkJobStarted(context.Background(), jobID, now); err != nil {
|
||||
t.Fatalf("mark started: %v", err)
|
||||
}
|
||||
|
||||
cookie := loginAsAdmin(t, st)
|
||||
req, _ := stdhttp.NewRequest("POST",
|
||||
ts.URL+"/api/jobs/"+jobID+"/cancel", nil)
|
||||
req.AddCookie(cookie)
|
||||
res, err := stdhttp.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusServiceUnavailable {
|
||||
t.Fatalf("status: got %d, want 503", res.StatusCode)
|
||||
}
|
||||
}
|
||||
@@ -1,144 +0,0 @@
|
||||
// dashboard_filter_test.go — covers the NS-04 filter + sort pipeline
|
||||
// in pure-Go form, without going through HTTP. The handler tests
|
||||
// elsewhere prove end-to-end render; here we focus on edge cases of
|
||||
// the column-sort + filter precedence so a regression in either is
|
||||
// surfaced loudly.
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
func makeFilterHosts() []store.Host {
|
||||
t1 := time.Date(2026, 5, 1, 12, 0, 0, 0, time.UTC)
|
||||
t2 := time.Date(2026, 5, 4, 12, 0, 0, 0, time.UTC)
|
||||
tSeen := time.Date(2026, 5, 5, 12, 0, 0, 0, time.UTC)
|
||||
return []store.Host{
|
||||
{
|
||||
ID: "01HHA", Name: "alpha", OS: "linux", Status: "online",
|
||||
RepoStatus: "ready", Tags: []string{"prod"}, SnapshotCount: 30,
|
||||
LastBackupAt: &t1, LastSeenAt: &tSeen, RepoSizeBytes: 1000,
|
||||
},
|
||||
{
|
||||
ID: "01HHB", Name: "bravo", OS: "linux", Status: "offline",
|
||||
RepoStatus: "init_failed", Tags: []string{"dev"}, SnapshotCount: 10,
|
||||
LastBackupAt: &t2, LastSeenAt: &tSeen, RepoSizeBytes: 5000,
|
||||
},
|
||||
{
|
||||
ID: "01HHC", Name: "charlie", OS: "windows", Status: "online",
|
||||
RepoStatus: "unknown", Tags: []string{"prod", "edge"}, SnapshotCount: 0,
|
||||
LastSeenAt: nil, // never_seen path
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterAndSortDashboardSearchAndStatus covers the precedence of
|
||||
// search ∧ status as combined filters.
|
||||
func TestFilterAndSortDashboardSearchAndStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
hosts := makeFilterHosts()
|
||||
|
||||
// status=online narrows to alpha + charlie.
|
||||
got := filterAndSortDashboardHosts(hosts, dashboardFilter{Status: "online", Sort: "name", Dir: "asc"})
|
||||
if len(got) != 2 || got[0].Name != "alpha" || got[1].Name != "charlie" {
|
||||
t.Errorf("status=online: got %d names %v, want [alpha charlie]", len(got), namesOf(got))
|
||||
}
|
||||
// q=bra narrows to bravo regardless of status default.
|
||||
got = filterAndSortDashboardHosts(hosts, dashboardFilter{Search: "bra", Sort: "name", Dir: "asc"})
|
||||
if len(got) != 1 || got[0].Name != "bravo" {
|
||||
t.Errorf("search=bra: got %v", namesOf(got))
|
||||
}
|
||||
// repo_status=init_failed narrows to bravo only.
|
||||
got = filterAndSortDashboardHosts(hosts, dashboardFilter{RepoStatus: "init_failed", Sort: "name", Dir: "asc"})
|
||||
if len(got) != 1 || got[0].Name != "bravo" {
|
||||
t.Errorf("repo_status=init_failed: got %v", namesOf(got))
|
||||
}
|
||||
// status=never_seen narrows on LastSeenAt == nil → charlie only.
|
||||
got = filterAndSortDashboardHosts(hosts, dashboardFilter{Status: "never_seen", Sort: "name", Dir: "asc"})
|
||||
if len(got) != 1 || got[0].Name != "charlie" {
|
||||
t.Errorf("status=never_seen: got %v", namesOf(got))
|
||||
}
|
||||
// tag=prod narrows to alpha + charlie.
|
||||
got = filterAndSortDashboardHosts(hosts, dashboardFilter{Tag: "prod", Sort: "name", Dir: "asc"})
|
||||
if len(got) != 2 || got[0].Name != "alpha" || got[1].Name != "charlie" {
|
||||
t.Errorf("tag=prod: got %v", namesOf(got))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSortDashboardHostsColumns verifies each meaningful column
|
||||
// sorts as expected, both ascending and descending.
|
||||
func TestSortDashboardHostsColumns(t *testing.T) {
|
||||
t.Parallel()
|
||||
hosts := makeFilterHosts()
|
||||
|
||||
cases := []struct {
|
||||
col, dir string
|
||||
want []string
|
||||
}{
|
||||
{"name", "asc", []string{"alpha", "bravo", "charlie"}},
|
||||
{"name", "desc", []string{"charlie", "bravo", "alpha"}},
|
||||
{"snapshot_count", "asc", []string{"charlie", "bravo", "alpha"}},
|
||||
{"snapshot_count", "desc", []string{"alpha", "bravo", "charlie"}},
|
||||
{"last_backup", "asc", []string{"charlie", "alpha", "bravo"}}, // nil → zero → first
|
||||
{"repo_status", "asc", []string{"bravo", "alpha", "charlie"}}, // init_failed < ready < unknown
|
||||
}
|
||||
for _, c := range cases {
|
||||
c := c
|
||||
t.Run(c.col+"_"+c.dir, func(t *testing.T) {
|
||||
got := append([]store.Host(nil), hosts...)
|
||||
sortDashboardHosts(got, c.col, c.dir)
|
||||
if names := namesOf(got); !sliceEq(names, c.want) {
|
||||
t.Errorf("got %v, want %v", names, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseDashboardFilterDefaults: empty query gives sort=name asc.
|
||||
func TestParseDashboardFilterDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := parseDashboardFilter(url.Values{})
|
||||
if f.Sort != "name" || f.Dir != "asc" {
|
||||
t.Errorf("defaults: got sort=%q dir=%q, want name/asc", f.Sort, f.Dir)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildDashboardSortURLsToggles: clicking the active column
|
||||
// flips direction; clicking another column resets to asc.
|
||||
func TestBuildDashboardSortURLsToggles(t *testing.T) {
|
||||
t.Parallel()
|
||||
active := dashboardFilter{Sort: "name", Dir: "asc"}
|
||||
urls := buildDashboardSortURLs(active)
|
||||
if got := urls["name"]; got != "/?dir=desc" {
|
||||
t.Errorf("name URL on active asc: got %q, want /?dir=desc", got)
|
||||
}
|
||||
// Switching to a non-default column also drops dir=asc since asc
|
||||
// is the encoded default.
|
||||
if got := urls["last_backup"]; got != "/?sort=last_backup" {
|
||||
t.Errorf("last_backup URL: got %q, want /?sort=last_backup", got)
|
||||
}
|
||||
}
|
||||
|
||||
func namesOf(hs []store.Host) []string {
|
||||
out := make([]string, len(hs))
|
||||
for i, h := range hs {
|
||||
out[i] = h.Name
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// sliceEq reports whether a and b have the same length and hold equal
// strings at every index. A nil slice and an empty slice compare equal.
func sliceEq(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for idx, want := range b {
		if a[idx] != want {
			return false
		}
	}
	return true
}
|
||||
@@ -1,150 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
stdhttp "net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// snapshotDiffRequest is the JSON body for POST .../snapshots/diff.
// Either short or long snapshot IDs are accepted (restic's diff
// command takes both). handleSnapshotDiff fills the same two fields
// from the snapshot_a / snapshot_b form values when the request is
// form-encoded instead of JSON.
type snapshotDiffRequest struct {
	// SnapshotA is the first snapshot reference supplied by the caller.
	SnapshotA string `json:"snapshot_a"`
	// SnapshotB is the second snapshot reference supplied by the caller.
	SnapshotB string `json:"snapshot_b"`
}
|
||||
|
||||
// handleSnapshotDiff dispatches a JobDiff. Output streams as
|
||||
// log.stream lines to the standard live job page; the operator reads
|
||||
// the diff text directly there. Behaves like the run-now endpoints:
|
||||
// 503 if the host is offline, 400 if the IDs are missing, 422 if
|
||||
// they're not in the host's snapshot list (we don't want operators
|
||||
// running diffs against arbitrary snapshot strings).
|
||||
func (s *Server) handleSnapshotDiff(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
user, ok := s.requireUser(r)
|
||||
if !ok {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
host, err := s.deps.Store.GetHost(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusNotFound, "host_not_found", "")
|
||||
return
|
||||
}
|
||||
|
||||
var req snapshotDiffRequest
|
||||
// HTMX form posts arrive as application/x-www-form-urlencoded;
|
||||
// the JSON shape is also accepted for REST callers.
|
||||
ct := r.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_form", err.Error())
|
||||
return
|
||||
}
|
||||
req.SnapshotA = strings.TrimSpace(r.PostForm.Get("snapshot_a"))
|
||||
req.SnapshotB = strings.TrimSpace(r.PostForm.Get("snapshot_b"))
|
||||
} else {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error())
|
||||
return
|
||||
}
|
||||
req.SnapshotA = strings.TrimSpace(req.SnapshotA)
|
||||
req.SnapshotB = strings.TrimSpace(req.SnapshotB)
|
||||
}
|
||||
if req.SnapshotA == "" || req.SnapshotB == "" {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "missing_snapshot",
|
||||
"snapshot_a and snapshot_b are both required")
|
||||
return
|
||||
}
|
||||
if req.SnapshotA == req.SnapshotB {
|
||||
writeJSONError(w, stdhttp.StatusUnprocessableEntity, "same_snapshot",
|
||||
"diff requires two different snapshots")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the IDs are known to this host. Match on long ID, short
|
||||
// ID, or any prefix match — operators sometimes paste a 6-char
|
||||
// shortened form.
|
||||
snaps, err := s.deps.Store.ListSnapshotsByHost(r.Context(), host.ID)
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "")
|
||||
return
|
||||
}
|
||||
resolveID := func(idOrShort string) string {
|
||||
for _, s := range snaps {
|
||||
if s.ID == idOrShort || s.ShortID == idOrShort {
|
||||
return s.ID
|
||||
}
|
||||
}
|
||||
// Prefix fallback (operator pasted 6 chars of a long id).
|
||||
for _, s := range snaps {
|
||||
if strings.HasPrefix(s.ID, idOrShort) {
|
||||
return s.ID
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
a := resolveID(req.SnapshotA)
|
||||
b := resolveID(req.SnapshotB)
|
||||
if a == "" || b == "" {
|
||||
writeJSONError(w, stdhttp.StatusUnprocessableEntity, "snapshot_not_found",
|
||||
"one or both snapshot ids are not in this host's snapshot list")
|
||||
return
|
||||
}
|
||||
|
||||
if !s.deps.Hub.Connected(host.ID) {
|
||||
writeJSONError(w, stdhttp.StatusServiceUnavailable, "host_offline",
|
||||
"agent is not connected; try again when it reconnects")
|
||||
return
|
||||
}
|
||||
|
||||
jobID := ulid.Make().String()
|
||||
now := time.Now().UTC()
|
||||
if err := s.deps.Store.CreateJob(r.Context(), store.Job{
|
||||
ID: jobID, HostID: host.ID, Kind: string(api.JobDiff),
|
||||
ActorKind: "user", ActorID: &user.ID, CreatedAt: now,
|
||||
}); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
env, err := api.Marshal(api.MsgCommandRun, jobID, api.CommandRunPayload{
|
||||
JobID: jobID, Kind: api.JobDiff,
|
||||
Diff: &api.DiffPayload{SnapshotA: a, SnapshotB: b},
|
||||
})
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "")
|
||||
return
|
||||
}
|
||||
if err := s.deps.Hub.Send(r.Context(), host.ID, env); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusServiceUnavailable, "host_offline", err.Error())
|
||||
return
|
||||
}
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(),
|
||||
UserID: &user.ID,
|
||||
Actor: "user",
|
||||
Action: "host.snapshot_diff",
|
||||
TargetKind: ptr("host"),
|
||||
TargetID: &host.ID,
|
||||
TS: now,
|
||||
})
|
||||
|
||||
jobURL := "/jobs/" + jobID
|
||||
if r.Header.Get("HX-Request") == "true" {
|
||||
w.Header().Set("HX-Redirect", jobURL)
|
||||
w.WriteHeader(stdhttp.StatusNoContent)
|
||||
return
|
||||
}
|
||||
writeJSON(w, stdhttp.StatusAccepted, map[string]string{
|
||||
"job_id": jobID,
|
||||
"job_url": jobURL,
|
||||
})
|
||||
}
|
||||
@@ -1,136 +0,0 @@
|
||||
// diff_test.go — covers POST /api/hosts/{id}/snapshots/diff (P3-09).
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
stdhttp "net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
)
|
||||
|
||||
// TestSnapshotDiffHappyPath verifies a valid two-snapshot form ships
|
||||
// a JobDiff command.run with the right payload.
|
||||
func TestSnapshotDiffHappyPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, st := rawTestServerWithUI(t)
|
||||
hostID, token := enrolHostForUI(t, srv, st, "diff-host")
|
||||
a, b := seedTwoSnapshots(t, st, hostID, "diff-host")
|
||||
c := agentDial(t, srv, ts, hostID, token)
|
||||
sendHello(t, c, "diff-host")
|
||||
_ = drainUntil(t, c, api.MsgScheduleSet)
|
||||
cookie := loginAsAdmin(t, st)
|
||||
|
||||
form := url.Values{
|
||||
"snapshot_a": {a},
|
||||
"snapshot_b": {b},
|
||||
}
|
||||
req, _ := stdhttp.NewRequest("POST",
|
||||
ts.URL+"/hosts/"+hostID+"/snapshots/diff",
|
||||
strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("HX-Request", "true")
|
||||
req.AddCookie(cookie)
|
||||
client := &stdhttp.Client{
|
||||
CheckRedirect: func(*stdhttp.Request, []*stdhttp.Request) error {
|
||||
return stdhttp.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusNoContent {
|
||||
t.Fatalf("status: got %d, want 204", res.StatusCode)
|
||||
}
|
||||
if res.Header.Get("HX-Redirect") == "" {
|
||||
t.Fatal("expected HX-Redirect to live job page")
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
var got api.Envelope
|
||||
for time.Now().Before(deadline) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
mt, raw, rerr := c.Read(ctx)
|
||||
cancel()
|
||||
if rerr != nil {
|
||||
break
|
||||
}
|
||||
if mt != websocket.MessageText {
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(string(raw), `"kind":"diff"`) {
|
||||
continue
|
||||
}
|
||||
_ = json.Unmarshal(raw, &got)
|
||||
break
|
||||
}
|
||||
if got.Type != api.MsgCommandRun {
|
||||
t.Fatal("never received diff command.run")
|
||||
}
|
||||
var cp api.CommandRunPayload
|
||||
_ = got.UnmarshalPayload(&cp)
|
||||
if cp.Diff == nil {
|
||||
t.Fatal("diff payload nil")
|
||||
}
|
||||
if cp.Diff.SnapshotA != a || cp.Diff.SnapshotB != b {
|
||||
t.Fatalf("diff payload: got %+v want a=%s b=%s", cp.Diff, a, b)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSnapshotDiffSameID rejects diff(a,a) with 422.
|
||||
func TestSnapshotDiffSameID(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, st := rawTestServerWithUI(t)
|
||||
hostID, _ := enrolHostForUI(t, srv, st, "diff-same")
|
||||
a := seedSnapshot(t, st, hostID, "diff-same")
|
||||
cookie := loginAsAdmin(t, st)
|
||||
|
||||
form := url.Values{"snapshot_a": {a}, "snapshot_b": {a}}
|
||||
req, _ := stdhttp.NewRequest("POST",
|
||||
ts.URL+"/hosts/"+hostID+"/snapshots/diff",
|
||||
strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.AddCookie(cookie)
|
||||
res, err := stdhttp.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusUnprocessableEntity {
|
||||
t.Fatalf("status: got %d, want 422", res.StatusCode)
|
||||
}
|
||||
_ = srv
|
||||
}
|
||||
|
||||
// TestSnapshotDiffUnknownID rejects ids not in the host's snapshot list.
|
||||
func TestSnapshotDiffUnknownID(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, st := rawTestServerWithUI(t)
|
||||
hostID, _ := enrolHostForUI(t, srv, st, "diff-unknown")
|
||||
_ = seedSnapshot(t, st, hostID, "diff-unknown")
|
||||
cookie := loginAsAdmin(t, st)
|
||||
|
||||
form := url.Values{"snapshot_a": {"deadbeef"}, "snapshot_b": {"cafebabe"}}
|
||||
req, _ := stdhttp.NewRequest("POST",
|
||||
ts.URL+"/hosts/"+hostID+"/snapshots/diff",
|
||||
strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.AddCookie(cookie)
|
||||
res, err := stdhttp.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusUnprocessableEntity {
|
||||
t.Fatalf("status: got %d, want 422", res.StatusCode)
|
||||
}
|
||||
_ = srv
|
||||
}
|
||||
@@ -213,7 +213,7 @@ func (s *Server) handleAgentEnroll(w stdhttp.ResponseWriter, r *stdhttp.Request)
|
||||
// session cookie and trust it, validating the cookie via store.
|
||||
func (s *Server) handleCreateEnrollmentToken(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if !s.authedUser(r) {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
// hooks_resolve.go — server-side resolution of pre/post hooks for a
|
||||
// backup dispatch (P2R-11). The agent receives plaintext hook bodies
|
||||
// in CommandRunPayload; this file is where the AEAD blob on the
|
||||
// source group (or the host's default) gets decrypted into the
|
||||
// strings the wire payload carries.
|
||||
//
|
||||
// Resolution order:
|
||||
// 1. source_group.<phase>_hook (per-group override)
|
||||
// 2. host.<phase>_hook_default (host-wide default)
|
||||
// 3. "" (no hook → agent skips that phase)
|
||||
//
|
||||
// Decrypt errors are logged and treated as "no hook configured" so
|
||||
// a malformed blob can't poison every backup. The audit trail
|
||||
// captures the underlying state regardless.
|
||||
package http
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// resolveBackupHooks returns the (pre, post) plaintext hook strings
|
||||
// the agent should run around the backup. Both are empty when no
|
||||
// hook is configured at either level.
|
||||
func (s *Server) resolveBackupHooks(host *store.Host, g *store.SourceGroup) (pre, post string) {
|
||||
if s.deps.AEAD == nil {
|
||||
return "", ""
|
||||
}
|
||||
pre = s.decryptHookOrFallback(g.PreHook, host.PreHookDefault, host.ID, "pre")
|
||||
post = s.decryptHookOrFallback(g.PostHook, host.PostHookDefault, host.ID, "post")
|
||||
return pre, post
|
||||
}
|
||||
|
||||
// decryptHookOrFallback returns the per-group hook decrypted, or
|
||||
// (when that's empty) the host default decrypted, or "" if neither
|
||||
// is configured. Decrypt failures log and degrade to empty.
|
||||
func (s *Server) decryptHookOrFallback(group, hostDefault, hostID, phase string) string {
|
||||
tryDecrypt := func(blob, slot string) (string, bool) {
|
||||
if blob == "" {
|
||||
return "", false
|
||||
}
|
||||
plain, err := s.deps.AEAD.Decrypt(blob, []byte("hook:"+hostID+":"+slot+":"+phase))
|
||||
if err != nil {
|
||||
slog.Error("decrypt hook", "host_id", hostID, "phase", phase, "slot", slot, "err", err)
|
||||
return "", false
|
||||
}
|
||||
return string(plain), true
|
||||
}
|
||||
if v, ok := tryDecrypt(group, "group"); ok {
|
||||
return v
|
||||
}
|
||||
if v, ok := tryDecrypt(hostDefault, "host"); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// EncryptHookForGroup encrypts a hook body for storage on a source
|
||||
// group. Caller passes the plaintext from a UI form; an empty body
|
||||
// returns "" so the store persists NULL (cleared).
|
||||
func (s *Server) EncryptHookForGroup(hostID, phase, body string) (string, error) {
|
||||
if body == "" {
|
||||
return "", nil
|
||||
}
|
||||
return s.deps.AEAD.Encrypt([]byte(body), []byte("hook:"+hostID+":group:"+phase))
|
||||
}
|
||||
|
||||
// EncryptHookForHost is the host-default twin of EncryptHookForGroup.
|
||||
func (s *Server) EncryptHookForHost(hostID, phase, body string) (string, error) {
|
||||
if body == "" {
|
||||
return "", nil
|
||||
}
|
||||
return s.deps.AEAD.Encrypt([]byte(body), []byte("hook:"+hostID+":host:"+phase))
|
||||
}
|
||||
@@ -27,7 +27,7 @@ type hostBandwidthView struct {
|
||||
|
||||
func (s *Server) handleUpdateHostBandwidth(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if !s.authedUser(r) {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
@@ -58,10 +58,5 @@ func (s *Server) handleUpdateHostBandwidth(w stdhttp.ResponseWriter, r *stdhttp.
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
// Fan out to the agent if connected. Errors are non-fatal — the
|
||||
// next reconnect's onAgentHello will resync.
|
||||
if s.deps.Hub != nil && s.deps.Hub.Connected(hostID) {
|
||||
_ = s.pushBandwidthToAgent(r.Context(), hostID, req.BandwidthUpKBps, req.BandwidthDownKBps)
|
||||
}
|
||||
writeJSON(w, stdhttp.StatusOK, hostBandwidthView(req))
|
||||
}
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
// host_bandwidth_push.go — server → agent fan-out of host-wide
|
||||
// bandwidth caps via config.update.
|
||||
//
|
||||
// Two entry points: pushBandwidthOnHello (called from onAgentHello,
|
||||
// always pushes the current state so the agent picks up edits made
|
||||
// while it was offline) and pushBandwidthToAgent (called after the
|
||||
// PUT bandwidth handler succeeds, so an online agent re-arms within
|
||||
// seconds).
|
||||
//
|
||||
// We always send pointer fields (zero-valued when uncapped) so the
|
||||
// agent can distinguish "no change" (nil → field absent on the wire)
|
||||
// from "explicitly cleared" (non-nil zero pointer). See
|
||||
// api.ConfigUpdatePayload doc for the wire semantics.
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws"
|
||||
)
|
||||
|
||||
// pushBandwidthOnHello ships the host's current bandwidth caps as a
|
||||
// config.update on the supplied conn. Silent no-op on lookup error.
|
||||
func (s *Server) pushBandwidthOnHello(ctx context.Context, hostID string, conn *ws.Conn) {
|
||||
host, err := s.deps.Store.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
slog.Warn("on-hello: load host for bandwidth", "host_id", hostID, "err", err)
|
||||
return
|
||||
}
|
||||
payload := bandwidthPayload(host.BandwidthUpKBps, host.BandwidthDownKBps)
|
||||
env, err := api.Marshal(api.MsgConfigUpdate, "", payload)
|
||||
if err != nil {
|
||||
slog.Error("on-hello: marshal bandwidth config.update", "host_id", hostID, "err", err)
|
||||
return
|
||||
}
|
||||
sendCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
if err := conn.Send(sendCtx, env); err != nil {
|
||||
slog.Warn("on-hello: send bandwidth config.update", "host_id", hostID, "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
// pushBandwidthToAgent ships the supplied caps via the hub. Caller is
|
||||
// expected to check Hub.Connected first when it matters.
|
||||
func (s *Server) pushBandwidthToAgent(ctx context.Context, hostID string, up, down *int) error {
|
||||
env, err := api.Marshal(api.MsgConfigUpdate, "", bandwidthPayload(up, down))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sendCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
return s.deps.Hub.Send(sendCtx, hostID, env)
|
||||
}
|
||||
|
||||
// bandwidthPayload builds a ConfigUpdatePayload with only the
|
||||
// bandwidth fields populated. Pointers are passed through verbatim;
|
||||
// callers wanting to clear a cap should pass a non-nil pointer to 0.
|
||||
// On the on-hello path we materialise zero-valued pointers when the
|
||||
// host record has no cap set, so the agent's stored state is always
|
||||
// in sync (rather than retaining whatever value it last received).
|
||||
func bandwidthPayload(up, down *int) api.ConfigUpdatePayload {
|
||||
zero := 0
|
||||
upPtr := up
|
||||
if upPtr == nil {
|
||||
upPtr = &zero
|
||||
}
|
||||
downPtr := down
|
||||
if downPtr == nil {
|
||||
downPtr = &zero
|
||||
}
|
||||
return api.ConfigUpdatePayload{
|
||||
BandwidthUpKBps: upPtr,
|
||||
BandwidthDownKBps: downPtr,
|
||||
}
|
||||
}
|
||||
@@ -32,7 +32,7 @@ type hostRepoCredsView struct {
|
||||
// creds for UI display. 404 if no credential has ever been set.
|
||||
func (s *Server) handleGetHostCredentials(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if !s.authedUser(r) {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
@@ -88,7 +88,7 @@ type hostRepoCredsRequest struct {
|
||||
func (s *Server) handleSetHostCredentials(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
user, ok := s.requireUser(r)
|
||||
if !ok {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
@@ -146,15 +146,6 @@ func (s *Server) handleSetHostCredentials(w stdhttp.ResponseWriter, r *stdhttp.R
|
||||
return
|
||||
}
|
||||
|
||||
// NS-03: clear the host's last probe outcome — the new creds may
|
||||
// reach a different repo (or fix an auth typo), so any prior
|
||||
// "init_failed" / "ready" tag is stale. The next init dispatch
|
||||
// (below, when the agent is online) will set it to a fresh value
|
||||
// on completion.
|
||||
if err := s.deps.Store.SetHostRepoStatus(r.Context(), hostID, "unknown", ""); err != nil {
|
||||
slog.Warn("repo creds set: reset repo_status", "host_id", hostID, "err", err)
|
||||
}
|
||||
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(),
|
||||
UserID: &user.ID,
|
||||
@@ -169,66 +160,12 @@ func (s *Server) handleSetHostCredentials(w stdhttp.ResponseWriter, r *stdhttp.R
|
||||
// the next reconnect will pick the row up via the hello handler.
|
||||
if s.deps.Hub != nil && s.deps.Hub.Connected(hostID) {
|
||||
_ = s.pushRepoCredsToAgent(r.Context(), hostID, existing)
|
||||
// Force a fresh probe so a typo / wrong URL surfaces now
|
||||
// rather than at the next scheduled job. No-op if offline —
|
||||
// the operator already saw "host offline" elsewhere.
|
||||
if err := s.dispatchInitJob(r.Context(), hostID, "user", &user.ID); err != nil {
|
||||
slog.Warn("repo creds set: dispatch init", "host_id", hostID, "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
w.WriteHeader(stdhttp.StatusNoContent)
|
||||
}
|
||||
|
||||
// dispatchInitJob creates an init job row, marshals the command.run,
|
||||
// ships it down the agent's WS connection (when connected), and
|
||||
// audits. NS-03 path: callers use this to force a fresh probe after
|
||||
// credentials change without waiting for the next hello — and without
|
||||
// the maybeAutoInit "first time only" guard. actorKind should be
|
||||
// "user" for operator-driven dispatches and "system" for the
|
||||
// auto-init-on-hello case so audit reflects intent.
|
||||
func (s *Server) dispatchInitJob(ctx context.Context, hostID, actorKind string, actorID *string) error {
|
||||
jobID := ulid.Make().String()
|
||||
now := time.Now().UTC()
|
||||
if err := s.deps.Store.CreateJob(ctx, store.Job{
|
||||
ID: jobID,
|
||||
HostID: hostID,
|
||||
Kind: string(api.JobInit),
|
||||
ActorKind: actorKind,
|
||||
ActorID: actorID,
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("dispatch init: persist job: %w", err)
|
||||
}
|
||||
env, err := api.Marshal(api.MsgCommandRun, jobID, api.CommandRunPayload{
|
||||
JobID: jobID,
|
||||
Kind: api.JobInit,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("dispatch init: marshal: %w", err)
|
||||
}
|
||||
if s.deps.Hub != nil && s.deps.Hub.Connected(hostID) {
|
||||
sendCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.deps.Hub.Send(sendCtx, hostID, env); err != nil {
|
||||
// Job row stays — the host's pending-runs drain or the next
|
||||
// hello picks it up. We leave the slate clean for the caller.
|
||||
return fmt.Errorf("dispatch init: ws send: %w", err)
|
||||
}
|
||||
}
|
||||
_ = s.deps.Store.AppendAudit(ctx, store.AuditEntry{
|
||||
ID: ulid.Make().String(),
|
||||
UserID: actorID,
|
||||
Actor: actorKind,
|
||||
Action: "host.repo_init_dispatched",
|
||||
TargetKind: ptr("host"),
|
||||
TargetID: &hostID,
|
||||
TS: now,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// pushRepoCredsToAgent serialises blob into a config.update envelope
|
||||
// pushRepoCredsToAgent serializes blob into a config.update envelope
|
||||
// and ships it down the agent's WS. Returns an error from the hub
|
||||
// (no-op if not connected — caller is expected to check first when it
|
||||
// matters).
|
||||
@@ -255,7 +192,7 @@ func (s *Server) pushRepoCredsToAgent(ctx context.Context, hostID string, blob r
|
||||
// uses this to pre-fill the edit form.
|
||||
func (s *Server) handleGetAdminCredentials(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if !s.authedUser(r) {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
@@ -297,7 +234,7 @@ func (s *Server) handleGetAdminCredentials(w stdhttp.ResponseWriter, r *stdhttp.
|
||||
func (s *Server) handleSetAdminCredentials(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
user, ok := s.requireUser(r)
|
||||
if !ok {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
@@ -382,7 +319,7 @@ func (s *Server) handleSetAdminCredentials(w stdhttp.ResponseWriter, r *stdhttp.
|
||||
func (s *Server) handleDeleteAdminCredentials(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
user, ok := s.requireUser(r)
|
||||
if !ok {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
@@ -462,10 +399,6 @@ func (s *Server) pushAdminCredsToAgent(ctx context.Context, hostID string) error
|
||||
// don't race a brand-new register against an old still-closing conn.
|
||||
func (s *Server) onAgentHello(ctx context.Context, hostID string, conn *ws.Conn) {
|
||||
s.pushRepoCredsOnHello(ctx, hostID, conn)
|
||||
// Bandwidth caps are sent unconditionally so an agent that
|
||||
// reconnects after a cap edit picks up the new state without
|
||||
// waiting for the next bandwidth PUT.
|
||||
s.pushBandwidthOnHello(ctx, hostID, conn)
|
||||
// Push the current schedule set in the same on-hello window so
|
||||
// the agent's local cron is in sync before any command.run lands.
|
||||
// An empty schedule list is a valid push: it tells the agent to
|
||||
|
||||
@@ -34,7 +34,7 @@ type hostView struct {
|
||||
// see the same projection.
|
||||
func (s *Server) handleListHosts(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if !s.authedUser(r) {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hosts, err := s.deps.Store.ListHosts(r.Context())
|
||||
@@ -55,7 +55,7 @@ func (s *Server) handleListHosts(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
// handleFleetSummary returns the dashboard tile aggregate.
|
||||
func (s *Server) handleFleetSummary(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if !s.authedUser(r) {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
fs, err := s.deps.Store.FleetSummary(r.Context())
|
||||
|
||||
@@ -1,135 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
stdhttp "net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// handleJobLogDownload is GET /api/jobs/{id}/log{.txt,.ndjson}.
|
||||
//
|
||||
// Source of truth is the persisted job_logs table — works any time,
|
||||
// regardless of whether the job is running or already finished. The
|
||||
// download is "everything the server has up to right now"; the live
|
||||
// stream is unaffected (no pause needed). If the operator wants a
|
||||
// fuller snapshot of a still-running job, they hit Download again.
|
||||
//
|
||||
// Format is picked from the URL suffix (.txt | .ndjson) for a
|
||||
// sensible filename in the browser, or the ?format= query param for
|
||||
// REST callers. Default is txt.
|
||||
func (s *Server) handleJobLogDownload(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if _, ok := s.requireUser(r); !ok {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
return
|
||||
}
|
||||
jobID := chi.URLParam(r, "id")
|
||||
if jobID == "" {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "missing_job_id", "")
|
||||
return
|
||||
}
|
||||
job, err := s.deps.Store.GetJob(r.Context(), jobID)
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusNotFound, "job_not_found", "")
|
||||
return
|
||||
}
|
||||
|
||||
format := r.URL.Query().Get("format")
|
||||
if format == "" {
|
||||
// Sniff the URL — chi routes both /log.txt and /log.ndjson here
|
||||
// (or .log if a future route adds it) via the {format} matcher.
|
||||
fmtParam := chi.URLParam(r, "format")
|
||||
switch fmtParam {
|
||||
case "ndjson":
|
||||
format = "ndjson"
|
||||
default:
|
||||
format = "txt"
|
||||
}
|
||||
}
|
||||
|
||||
logs, err := s.deps.Store.ListJobLogs(r.Context(), jobID, 0, 0)
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
short := jobID
|
||||
if len(short) > 8 {
|
||||
short = short[:8]
|
||||
}
|
||||
filename := "job-" + job.Kind + "-" + short
|
||||
switch format {
|
||||
case "ndjson":
|
||||
w.Header().Set("Content-Type", "application/x-ndjson; charset=utf-8")
|
||||
w.Header().Set("Content-Disposition",
|
||||
`attachment; filename="`+filename+`.ndjson"`)
|
||||
writeLogsNDJSON(w, logs)
|
||||
default:
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("Content-Disposition",
|
||||
`attachment; filename="`+filename+`.txt"`)
|
||||
writeLogsText(w, job, logs)
|
||||
}
|
||||
}
|
||||
|
||||
// writeLogsText renders the logs in the same shape the live page shows:
|
||||
// "HH:MM:SS.mmm TAG payload". Adds a small header so the file is
|
||||
// useful as a standalone artefact (operator pastes it into a ticket).
|
||||
func writeLogsText(w stdhttp.ResponseWriter, job *store.Job, logs []store.JobLogLine) {
|
||||
bw := bufio.NewWriter(w)
|
||||
defer func() { _ = bw.Flush() }()
|
||||
_, _ = fmt.Fprintf(bw, "# job %s · kind %s · status %s\n",
|
||||
job.ID, job.Kind, job.Status)
|
||||
if job.StartedAt != nil {
|
||||
_, _ = fmt.Fprintf(bw, "# started %s\n", job.StartedAt.UTC().Format("2006-01-02T15:04:05.000Z"))
|
||||
}
|
||||
if job.FinishedAt != nil {
|
||||
_, _ = fmt.Fprintf(bw, "# finished %s\n", job.FinishedAt.UTC().Format("2006-01-02T15:04:05.000Z"))
|
||||
}
|
||||
_, _ = fmt.Fprintf(bw, "# %d log lines\n\n", len(logs))
|
||||
for _, l := range logs {
|
||||
tag := streamTag(l.Stream)
|
||||
ts := l.TS.UTC().Format("15:04:05.000")
|
||||
// Strip embedded newlines from payload — log lines should be
|
||||
// single-line, but defensive: a stray '\n' in stderr would
|
||||
// break grep -n.
|
||||
payload := strings.ReplaceAll(l.Payload, "\n", " ")
|
||||
_, _ = fmt.Fprintf(bw, "%s %s %s\n", ts, tag, payload)
|
||||
}
|
||||
}
|
||||
|
||||
// writeLogsNDJSON emits one JSON object per line. Each object stands
|
||||
// alone — appending to the file remains valid NDJSON.
|
||||
func writeLogsNDJSON(w stdhttp.ResponseWriter, logs []store.JobLogLine) {
|
||||
enc := json.NewEncoder(w)
|
||||
for _, l := range logs {
|
||||
_ = enc.Encode(struct {
|
||||
Seq int64 `json:"seq"`
|
||||
TS string `json:"ts"`
|
||||
Stream string `json:"stream"`
|
||||
Payload string `json:"payload"`
|
||||
}{
|
||||
Seq: l.Seq,
|
||||
TS: l.TS.UTC().Format("2006-01-02T15:04:05.000Z"),
|
||||
Stream: l.Stream,
|
||||
Payload: l.Payload,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// streamTag maps a log stream name to the short tag used in the text
// download: "stdout" becomes "OUT" and "stderr" becomes "ERR". Any other
// name, including "event" (which yields "EVENT"), is upper-cased as-is.
func streamTag(s string) string {
	if s == "stdout" {
		return "OUT"
	}
	if s == "stderr" {
		return "ERR"
	}
	return strings.ToUpper(s)
}
|
||||
@@ -1,181 +0,0 @@
|
||||
// job_download_test.go — covers GET /api/jobs/{id}/log.{txt,ndjson}.
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
stdhttp "net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// seedJobWithLogs creates a job + a few log lines for it. Returns the
|
||||
// job ID. Caller is responsible for the test server + auth.
|
||||
func seedJobWithLogs(t *testing.T, st *store.Store, hostID string, lineCount int) string {
|
||||
t.Helper()
|
||||
jobID := ulid.Make().String()
|
||||
now := time.Now().UTC()
|
||||
if err := st.CreateJob(context.Background(), store.Job{
|
||||
ID: jobID, HostID: hostID, Kind: "diff",
|
||||
ActorKind: "user", CreatedAt: now,
|
||||
}); err != nil {
|
||||
t.Fatalf("create job: %v", err)
|
||||
}
|
||||
if err := st.MarkJobStarted(context.Background(), jobID, now); err != nil {
|
||||
t.Fatalf("mark started: %v", err)
|
||||
}
|
||||
for i := 0; i < lineCount; i++ {
|
||||
stream := "stdout"
|
||||
if i%5 == 0 {
|
||||
stream = "stderr"
|
||||
}
|
||||
payload := `{"message_type":"change","path":"/etc/file` +
|
||||
ulid.Make().String()[:6] + `","modifier":"M"}`
|
||||
if err := st.AppendJobLog(context.Background(), jobID, int64(i+1),
|
||||
now.Add(time.Duration(i)*time.Millisecond),
|
||||
stream, payload); err != nil {
|
||||
t.Fatalf("append log: %v", err)
|
||||
}
|
||||
}
|
||||
if err := st.MarkJobFinished(context.Background(), jobID, "succeeded", 0, nil, "", now); err != nil {
|
||||
t.Fatalf("mark finished: %v", err)
|
||||
}
|
||||
return jobID
|
||||
}
|
||||
|
||||
// TestJobLogDownloadTxt: plain-text format includes a header + one
|
||||
// line per log row in the expected shape.
|
||||
func TestJobLogDownloadTxt(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, st := rawTestServer(t)
|
||||
hostID, _ := enrolHostForWS(t, srv, st, "dl-txt-host")
|
||||
jobID := seedJobWithLogs(t, st, hostID, 12)
|
||||
cookie := loginAsAdmin(t, st)
|
||||
|
||||
req, _ := stdhttp.NewRequest("GET",
|
||||
ts.URL+"/api/jobs/"+jobID+"/log.txt", nil)
|
||||
req.AddCookie(cookie)
|
||||
res, err := stdhttp.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusOK {
|
||||
t.Fatalf("status: got %d, want 200", res.StatusCode)
|
||||
}
|
||||
if ct := res.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/plain") {
|
||||
t.Errorf("content-type: got %q", ct)
|
||||
}
|
||||
if cd := res.Header.Get("Content-Disposition"); !strings.Contains(cd, ".txt") {
|
||||
t.Errorf("content-disposition: got %q", cd)
|
||||
}
|
||||
body := readBody(t, res.Body)
|
||||
// Header lines.
|
||||
if !strings.HasPrefix(body, "# job ") {
|
||||
t.Errorf("expected '# job ...' header line; got %q", short(body))
|
||||
}
|
||||
if !strings.Contains(body, "12 log lines") {
|
||||
t.Errorf("expected '12 log lines'; got %q", short(body))
|
||||
}
|
||||
// One body line per log row — count non-comment, non-empty lines.
|
||||
var rows int
|
||||
for _, line := range strings.Split(body, "\n") {
|
||||
l := strings.TrimSpace(line)
|
||||
if l == "" || strings.HasPrefix(l, "#") {
|
||||
continue
|
||||
}
|
||||
rows++
|
||||
}
|
||||
if rows != 12 {
|
||||
t.Errorf("expected 12 body rows, got %d", rows)
|
||||
}
|
||||
// Tag check: at least one ERR row (every 5th was stderr).
|
||||
if !strings.Contains(body, " ERR ") {
|
||||
t.Errorf("expected at least one ERR row")
|
||||
}
|
||||
}
|
||||
|
||||
// TestJobLogDownloadNDJSON: each line is a self-contained JSON object.
|
||||
func TestJobLogDownloadNDJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, st := rawTestServer(t)
|
||||
hostID, _ := enrolHostForWS(t, srv, st, "dl-ndjson-host")
|
||||
jobID := seedJobWithLogs(t, st, hostID, 5)
|
||||
cookie := loginAsAdmin(t, st)
|
||||
|
||||
req, _ := stdhttp.NewRequest("GET",
|
||||
ts.URL+"/api/jobs/"+jobID+"/log.ndjson", nil)
|
||||
req.AddCookie(cookie)
|
||||
res, err := stdhttp.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusOK {
|
||||
t.Fatalf("status: got %d, want 200", res.StatusCode)
|
||||
}
|
||||
if ct := res.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-ndjson") {
|
||||
t.Errorf("content-type: got %q", ct)
|
||||
}
|
||||
body := readBody(t, res.Body)
|
||||
// Each non-empty line should parse as an object with seq/ts/stream/payload.
|
||||
var seen int
|
||||
for _, line := range strings.Split(body, "\n") {
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
var obj struct {
|
||||
Seq int64 `json:"seq"`
|
||||
TS string `json:"ts"`
|
||||
Stream string `json:"stream"`
|
||||
Payload string `json:"payload"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(line), &obj); err != nil {
|
||||
t.Fatalf("parse line %q: %v", line, err)
|
||||
}
|
||||
if obj.Seq == 0 || obj.TS == "" || obj.Stream == "" || obj.Payload == "" {
|
||||
t.Errorf("incomplete object: %+v", obj)
|
||||
}
|
||||
seen++
|
||||
}
|
||||
if seen != 5 {
|
||||
t.Errorf("parsed %d objects, want 5", seen)
|
||||
}
|
||||
}
|
||||
|
||||
// TestJobLogDownloadNotFound: 404 for an unknown job id.
|
||||
func TestJobLogDownloadNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ts, st := rawTestServer(t)
|
||||
cookie := loginAsAdmin(t, st)
|
||||
req, _ := stdhttp.NewRequest("GET",
|
||||
ts.URL+"/api/jobs/"+ulid.Make().String()+"/log.txt", nil)
|
||||
req.AddCookie(cookie)
|
||||
res, err := stdhttp.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusNotFound {
|
||||
t.Fatalf("status: got %d, want 404", res.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// TestJobLogDownloadUnauthenticated: without a session cookie, 401.
|
||||
func TestJobLogDownloadUnauthenticated(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ts, _ := rawTestServer(t)
|
||||
res, err := stdhttp.Get(ts.URL + "/api/jobs/x/log.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusUnauthorized {
|
||||
t.Fatalf("status: got %d, want 401", res.StatusCode)
|
||||
}
|
||||
}
|
||||
@@ -31,7 +31,7 @@ type runNowResponse struct {
|
||||
func (s *Server) handleRunNow(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
user, ok := s.requireUser(r)
|
||||
if !ok {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
@@ -65,7 +65,7 @@ func (s *Server) handleRunNow(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
func (s *Server) dispatchJob(ctx context.Context, user *store.User,
|
||||
hostID string, kind api.JobKind, args []string,
|
||||
) (res runNowResponse, status int, code, msg string) {
|
||||
return s.dispatchJobWithPayload(ctx, user, hostID, kind, nil, api.CommandRunPayload{
|
||||
return s.dispatchJobWithPayload(ctx, user, hostID, kind, api.CommandRunPayload{
|
||||
Kind: kind,
|
||||
Args: args,
|
||||
})
|
||||
@@ -75,12 +75,8 @@ func (s *Server) dispatchJob(ctx context.Context, user *store.User,
|
||||
// fill in structured fields (Includes/Excludes/Tag/ForgetGroups/RequiresAdminCreds)
|
||||
// — used by the per-source-group Run-now path. JobID is filled in
|
||||
// here; callers leave it zero on the input payload.
|
||||
//
|
||||
// sourceGroupID is the dedup key the alert engine will key on for
|
||||
// backup_failed. Pass non-nil for backups; nil for prune/check/unlock
|
||||
// (those are repo-scoped and dedup at host_id only).
|
||||
func (s *Server) dispatchJobWithPayload(ctx context.Context, user *store.User,
|
||||
hostID string, kind api.JobKind, sourceGroupID *string, payload api.CommandRunPayload,
|
||||
hostID string, kind api.JobKind, payload api.CommandRunPayload,
|
||||
) (res runNowResponse, status int, code, msg string) {
|
||||
if !validJobKind(kind) {
|
||||
return res, stdhttp.StatusBadRequest, "invalid_kind",
|
||||
@@ -104,13 +100,12 @@ func (s *Server) dispatchJobWithPayload(ctx context.Context, user *store.User,
|
||||
actorID = &user.ID
|
||||
}
|
||||
if err := s.deps.Store.CreateJob(ctx, store.Job{
|
||||
ID: jobID,
|
||||
HostID: host.ID,
|
||||
Kind: string(kind),
|
||||
SourceGroupID: sourceGroupID,
|
||||
ActorKind: actor,
|
||||
ActorID: actorID,
|
||||
CreatedAt: now,
|
||||
ID: jobID,
|
||||
HostID: host.ID,
|
||||
Kind: string(kind),
|
||||
ActorKind: actor,
|
||||
ActorID: actorID,
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
return res, stdhttp.StatusInternalServerError, "internal", ""
|
||||
}
|
||||
@@ -152,19 +147,12 @@ func (s *Server) requireUser(r *stdhttp.Request) (*store.User, bool) {
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
if u.DisabledAt != nil {
|
||||
// Disabled mid-session — kill the session and reject the
|
||||
// request as if it were unauthenticated.
|
||||
_ = s.deps.Store.DeleteSession(r.Context(), auth.HashToken(c.Value))
|
||||
return nil, false
|
||||
}
|
||||
return u, true
|
||||
}
|
||||
|
||||
func validJobKind(k api.JobKind) bool {
|
||||
switch k {
|
||||
case api.JobBackup, api.JobInit, api.JobForget, api.JobPrune,
|
||||
api.JobCheck, api.JobUnlock, api.JobRestore, api.JobDiff:
|
||||
case api.JobBackup, api.JobInit, api.JobForget, api.JobPrune, api.JobCheck, api.JobUnlock:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -43,7 +43,7 @@ func (s *Server) DispatchMaintenance(ctx context.Context, decisions []maintenanc
|
||||
"host_id", d.HostID)
|
||||
continue
|
||||
}
|
||||
_, _, code, msg := s.dispatchJobWithPayload(ctx, nil, d.HostID, api.JobForget, nil, payload)
|
||||
_, _, code, msg := s.dispatchJobWithPayload(ctx, nil, d.HostID, api.JobForget, payload)
|
||||
if code != "" {
|
||||
slog.Warn("maintenance: forget dispatch failed",
|
||||
"host_id", d.HostID, "code", code, "msg", msg)
|
||||
@@ -65,14 +65,14 @@ func (s *Server) DispatchMaintenance(ctx context.Context, decisions []maintenanc
|
||||
continue
|
||||
}
|
||||
payload := api.CommandRunPayload{RequiresAdminCreds: true}
|
||||
_, _, code, msg := s.dispatchJobWithPayload(ctx, nil, d.HostID, api.JobPrune, nil, payload)
|
||||
_, _, code, msg := s.dispatchJobWithPayload(ctx, nil, d.HostID, api.JobPrune, payload)
|
||||
if code != "" {
|
||||
slog.Warn("maintenance: prune dispatch failed",
|
||||
"host_id", d.HostID, "code", code, "msg", msg)
|
||||
}
|
||||
case "check":
|
||||
payload := api.CommandRunPayload{Args: []string{strconv.Itoa(d.SubsetPct)}}
|
||||
_, _, code, msg := s.dispatchJobWithPayload(ctx, nil, d.HostID, api.JobCheck, nil, payload)
|
||||
_, _, code, msg := s.dispatchJobWithPayload(ctx, nil, d.HostID, api.JobCheck, payload)
|
||||
if code != "" {
|
||||
slog.Warn("maintenance: check dispatch failed",
|
||||
"host_id", d.HostID, "code", code, "msg", msg)
|
||||
|
||||
@@ -1,205 +0,0 @@
|
||||
// oidc_handlers.go — OIDC sign-in handlers. Public routes when oidc
|
||||
// is configured (s.deps.OIDC != nil), otherwise not mounted.
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
stdhttp "net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// handleOIDCLogin generates state + PKCE pair, persists them, and
|
||||
// redirects to the IdP authorization endpoint.
|
||||
func (s *Server) handleOIDCLogin(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
state, err := oidc.RandomState()
|
||||
if err != nil {
|
||||
slog.Error("oidc login: state", "err", err)
|
||||
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
verifier, challenge, err := oidc.PKCEPair()
|
||||
if err != nil {
|
||||
slog.Error("oidc login: pkce", "err", err)
|
||||
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := s.deps.Store.PutOIDCState(r.Context(),
|
||||
oidc.HashState(state), verifier, time.Now().UTC()); err != nil {
|
||||
slog.Error("oidc login: persist state", "err", err)
|
||||
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
stdhttp.Redirect(w, r, s.deps.OIDC.AuthURL(state, challenge), stdhttp.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (s *Server) handleOIDCCallback(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
q := r.URL.Query()
|
||||
code := q.Get("code")
|
||||
state := q.Get("state")
|
||||
if code == "" || state == "" {
|
||||
s.oidcRedirectError(w, r, "missing_params")
|
||||
return
|
||||
}
|
||||
verifier, err := s.deps.Store.ConsumeOIDCState(r.Context(), oidc.HashState(state))
|
||||
if err != nil {
|
||||
s.oidcRedirectError(w, r, "bad_state")
|
||||
return
|
||||
}
|
||||
claims, rawIDToken, err := s.deps.OIDC.Exchange(r.Context(), code, verifier)
|
||||
if err != nil {
|
||||
slog.Warn("oidc callback: exchange", "err", err)
|
||||
s.oidcRedirectError(w, r, "exchange_failed")
|
||||
return
|
||||
}
|
||||
|
||||
uname := strings.ToLower(strings.TrimSpace(claims.PreferredUsername))
|
||||
if uname == "" {
|
||||
uname = strings.ToLower(strings.TrimSpace(claims.Email))
|
||||
}
|
||||
if uname == "" || claims.Subject == "" {
|
||||
s.oidcRedirectError(w, r, "missing_claims")
|
||||
return
|
||||
}
|
||||
|
||||
role := s.deps.OIDC.MapRole(claims.Roles)
|
||||
if role == "" {
|
||||
_ = s.auditOIDCBlocked(r, claims, "no_role_match")
|
||||
s.oidcRedirectError(w, r, "no_role_match")
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Returning OIDC user — refresh role + email + last_login.
|
||||
existing, err := s.deps.Store.GetUserByOIDCSubject(r.Context(), claims.Subject)
|
||||
if err == nil {
|
||||
if existing.DisabledAt != nil {
|
||||
s.oidcRedirectError(w, r, "user_disabled")
|
||||
return
|
||||
}
|
||||
_ = s.deps.Store.SetUserRole(r.Context(), existing.ID, store.Role(role))
|
||||
_ = s.deps.Store.SetUserEmail(r.Context(), existing.ID, claims.Email)
|
||||
_ = s.deps.Store.MarkUserLogin(r.Context(), existing.ID, now)
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(), UserID: &existing.ID, Actor: "user",
|
||||
Action: "user.oidc_login", TargetKind: ptr("user"),
|
||||
TargetID: &existing.ID, TS: now,
|
||||
})
|
||||
s.oidcDropSessionAndRedirect(w, r, existing.ID, rawIDToken, now)
|
||||
return
|
||||
} else if !errors.Is(err, store.ErrNotFound) {
|
||||
slog.Error("oidc callback: lookup by sub", "err", err)
|
||||
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// New OIDC user — first check the username doesn't collide with
|
||||
// a local user.
|
||||
if _, err := s.deps.Store.GetUserByUsername(r.Context(), uname); err == nil {
|
||||
_ = s.auditOIDCBlocked(r, claims, "username_taken")
|
||||
s.oidcRedirectError(w, r, "username_taken")
|
||||
return
|
||||
} else if !errors.Is(err, store.ErrNotFound) {
|
||||
slog.Error("oidc callback: lookup by username", "err", err)
|
||||
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// JIT-provision.
|
||||
id := ulid.Make().String()
|
||||
var emailPtr *string
|
||||
if claims.Email != "" {
|
||||
em := strings.ToLower(claims.Email)
|
||||
emailPtr = &em
|
||||
}
|
||||
sub := claims.Subject
|
||||
if err := s.deps.Store.CreateUser(r.Context(), store.User{
|
||||
ID: id, Username: uname, PasswordHash: "",
|
||||
Role: store.Role(role), Email: emailPtr,
|
||||
AuthSource: "oidc", OIDCSubject: &sub,
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
slog.Error("oidc callback: provision", "err", err)
|
||||
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_ = s.deps.Store.MarkUserLogin(r.Context(), id, now)
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(), UserID: &id, Actor: "user",
|
||||
Action: "user.created", TargetKind: ptr("user"), TargetID: &id,
|
||||
TS: now,
|
||||
Payload: jsonMust(map[string]any{"auth_source": "oidc"}),
|
||||
})
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(), UserID: &id, Actor: "user",
|
||||
Action: "user.oidc_login", TargetKind: ptr("user"), TargetID: &id,
|
||||
TS: now,
|
||||
})
|
||||
s.oidcDropSessionAndRedirect(w, r, id, rawIDToken, now)
|
||||
}
|
||||
|
||||
func (s *Server) oidcDropSessionAndRedirect(w stdhttp.ResponseWriter, r *stdhttp.Request, userID, idToken string, now time.Time) {
|
||||
rawSession, err := auth.NewToken()
|
||||
if err != nil {
|
||||
slog.Error("oidc: session token", "err", err)
|
||||
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
hashed := auth.HashToken(rawSession)
|
||||
if err := s.deps.Store.CreateSession(r.Context(), store.Session{
|
||||
ID: hashed, UserID: userID, CreatedAt: now,
|
||||
ExpiresAt: now.Add(8 * time.Hour),
|
||||
IDToken: idToken,
|
||||
}, hashed); err != nil {
|
||||
slog.Error("oidc: create session", "err", err)
|
||||
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
stdhttp.SetCookie(w, &stdhttp.Cookie{
|
||||
Name: sessionCookieName, Value: rawSession,
|
||||
Path: "/", HttpOnly: true,
|
||||
SameSite: stdhttp.SameSiteLaxMode,
|
||||
Secure: s.deps.Cfg.CookieSecure,
|
||||
Expires: now.Add(8 * time.Hour),
|
||||
})
|
||||
stdhttp.Redirect(w, r, "/", stdhttp.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (s *Server) oidcRedirectError(w stdhttp.ResponseWriter, r *stdhttp.Request, code string) {
|
||||
stdhttp.Redirect(w, r, "/login?oidc_error="+code, stdhttp.StatusSeeOther)
|
||||
}
|
||||
|
||||
// auditOIDCBlocked records a failed sign-in. user_id is nil because
|
||||
// no row was created; the IdP subject + reason go in the payload so
|
||||
// admin can correlate.
|
||||
func (s *Server) auditOIDCBlocked(r *stdhttp.Request, claims *oidc.Claims, reason string) error {
|
||||
return s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(), UserID: nil, Actor: "system",
|
||||
Action: "user.oidc_login_blocked", TargetKind: ptr("user"),
|
||||
TargetID: nil, TS: time.Now().UTC(),
|
||||
Payload: jsonMust(map[string]any{
|
||||
"sub": claims.Subject,
|
||||
"username": claims.PreferredUsername,
|
||||
"reason": reason,
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
// jsonMust encodes v as JSON and returns it as a json.RawMessage. Despite
// the name it never panics: when encoding fails it returns nil, so an
// audit row can still be written without its payload (best-effort).
func jsonMust(v any) json.RawMessage {
	if encoded, err := json.Marshal(v); err == nil {
		return json.RawMessage(encoded)
	}
	return nil
}
|
||||
@@ -1,293 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
stdhttp "net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/config"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc/oidctest"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// newTestServerWithOIDC returns a Server wired to a stub IdP.
|
||||
// Returned ts is the httptest.Server fronting the actual server;
|
||||
// stub is the IdP for minting codes / configuring claims.
|
||||
func newTestServerWithOIDC(t *testing.T) (*Server, *httptest.Server, *oidctest.StubIdP) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("store: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = st.Close() })
|
||||
|
||||
keyPath := filepath.Join(dir, "secret.key")
|
||||
if err := crypto.GenerateKeyFile(keyPath); err != nil {
|
||||
t.Fatalf("genkey: %v", err)
|
||||
}
|
||||
key, _ := crypto.LoadKeyFromFile(keyPath)
|
||||
aead, _ := crypto.NewAEAD(key)
|
||||
|
||||
stub := oidctest.New(t)
|
||||
cfg := &config.OIDCConfig{
|
||||
Issuer: stub.URL(), ClientID: "test-client", ClientSecret: "x",
|
||||
Scopes: []string{"openid"}, RoleClaim: "groups",
|
||||
RoleMapping: map[string]string{
|
||||
"rm-admins": "admin",
|
||||
"rm-operators": "operator",
|
||||
"rm-viewers": "viewer",
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
oidcClient, err := oidc.New(ctx, cfg, "http://test")
|
||||
if err != nil {
|
||||
t.Fatalf("oidc client: %v", err)
|
||||
}
|
||||
|
||||
deps := Deps{
|
||||
Cfg: config.Config{Listen: ":0", DataDir: dir, SecretKeyFile: keyPath, BaseURL: "http://test"},
|
||||
Store: st,
|
||||
AEAD: aead,
|
||||
OIDC: oidcClient,
|
||||
}
|
||||
s := New(deps)
|
||||
ts := httptest.NewServer(s.srv.Handler)
|
||||
t.Cleanup(ts.Close)
|
||||
return s, ts, stub
|
||||
}
|
||||
|
||||
func TestOIDCLoginRedirectsToIdP(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, _ := newTestServerWithOIDC(t)
|
||||
c := &stdhttp.Client{CheckRedirect: func(*stdhttp.Request, []*stdhttp.Request) error {
|
||||
return stdhttp.ErrUseLastResponse
|
||||
}}
|
||||
res, err := c.Get(ts.URL + "/auth/oidc/login")
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusSeeOther {
|
||||
t.Errorf("status: got %d want 303", res.StatusCode)
|
||||
}
|
||||
loc := res.Header.Get("Location")
|
||||
if !strings.Contains(loc, "code_challenge=") || !strings.Contains(loc, "state=") {
|
||||
t.Errorf("location: %q", loc)
|
||||
}
|
||||
_ = srv
|
||||
}
|
||||
|
||||
// runCallback drives the auth code flow against the stub: kicks off
|
||||
// /auth/oidc/login (capturing the state), mints a code at the stub
|
||||
// with the given claims, then GETs /auth/oidc/callback. Returns the
|
||||
// final response.
|
||||
func runCallback(t *testing.T, ts *httptest.Server, stub *oidctest.StubIdP, claims map[string]any) *stdhttp.Response {
|
||||
t.Helper()
|
||||
jar, _ := cookiejar.New(nil)
|
||||
c := &stdhttp.Client{Jar: jar, CheckRedirect: func(*stdhttp.Request, []*stdhttp.Request) error {
|
||||
return stdhttp.ErrUseLastResponse
|
||||
}}
|
||||
res, err := c.Get(ts.URL + "/auth/oidc/login")
|
||||
if err != nil {
|
||||
t.Fatalf("login: %v", err)
|
||||
}
|
||||
res.Body.Close()
|
||||
authURL, _ := url.Parse(res.Header.Get("Location"))
|
||||
state := authURL.Query().Get("state")
|
||||
|
||||
code := stub.MintCode(claims)
|
||||
res, err = c.Get(ts.URL + "/auth/oidc/callback?code=" + code + "&state=" + state)
|
||||
if err != nil {
|
||||
t.Fatalf("callback: %v", err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func TestOIDCCallbackHappyPathAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, stub := newTestServerWithOIDC(t)
|
||||
res := runCallback(t, ts, stub, map[string]any{
|
||||
"sub": "admin-sub",
|
||||
"preferred_username": "alice",
|
||||
"email": "alice@example.com",
|
||||
"groups": []string{"rm-admins"},
|
||||
"aud": "test-client",
|
||||
})
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusSeeOther || res.Header.Get("Location") != "/" {
|
||||
t.Errorf("status: %d Location: %q", res.StatusCode, res.Header.Get("Location"))
|
||||
}
|
||||
u, err := srv.deps.Store.GetUserByOIDCSubject(t.Context(), "admin-sub")
|
||||
if err != nil || u.AuthSource != "oidc" || u.Role != "admin" || u.Username != "alice" {
|
||||
t.Errorf("user: %+v err: %v", u, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCCallbackNoRoleMatchDeny(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ts, stub := newTestServerWithOIDC(t)
|
||||
res := runCallback(t, ts, stub, map[string]any{
|
||||
"sub": "other-sub",
|
||||
"preferred_username": "bob",
|
||||
"groups": []string{"something-else"},
|
||||
"aud": "test-client",
|
||||
})
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusSeeOther {
|
||||
t.Errorf("status: got %d want 303", res.StatusCode)
|
||||
}
|
||||
loc := res.Header.Get("Location")
|
||||
if !strings.Contains(loc, "oidc_error=no_role_match") {
|
||||
t.Errorf("location: %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCCallbackUsernameCollision(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, stub := newTestServerWithOIDC(t)
|
||||
if err := srv.deps.Store.CreateUser(t.Context(), store.User{
|
||||
ID: "local-alice", Username: "alice", PasswordHash: "x",
|
||||
Role: store.RoleViewer, CreatedAt: time.Now().UTC(),
|
||||
}); err != nil {
|
||||
t.Fatalf("seed: %v", err)
|
||||
}
|
||||
|
||||
res := runCallback(t, ts, stub, map[string]any{
|
||||
"sub": "remote-sub",
|
||||
"preferred_username": "alice",
|
||||
"groups": []string{"rm-admins"},
|
||||
"aud": "test-client",
|
||||
})
|
||||
defer res.Body.Close()
|
||||
loc := res.Header.Get("Location")
|
||||
if !strings.Contains(loc, "oidc_error=username_taken") {
|
||||
t.Errorf("location: %q", loc)
|
||||
}
|
||||
if _, err := srv.deps.Store.GetUserByOIDCSubject(t.Context(), "remote-sub"); err == nil {
|
||||
t.Error("collision should not have provisioned a user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCCallbackReturningUserRefreshesRole(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, stub := newTestServerWithOIDC(t)
|
||||
res := runCallback(t, ts, stub, map[string]any{
|
||||
"sub": "carol-sub",
|
||||
"preferred_username": "carol",
|
||||
"groups": []string{"rm-operators"},
|
||||
"aud": "test-client",
|
||||
})
|
||||
res.Body.Close()
|
||||
res = runCallback(t, ts, stub, map[string]any{
|
||||
"sub": "carol-sub",
|
||||
"preferred_username": "carol",
|
||||
"groups": []string{"rm-admins"},
|
||||
"aud": "test-client",
|
||||
})
|
||||
res.Body.Close()
|
||||
u, _ := srv.deps.Store.GetUserByOIDCSubject(t.Context(), "carol-sub")
|
||||
if u.Role != "admin" {
|
||||
t.Errorf("role refresh: got %q want admin", u.Role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCLogoutRedirectsToEndSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, stub := newTestServerWithOIDC(t)
|
||||
endSessionURL := stub.URL() + "/logout-end"
|
||||
stub.SetEndSessionEndpoint(endSessionURL)
|
||||
|
||||
// Rebuild the OIDC client because end_session_endpoint is read at
|
||||
// New() time from the discovery doc.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
cfg := &config.OIDCConfig{
|
||||
Issuer: stub.URL(), ClientID: "test-client", ClientSecret: "x",
|
||||
Scopes: []string{"openid"}, RoleClaim: "groups",
|
||||
RoleMapping: map[string]string{"rm-admins": "admin"},
|
||||
}
|
||||
newClient, err := oidc.New(ctx, cfg, "http://test")
|
||||
if err != nil {
|
||||
t.Fatalf("rebuild client: %v", err)
|
||||
}
|
||||
srv.deps.OIDC = newClient
|
||||
|
||||
// Sign in via the OIDC flow.
|
||||
res := runCallback(t, ts, stub, map[string]any{
|
||||
"sub": "logout-sub",
|
||||
"preferred_username": "lo",
|
||||
"groups": []string{"rm-admins"},
|
||||
"aud": "test-client",
|
||||
})
|
||||
res.Body.Close()
|
||||
cookies := res.Cookies()
|
||||
if len(cookies) == 0 {
|
||||
t.Fatal("expected session cookie after sign-in")
|
||||
}
|
||||
sessionCookie := cookies[0]
|
||||
|
||||
// POST /logout — should 303 to the end_session endpoint with
|
||||
// id_token_hint + post_logout_redirect_uri.
|
||||
c := &stdhttp.Client{CheckRedirect: func(*stdhttp.Request, []*stdhttp.Request) error {
|
||||
return stdhttp.ErrUseLastResponse
|
||||
}}
|
||||
req, _ := stdhttp.NewRequest("POST", ts.URL+"/logout", nil)
|
||||
req.AddCookie(sessionCookie)
|
||||
res, err = c.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("logout: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusSeeOther {
|
||||
t.Errorf("status: got %d want 303", res.StatusCode)
|
||||
}
|
||||
loc := res.Header.Get("Location")
|
||||
if !strings.Contains(loc, "/logout-end") {
|
||||
t.Errorf("location not at end_session: %q", loc)
|
||||
}
|
||||
if !strings.Contains(loc, "id_token_hint=") {
|
||||
t.Errorf("location missing id_token_hint: %q", loc)
|
||||
}
|
||||
if !strings.Contains(loc, "post_logout_redirect_uri=") {
|
||||
t.Errorf("location missing post_logout_redirect_uri: %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalLoginRejectsOIDCUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, urlBase := newTestServer(t, false)
|
||||
uid := "u-oidc"
|
||||
sub := "sub-x"
|
||||
if err := srv.deps.Store.CreateUser(t.Context(), store.User{
|
||||
ID: uid, Username: "ouser", PasswordHash: "",
|
||||
Role: store.RoleOperator, CreatedAt: time.Now().UTC(),
|
||||
AuthSource: "oidc", OIDCSubject: &sub,
|
||||
}); err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"username": "ouser", "password": "anything",
|
||||
})
|
||||
res, err := stdhttp.Post(urlBase+"/api/auth/login",
|
||||
"application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("post: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusUnauthorized {
|
||||
t.Errorf("status: got %d want 401", res.StatusCode)
|
||||
}
|
||||
}
|
||||
@@ -81,7 +81,7 @@ func drainUntil(t *testing.T, c *websocket.Conn, wantType api.MessageType) api.E
|
||||
return api.Envelope{}
|
||||
}
|
||||
|
||||
// enrolHostForWS pre-enrols a host with bound repo creds so the server
|
||||
// enrolHostForWS pre-enrolls a host with bound repo creds so the server
|
||||
// will treat it as ready to receive command.run.
|
||||
func enrolHostForWS(t *testing.T, srv *Server, st *store.Store, name string) (hostID, token string) {
|
||||
t.Helper()
|
||||
|
||||
@@ -145,12 +145,7 @@ func TestDrainPendingDispatchesOnReconnect(t *testing.T) {
|
||||
t.Errorf("backup tag: %q", got.Tag)
|
||||
}
|
||||
|
||||
// Pending row should be gone. Poll briefly: the drain goroutine
|
||||
// sends command.run via conn.Send and only then calls
|
||||
// DeletePendingRun. Reading the envelope off the wire above proves
|
||||
// the send happened, but the delete runs after that on the drain
|
||||
// goroutine — small window where the count is still 1.
|
||||
waitForPendingCount(t, st, hostID, 0, 2*time.Second)
|
||||
// Pending row should be gone.
|
||||
if n := countPendingForHost(t, st, hostID); n != 0 {
|
||||
t.Errorf("pending rows after drain: got %d, want 0", n)
|
||||
}
|
||||
@@ -506,12 +501,12 @@ func TestEnqueueOnDispatchFailure(t *testing.T) {
|
||||
func TestDrainPendingSerializesPerHost(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, st := rawTestServer(t)
|
||||
hostID, token := enrolHostForWS(t, srv, st, "serialise-host")
|
||||
hostID, token := enrolHostForWS(t, srv, st, "serialize-host")
|
||||
gid, sid := seedSchedAndGroup(t, st, hostID, 10)
|
||||
|
||||
// Connect the agent so DrainPending can dispatch.
|
||||
c := agentDial(t, srv, ts, hostID, token)
|
||||
sendHello(t, c, "serialise-host")
|
||||
sendHello(t, c, "serialize-host")
|
||||
// Drain the on-hello goroutine's pass first (no pending rows yet),
|
||||
// then wait for the schedule.set so the connection is fully settled.
|
||||
_ = drainUntil(t, c, api.MsgScheduleSet)
|
||||
|
||||
@@ -1,349 +0,0 @@
|
||||
// pending_ws.go — /ws/agent/pending and the admin accept/reject
|
||||
// endpoints for the announce-and-approve enrolment flow (P2-18b).
|
||||
//
|
||||
// Flow:
|
||||
// 1. Agent has previously called POST /api/agents/announce, which
|
||||
// returned its pending_id + fingerprint. Agent persists the
|
||||
// keypair locally.
|
||||
// 2. Agent connects to /ws/agent/pending?pending_id=… (no auth).
|
||||
// Server reads the row, generates a 32-byte nonce, sends it.
|
||||
// 3. Agent signs the nonce with its Ed25519 private key, sends the
|
||||
// signature back. Server verifies; close on bad sig.
|
||||
// 4. The connection sits open; the agent reads but doesn't write.
|
||||
// 5. Admin clicks Accept: POST /api/pending-hosts/{id}/accept with
|
||||
// the same repo-creds form the token-mint flow uses. Server
|
||||
// mints a Host row + bearer + encrypted creds, pushes one
|
||||
// `enrolled` message down the open socket, closes cleanly.
|
||||
// 6. Admin clicks Reject: socket closes with code 4001.
|
||||
//
|
||||
// Hub: a process-local in-memory map of pending_id → live conn so
|
||||
// the accept/reject handlers can find the right socket. Sole
|
||||
// instance lives on Server.pendingHub.
|
||||
package http
|
||||
|
||||
import (
	"context"
	"crypto/ed25519"
	"crypto/rand"
	"encoding/base64"
	"encoding/json"
	"errors"
	"log/slog"
	"mime"
	stdhttp "net/http"
	"sync"
	"time"

	"github.com/coder/websocket"
	"github.com/go-chi/chi/v5"
	"github.com/oklog/ulid/v2"

	"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
	"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
|
||||
|
||||
// pendingConn is a single live /ws/agent/pending session. The accept
|
||||
// handler sends the enrolment message via Send and closes the socket;
|
||||
// the WS read loop is just waiting for that close.
|
||||
type pendingConn struct {
|
||||
conn *websocket.Conn
|
||||
pendingID string
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
// pendingHub is the in-memory map of pending_id → live socket.
|
||||
type pendingHub struct {
|
||||
mu sync.Mutex
|
||||
conns map[string]*pendingConn
|
||||
}
|
||||
|
||||
func newPendingHub() *pendingHub {
|
||||
return &pendingHub{conns: map[string]*pendingConn{}}
|
||||
}
|
||||
|
||||
func (h *pendingHub) register(pc *pendingConn) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
// Replace any existing socket for the same pending_id (an agent
|
||||
// reconnected) — close the old one cleanly first so its goroutine
|
||||
// can exit.
|
||||
if old, ok := h.conns[pc.pendingID]; ok {
|
||||
_ = old.conn.Close(websocket.StatusNormalClosure, "superseded")
|
||||
close(old.closed)
|
||||
}
|
||||
h.conns[pc.pendingID] = pc
|
||||
}
|
||||
|
||||
func (h *pendingHub) unregister(pendingID string, pc *pendingConn) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if cur, ok := h.conns[pendingID]; ok && cur == pc {
|
||||
delete(h.conns, pendingID)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *pendingHub) get(pendingID string) *pendingConn {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
return h.conns[pendingID]
|
||||
}
|
||||
|
||||
// nonceMessage is the first frame the server sends on
// /ws/agent/pending: the challenge the agent has to sign.
type nonceMessage struct {
	// Type is always "nonce".
	Type string `json:"type"`
	// Nonce is the challenge, standard-base64 encoded.
	Nonce string `json:"nonce"`
}
|
||||
|
||||
// signedNonceMessage is the agent's reply to the nonce challenge.
type signedNonceMessage struct {
	// Type is always "signed_nonce".
	Type string `json:"type"`
	// Signature is the signature over the nonce, standard-base64
	// encoded.
	Signature string `json:"signature"`
}
|
||||
|
||||
// enrolledMessage is the frame the server pushes when an admin accepts
// the pending host. The agent persists the bearer to agent.yaml and
// exits announce mode.
type enrolledMessage struct {
	// Type is always "enrolled".
	Type string `json:"type"`
	// HostID is the id of the newly created host row.
	HostID string `json:"host_id"`
	// Bearer is the agent's persistent bearer token.
	Bearer string `json:"bearer"`
	// ServerID is optional and left out of the JSON when empty.
	ServerID string `json:"server_id,omitempty"`
}
|
||||
|
||||
// handlePendingWS upgrades the WS, runs the nonce-sign handshake,
|
||||
// registers the conn in the hub, and blocks until the conn is
|
||||
// closed (by accept/reject or by the agent disconnecting).
|
||||
func (s *Server) handlePendingWS(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
pendingID := r.URL.Query().Get("pending_id")
|
||||
if pendingID == "" {
|
||||
stdhttp.Error(w, "missing pending_id", stdhttp.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
row, err := s.deps.Store.GetPendingHost(r.Context(), pendingID)
|
||||
if err != nil {
|
||||
stdhttp.Error(w, "pending host not found", stdhttp.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if time.Now().UTC().After(row.ExpiresAt) {
|
||||
stdhttp.Error(w, "pending host expired", stdhttp.StatusGone)
|
||||
return
|
||||
}
|
||||
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
// Same-origin defaults are safe: the agent isn't a browser.
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("pending ws: accept", "pending_id", pendingID, "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate + send nonce.
|
||||
nonce := make([]byte, 32)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
_ = conn.Close(websocket.StatusInternalError, "nonce gen")
|
||||
return
|
||||
}
|
||||
nm := nonceMessage{Type: "nonce", Nonce: base64.StdEncoding.EncodeToString(nonce)}
|
||||
raw, _ := json.Marshal(nm)
|
||||
wctx, wcancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
if err := conn.Write(wctx, websocket.MessageText, raw); err != nil {
|
||||
wcancel()
|
||||
_ = conn.Close(websocket.StatusInternalError, "send nonce")
|
||||
return
|
||||
}
|
||||
wcancel()
|
||||
|
||||
// Read signed nonce back.
|
||||
rctx, rcancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||
mt, body, err := conn.Read(rctx)
|
||||
rcancel()
|
||||
if err != nil || mt != websocket.MessageText {
|
||||
_ = conn.Close(websocket.StatusPolicyViolation, "no signed nonce")
|
||||
return
|
||||
}
|
||||
var sig signedNonceMessage
|
||||
if err := json.Unmarshal(body, &sig); err != nil || sig.Type != "signed_nonce" {
|
||||
_ = conn.Close(websocket.StatusPolicyViolation, "bad signed nonce shape")
|
||||
return
|
||||
}
|
||||
sigBytes, err := base64.StdEncoding.DecodeString(sig.Signature)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusPolicyViolation, "bad signature b64")
|
||||
return
|
||||
}
|
||||
if !ed25519.Verify(row.PublicKey, nonce, sigBytes) {
|
||||
_ = conn.Close(websocket.StatusPolicyViolation, "signature does not verify")
|
||||
return
|
||||
}
|
||||
|
||||
// Touch the row so the dashboard knows the agent is live.
|
||||
_ = s.deps.Store.TouchPendingHost(context.Background(), pendingID, time.Now().UTC())
|
||||
|
||||
// Register and block until close.
|
||||
pc := &pendingConn{conn: conn, pendingID: pendingID, closed: make(chan struct{})}
|
||||
s.pendingHub.register(pc)
|
||||
defer s.pendingHub.unregister(pendingID, pc)
|
||||
|
||||
// Read loop: we don't expect any further frames from the agent.
|
||||
// If the agent closes, we exit.
|
||||
go func() {
|
||||
for {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
|
||||
_, _, err := conn.Read(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
close(pc.closed)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
<-pc.closed
|
||||
}
|
||||
|
||||
// acceptForm carries the repo credentials an admin submits to
// POST /api/pending-hosts/{id}/accept, either as a JSON body or as
// form-encoded fields of the same names. The accept handler rejects
// the request when repo_url or repo_password is empty; repo_username
// is not checked there.
type acceptForm struct {
	// RepoURL is the repository location (required).
	RepoURL string `json:"repo_url"`
	// RepoUsername is the repository user name (may be empty).
	RepoUsername string `json:"repo_username"`
	// RepoPassword is the repository password (required).
	RepoPassword string `json:"repo_password"`
}
|
||||
|
||||
// handleAcceptPendingHost mints a real Host row + bearer + encrypted
|
||||
// repo creds and pushes the bearer down the agent's open pending WS.
|
||||
// Admin-auth required.
|
||||
func (s *Server) handleAcceptPendingHost(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
user, ok := s.requireUser(r)
|
||||
if !ok {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
return
|
||||
}
|
||||
pendingID := chi.URLParam(r, "id")
|
||||
row, err := s.deps.Store.GetPendingHost(r.Context(), pendingID)
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusNotFound, "pending_not_found", "")
|
||||
return
|
||||
}
|
||||
pc := s.pendingHub.get(pendingID)
|
||||
if pc == nil {
|
||||
writeJSONError(w, stdhttp.StatusConflict, "agent_not_connected",
|
||||
"the pending agent is not currently connected; ask it to retry")
|
||||
return
|
||||
}
|
||||
|
||||
var form acceptForm
|
||||
// Accept either JSON or form-urlencoded so HTMX-style POST works.
|
||||
if r.Header.Get("Content-Type") == "application/json" {
|
||||
if err := json.NewDecoder(r.Body).Decode(&form); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "invalid_json", err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "bad_form", err.Error())
|
||||
return
|
||||
}
|
||||
form.RepoURL = r.PostForm.Get("repo_url")
|
||||
form.RepoUsername = r.PostForm.Get("repo_username")
|
||||
form.RepoPassword = r.PostForm.Get("repo_password")
|
||||
}
|
||||
if form.RepoURL == "" || form.RepoPassword == "" {
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "missing_field",
|
||||
"repo_url and repo_password are required")
|
||||
return
|
||||
}
|
||||
|
||||
// Mint persistent bearer + Host row.
|
||||
hostID := ulid.Make().String()
|
||||
token, err := auth.NewToken()
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
host := store.Host{
|
||||
ID: hostID, Name: row.Hostname, OS: row.OS, Arch: row.Arch,
|
||||
AgentVersion: row.AgentVersion, ResticVersion: row.ResticVersion,
|
||||
EnrolledAt: time.Now().UTC(),
|
||||
}
|
||||
if err := s.deps.Store.CreateHost(r.Context(), host, auth.HashToken(token), ""); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
// Encrypt + persist repo creds.
|
||||
enc, err := s.encryptRepoCreds(repoCredsBlob(form), []byte("host:"+hostID))
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
if err := s.deps.Store.SetHostCredentials(r.Context(), hostID, store.CredKindRepo, enc); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
// Drop the pending row.
|
||||
if err := s.deps.Store.DeletePendingHost(r.Context(), pendingID); err != nil {
|
||||
slog.Warn("accept pending: delete row", "pending_id", pendingID, "err", err)
|
||||
}
|
||||
// Push enrolled message + close the pending WS.
|
||||
enrolled := enrolledMessage{Type: "enrolled", HostID: hostID, Bearer: token}
|
||||
raw, _ := json.Marshal(enrolled)
|
||||
wctx, wcancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
if err := pc.conn.Write(wctx, websocket.MessageText, raw); err != nil {
|
||||
slog.Warn("accept pending: write enrolled", "pending_id", pendingID, "err", err)
|
||||
}
|
||||
wcancel()
|
||||
_ = pc.conn.Close(websocket.StatusNormalClosure, "accepted")
|
||||
|
||||
// Audit.
|
||||
uid := user.ID
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(),
|
||||
UserID: &uid,
|
||||
Actor: "user",
|
||||
Action: "host.accept_pending",
|
||||
TargetKind: ptr("host"),
|
||||
TargetID: &hostID,
|
||||
TS: time.Now().UTC(),
|
||||
})
|
||||
|
||||
writeJSON(w, stdhttp.StatusOK, map[string]any{
|
||||
"host_id": hostID,
|
||||
"fingerprint": row.Fingerprint,
|
||||
})
|
||||
}
|
||||
|
||||
// handleRejectPendingHost deletes the pending row and closes any
|
||||
// open WS for it. Admin-auth required.
|
||||
func (s *Server) handleRejectPendingHost(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
user, ok := s.requireUser(r)
|
||||
if !ok {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
return
|
||||
}
|
||||
pendingID := chi.URLParam(r, "id")
|
||||
row, err := s.deps.Store.GetPendingHost(r.Context(), pendingID)
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNotFound) {
|
||||
w.WriteHeader(stdhttp.StatusNoContent)
|
||||
return
|
||||
}
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
if pc := s.pendingHub.get(pendingID); pc != nil {
|
||||
_ = pc.conn.Close(4001, "rejected")
|
||||
}
|
||||
if err := s.deps.Store.DeletePendingHost(r.Context(), pendingID); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", err.Error())
|
||||
return
|
||||
}
|
||||
uid := user.ID
|
||||
_ = s.deps.Store.AppendAudit(r.Context(), store.AuditEntry{
|
||||
ID: ulid.Make().String(),
|
||||
UserID: &uid,
|
||||
Actor: "user",
|
||||
Action: "host.reject_pending",
|
||||
TargetKind: ptr("pending_host"),
|
||||
TargetID: &row.ID,
|
||||
TS: time.Now().UTC(),
|
||||
})
|
||||
w.WriteHeader(stdhttp.StatusNoContent)
|
||||
}
|
||||
@@ -1,203 +0,0 @@
|
||||
// pending_ws_test.go — end-to-end test of the announce → pending WS
|
||||
// → admin accept → bearer push round trip (P2-18b/c).
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
stdhttp "net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// TestPendingWSNonceSignAcceptFlow: simulate an agent. Announce →
|
||||
// open pending WS → sign nonce → admin accept (with repo creds) →
|
||||
// expect 'enrolled' message with bearer.
|
||||
func TestPendingWSNonceSignAcceptFlow(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, st := rawTestServer(t)
|
||||
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("ed25519: %v", err)
|
||||
}
|
||||
|
||||
// Pre-seed pending row directly (bypass the announce HTTP path
|
||||
// since announce coverage lives in announce_test.go).
|
||||
pendingID := ulid.Make().String()
|
||||
if err := st.CreatePendingHost(context.Background(), &store.PendingHost{
|
||||
ID: pendingID, Hostname: "ann-host", OS: "linux", Arch: "amd64",
|
||||
AgentVersion: "1.0", ResticVersion: "0.17",
|
||||
PublicKey: pub, Fingerprint: store.FingerprintForKey(pub),
|
||||
AnnouncedFromIP: "127.0.0.1",
|
||||
FirstSeenAt: time.Now().UTC(),
|
||||
LastSeenAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(time.Hour),
|
||||
}); err != nil {
|
||||
t.Fatalf("seed: %v", err)
|
||||
}
|
||||
|
||||
// Open the pending WS.
|
||||
wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/agent/pending?pending_id=" + pendingID
|
||||
dialCtx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer dialCancel()
|
||||
c, res, err := websocket.Dial(dialCtx, wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial pending ws: %v", err)
|
||||
}
|
||||
if res != nil && res.Body != nil {
|
||||
_ = res.Body.Close()
|
||||
}
|
||||
t.Cleanup(func() { _ = c.CloseNow() })
|
||||
|
||||
// Read nonce.
|
||||
rctx, rcancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, raw, err := c.Read(rctx)
|
||||
rcancel()
|
||||
if err != nil {
|
||||
t.Fatalf("read nonce: %v", err)
|
||||
}
|
||||
var nm nonceMessage
|
||||
if err := json.Unmarshal(raw, &nm); err != nil {
|
||||
t.Fatalf("unmarshal nonce: %v", err)
|
||||
}
|
||||
nonce, _ := base64.StdEncoding.DecodeString(nm.Nonce)
|
||||
|
||||
// Sign + reply.
|
||||
sig := ed25519.Sign(priv, nonce)
|
||||
reply, _ := json.Marshal(signedNonceMessage{
|
||||
Type: "signed_nonce", Signature: base64.StdEncoding.EncodeToString(sig),
|
||||
})
|
||||
wctx, wcancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
if err := c.Write(wctx, websocket.MessageText, reply); err != nil {
|
||||
wcancel()
|
||||
t.Fatalf("write signed nonce: %v", err)
|
||||
}
|
||||
wcancel()
|
||||
|
||||
// Wait briefly so the server's hub.register completes before we
|
||||
// fire accept.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if srv.pendingHub.get(pendingID) != nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Admin POST accept (form-encoded, with cookie).
|
||||
cookie := loginAsAdmin(t, st)
|
||||
form := url.Values{
|
||||
"repo_url": {"rest:http://r/x"},
|
||||
"repo_username": {"u"},
|
||||
"repo_password": {"p"},
|
||||
}
|
||||
req, _ := stdhttp.NewRequest("POST",
|
||||
ts.URL+"/api/pending-hosts/"+pendingID+"/accept",
|
||||
strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.AddCookie(cookie)
|
||||
resAccept, err := stdhttp.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("accept: %v", err)
|
||||
}
|
||||
defer resAccept.Body.Close()
|
||||
if resAccept.StatusCode != stdhttp.StatusOK {
|
||||
t.Fatalf("accept status: %d", resAccept.StatusCode)
|
||||
}
|
||||
|
||||
// Expect 'enrolled' message + close.
|
||||
rctx2, rcancel2 := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, raw2, err := c.Read(rctx2)
|
||||
rcancel2()
|
||||
if err != nil {
|
||||
t.Fatalf("read enrolled: %v", err)
|
||||
}
|
||||
var em enrolledMessage
|
||||
if err := json.Unmarshal(raw2, &em); err != nil {
|
||||
t.Fatalf("unmarshal enrolled: %v", err)
|
||||
}
|
||||
if em.Type != "enrolled" || em.Bearer == "" || em.HostID == "" {
|
||||
t.Fatalf("enrolled payload bad: %+v", em)
|
||||
}
|
||||
|
||||
// Pending row should be gone.
|
||||
if _, err := st.GetPendingHost(context.Background(), pendingID); err == nil {
|
||||
t.Error("pending row should have been deleted on accept")
|
||||
}
|
||||
// Real host row should exist.
|
||||
if _, err := st.GetHost(context.Background(), em.HostID); err != nil {
|
||||
t.Errorf("host row not created: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPendingWSBadSignatureClosed: server closes the WS when the
|
||||
// signature does not verify against the row's public key.
|
||||
func TestPendingWSBadSignatureClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, st := rawTestServer(t)
|
||||
_ = srv
|
||||
|
||||
// Two distinct keypairs — agent signs with the wrong one.
|
||||
pubReal, _, _ := ed25519.GenerateKey(rand.Reader)
|
||||
_, privAttacker, _ := ed25519.GenerateKey(rand.Reader)
|
||||
|
||||
pendingID := ulid.Make().String()
|
||||
if err := st.CreatePendingHost(context.Background(), &store.PendingHost{
|
||||
ID: pendingID, Hostname: "bad-host", OS: "linux", Arch: "amd64",
|
||||
PublicKey: pubReal, Fingerprint: store.FingerprintForKey(pubReal),
|
||||
AnnouncedFromIP: "127.0.0.1",
|
||||
FirstSeenAt: time.Now().UTC(),
|
||||
LastSeenAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(time.Hour),
|
||||
}); err != nil {
|
||||
t.Fatalf("seed: %v", err)
|
||||
}
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/agent/pending?pending_id=" + pendingID
|
||||
dialCtx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer dialCancel()
|
||||
c, res, err := websocket.Dial(dialCtx, wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
if res != nil && res.Body != nil {
|
||||
_ = res.Body.Close()
|
||||
}
|
||||
defer func() { _ = c.CloseNow() }()
|
||||
|
||||
// Read nonce.
|
||||
rctx, rcancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, raw, _ := c.Read(rctx)
|
||||
rcancel()
|
||||
var nm nonceMessage
|
||||
_ = json.Unmarshal(raw, &nm)
|
||||
nonce, _ := base64.StdEncoding.DecodeString(nm.Nonce)
|
||||
|
||||
// Sign with the wrong key.
|
||||
sig := ed25519.Sign(privAttacker, nonce)
|
||||
reply, _ := json.Marshal(signedNonceMessage{
|
||||
Type: "signed_nonce", Signature: base64.StdEncoding.EncodeToString(sig),
|
||||
})
|
||||
wctx, wcancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_ = c.Write(wctx, websocket.MessageText, reply)
|
||||
wcancel()
|
||||
|
||||
// Server should close. Read until error.
|
||||
rctx2, rcancel2 := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, _, err = c.Read(rctx2)
|
||||
rcancel2()
|
||||
if err == nil {
|
||||
t.Fatal("expected ws to close on bad signature")
|
||||
}
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
stdhttp "net/http"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/ui"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// rank maps each role to a numeric tier so 'A is at least B' becomes
|
||||
// 'rank[A] >= rank[B] && both are known'. Unknown roles return 0 →
|
||||
// fail-closed against either argument.
|
||||
var roleRank = map[store.Role]int{
|
||||
store.RoleViewer: 1,
|
||||
store.RoleOperator: 2,
|
||||
store.RoleAdmin: 3,
|
||||
}
|
||||
|
||||
// roleAtLeast reports whether `have` meets or exceeds `min` in the
|
||||
// admin > operator > viewer hierarchy. Either side being an unknown
|
||||
// role returns false.
|
||||
func roleAtLeast(have, min store.Role) bool {
|
||||
h, hok := roleRank[have]
|
||||
m, mok := roleRank[min]
|
||||
if !hok || !mok {
|
||||
return false
|
||||
}
|
||||
return h >= m
|
||||
}
|
||||
|
||||
// requireRole returns chi middleware that 403s any request whose
|
||||
// session-resolved user doesn't meet the minimum role. Unauthenticated
|
||||
// requests return 401 (JSON) or 303 → /login (HTML) so the caller
|
||||
// gets a usable error rather than a confusing 403.
|
||||
//
|
||||
// The middleware re-reads the user row on every request — by the time
|
||||
// you read this you might be tempted to cache; don't. SQLite's WAL
|
||||
// makes the lookup cheap and admin-driven changes (disable, role
|
||||
// change) need to land immediately.
|
||||
func (s *Server) requireRole(min store.Role) func(stdhttp.Handler) stdhttp.Handler {
|
||||
return func(next stdhttp.Handler) stdhttp.Handler {
|
||||
return stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
u, ok := s.requireUser(r)
|
||||
if !ok {
|
||||
if isAPIPath(r) {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
return
|
||||
}
|
||||
stdhttp.Redirect(w, r, "/login", stdhttp.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
if !roleAtLeast(u.Role, min) {
|
||||
if isAPIPath(r) {
|
||||
writeJSONError(w, stdhttp.StatusForbidden, "insufficient_role", "")
|
||||
return
|
||||
}
|
||||
renderForbiddenHTML(s, w, r, u, min)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// isAPIPath reports whether the request path begins with "/api/". It
// lets a single middleware choose between a JSON and an HTML answer
// instead of needing two near-identical wrappers.
func isAPIPath(r *stdhttp.Request) bool {
	const prefix = "/api/"
	path := r.URL.Path
	if len(path) < len(prefix) {
		return false
	}
	return path[:len(prefix)] == prefix
}
|
||||
|
||||
// renderForbiddenHTML emits a small "you don't have permission"
|
||||
// panel inside the chrome so the user keeps their nav and can
|
||||
// move away to a page they can see.
|
||||
func renderForbiddenHTML(s *Server, w stdhttp.ResponseWriter, r *stdhttp.Request, u *store.User, min store.Role) {
|
||||
w.WriteHeader(stdhttp.StatusForbidden)
|
||||
view := s.baseView(r, &ui.User{ID: u.ID, Username: u.Username, Role: string(u.Role)})
|
||||
view.Title = "Forbidden · restic-manager"
|
||||
view.Page = struct {
|
||||
Required string
|
||||
Have string
|
||||
}{Required: string(min), Have: string(u.Role)}
|
||||
if err := s.deps.UI.Render(w, "forbidden", view); err != nil {
|
||||
_, _ = w.Write([]byte("403 Forbidden — your role does not permit this page."))
|
||||
}
|
||||
}
|
||||
@@ -1,162 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
stdhttp "net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
func TestRoleAtLeast(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
have store.Role
|
||||
min store.Role
|
||||
want bool
|
||||
}{
|
||||
{store.RoleViewer, store.RoleViewer, true},
|
||||
{store.RoleOperator, store.RoleViewer, true},
|
||||
{store.RoleAdmin, store.RoleViewer, true},
|
||||
{store.RoleAdmin, store.RoleOperator, true},
|
||||
{store.RoleAdmin, store.RoleAdmin, true},
|
||||
{store.RoleViewer, store.RoleOperator, false},
|
||||
{store.RoleViewer, store.RoleAdmin, false},
|
||||
{store.RoleOperator, store.RoleAdmin, false},
|
||||
{store.Role("nonsense"), store.RoleViewer, false},
|
||||
{store.RoleAdmin, store.Role("nonsense"), false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := roleAtLeast(c.have, c.min)
|
||||
if got != c.want {
|
||||
t.Errorf("have=%q min=%q: got %v want %v", c.have, c.min, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireRoleViewerAdmits(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, _ := newTestServer(t, false)
|
||||
uid := makeUser(t, srv, "viewer1", store.RoleViewer)
|
||||
cookie := loginAs(t, srv, uid)
|
||||
|
||||
mid := srv.requireRole(store.RoleViewer)
|
||||
h := mid(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, _ *stdhttp.Request) {
|
||||
w.WriteHeader(stdhttp.StatusOK)
|
||||
}))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req, _ := stdhttp.NewRequest("GET", "/api/dummy", nil)
|
||||
req.AddCookie(cookie)
|
||||
h.ServeHTTP(rr, req)
|
||||
if rr.Code != stdhttp.StatusOK {
|
||||
t.Errorf("status: got %d want 200", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireRoleViewerRejectedFromOperator(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, _ := newTestServer(t, false)
|
||||
uid := makeUser(t, srv, "viewer2", store.RoleViewer)
|
||||
cookie := loginAs(t, srv, uid)
|
||||
|
||||
mid := srv.requireRole(store.RoleOperator)
|
||||
h := mid(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, _ *stdhttp.Request) {
|
||||
w.WriteHeader(stdhttp.StatusOK)
|
||||
}))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req, _ := stdhttp.NewRequest("GET", "/api/dummy", nil)
|
||||
req.AddCookie(cookie)
|
||||
h.ServeHTTP(rr, req)
|
||||
if rr.Code != stdhttp.StatusForbidden {
|
||||
t.Errorf("status: got %d want 403", rr.Code)
|
||||
}
|
||||
if !strings.Contains(rr.Body.String(), "insufficient_role") {
|
||||
t.Errorf("body: got %q", rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireRoleUnauthenticated401OnAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, _ := newTestServer(t, false)
|
||||
|
||||
mid := srv.requireRole(store.RoleViewer)
|
||||
h := mid(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, _ *stdhttp.Request) {
|
||||
w.WriteHeader(stdhttp.StatusOK)
|
||||
}))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req, _ := stdhttp.NewRequest("GET", "/api/dummy", nil)
|
||||
h.ServeHTTP(rr, req)
|
||||
if rr.Code != stdhttp.StatusUnauthorized {
|
||||
t.Errorf("status: got %d want 401", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireRoleRejectsDisabledMidSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, urlBase := newTestServer(t, false)
|
||||
uid := makeUser(t, srv, "victim", store.RoleOperator)
|
||||
cookie := loginAs(t, srv, uid)
|
||||
|
||||
// Disable the user *while their session is still valid*.
|
||||
if err := srv.deps.Store.DisableUser(t.Context(), uid, time.Now().UTC()); err != nil {
|
||||
t.Fatalf("disable: %v", err)
|
||||
}
|
||||
|
||||
req, _ := stdhttp.NewRequest("GET", urlBase+"/api/hosts", nil)
|
||||
req.AddCookie(cookie)
|
||||
res, err := stdhttp.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("GET: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusUnauthorized {
|
||||
t.Errorf("status: got %d want 401", res.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginRejectsDisabledUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, urlBase := newTestServer(t, false)
|
||||
uid := makeUser(t, srv, "disabled1", store.RoleOperator)
|
||||
if err := srv.deps.Store.DisableUser(t.Context(), uid, time.Now().UTC()); err != nil {
|
||||
t.Fatalf("disable: %v", err)
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"username": "disabled1", "password": "test-password",
|
||||
})
|
||||
res, err := stdhttp.Post(urlBase+"/api/auth/login", "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("POST: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusUnauthorized {
|
||||
t.Errorf("status: got %d want 401", res.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminBandRejectsOperator(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, urlBase := newTestServer(t, false)
|
||||
makeUser(t, srv, "admin1", store.RoleAdmin)
|
||||
opID := makeUser(t, srv, "op1", store.RoleOperator)
|
||||
cookie := loginAs(t, srv, opID)
|
||||
|
||||
req, _ := stdhttp.NewRequest("GET", urlBase+"/api/users", nil)
|
||||
req.AddCookie(cookie)
|
||||
res, err := stdhttp.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("GET: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusForbidden {
|
||||
t.Errorf("status: got %d want 403", res.StatusCode)
|
||||
}
|
||||
}
|
||||
@@ -41,7 +41,7 @@ func toRepoMaintenanceView(m store.HostRepoMaintenance) repoMaintenanceView {
|
||||
|
||||
func (s *Server) handleGetRepoMaintenance(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if !s.authedUser(r) {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
@@ -84,7 +84,7 @@ type repoMaintenanceWriteRequest struct {
|
||||
|
||||
func (s *Server) handleUpdateRepoMaintenance(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if !s.authedUser(r) {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
|
||||
@@ -26,7 +26,7 @@ func (s *Server) handleRunRepoPrune(w stdhttp.ResponseWriter, r *stdhttp.Request
|
||||
stdhttp.Redirect(w, r, "/login", stdhttp.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
@@ -52,7 +52,7 @@ func (s *Server) handleRunRepoPrune(w stdhttp.ResponseWriter, r *stdhttp.Request
|
||||
return
|
||||
}
|
||||
|
||||
res, status, code, msg := s.dispatchJobWithPayload(r.Context(), user, hostID, api.JobPrune, nil,
|
||||
res, status, code, msg := s.dispatchJobWithPayload(r.Context(), user, hostID, api.JobPrune,
|
||||
api.CommandRunPayload{RequiresAdminCreds: true})
|
||||
if code != "" {
|
||||
s.runOpError(w, r, status, code, msg)
|
||||
@@ -72,7 +72,7 @@ func (s *Server) handleRunRepoCheck(w stdhttp.ResponseWriter, r *stdhttp.Request
|
||||
stdhttp.Redirect(w, r, "/login", stdhttp.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
@@ -107,7 +107,7 @@ func (s *Server) handleRunRepoCheck(w stdhttp.ResponseWriter, r *stdhttp.Request
|
||||
// Non-numeric ?subset silently falls back to DB value.
|
||||
}
|
||||
|
||||
res, status, code, msg := s.dispatchJobWithPayload(r.Context(), user, hostID, api.JobCheck, nil,
|
||||
res, status, code, msg := s.dispatchJobWithPayload(r.Context(), user, hostID, api.JobCheck,
|
||||
api.CommandRunPayload{Args: []string{strconv.Itoa(subset)}})
|
||||
if code != "" {
|
||||
s.runOpError(w, r, status, code, msg)
|
||||
@@ -125,7 +125,7 @@ func (s *Server) handleRunRepoUnlock(w stdhttp.ResponseWriter, r *stdhttp.Reques
|
||||
stdhttp.Redirect(w, r, "/login", stdhttp.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
@@ -134,7 +134,7 @@ func (s *Server) handleRunRepoUnlock(w stdhttp.ResponseWriter, r *stdhttp.Reques
|
||||
return
|
||||
}
|
||||
|
||||
res, status, code, msg := s.dispatchJobWithPayload(r.Context(), user, hostID, api.JobUnlock, nil,
|
||||
res, status, code, msg := s.dispatchJobWithPayload(r.Context(), user, hostID, api.JobUnlock,
|
||||
api.CommandRunPayload{})
|
||||
if code != "" {
|
||||
s.runOpError(w, r, status, code, msg)
|
||||
|
||||
@@ -9,7 +9,6 @@ package http
|
||||
import (
|
||||
"errors"
|
||||
stdhttp "net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
@@ -17,34 +16,6 @@ import (
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// parseBandwidthOverride reads the optional bandwidth_up_kbps and
// bandwidth_down_kbps values from the request via FormValue (form body
// or URL query).
//
// A field that is absent or empty yields a nil pointer. An explicit "0"
// yields a non-nil pointer to 0 — i.e., "no cap for this run, even if
// the host has one set." A value that is not an integer, or is
// negative, is rejected with an error that names the offending field.
//
// On any error both returned pointers are nil, so a caller never sees a
// half-parsed pair next to a non-nil error.
func parseBandwidthOverride(r *stdhttp.Request) (up *int, down *int, err error) {
	// parse handles one named field; (nil, nil) means "not supplied".
	parse := func(name string) (*int, error) {
		raw := r.FormValue(name)
		if raw == "" {
			return nil, nil
		}
		n, convErr := strconv.Atoi(raw)
		if convErr != nil {
			return nil, errors.New(name + " must be an integer")
		}
		if n < 0 {
			return nil, errors.New(name + " must be >= 0")
		}
		return &n, nil
	}

	if up, err = parse("bandwidth_up_kbps"); err != nil {
		return nil, nil, err
	}
	if down, err = parse("bandwidth_down_kbps"); err != nil {
		// Drop the already-parsed up value: results are all-or-nothing.
		return nil, nil, err
	}
	return up, down, nil
}
|
||||
|
||||
func (s *Server) handleRunSourceGroup(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
user, ok := s.requireUser(r)
|
||||
if !ok {
|
||||
@@ -53,7 +24,7 @@ func (s *Server) handleRunSourceGroup(w stdhttp.ResponseWriter, r *stdhttp.Reque
|
||||
stdhttp.Redirect(w, r, "/login", stdhttp.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
@@ -69,34 +40,13 @@ func (s *Server) handleRunSourceGroup(w stdhttp.ResponseWriter, r *stdhttp.Reque
|
||||
return
|
||||
}
|
||||
|
||||
// Optional per-run bandwidth override. Disclosed in the UI under a
|
||||
// <details> "Limit bandwidth for this run" affordance; absent on
|
||||
// the wire (and from JSON callers that don't supply it) means
|
||||
// "fall back to the host's standing caps."
|
||||
upOverride, downOverride, perr := parseBandwidthOverride(r)
|
||||
if perr != nil {
|
||||
s.runGroupError(w, r, stdhttp.StatusBadRequest, "invalid_value", perr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve hooks (group → host default → empty). Best-effort host
|
||||
// lookup; failure proceeds with no hook rather than block the run.
|
||||
var preHook, postHook string
|
||||
if host, herr := s.deps.Store.GetHost(r.Context(), hostID); herr == nil {
|
||||
preHook, postHook = s.resolveBackupHooks(host, g)
|
||||
}
|
||||
|
||||
// Backup invocations don't consume RetentionPolicy — that lives on
|
||||
// forget. Sending the resolved set here would just be dead weight.
|
||||
res, status, code, msg := s.dispatchJobWithPayload(r.Context(), user, hostID, api.JobBackup, &g.ID,
|
||||
res, status, code, msg := s.dispatchJobWithPayload(r.Context(), user, hostID, api.JobBackup,
|
||||
api.CommandRunPayload{
|
||||
Includes: g.Includes,
|
||||
Excludes: g.Excludes,
|
||||
Tag: g.Name,
|
||||
BandwidthUpKBps: upOverride,
|
||||
BandwidthDownKBps: downOverride,
|
||||
PreHook: preHook,
|
||||
PostHook: postHook,
|
||||
Includes: g.Includes,
|
||||
Excludes: g.Excludes,
|
||||
Tag: g.Name,
|
||||
})
|
||||
if code != "" {
|
||||
s.runGroupError(w, r, status, code, msg)
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
// run_group_bandwidth_test.go — covers the per-job bandwidth override
|
||||
// that operators can set via the Run-now form's "Limit bandwidth for
|
||||
// this run" disclosure (P2R-13b).
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
stdhttp "net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// TestRunSourceGroupBandwidthOverride: connect a fake agent, POST the
|
||||
// per-group Run-now endpoint with bandwidth_up_kbps=512, assert the
|
||||
// dispatched command.run carries it.
|
||||
func TestRunSourceGroupBandwidthOverride(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, ts, st := rawTestServer(t)
|
||||
hostID, token := enrolHostForWS(t, srv, st, "bw-host")
|
||||
|
||||
// Pre-seed an init job so auto-init doesn't fire on hello and
|
||||
// pollute our envelope sequence.
|
||||
if err := st.CreateJob(context.Background(), store.Job{
|
||||
ID: ulid.Make().String(), HostID: hostID, Kind: "init",
|
||||
ActorKind: "system", CreatedAt: time.Now().UTC(),
|
||||
}); err != nil {
|
||||
t.Fatalf("seed init: %v", err)
|
||||
}
|
||||
|
||||
gid := ulid.Make().String()
|
||||
if err := st.CreateSourceGroup(context.Background(), &store.SourceGroup{
|
||||
ID: gid, HostID: hostID, Name: "etc", Includes: []string{"/etc"},
|
||||
}); err != nil {
|
||||
t.Fatalf("group: %v", err)
|
||||
}
|
||||
|
||||
c := agentDial(t, srv, ts, hostID, token)
|
||||
sendHello(t, c, "bw-host")
|
||||
// Drain on-hello burst before issuing the run-now.
|
||||
_ = drainUntil(t, c, api.MsgScheduleSet)
|
||||
|
||||
cookie := loginAsAdmin(t, st)
|
||||
form := url.Values{
|
||||
"bandwidth_up_kbps": {"512"},
|
||||
"bandwidth_down_kbps": {"256"},
|
||||
}
|
||||
req, _ := stdhttp.NewRequest("POST",
|
||||
ts.URL+"/hosts/"+hostID+"/source-groups/"+gid+"/run",
|
||||
strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.AddCookie(cookie)
|
||||
res, err := stdhttp.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusAccepted {
|
||||
t.Fatalf("status: got %d, want 202", res.StatusCode)
|
||||
}
|
||||
|
||||
// Read the dispatched command.run; assert overrides are present.
|
||||
deadline := time.Now().Add(3 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 800*time.Millisecond)
|
||||
mt, raw, rerr := c.Read(ctx)
|
||||
cancel()
|
||||
if rerr != nil {
|
||||
break
|
||||
}
|
||||
if mt != websocket.MessageText {
|
||||
continue
|
||||
}
|
||||
var env api.Envelope
|
||||
_ = json.Unmarshal(raw, &env)
|
||||
if env.Type != api.MsgCommandRun {
|
||||
continue
|
||||
}
|
||||
var p api.CommandRunPayload
|
||||
if err := env.UnmarshalPayload(&p); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if p.Kind != api.JobBackup {
|
||||
continue
|
||||
}
|
||||
if p.BandwidthUpKBps == nil || *p.BandwidthUpKBps != 512 {
|
||||
t.Fatalf("BandwidthUpKBps: got %v, want 512", p.BandwidthUpKBps)
|
||||
}
|
||||
if p.BandwidthDownKBps == nil || *p.BandwidthDownKBps != 256 {
|
||||
t.Fatalf("BandwidthDownKBps: got %v, want 256", p.BandwidthDownKBps)
|
||||
}
|
||||
return
|
||||
}
|
||||
t.Fatal("timed out waiting for command.run with bandwidth override")
|
||||
}
|
||||
|
||||
// TestRunSourceGroupBandwidthRejectsNegative: invalid value → 400.
|
||||
func TestRunSourceGroupBandwidthRejectsNegative(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, url2, st := newTestServerWithHub(t)
|
||||
cookie := loginAsAdmin(t, st)
|
||||
hostID := makeHost(t, st, "bw-rej-host")
|
||||
gid := ulid.Make().String()
|
||||
if err := st.CreateSourceGroup(context.Background(), &store.SourceGroup{
|
||||
ID: gid, HostID: hostID, Name: "etc", Includes: []string{"/etc"},
|
||||
}); err != nil {
|
||||
t.Fatalf("group: %v", err)
|
||||
}
|
||||
form := url.Values{"bandwidth_up_kbps": {"-1"}}
|
||||
req, _ := stdhttp.NewRequest("POST",
|
||||
url2+"/hosts/"+hostID+"/source-groups/"+gid+"/run",
|
||||
strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.AddCookie(cookie)
|
||||
res, err := stdhttp.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != stdhttp.StatusBadRequest {
|
||||
t.Fatalf("status: got %d, want 400", res.StatusCode)
|
||||
}
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
// schedule_nextrun_test.go — pin the cron parser → next-run shape we
|
||||
// rely on for the dashboard host row + schedules tab (P2R-14).
|
||||
package http
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCronParserNext(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
expr string
|
||||
from time.Time
|
||||
want time.Time
|
||||
}{
|
||||
{
|
||||
name: "daily at 03:00",
|
||||
expr: "0 3 * * *",
|
||||
from: time.Date(2026, 5, 4, 1, 0, 0, 0, time.UTC),
|
||||
want: time.Date(2026, 5, 4, 3, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "daily at 03:00 (after time of day → next day)",
|
||||
expr: "0 3 * * *",
|
||||
from: time.Date(2026, 5, 4, 5, 0, 0, 0, time.UTC),
|
||||
want: time.Date(2026, 5, 5, 3, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "every 15 minutes",
|
||||
expr: "*/15 * * * *",
|
||||
from: time.Date(2026, 5, 4, 1, 7, 0, 0, time.UTC),
|
||||
want: time.Date(2026, 5, 4, 1, 15, 0, 0, time.UTC),
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
parsed, err := cronParser.Parse(c.expr)
|
||||
if err != nil {
|
||||
t.Fatalf("parse %q: %v", c.expr, err)
|
||||
}
|
||||
got := parsed.Next(c.from)
|
||||
if !got.Equal(c.want) {
|
||||
t.Fatalf("Next(%v) = %v, want %v", c.from, got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -180,32 +180,18 @@ func (s *Server) dispatchBackupForGroupCore(ctx context.Context, conn *ws.Conn,
|
||||
jobID := ulid.Make().String()
|
||||
now := time.Now().UTC()
|
||||
scheduleRef := scheduleID
|
||||
groupRef := g.ID
|
||||
if err := s.deps.Store.CreateJob(ctx, store.Job{
|
||||
ID: jobID,
|
||||
HostID: hostID,
|
||||
Kind: string(api.JobBackup),
|
||||
ScheduledID: &scheduleRef,
|
||||
SourceGroupID: &groupRef,
|
||||
ActorKind: "schedule",
|
||||
CreatedAt: now,
|
||||
ID: jobID,
|
||||
HostID: hostID,
|
||||
Kind: string(api.JobBackup),
|
||||
ScheduledID: &scheduleRef,
|
||||
ActorKind: "schedule",
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
slog.Warn("schedule.fire: persist job", "host_id", hostID,
|
||||
"schedule_id", scheduleID, "group", g.Name, "err", err)
|
||||
return "", err
|
||||
}
|
||||
// Resolve pre/post hooks (group → host default → empty) so they
|
||||
// ride on the backup payload as plaintext. The host lookup is
|
||||
// cheap; failure here is non-fatal (we proceed without hooks
|
||||
// rather than block the backup).
|
||||
var preHook, postHook string
|
||||
if host, herr := s.deps.Store.GetHost(ctx, hostID); herr == nil {
|
||||
preHook, postHook = s.resolveBackupHooks(host, g)
|
||||
} else {
|
||||
slog.Warn("schedule.fire: load host for hook resolve",
|
||||
"host_id", hostID, "err", herr)
|
||||
}
|
||||
|
||||
// Backup ignores RetentionPolicy — the forget cadence lives on
|
||||
// host_repo_maintenance and is driven by the server-side ticker
|
||||
// (P2R-06). Don't ship the field on backup dispatches.
|
||||
@@ -215,8 +201,6 @@ func (s *Server) dispatchBackupForGroupCore(ctx context.Context, conn *ws.Conn,
|
||||
Includes: g.Includes,
|
||||
Excludes: g.Excludes,
|
||||
Tag: g.Name,
|
||||
PreHook: preHook,
|
||||
PostHook: postHook,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("schedule.fire: marshal command.run",
|
||||
|
||||
@@ -61,7 +61,7 @@ var cronParser = cron.NewParser(
|
||||
|
||||
func (s *Server) handleListSchedules(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if !s.authedUser(r) {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
@@ -89,7 +89,7 @@ func (s *Server) handleListSchedules(w stdhttp.ResponseWriter, r *stdhttp.Reques
|
||||
|
||||
func (s *Server) handleCreateSchedule(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if !s.authedUser(r) {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
@@ -126,7 +126,7 @@ func (s *Server) handleCreateSchedule(w stdhttp.ResponseWriter, r *stdhttp.Reque
|
||||
|
||||
func (s *Server) handleUpdateSchedule(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if !s.authedUser(r) {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
@@ -173,7 +173,7 @@ func (s *Server) handleUpdateSchedule(w stdhttp.ResponseWriter, r *stdhttp.Reque
|
||||
|
||||
func (s *Server) handleDeleteSchedule(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||||
if !s.authedUser(r) {
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorised", "")
|
||||
writeJSONError(w, stdhttp.StatusUnauthorized, "unauthorized", "")
|
||||
return
|
||||
}
|
||||
hostID := chi.URLParam(r, "id")
|
||||
|
||||
+154
-207
@@ -13,11 +13,8 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/alert"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/notification"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/config"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/oidc"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/ui"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/server/ws"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
@@ -32,13 +29,6 @@ type Deps struct {
|
||||
Hub *ws.Hub
|
||||
JobHub *ws.JobHub
|
||||
UI *ui.Renderer
|
||||
// AlertEngine (optional, wired in G1) receives job-finished and
|
||||
// host-online events from the WS handler. Nil until G1 constructs
|
||||
// the engine at boot.
|
||||
AlertEngine *alert.Engine
|
||||
// NotificationHub (optional, wired in G1) is used by the test-fire
|
||||
// endpoint to dispatch a single synthetic payload through a channel.
|
||||
NotificationHub *notification.Hub
|
||||
// Version is the binary's build version, surfaced in the chrome.
|
||||
// Empty falls back to "dev".
|
||||
Version string
|
||||
@@ -46,9 +36,6 @@ type Deps struct {
|
||||
// admin-bootstrap token printed in the server logs. While set, the
|
||||
// /bootstrap endpoint accepts it to create the first admin user.
|
||||
BootstrapToken string
|
||||
// OIDC (optional). Non-nil when the operator has configured an
|
||||
// IdP — handlers under /auth/oidc/* are mounted only when set.
|
||||
OIDC *oidc.Client
|
||||
}
|
||||
|
||||
// Server is the running HTTP server.
|
||||
@@ -56,26 +43,12 @@ type Server struct {
|
||||
srv *stdhttp.Server
|
||||
deps Deps
|
||||
|
||||
// drainLocks serialises DrainPending per host. The on-hello
|
||||
// drainLocks serializes DrainPending per host. The on-hello
|
||||
// goroutine and the 30s ticker can otherwise race for the same
|
||||
// host, double-dispatching every pending row. Map of hostID →
|
||||
// sync.Mutex; checked-and-locked atomically via drainLocksMu.
|
||||
drainLocksMu sync.Mutex
|
||||
drainLocks map[string]*sync.Mutex
|
||||
|
||||
// announceRL is the per-source-IP token-bucket guarding
|
||||
// POST /api/agents/announce (P2-18). One process-local map.
|
||||
announceRL *announceLimiter
|
||||
|
||||
// pendingHub holds live /ws/agent/pending sockets keyed by
|
||||
// pending_id so the accept/reject handlers can push the bearer
|
||||
// or close cleanly (P2-18b).
|
||||
pendingHub *pendingHub
|
||||
|
||||
// treeCache holds per-wizard-session listings of snapshot
|
||||
// directories (P3-X2). Pre-allocated in New so the lazy-init
|
||||
// race is impossible.
|
||||
treeCache *treeCache
|
||||
}
|
||||
|
||||
// New builds a configured but not-yet-started server.
|
||||
@@ -89,13 +62,12 @@ func New(deps Deps) *Server {
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(requestLogger)
|
||||
|
||||
s := &Server{
|
||||
deps: deps,
|
||||
drainLocks: make(map[string]*sync.Mutex),
|
||||
announceRL: newAnnounceLimiter(),
|
||||
pendingHub: newPendingHub(),
|
||||
treeCache: newTreeCache(),
|
||||
}
|
||||
// Health endpoint — unauthenticated, no audit, deliberately cheap.
|
||||
r.Get("/healthz", func(w stdhttp.ResponseWriter, _ *stdhttp.Request) {
|
||||
w.WriteHeader(stdhttp.StatusNoContent)
|
||||
})
|
||||
|
||||
s := &Server{deps: deps, drainLocks: make(map[string]*sync.Mutex)}
|
||||
s.routes(r)
|
||||
|
||||
s.srv = &stdhttp.Server{
|
||||
@@ -112,196 +84,171 @@ func New(deps Deps) *Server {
|
||||
// routes wires the API tree. Subtrees live in this file by area so a
|
||||
// reader can scan one place and see the surface.
|
||||
func (s *Server) routes(r chi.Router) {
|
||||
// Public, unauthenticated.
|
||||
r.Get("/healthz", func(w stdhttp.ResponseWriter, _ *stdhttp.Request) {
|
||||
w.WriteHeader(stdhttp.StatusNoContent)
|
||||
r.Route("/api", func(r chi.Router) {
|
||||
r.Post("/auth/login", s.handleLogin)
|
||||
r.Post("/auth/logout", s.handleLogout)
|
||||
r.Post("/bootstrap", s.handleBootstrap)
|
||||
|
||||
// Agent enrollment (open endpoint — token is the credential).
|
||||
r.Post("/agents/enroll", s.handleAgentEnroll)
|
||||
|
||||
// Operator → server (authenticated). Spec.md §6.1's
|
||||
// /hosts/{id}/enrollment-token (regenerate) lands when the
|
||||
// host page can call it; for now just the create endpoint.
|
||||
r.Post("/enrollment-tokens", s.handleCreateEnrollmentToken)
|
||||
|
||||
// Fleet read endpoints — back the dashboard.
|
||||
r.Get("/hosts", s.handleListHosts)
|
||||
r.Get("/fleet/summary", s.handleFleetSummary)
|
||||
|
||||
// Run-now: dispatch a job to a host's agent.
|
||||
r.Post("/hosts/{id}/jobs", s.handleRunNow)
|
||||
|
||||
// Snapshot projection (refreshed by the agent after each backup).
|
||||
r.Get("/hosts/{id}/snapshots", s.handleListHostSnapshots)
|
||||
|
||||
// Repo credentials — operator can edit after enrollment. The
|
||||
// initial set is supplied at token-mint time (see enrollment.go).
|
||||
// GET returns a redacted view (URL, username, has_password).
|
||||
r.Get("/hosts/{id}/repo-credentials", s.handleGetHostCredentials)
|
||||
r.Put("/hosts/{id}/repo-credentials", s.handleSetHostCredentials)
|
||||
|
||||
// Admin credentials — the prune-capable slot (separate from the
|
||||
// everyday repo creds). Optional: hosts that don't prune against
|
||||
// a rest-server repo with a separate admin user never need this.
|
||||
r.Get("/hosts/{id}/admin-credentials", s.handleGetAdminCredentials)
|
||||
r.Put("/hosts/{id}/admin-credentials", s.handleSetAdminCredentials)
|
||||
r.Delete("/hosts/{id}/admin-credentials", s.handleDeleteAdminCredentials)
|
||||
|
||||
// Per-host schedule CRUD. Mutations bump host_schedule_version
|
||||
// and async-push to a connected agent (see schedule_push.go).
|
||||
r.Get("/hosts/{id}/schedules", s.handleListSchedules)
|
||||
r.Post("/hosts/{id}/schedules", s.handleCreateSchedule)
|
||||
r.Put("/hosts/{id}/schedules/{sid}", s.handleUpdateSchedule)
|
||||
r.Delete("/hosts/{id}/schedules/{sid}", s.handleDeleteSchedule)
|
||||
|
||||
// Source-group CRUD. A group is "what gets backed up" — paths,
|
||||
// excludes, retention, retry. Group name doubles as the
|
||||
// snapshot tag (restic --tag <name>).
|
||||
r.Get("/hosts/{id}/source-groups", s.handleListSourceGroups)
|
||||
r.Post("/hosts/{id}/source-groups", s.handleCreateSourceGroup)
|
||||
r.Get("/hosts/{id}/source-groups/{gid}", s.handleGetSourceGroup)
|
||||
r.Put("/hosts/{id}/source-groups/{gid}", s.handleUpdateSourceGroup)
|
||||
r.Delete("/hosts/{id}/source-groups/{gid}", s.handleDeleteSourceGroup)
|
||||
|
||||
// Repo maintenance cadences (forget / prune / check). Driven
|
||||
// by the server-side ticker (P2R-06), not the agent's cron.
|
||||
r.Get("/hosts/{id}/repo-maintenance", s.handleGetRepoMaintenance)
|
||||
r.Put("/hosts/{id}/repo-maintenance", s.handleUpdateRepoMaintenance)
|
||||
|
||||
// Host-wide bandwidth caps (host.bandwidth_up_kbps /
|
||||
// bandwidth_down_kbps). Apply to every restic invocation.
|
||||
r.Put("/hosts/{id}/bandwidth", s.handleUpdateHostBandwidth)
|
||||
|
||||
// Per-source-group Run-now (JSON variant). HTMX action is
|
||||
// mounted at the equivalent path outside /api below — both
|
||||
// resolve to the same handler, which sniffs HX-Request.
|
||||
r.Post("/hosts/{id}/source-groups/{gid}/run", s.handleRunSourceGroup)
|
||||
|
||||
// Repo-level run-now: prune (needs admin creds), check, unlock.
|
||||
// HTMX forms are also mounted outside /api below.
|
||||
r.Post("/hosts/{id}/repo/prune", s.handleRunRepoPrune)
|
||||
r.Post("/hosts/{id}/repo/check", s.handleRunRepoCheck)
|
||||
r.Post("/hosts/{id}/repo/unlock", s.handleRunRepoUnlock)
|
||||
})
|
||||
r.Post("/api/auth/login", s.handleLogin)
|
||||
r.Post("/api/auth/logout", s.handleLogout)
|
||||
r.Post("/api/bootstrap", s.handleBootstrap)
|
||||
r.Post("/api/agents/enroll", s.handleAgentEnroll)
|
||||
r.Post("/api/agents/announce", s.handleAnnounce)
|
||||
r.Get("/agent/binary", s.handleAgentBinary)
|
||||
r.Get("/install/*", s.handleInstallAsset)
|
||||
|
||||
// Per-source-group Run-now (HTMX form action). Available even
|
||||
// when the server is started without UI templates so REST callers
|
||||
// against the non-/api path also work.
|
||||
r.Post("/hosts/{id}/source-groups/{gid}/run", s.handleRunSourceGroup)
|
||||
// Repo-level run-now (HTMX form actions). Same handlers as the /api
|
||||
// variants — wantsHTML sniff distinguishes JSON vs HTMX response.
|
||||
r.Post("/hosts/{id}/repo/prune", s.handleRunRepoPrune)
|
||||
r.Post("/hosts/{id}/repo/check", s.handleRunRepoCheck)
|
||||
r.Post("/hosts/{id}/repo/unlock", s.handleRunRepoUnlock)
|
||||
// Retired routes — see ui_handlers.go for the messages. Mounted
|
||||
// outside the UI gate so cached browser tabs get a clear 410
|
||||
// even if the server runs without templates.
|
||||
r.Post("/hosts/{id}/run-backup", s.handleUIRunBackupGone)
|
||||
r.Post("/hosts/{id}/init-repo", s.handleUIInitRepoGone)
|
||||
|
||||
// Agent ↔ server WebSocket. Bearer-authenticated inside the handler.
|
||||
if s.deps.Hub != nil {
|
||||
r.Mount("/ws/agent", ws.AgentHandler(ws.HandlerDeps{
|
||||
Hub: s.deps.Hub,
|
||||
Store: s.deps.Store,
|
||||
JobHub: s.deps.JobHub,
|
||||
AlertEngine: s.deps.AlertEngine,
|
||||
OnHello: s.onAgentHello,
|
||||
OnScheduleAck: s.applyScheduleAck,
|
||||
OnScheduleFire: s.dispatchScheduledJob,
|
||||
}))
|
||||
}
|
||||
r.Get("/ws/agent/pending", s.handlePendingWS)
|
||||
|
||||
// Agent binaries + install scripts. Open endpoints — content is
|
||||
// unprivileged on its own, gating happens via the enrollment
|
||||
// token. See agent_assets.go.
|
||||
r.Get("/agent/binary", s.handleAgentBinary)
|
||||
r.Get("/install/*", s.handleInstallAsset)
|
||||
|
||||
// Static assets (Tailwind CSS bundle, future favicon).
|
||||
r.Mount("/static/", staticHandler())
|
||||
|
||||
// POST /logout is always mounted — it handles both local and OIDC
|
||||
// sessions and doesn't require the UI renderer.
|
||||
r.Post("/logout", s.handleUILogoutPost)
|
||||
// HTML UI. The renderer is required — fail loud if the binary
|
||||
// was built without templates (impossible in practice given
|
||||
// embed, but guards bad test wiring).
|
||||
if s.deps.UI != nil {
|
||||
r.Get("/bootstrap", s.handleUIBootstrapGet)
|
||||
r.Post("/bootstrap", s.handleUIBootstrapPost)
|
||||
r.Get("/", s.handleUIDashboard)
|
||||
r.Get("/login", s.handleUILoginGet)
|
||||
r.Post("/login", s.handleUILoginPost)
|
||||
r.Get("/setup", s.handleUISetupGet)
|
||||
r.Post("/setup", s.handleUISetupPost)
|
||||
}
|
||||
if s.deps.OIDC != nil {
|
||||
r.Get("/auth/oidc/login", s.handleOIDCLogin)
|
||||
r.Get("/auth/oidc/callback", s.handleOIDCCallback)
|
||||
r.Post("/logout", s.handleUILogoutPost)
|
||||
// Per-host Run-now and manual Init-repo are mounted at the
|
||||
// outer router (so they reply 410 even without UI). Per-
|
||||
// source-group Run-now lives there too — same reason.
|
||||
// Add host flow.
|
||||
r.Get("/hosts/new", s.handleUIAddHostGet)
|
||||
r.Post("/hosts/new", s.handleUIAddHostPost)
|
||||
// Durable post-Add-host page (operator can refresh / come
|
||||
// back; password decrypted from the token row each render).
|
||||
// Polled fragment under /awaiting flips to "connected" once
|
||||
// the agent enrolls.
|
||||
r.Get("/hosts/pending/{token}", s.handleUIPendingHost)
|
||||
r.Get("/hosts/pending/{token}/awaiting", s.handleUIPendingAwaiting)
|
||||
// Host detail (Snapshots tab is the default).
|
||||
r.Get("/hosts/{id}", s.handleUIHostDetail)
|
||||
// Sources tab + source-group CRUD forms.
|
||||
r.Get("/hosts/{id}/sources", s.handleUIHostSources)
|
||||
r.Get("/hosts/{id}/sources/new", s.handleUISourceGroupNewGet)
|
||||
r.Post("/hosts/{id}/sources/new", s.handleUISourceGroupSave)
|
||||
r.Get("/hosts/{id}/sources/{gid}/edit", s.handleUISourceGroupEditGet)
|
||||
r.Post("/hosts/{id}/sources/{gid}/edit", s.handleUISourceGroupSave)
|
||||
r.Post("/hosts/{id}/sources/{gid}/delete", s.handleUISourceGroupDelete)
|
||||
// Repo tab — connection / bandwidth / maintenance. Three
|
||||
// independent forms so saving one doesn't touch the others.
|
||||
r.Get("/hosts/{id}/repo", s.handleUIHostRepo)
|
||||
r.Post("/hosts/{id}/repo/credentials", s.handleUIRepoCredentialsSave)
|
||||
r.Post("/hosts/{id}/repo/bandwidth", s.handleUIRepoBandwidthSave)
|
||||
r.Post("/hosts/{id}/repo/maintenance", s.handleUIRepoMaintenanceSave)
|
||||
// Admin credentials form (separate slot for prune-capable user).
|
||||
r.Post("/hosts/{id}/admin-credentials", s.handleUIAdminCredentialsSave)
|
||||
r.Post("/hosts/{id}/admin-credentials/delete", s.handleUIAdminCredentialsDelete)
|
||||
// Schedules tab + create/edit/delete forms.
|
||||
r.Get("/hosts/{id}/schedules", s.handleUISchedulesList)
|
||||
r.Get("/hosts/{id}/schedules/new", s.handleUIScheduleNewGet)
|
||||
r.Post("/hosts/{id}/schedules/new", s.handleUIScheduleSave)
|
||||
r.Get("/hosts/{id}/schedules/{sid}/edit", s.handleUIScheduleEditGet)
|
||||
r.Post("/hosts/{id}/schedules/{sid}/edit", s.handleUIScheduleSave)
|
||||
r.Post("/hosts/{id}/schedules/{sid}/delete", s.handleUIScheduleDelete)
|
||||
r.Post("/hosts/{id}/schedules/{sid}/run", s.handleUIScheduleRun)
|
||||
// Live job log.
|
||||
r.Get("/jobs/{id}", s.handleUIJobDetail)
|
||||
}
|
||||
|
||||
// Viewer band — anyone authenticated can read.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(s.requireRole(store.RoleViewer))
|
||||
|
||||
// Read APIs.
|
||||
r.Get("/api/hosts", s.handleListHosts)
|
||||
r.Get("/api/fleet/summary", s.handleFleetSummary)
|
||||
r.Get("/api/hosts/{id}/snapshots", s.handleListHostSnapshots)
|
||||
r.Get("/api/hosts/{id}/repo-credentials", s.handleGetHostCredentials)
|
||||
r.Get("/api/hosts/{id}/admin-credentials", s.handleGetAdminCredentials)
|
||||
r.Get("/api/hosts/{id}/schedules", s.handleListSchedules)
|
||||
r.Get("/api/hosts/{id}/source-groups", s.handleListSourceGroups)
|
||||
r.Get("/api/hosts/{id}/source-groups/{gid}", s.handleGetSourceGroup)
|
||||
r.Get("/api/hosts/{id}/repo-maintenance", s.handleGetRepoMaintenance)
|
||||
r.Get("/api/alerts", s.handleAPIAlerts)
|
||||
r.Get("/api/audit", s.handleAPIAudit)
|
||||
r.Post("/api/account/password", s.handleAPIAccountPassword)
|
||||
|
||||
// Job log stream + download (read-only; any authenticated user).
|
||||
if s.deps.JobHub != nil {
|
||||
r.Get("/api/jobs/{id}/stream", s.handleJobStream)
|
||||
}
|
||||
r.Get("/api/jobs/{id}/log.{format:txt|ndjson}", s.handleJobLogDownload)
|
||||
|
||||
if s.deps.UI != nil {
|
||||
r.Get("/", s.handleUIDashboard)
|
||||
r.Get("/hosts/{id}", s.handleUIHostDetail)
|
||||
r.Get("/hosts/{id}/sources", s.handleUIHostSources)
|
||||
r.Get("/hosts/{id}/sources/new", s.handleUISourceGroupNewGet)
|
||||
r.Get("/hosts/{id}/sources/{gid}/edit", s.handleUISourceGroupEditGet)
|
||||
r.Get("/hosts/{id}/repo", s.handleUIHostRepo)
|
||||
r.Get("/hosts/{id}/schedules", s.handleUISchedulesList)
|
||||
r.Get("/hosts/{id}/schedules/new", s.handleUIScheduleNewGet)
|
||||
r.Get("/hosts/{id}/schedules/{sid}/edit", s.handleUIScheduleEditGet)
|
||||
r.Get("/jobs/{id}", s.handleUIJobDetail)
|
||||
r.Get("/hosts/{id}/restore", s.handleUIRestoreGet)
|
||||
r.Get("/hosts/{id}/snapshots/{sid}/restore", s.handleUIRestoreGet)
|
||||
r.Get("/hosts/{id}/restore/tree", s.handleUIRestoreTree)
|
||||
r.Get("/alerts", s.handleUIAlerts)
|
||||
r.Get("/audit", s.handleUIAudit)
|
||||
r.Get("/audit.csv", s.handleUIAuditCSV)
|
||||
r.Get("/settings/account", s.handleUIAccountGet)
|
||||
r.Post("/settings/account", s.handleUIAccountPost)
|
||||
}
|
||||
})
|
||||
|
||||
// Operator band — mutating endpoints up to backup ops.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(s.requireRole(store.RoleOperator))
|
||||
|
||||
// Pending hosts approval.
|
||||
r.Post("/api/pending-hosts/{id}/accept", s.handleAcceptPendingHost)
|
||||
r.Post("/api/pending-hosts/{id}/reject", s.handleRejectPendingHost)
|
||||
r.Post("/api/enrollment-tokens", s.handleCreateEnrollmentToken)
|
||||
r.Post("/hosts/enrollment-tokens/{hash}/regenerate", s.handleUIEnrollmentTokenRegenerate)
|
||||
r.Post("/hosts/enrollment-tokens/{hash}/revoke", s.handleUIEnrollmentTokenRevoke)
|
||||
|
||||
// Run-now, restore, repo ops (JSON).
|
||||
r.Post("/api/hosts/{id}/jobs", s.handleRunNow)
|
||||
r.Put("/api/hosts/{id}/repo-credentials", s.handleSetHostCredentials)
|
||||
r.Put("/api/hosts/{id}/admin-credentials", s.handleSetAdminCredentials)
|
||||
r.Delete("/api/hosts/{id}/admin-credentials", s.handleDeleteAdminCredentials)
|
||||
r.Post("/api/hosts/{id}/schedules", s.handleCreateSchedule)
|
||||
r.Put("/api/hosts/{id}/schedules/{sid}", s.handleUpdateSchedule)
|
||||
r.Delete("/api/hosts/{id}/schedules/{sid}", s.handleDeleteSchedule)
|
||||
r.Post("/api/hosts/{id}/source-groups", s.handleCreateSourceGroup)
|
||||
r.Put("/api/hosts/{id}/source-groups/{gid}", s.handleUpdateSourceGroup)
|
||||
r.Delete("/api/hosts/{id}/source-groups/{gid}", s.handleDeleteSourceGroup)
|
||||
r.Put("/api/hosts/{id}/repo-maintenance", s.handleUpdateRepoMaintenance)
|
||||
r.Put("/api/hosts/{id}/bandwidth", s.handleUpdateHostBandwidth)
|
||||
r.Post("/api/hosts/{id}/source-groups/{gid}/run", s.handleRunSourceGroup)
|
||||
r.Post("/api/hosts/{id}/repo/prune", s.handleRunRepoPrune)
|
||||
r.Post("/api/hosts/{id}/repo/check", s.handleRunRepoCheck)
|
||||
r.Post("/api/hosts/{id}/repo/unlock", s.handleRunRepoUnlock)
|
||||
r.Post("/api/jobs/{id}/cancel", s.handleCancelJob)
|
||||
r.Post("/api/hosts/{id}/snapshots/diff", s.handleSnapshotDiff)
|
||||
|
||||
// HTMX form variants outside /api.
|
||||
r.Post("/hosts/{id}/snapshots/diff", s.handleSnapshotDiff)
|
||||
r.Post("/hosts/{id}/source-groups/{gid}/run", s.handleRunSourceGroup)
|
||||
r.Post("/hosts/{id}/repo/prune", s.handleRunRepoPrune)
|
||||
r.Post("/hosts/{id}/repo/check", s.handleRunRepoCheck)
|
||||
r.Post("/hosts/{id}/repo/unlock", s.handleRunRepoUnlock)
|
||||
r.Post("/hosts/{id}/run-backup", s.handleUIRunBackupGone)
|
||||
r.Post("/hosts/{id}/init-repo", s.handleUIInitRepoGone)
|
||||
|
||||
if s.deps.UI != nil {
|
||||
r.Get("/hosts/new", s.handleUIAddHostGet)
|
||||
r.Post("/hosts/new", s.handleUIAddHostPost)
|
||||
r.Get("/hosts/pending/{token}", s.handleUIPendingHost)
|
||||
r.Get("/hosts/pending/{token}/awaiting", s.handleUIPendingAwaiting)
|
||||
r.Post("/hosts/{id}/sources/new", s.handleUISourceGroupSave)
|
||||
r.Post("/hosts/{id}/sources/{gid}/edit", s.handleUISourceGroupSave)
|
||||
r.Post("/hosts/{id}/sources/{gid}/delete", s.handleUISourceGroupDelete)
|
||||
r.Post("/hosts/{id}/repo/credentials", s.handleUIRepoCredentialsSave)
|
||||
r.Post("/hosts/{id}/repo/bandwidth", s.handleUIRepoBandwidthSave)
|
||||
r.Post("/hosts/{id}/repo/maintenance", s.handleUIRepoMaintenanceSave)
|
||||
r.Post("/hosts/{id}/repo/reinit", s.handleUIRepoReinit)
|
||||
r.Post("/hosts/{id}/repo/probe", s.handleUIRepoProbe)
|
||||
r.Post("/hosts/{id}/repo/hooks", s.handleUIRepoHooksSave)
|
||||
r.Post("/hosts/{id}/tags", s.handleUIHostTagsSave)
|
||||
r.Post("/hosts/{id}/admin-credentials", s.handleUIAdminCredentialsSave)
|
||||
r.Post("/hosts/{id}/admin-credentials/delete", s.handleUIAdminCredentialsDelete)
|
||||
r.Post("/hosts/{id}/schedules/new", s.handleUIScheduleSave)
|
||||
r.Post("/hosts/{id}/schedules/{sid}/edit", s.handleUIScheduleSave)
|
||||
r.Post("/hosts/{id}/schedules/{sid}/delete", s.handleUIScheduleDelete)
|
||||
r.Post("/hosts/{id}/schedules/{sid}/run", s.handleUIScheduleRun)
|
||||
r.Post("/hosts/{id}/restore", s.handleUIRestorePost)
|
||||
r.Post("/alerts/{id}/acknowledge", s.handleUIAlertAcknowledge)
|
||||
r.Post("/alerts/{id}/resolve", s.handleUIAlertResolve)
|
||||
}
|
||||
})
|
||||
|
||||
// Admin band — channels, server-shape config.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(s.requireRole(store.RoleAdmin))
|
||||
|
||||
r.Get("/api/users", s.handleAPIUsersList)
|
||||
r.Post("/api/users", s.handleAPIUserCreate)
|
||||
r.Get("/api/users/{id}", s.handleAPIUserGet)
|
||||
r.Patch("/api/users/{id}", s.handleAPIUserPatch)
|
||||
r.Post("/api/users/{id}/disable", s.handleAPIUserDisable)
|
||||
r.Post("/api/users/{id}/enable", s.handleAPIUserEnable)
|
||||
r.Post("/api/users/{id}/regenerate-setup", s.handleAPIUserRegenerateSetup)
|
||||
r.Post("/api/users/{id}/force-logout", s.handleAPIUserForceLogout)
|
||||
r.Post("/api/notifications/{id}/test", s.handleAPINotificationTest)
|
||||
|
||||
if s.deps.UI != nil {
|
||||
r.Post("/hosts/{id}/delete", s.handleUIHostDelete)
|
||||
r.Get("/settings", s.handleUISettings)
|
||||
r.Get("/settings/users", s.handleUIUsersList)
|
||||
r.Get("/settings/users/new", s.handleUIUserNewGet)
|
||||
r.Post("/settings/users/new", s.handleUIUserNewPost)
|
||||
r.Get("/settings/users/{id}/edit", s.handleUIUserEditGet)
|
||||
r.Post("/settings/users/{id}/edit", s.handleUIUserEditPost)
|
||||
r.Post("/settings/users/{id}/disable", s.handleUIUserDisablePost)
|
||||
r.Post("/settings/users/{id}/enable", s.handleUIUserEnablePost)
|
||||
r.Post("/settings/users/{id}/regenerate-setup", s.handleUIUserRegenerateSetupPost)
|
||||
r.Post("/settings/users/{id}/force-logout", s.handleUIUserForceLogoutPost)
|
||||
r.Get("/settings/users/{id}/setup-link", s.handleUIUserSetupLinkGet)
|
||||
r.Get("/settings/notifications", s.handleUINotificationsList)
|
||||
r.Get("/settings/notifications/new", s.handleUINotificationNewGet)
|
||||
r.Post("/settings/notifications/new", s.handleUINotificationNewPost)
|
||||
r.Get("/settings/notifications/{id}/edit", s.handleUINotificationEditGet)
|
||||
r.Post("/settings/notifications/{id}/edit", s.handleUINotificationEditPost)
|
||||
r.Post("/settings/notifications/{id}/delete", s.handleUINotificationDelete)
|
||||
r.Post("/settings/notifications/{id}/toggle", s.handleUINotificationToggle)
|
||||
}
|
||||
})
|
||||
// Browser job-log stream (separate from /ws/agent so the auth
|
||||
// layer is session-cookie not bearer). Mounted regardless of
|
||||
// whether the UI is up — JSON callers may also subscribe.
|
||||
if s.deps.JobHub != nil {
|
||||
r.Get("/api/jobs/{id}/stream", s.handleJobStream)
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins listening. Blocks until ListenAndServe returns
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user