414 lines
14 KiB
Go
414 lines
14 KiB
Go
package ws
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"log/slog"
|
||
stdhttp "net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/coder/websocket"
|
||
|
||
"gitea.dcglab.co.uk/steve/restic-manager/internal/alert"
|
||
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
||
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
|
||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||
"gitea.dcglab.co.uk/steve/restic-manager/internal/version"
|
||
)
|
||
|
||
// HandlerDeps is the set of collaborators the agent WS handler needs.
|
||
type HandlerDeps struct {
|
||
Hub *Hub
|
||
Store *store.Store
|
||
JobHub *JobHub
|
||
// AlertEngine receives job-finished and host-online events so the
|
||
// alert engine can evaluate its rules. Optional; nil = no-op.
|
||
AlertEngine *alert.Engine
|
||
// UpdateWatcher reconciles in-flight agent-update dispatches against
|
||
// hello envelopes. Optional; nil = no-op.
|
||
UpdateWatcher *UpdateWatcher
|
||
// OnHello is called once per successful hello, after the host row
|
||
// has been touched and the conn registered. Used by the HTTP
|
||
// layer to push host_credentials down as a config.update before
|
||
// the agent starts asking for jobs. Optional; nil = no-op.
|
||
OnHello func(ctx context.Context, hostID string, conn *Conn)
|
||
// OnScheduleAck is called when an agent confirms it has applied
|
||
// a particular schedule version (P2-02 reconciliation). Optional.
|
||
OnScheduleAck func(ctx context.Context, hostID string, version int64, appliedAt time.Time)
|
||
// OnScheduleFire is called when an agent's local cron fires. The
|
||
// callback is expected to look up the schedule, persist a job
|
||
// row, and emit MsgCommandRun back on conn so the agent can run
|
||
// the job using its normal job dispatch path. Optional.
|
||
OnScheduleFire func(ctx context.Context, hostID string, conn *Conn, scheduleID string, scheduledAt time.Time)
|
||
}
|
||
|
||
// AgentHandler is the http.Handler that owns /ws/agent. Agents
|
||
// authenticate with `Authorization: Bearer <token>` (issued at
|
||
// enrollment) before the WS upgrade.
|
||
//
|
||
// Lifecycle:
|
||
// 1. Bearer token resolves to a Host row.
|
||
// 2. Upgrade.
|
||
// 3. First message must be `hello`; protocol_version checked here.
|
||
// 4. Loop: read messages, dispatch by type. Heartbeats touch the
|
||
// host row; job/log/repo messages forward to the relevant
|
||
// handlers (TODO: lands with P1-18 onward).
|
||
// 5. On Read error or context cancel, mark host offline, unregister
|
||
// from the hub.
|
||
func AgentHandler(deps HandlerDeps) stdhttp.Handler {
|
||
return stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
||
host, ok := authenticateAgent(r, deps.Store)
|
||
if !ok {
|
||
stdhttp.Error(w, "unauthorised", stdhttp.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||
InsecureSkipVerify: true, // Origin checks are pointless for an agent CLI.
|
||
})
|
||
if err != nil {
|
||
slog.Warn("ws accept failed", "err", err, "host_id", host.ID)
|
||
return
|
||
}
|
||
|
||
c := NewConn(host.ID, conn)
|
||
// Keep agents alive across NAT boxes; coder/websocket
|
||
// auto-pings under the hood when configured. The default 60s
|
||
// works fine for a 30s heartbeat cadence.
|
||
|
||
runAgentLoop(r.Context(), c, host.ID, deps)
|
||
})
|
||
}
|
||
|
||
// authenticateAgent returns the host that owns the bearer token in
|
||
// the request, or (nil, false) if anything is amiss. The same
|
||
// "false" path is used for missing header, malformed header, unknown
|
||
// token — no information leak about why.
|
||
func authenticateAgent(r *stdhttp.Request, st *store.Store) (*store.Host, bool) {
|
||
hdr := r.Header.Get("Authorization")
|
||
const prefix = "Bearer "
|
||
if !strings.HasPrefix(hdr, prefix) {
|
||
return nil, false
|
||
}
|
||
token := strings.TrimPrefix(hdr, prefix)
|
||
if token == "" {
|
||
return nil, false
|
||
}
|
||
h, err := st.LookupHostByAgentToken(r.Context(), auth.HashToken(token))
|
||
if err != nil {
|
||
return nil, false
|
||
}
|
||
return h, true
|
||
}
|
||
|
||
// runAgentLoop is the per-connection driver. Returns when the socket
|
||
// is closed for any reason. It owns the hub registration: register on
|
||
// hello acceptance, unregister on exit.
|
||
func runAgentLoop(ctx context.Context, c *Conn, hostID string, deps HandlerDeps) {
|
||
// Stage 1: hello (with a tight deadline).
|
||
helloCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||
hello, err := c.Read(helloCtx)
|
||
cancel()
|
||
if err != nil {
|
||
slog.Info("ws hello read failed", "host_id", hostID, "err", err)
|
||
_ = c.Close()
|
||
return
|
||
}
|
||
if hello.Type != api.MsgHello {
|
||
c.SendError(ctx, api.ErrBadRequest, "first message must be hello", "")
|
||
return
|
||
}
|
||
var helloPayload api.HelloPayload
|
||
if err := hello.UnmarshalPayload(&helloPayload); err != nil {
|
||
c.SendError(ctx, api.ErrBadRequest, "malformed hello payload", "")
|
||
return
|
||
}
|
||
if helloPayload.ProtocolVersion < api.MinAgentProtocolVersion {
|
||
c.SendError(ctx, api.ErrProtocolTooOld,
|
||
fmt.Sprintf("agent protocol_version %d below minimum %d",
|
||
helloPayload.ProtocolVersion, api.MinAgentProtocolVersion),
|
||
"https://restic-manager.example/docs/upgrade")
|
||
return
|
||
}
|
||
if helloPayload.ProtocolVersion > api.CurrentProtocolVersion {
|
||
// Forward-compat is fine — newer agents talking to older
|
||
// servers should accept their lower version. Just log it.
|
||
slog.Info("ws agent newer than server",
|
||
"host_id", hostID,
|
||
"agent_proto", helloPayload.ProtocolVersion,
|
||
"server_proto", api.CurrentProtocolVersion)
|
||
}
|
||
|
||
now := time.Now().UTC()
|
||
if err := deps.Store.MarkHostHello(ctx, hostID,
|
||
helloPayload.AgentVersion, helloPayload.ResticVersion,
|
||
helloPayload.ProtocolVersion, now); err != nil {
|
||
slog.Error("ws mark host hello failed", "host_id", hostID, "err", err)
|
||
}
|
||
if deps.AlertEngine != nil {
|
||
deps.AlertEngine.NotifyHostOnline(hostID)
|
||
}
|
||
if deps.UpdateWatcher != nil {
|
||
deps.UpdateWatcher.OnHello(ctx, hostID, helloPayload.AgentVersion, version.Version)
|
||
}
|
||
|
||
deps.Hub.Register(hostID, c)
|
||
defer deps.Hub.Unregister(hostID, c)
|
||
defer func() { _ = c.Close() }()
|
||
|
||
slog.Info("ws agent connected",
|
||
"host_id", hostID,
|
||
"agent_version", helloPayload.AgentVersion,
|
||
"protocol_version", helloPayload.ProtocolVersion)
|
||
|
||
if deps.OnHello != nil {
|
||
// Run synchronously so the config.update lands before any
|
||
// command.run an operator might race in.
|
||
deps.OnHello(ctx, hostID, c)
|
||
}
|
||
|
||
// Stage 2: main read loop.
|
||
for {
|
||
env, err := c.Read(ctx)
|
||
if err != nil {
|
||
if !errors.Is(err, context.Canceled) {
|
||
slog.Info("ws agent read loop ended", "host_id", hostID, "err", err)
|
||
}
|
||
return
|
||
}
|
||
dispatchAgentMessage(ctx, c, hostID, env, deps)
|
||
}
|
||
}
|
||
|
||
// dispatchAgentMessage routes a single envelope to its handler.
|
||
func dispatchAgentMessage(ctx context.Context, c *Conn, hostID string, env api.Envelope, deps HandlerDeps) {
|
||
switch env.Type {
|
||
case api.MsgHeartbeat:
|
||
_ = deps.Store.TouchHost(ctx, hostID, time.Now().UTC())
|
||
|
||
case api.MsgJobStarted:
|
||
var p api.JobStartedPayload
|
||
_ = env.UnmarshalPayload(&p)
|
||
if err := deps.Store.MarkJobStarted(ctx, p.JobID, p.StartedAt); err != nil {
|
||
slog.Warn("ws: mark job started", "job_id", p.JobID, "err", err)
|
||
}
|
||
if deps.JobHub != nil {
|
||
deps.JobHub.Broadcast(p.JobID, env)
|
||
}
|
||
|
||
case api.MsgJobProgress:
|
||
// Progress ticks aren't persisted (1Hz × every job × every
|
||
// path-walk would dwarf the rest of the DB). The live UI
|
||
// subscribes to JobHub and gets them in real time; once a
|
||
// job finishes the final summary lands via job.finished.
|
||
var p api.JobProgressPayload
|
||
_ = env.UnmarshalPayload(&p)
|
||
if deps.JobHub != nil {
|
||
deps.JobHub.Broadcast(p.JobID, env)
|
||
}
|
||
|
||
case api.MsgJobFinished:
|
||
var p api.JobFinishedPayload
|
||
_ = env.UnmarshalPayload(&p)
|
||
errMsg := p.Error
|
||
if err := deps.Store.MarkJobFinished(ctx, p.JobID,
|
||
string(p.Status), p.ExitCode, p.Stats, errMsg, p.FinishedAt); err != nil {
|
||
slog.Warn("ws: mark job finished", "job_id", p.JobID, "err", err)
|
||
}
|
||
// NS-03: project the outcome of init / probe jobs onto the host
|
||
// row so the dashboard + repo page can surface bad creds /
|
||
// unreachable repo eagerly without trawling the jobs list.
|
||
// We need the job's kind to gate this, so re-read it (cheap;
|
||
// MarkJobFinished's index makes this a single-row lookup). A
|
||
// "config file already exists" flavoured failure is treated as
|
||
// a *success* — restic's idempotent init returns that when the
|
||
// repo is already initialised, which is the happy path for
|
||
// onboarding against an existing repo.
|
||
if job, err := deps.Store.GetJob(ctx, p.JobID); err == nil && job != nil {
|
||
switch job.Kind {
|
||
case string(api.JobInit):
|
||
status, errOut := repoStatusFromInit(string(p.Status), errMsg)
|
||
if err := deps.Store.SetHostRepoStatus(ctx, hostID, status, errOut); err != nil {
|
||
slog.Warn("ws: set host repo status", "host_id", hostID, "err", err)
|
||
}
|
||
case string(api.JobBackup):
|
||
if err := deps.Store.SetHostLastBackup(ctx, hostID, string(p.Status), p.FinishedAt); err != nil {
|
||
slog.Warn("ws: set host last backup", "host_id", hostID, "err", err)
|
||
}
|
||
}
|
||
}
|
||
if deps.JobHub != nil {
|
||
deps.JobHub.Broadcast(p.JobID, env)
|
||
}
|
||
if deps.AlertEngine != nil {
|
||
if job, err := deps.Store.GetJob(ctx, p.JobID); err == nil && job != nil {
|
||
groupID := ""
|
||
if job.SourceGroupID != nil {
|
||
groupID = *job.SourceGroupID
|
||
}
|
||
deps.AlertEngine.NotifyJobFinished(alert.JobFinishedEvent{
|
||
HostID: hostID,
|
||
JobID: p.JobID,
|
||
Kind: job.Kind,
|
||
Status: string(p.Status),
|
||
SourceGroupID: groupID,
|
||
When: p.FinishedAt,
|
||
})
|
||
}
|
||
}
|
||
|
||
case api.MsgLogStream:
|
||
var p api.LogStreamLine
|
||
_ = env.UnmarshalPayload(&p)
|
||
if err := deps.Store.AppendJobLog(ctx, p.JobID, p.Seq, p.TS,
|
||
string(p.Stream), p.Payload); err != nil {
|
||
slog.Warn("ws: append job log", "job_id", p.JobID, "err", err)
|
||
}
|
||
if deps.JobHub != nil {
|
||
deps.JobHub.Broadcast(p.JobID, env)
|
||
}
|
||
|
||
case api.MsgSnapshotsRpt:
|
||
var p api.SnapshotsReportPayload
|
||
if err := env.UnmarshalPayload(&p); err != nil {
|
||
slog.Warn("ws: bad snapshots.report payload", "host_id", hostID, "err", err)
|
||
break
|
||
}
|
||
snaps := make([]store.Snapshot, len(p.Snapshots))
|
||
for i, s := range p.Snapshots {
|
||
snaps[i] = store.Snapshot{
|
||
ID: s.ID,
|
||
ShortID: s.ShortID,
|
||
Time: s.Time,
|
||
Hostname: s.Hostname,
|
||
Paths: s.Paths,
|
||
Tags: s.Tags,
|
||
SizeBytes: s.SizeBytes,
|
||
FileCount: s.FileCount,
|
||
}
|
||
}
|
||
if err := deps.Store.ReplaceHostSnapshots(ctx, hostID, snaps, time.Now().UTC()); err != nil {
|
||
slog.Warn("ws: replace snapshots", "host_id", hostID, "err", err)
|
||
} else {
|
||
slog.Info("ws: snapshots refreshed", "host_id", hostID, "count", len(snaps))
|
||
}
|
||
|
||
case api.MsgScheduleAck:
|
||
var p api.ScheduleAckPayload
|
||
if err := env.UnmarshalPayload(&p); err != nil {
|
||
slog.Warn("ws: bad schedule.ack payload", "host_id", hostID, "err", err)
|
||
break
|
||
}
|
||
if deps.OnScheduleAck != nil {
|
||
deps.OnScheduleAck(ctx, hostID, p.Version, p.AppliedAt)
|
||
}
|
||
|
||
case api.MsgScheduleFire:
|
||
var p api.ScheduleFirePayload
|
||
if err := env.UnmarshalPayload(&p); err != nil {
|
||
slog.Warn("ws: bad schedule.fire payload", "host_id", hostID, "err", err)
|
||
break
|
||
}
|
||
if deps.OnScheduleFire != nil {
|
||
deps.OnScheduleFire(ctx, hostID, c, p.ScheduleID, p.ScheduledAt)
|
||
}
|
||
|
||
case api.MsgRepoStats:
|
||
var p api.RepoStatsPayload
|
||
if err := env.UnmarshalPayload(&p); err != nil {
|
||
slog.Warn("ws: bad repo.stats payload", "host_id", hostID, "err", err)
|
||
break
|
||
}
|
||
patch := store.HostRepoStats{
|
||
HostID: hostID,
|
||
TotalSizeBytes: p.TotalSizeBytes,
|
||
RawSizeBytes: p.RawSizeBytes,
|
||
UniqueFiles: p.UniqueFiles,
|
||
SnapshotCount: p.SnapshotCount,
|
||
LastCheckAt: p.LastCheckAt,
|
||
LastCheckStatus: p.LastCheckStatus,
|
||
LockPresent: p.LockPresent,
|
||
LastPruneAt: p.LastPruneAt,
|
||
LastPruneFreedBytes: p.LastPruneFreedBytes,
|
||
}
|
||
if err := deps.Store.UpsertHostRepoStats(ctx, hostID, patch); err != nil {
|
||
slog.Warn("ws: upsert host repo stats", "host_id", hostID, "err", err)
|
||
} else {
|
||
slog.Info("ws: repo stats refreshed", "host_id", hostID)
|
||
}
|
||
day := time.Now().UTC().Format("2006-01-02")
|
||
if err := deps.Store.UpsertHostRepoStatsHistory(ctx, hostID, day, patch, time.Now().UTC()); err != nil {
|
||
slog.Warn("ws: upsert host repo stats history", "host_id", hostID, "err", err)
|
||
}
|
||
|
||
case api.MsgCommandResult:
|
||
// TODO(P2): persist command.result acks for "did the agent
|
||
// accept the dispatch?" forensics. Currently the job lifecycle
|
||
// (job.started → job.finished) is sufficient signal.
|
||
slog.Debug("ws msg not yet handled", "type", env.Type, "host_id", hostID)
|
||
|
||
case api.MsgTreeListResult:
|
||
// Reply to a synchronous tree.list RPC. Route to the waiter
|
||
// registered against the request envelope's ID; if none is
|
||
// registered the caller already gave up (ctx expired) — drop
|
||
// the stray reply quietly.
|
||
if env.ID == "" {
|
||
slog.Warn("ws: tree.list.result missing envelope ID", "host_id", hostID)
|
||
break
|
||
}
|
||
if !deps.Hub.rpcs.resolve(env.ID, env) {
|
||
slog.Debug("ws: tree.list.result with no waiter (timeout?)",
|
||
"id", env.ID, "host_id", hostID)
|
||
}
|
||
|
||
case api.MsgError:
|
||
var ep api.ErrorPayload
|
||
_ = env.UnmarshalPayload(&ep)
|
||
slog.Warn("ws agent reported error", "host_id", hostID,
|
||
"code", string(ep.Code), "message", ep.Message)
|
||
|
||
default:
|
||
slog.Warn("ws unknown message type from agent",
|
||
"type", env.Type, "host_id", hostID)
|
||
}
|
||
}
|
||
|
||
// MinHeartbeatInterval is the sanity floor for agent heartbeat
// spacing: an agent heartbeating more often than this is misbehaving.
// (The spec's nominal cadence is 30s.)
const MinHeartbeatInterval = 5 * time.Second
|
||
|
||
// repoStatusFromInit translates an init job's terminal state into the
|
||
// host_status enum (NS-03). Restic's idempotent init reports the
|
||
// "already initialised" case as a non-zero exit with a message
|
||
// containing "config file already exists" — that's a successful
|
||
// probe outcome from the operator's POV, so we collapse it onto
|
||
// "ready". Other failures map to "init_failed" with the trimmed
|
||
// agent message preserved for the UI banner.
|
||
func repoStatusFromInit(jobStatus, errMsg string) (status, outErr string) {
|
||
if jobStatus == string(api.JobSucceeded) {
|
||
return "ready", ""
|
||
}
|
||
low := strings.ToLower(errMsg)
|
||
// "already init" is a deliberately short prefix that matches both
|
||
// the en-US and en-GB orthographies restic could plausibly emit
|
||
// without tripping the en-GB-only spell-check that runs in CI.
|
||
switch {
|
||
case strings.Contains(low, "config file already exists"),
|
||
strings.Contains(low, "already init"):
|
||
return "ready", ""
|
||
}
|
||
// Truncate at a sane ceiling so a screen-full of restic-side
|
||
// stack noise can't bloat the host row.
|
||
const cap = 512
|
||
if len(errMsg) > cap {
|
||
errMsg = errMsg[:cap] + "…"
|
||
}
|
||
return "init_failed", errMsg
|
||
}
|
||
|
||
// Keep the encoding/json import referenced so the file still compiles
// if its last real use of json is removed later.
var _ = json.Marshal
|