Files
restic-manager/internal/server/ws/handler.go
T
steve 608962441b P2-02 (agent side) + P2-03: agent scheduler + schedule.fire dispatch
Closes the schedule reconciliation loop end-to-end.

* New `internal/agent/scheduler` package wraps robfig/cron/v3 with
  the lifecycle the agent needs:
  - Apply(ScheduleSetPayload, Sender) stops the prior cron (waiting
    for in-flight entries to return), rebuilds from scratch, starts,
    and emits schedule.ack with the version we just applied.
  - Disabled entries are skipped silently; bad cron expressions
    (which shouldn't reach us because the server validates them,
    but we stay defensive) log a warning and are skipped.
  - On each cron tick the entry sends a new schedule.fire envelope
    to the server with {schedule_id, scheduled_at}. The scheduler
    itself never builds CommandRunPayloads — server is the source
    of truth for jobs.
  - tx is swapped on every Apply, so reconnect is handled
    naturally: cron entries that fire against a dropped tx log
    "no active connection" and skip the tick.
  - Stop() is idempotent and waits for the cron's in-flight
    workers via cron.Stop().Done().

* New wire message api.MsgScheduleFire + api.ScheduleFirePayload
  for the agent → server "I just fired locally" RPC.

* Server-side dispatch (schedule_push.go: dispatchScheduledJob):
  looks up the schedule by id, validates ownership + that it's
  enabled, builds args from kind (paths for backup; other kinds
  are still arg-less in Phase 2 and grow as those job kinds land
  in P2-05..08), persists a jobs row with actor_kind=schedule +
  scheduled_id, and writes command.run back on the same conn so
  the agent runs through its existing dispatch path.

* store.CreateJob now writes scheduled_id. This column was in the
  schema since 0001 but never populated — the original P1 path
  only had operator-driven jobs, so actor_kind was always 'user'
  and scheduled_id was always nil.

* cmd/agent/main.go integration: dispatcher gains a
  *scheduler.Scheduler; the MsgScheduleSet case now hands the
  payload to scheduler.Apply (in a goroutine so the WS read loop
  keeps draining other messages).

* WS dispatcher gains OnScheduleFire alongside OnScheduleAck.

* Tests:
  - scheduler unit tests (4): ack-on-apply, cron tick fires
    schedule.fire envelope, disabled entries don't fire, replace-
    prior-state stops the old cron.
  - Server-side end-to-end: schedule.fire → command.run with the
    right job_id / kind / args, plus jobs row with actor_kind=
    "schedule" and scheduled_id linking back to the schedule.

Persistence of next-fire times across agent restarts is
deliberately deferred. A missed fire window during downtime
simply fires once on reconnect — that's the desirable behaviour
(the operator wants the missed backup to run, not be silently
skipped because we lost track of when it was due).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-02 11:29:12 +01:00

308 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package ws
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
stdhttp "net/http"
"strings"
"time"
"github.com/coder/websocket"
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
)
// HandlerDeps is the set of collaborators the agent WS handler needs.
// Hub and Store are used unconditionally by the connection loop;
// JobHub and the On* callbacks are nil-checked before every use.
type HandlerDeps struct {
// Hub tracks live agent connections: a conn is registered once its
// hello has been accepted and unregistered when its loop exits.
Hub *Hub
// Store is the persistence layer: bearer-token lookup plus the
// host, job, job-log and snapshot writes driven by agent messages.
Store *store.Store
// JobHub receives the job started/progress/finished and log-stream
// envelopes, keyed by job ID, via Broadcast. Optional; nil = the
// envelopes are not broadcast.
JobHub *JobHub
// OnHello is called once per successful hello, after the host row
// has been touched and the conn registered. Used by the HTTP
// layer to push host_credentials down as a config.update before
// the agent starts asking for jobs. Optional; nil = no-op.
OnHello func(ctx context.Context, hostID string, conn *Conn)
// OnScheduleAck is called when an agent confirms it has applied
// a particular schedule version (P2-02 reconciliation). Optional.
OnScheduleAck func(ctx context.Context, hostID string, version int64, appliedAt time.Time)
// OnScheduleFire is called when an agent's local cron fires. The
// callback is expected to look up the schedule, persist a job
// row, and emit MsgCommandRun back on conn so the agent can run
// the job using its normal job dispatch path. Optional.
OnScheduleFire func(ctx context.Context, hostID string, conn *Conn, scheduleID string, scheduledAt time.Time)
}
// AgentHandler builds the http.Handler that serves /ws/agent.
//
// Agents must present `Authorization: Bearer <token>` (the token
// issued at enrollment). The check runs before the WebSocket upgrade,
// so an unauthenticated caller only ever receives a plain 401.
//
// Per-connection lifecycle:
//  1. The bearer token is resolved to a Host row.
//  2. The request is upgraded to a WebSocket.
//  3. runAgentLoop takes over: the first message must be `hello`
//     (protocol_version is checked there), after which every
//     envelope is routed by type through dispatchAgentMessage.
//  4. runAgentLoop returns once a read fails or the request context
//     is cancelled; a conn that was registered with the hub is
//     unregistered on the way out.
func AgentHandler(deps HandlerDeps) stdhttp.Handler {
	serve := func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
		host, authed := authenticateAgent(r, deps.Store)
		if !authed {
			stdhttp.Error(w, "unauthorized", stdhttp.StatusUnauthorized)
			return
		}

		// Origin checks are pointless for an agent CLI, so they are
		// switched off; the bearer token above is the authentication.
		opts := websocket.AcceptOptions{InsecureSkipVerify: true}
		wsConn, err := websocket.Accept(w, r, &opts)
		if err != nil {
			slog.Warn("ws accept failed", "err", err, "host_id", host.ID)
			return
		}

		// No ping/keepalive settings are passed to Accept here.
		// NOTE(review): confirm whether NewConn sets up pings, or
		// whether the agent's heartbeat messages are what keep the
		// connection alive across NAT boxes.
		runAgentLoop(r.Context(), NewConn(host.ID, wsConn), host.ID, deps)
	}
	return stdhttp.HandlerFunc(serve)
}
// authenticateAgent resolves the request's bearer token to the host
// that owns it. It reports (nil, false) for every failure — header
// absent, header not of the form "Bearer <token>", empty token, or
// a token the store cannot resolve — so the caller cannot tell the
// cases apart and nothing leaks about why authentication failed.
//
// Only the hash of the token (auth.HashToken) is handed to the store.
func authenticateAgent(r *stdhttp.Request, st *store.Store) (*store.Host, bool) {
	token, isBearer := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ")
	if !isBearer || token == "" {
		return nil, false
	}

	host, err := st.LookupHostByAgentToken(r.Context(), auth.HashToken(token))
	if err != nil {
		return nil, false
	}
	return host, true
}
// runAgentLoop is the per-connection driver. It returns when the
// socket is finished with for any reason: the hello handshake failed
// or was rejected, a read failed, or ctx was cancelled.
//
// It owns both the connection and its hub registration:
//   - c is closed on every exit path, including hello rejections
//     (wrong first message, malformed payload, protocol too old).
//   - c is registered with deps.Hub once the hello is accepted and
//     unregistered on exit. Unregister runs before Close, so the hub
//     stops handing the conn out before it is torn down.
//
// deps.OnHello, when set, runs synchronously between registration and
// the first read of the main loop.
func runAgentLoop(ctx context.Context, c *Conn, hostID string, deps HandlerDeps) {
	// One deferred Close covers every return below. The close error is
	// discarded on purpose: by then the conn is done either way, and
	// the peer may already have gone.
	defer func() { _ = c.Close() }()

	// Stage 1: hello (with a tight deadline).
	helloCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	hello, err := c.Read(helloCtx)
	cancel()
	if err != nil {
		slog.Info("ws hello read failed", "host_id", hostID, "err", err)
		return
	}
	if hello.Type != api.MsgHello {
		c.SendError(ctx, api.ErrBadRequest, "first message must be hello", "")
		return
	}
	var helloPayload api.HelloPayload
	if err := hello.UnmarshalPayload(&helloPayload); err != nil {
		c.SendError(ctx, api.ErrBadRequest, "malformed hello payload", "")
		return
	}
	if helloPayload.ProtocolVersion < api.MinAgentProtocolVersion {
		c.SendError(ctx, api.ErrProtocolTooOld,
			fmt.Sprintf("agent protocol_version %d below minimum %d",
				helloPayload.ProtocolVersion, api.MinAgentProtocolVersion),
			"https://restic-manager.example/docs/upgrade")
		return
	}
	if helloPayload.ProtocolVersion > api.CurrentProtocolVersion {
		// Forward-compat is fine — newer agents talking to older
		// servers should accept their lower version. Just log it.
		slog.Info("ws agent newer than server",
			"host_id", hostID,
			"agent_proto", helloPayload.ProtocolVersion,
			"server_proto", api.CurrentProtocolVersion)
	}

	// A failure to record the hello is logged but does not reject the
	// agent: the connection itself is still usable.
	now := time.Now().UTC()
	if err := deps.Store.MarkHostHello(ctx, hostID,
		helloPayload.AgentVersion, helloPayload.ResticVersion,
		helloPayload.ProtocolVersion, now); err != nil {
		slog.Error("ws mark host hello failed", "host_id", hostID, "err", err)
	}

	deps.Hub.Register(hostID, c)
	defer deps.Hub.Unregister(hostID, c)

	slog.Info("ws agent connected",
		"host_id", hostID,
		"agent_version", helloPayload.AgentVersion,
		"protocol_version", helloPayload.ProtocolVersion)

	if deps.OnHello != nil {
		// Run synchronously so the config.update lands before any
		// command.run an operator might race in.
		deps.OnHello(ctx, hostID, c)
	}

	// Stage 2: main read loop.
	for {
		env, err := c.Read(ctx)
		if err != nil {
			// Context cancellation is the normal shutdown path and is
			// not worth a log line; anything else is.
			if !errors.Is(err, context.Canceled) {
				slog.Info("ws agent read loop ended", "host_id", hostID, "err", err)
			}
			return
		}
		dispatchAgentMessage(ctx, c, hostID, env, deps)
	}
}
// dispatchAgentMessage routes a single envelope from an agent to its
// handler, switching on env.Type.
//
// Parameters:
//   - ctx: context passed through to store calls and callbacks.
//   - c: the agent's connection; only handed to deps.OnScheduleFire.
//   - hostID: the authenticated host that owns the connection.
//   - env: the envelope just read from the socket.
//   - deps: store, hubs and optional callbacks.
//
// The function never returns an error. Failures are logged and the
// message is dropped so one bad envelope cannot end the read loop. A
// payload that fails to decode is never acted on: no store write, no
// broadcast, no callback.
//
// NOTE(review): job IDs are taken from the agent's payload as-is;
// nothing here checks that the job belongs to hostID. Confirm the
// store (or job creation) enforces that.
func dispatchAgentMessage(ctx context.Context, c *Conn, hostID string, env api.Envelope, deps HandlerDeps) {
	// decode unmarshals the payload into dst and reports whether it
	// succeeded, logging the failure so callers can simply bail out.
	decode := func(dst any) bool {
		if err := env.UnmarshalPayload(dst); err != nil {
			slog.Warn("ws: bad payload", "type", env.Type, "host_id", hostID, "err", err)
			return false
		}
		return true
	}

	switch env.Type {
	case api.MsgHeartbeat:
		if err := deps.Store.TouchHost(ctx, hostID, time.Now().UTC()); err != nil {
			slog.Warn("ws: touch host", "host_id", hostID, "err", err)
		}

	case api.MsgJobStarted:
		var p api.JobStartedPayload
		if !decode(&p) {
			break
		}
		if err := deps.Store.MarkJobStarted(ctx, p.JobID, p.StartedAt); err != nil {
			slog.Warn("ws: mark job started", "job_id", p.JobID, "err", err)
		}
		if deps.JobHub != nil {
			deps.JobHub.Broadcast(p.JobID, env)
		}

	case api.MsgJobProgress:
		// Progress ticks aren't persisted (1Hz × every job × every
		// path-walk would dwarf the rest of the DB). The live UI
		// subscribes to JobHub and gets them in real time; once a
		// job finishes the final summary lands via job.finished.
		var p api.JobProgressPayload
		if !decode(&p) {
			break
		}
		if deps.JobHub != nil {
			deps.JobHub.Broadcast(p.JobID, env)
		}

	case api.MsgJobFinished:
		var p api.JobFinishedPayload
		if !decode(&p) {
			break
		}
		if err := deps.Store.MarkJobFinished(ctx, p.JobID,
			string(p.Status), p.ExitCode, p.Stats, p.Error, p.FinishedAt); err != nil {
			slog.Warn("ws: mark job finished", "job_id", p.JobID, "err", err)
		}
		// A successful backup or init proves the repo exists; flip
		// repo_initialised_at on the host (idempotent — set-if-null).
		// A failed job lookup just skips the flip.
		if p.Status == api.JobSucceeded {
			if job, err := deps.Store.GetJob(ctx, p.JobID); err == nil &&
				(job.Kind == string(api.JobBackup) || job.Kind == string(api.JobInit)) {
				if _, err := deps.Store.MarkHostRepoInitialised(ctx, hostID, p.FinishedAt); err != nil {
					slog.Warn("ws: mark repo initialised", "host_id", hostID, "err", err)
				}
			}
		}
		if deps.JobHub != nil {
			deps.JobHub.Broadcast(p.JobID, env)
		}

	case api.MsgLogStream:
		var p api.LogStreamLine
		if !decode(&p) {
			break
		}
		if err := deps.Store.AppendJobLog(ctx, p.JobID, p.Seq, p.TS,
			string(p.Stream), p.Payload); err != nil {
			slog.Warn("ws: append job log", "job_id", p.JobID, "err", err)
		}
		if deps.JobHub != nil {
			deps.JobHub.Broadcast(p.JobID, env)
		}

	case api.MsgSnapshotsRpt:
		var p api.SnapshotsReportPayload
		if err := env.UnmarshalPayload(&p); err != nil {
			slog.Warn("ws: bad snapshots.report payload", "host_id", hostID, "err", err)
			break
		}
		// Translate the wire snapshots into store rows one-for-one.
		snaps := make([]store.Snapshot, len(p.Snapshots))
		for i, s := range p.Snapshots {
			snaps[i] = store.Snapshot{
				ID:        s.ID,
				ShortID:   s.ShortID,
				Time:      s.Time,
				Hostname:  s.Hostname,
				Paths:     s.Paths,
				Tags:      s.Tags,
				SizeBytes: s.SizeBytes,
				FileCount: s.FileCount,
			}
		}
		if err := deps.Store.ReplaceHostSnapshots(ctx, hostID, snaps, time.Now().UTC()); err != nil {
			slog.Warn("ws: replace snapshots", "host_id", hostID, "err", err)
		} else {
			slog.Info("ws: snapshots refreshed", "host_id", hostID, "count", len(snaps))
		}
		// A non-empty snapshot list also proves the repo is initialised
		// (catches the case where an external job — `restic init` from
		// the CLI, or a backup ran outside this control plane —
		// initialised it before our first job dispatched).
		if len(snaps) > 0 {
			if _, err := deps.Store.MarkHostRepoInitialised(ctx, hostID, time.Now().UTC()); err != nil {
				slog.Warn("ws: mark repo initialised (snapshots)", "host_id", hostID, "err", err)
			}
		}

	case api.MsgScheduleAck:
		var p api.ScheduleAckPayload
		if err := env.UnmarshalPayload(&p); err != nil {
			slog.Warn("ws: bad schedule.ack payload", "host_id", hostID, "err", err)
			break
		}
		if deps.OnScheduleAck != nil {
			deps.OnScheduleAck(ctx, hostID, p.Version, p.AppliedAt)
		}

	case api.MsgScheduleFire:
		var p api.ScheduleFirePayload
		if err := env.UnmarshalPayload(&p); err != nil {
			slog.Warn("ws: bad schedule.fire payload", "host_id", hostID, "err", err)
			break
		}
		// The callback runs on the read loop's goroutine, so the next
		// envelope is not read until it returns.
		if deps.OnScheduleFire != nil {
			deps.OnScheduleFire(ctx, hostID, c, p.ScheduleID, p.ScheduledAt)
		}

	case api.MsgRepoStats, api.MsgCommandResult:
		// TODO(P2): persist these projections.
		slog.Debug("ws msg not yet handled", "type", env.Type, "host_id", hostID)

	case api.MsgError:
		var ep api.ErrorPayload
		if !decode(&ep) {
			break
		}
		slog.Warn("ws agent reported error", "host_id", hostID,
			"code", string(ep.Code), "message", ep.Message)

	default:
		slog.Warn("ws unknown message type from agent",
			"type", env.Type, "host_id", hostID)
	}
}
// MinHeartbeatInterval is a sanity floor — any agent reporting
// heartbeats more often than this is misbehaving. (Spec says 30s.)
//
// NOTE(review): the MsgHeartbeat case in dispatchAgentMessage does
// not consult this value; confirm where (or whether) it is enforced.
const MinHeartbeatInterval = 5 * time.Second
// Blank use of encoding/json: keeps the "encoding/json" import
// compiling while no other code shown in this file references the
// package, so the import does not have to be removed and re-added
// as json usage comes and goes.
var _ = json.Marshal