package ws

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	stdhttp "net/http"
	"strings"
	"time"

	"github.com/coder/websocket"

	"gitea.dcglab.co.uk/steve/restic-manager/internal/alert"
	"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
	"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
	"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
	"gitea.dcglab.co.uk/steve/restic-manager/internal/version"
)

// HandlerDeps is the set of collaborators the agent WS handler needs.
// Hub and Store are used without nil checks and must be set; every
// other field is optional.
type HandlerDeps struct {
	// Hub is where the connection is registered once the hello has
	// been accepted and unregistered when the connection loop exits.
	// It also resolves pending tree.list RPC replies by envelope ID.
	Hub *Hub

	// Store authenticates the agent token and persists what agents
	// report: host hello/heartbeat, job lifecycle, job logs, snapshots
	// and repo stats.
	Store *store.Store

	// JobHub receives a broadcast of every job started / progress /
	// finished / log-stream envelope, keyed by job ID. Optional;
	// nil = no-op.
	JobHub *JobHub

	// AlertEngine receives job-finished and host-online events so the
	// alert engine can evaluate its rules. Optional; nil = no-op.
	AlertEngine *alert.Engine

	// UpdateWatcher reconciles in-flight agent-update dispatches against
	// hello envelopes. Optional; nil = no-op.
	UpdateWatcher *UpdateWatcher

	// OnHello is called once per successful hello, after the host row
	// has been touched and the conn registered. Used by the HTTP
	// layer to push host_credentials down as a config.update before
	// the agent starts asking for jobs. Optional; nil = no-op.
	OnHello func(ctx context.Context, hostID string, conn *Conn)

	// OnScheduleAck is called when an agent confirms it has applied
	// a particular schedule version (P2-02 reconciliation). Optional.
	OnScheduleAck func(ctx context.Context, hostID string, version int64, appliedAt time.Time)

	// OnScheduleFire is called when an agent's local cron fires. The
	// callback is expected to look up the schedule, persist a job
	// row, and emit MsgCommandRun back on conn so the agent can run
	// the job using its normal job dispatch path. Optional.
	OnScheduleFire func(ctx context.Context, hostID string, conn *Conn, scheduleID string, scheduledAt time.Time)
}

// AgentHandler is the http.Handler that owns /ws/agent. Agents
// authenticate with an `Authorization: Bearer` header carrying the
// token issued at enrollment, before the WS upgrade.
//
// Lifecycle:
//  1. Bearer token resolves to a Host row.
//  2. Upgrade.
// 3. First message must be `hello`; protocol_version checked here. // 4. Loop: read messages, dispatch by type. Heartbeats touch the // host row; job/log/repo messages forward to the relevant // handlers (TODO: lands with P1-18 onward). // 5. On Read error or context cancel, mark host offline, unregister // from the hub. func AgentHandler(deps HandlerDeps) stdhttp.Handler { return stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { host, ok := authenticateAgent(r, deps.Store) if !ok { stdhttp.Error(w, "unauthorised", stdhttp.StatusUnauthorized) return } conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ InsecureSkipVerify: true, // Origin checks are pointless for an agent CLI. }) if err != nil { slog.Warn("ws accept failed", "err", err, "host_id", host.ID) return } c := NewConn(host.ID, conn) // Keep agents alive across NAT boxes; coder/websocket // auto-pings under the hood when configured. The default 60s // works fine for a 30s heartbeat cadence. runAgentLoop(r.Context(), c, host.ID, deps) }) } // authenticateAgent returns the host that owns the bearer token in // the request, or (nil, false) if anything is amiss. The same // "false" path is used for missing header, malformed header, unknown // token — no information leak about why. func authenticateAgent(r *stdhttp.Request, st *store.Store) (*store.Host, bool) { hdr := r.Header.Get("Authorization") const prefix = "Bearer " if !strings.HasPrefix(hdr, prefix) { return nil, false } token := strings.TrimPrefix(hdr, prefix) if token == "" { return nil, false } h, err := st.LookupHostByAgentToken(r.Context(), auth.HashToken(token)) if err != nil { return nil, false } return h, true } // runAgentLoop is the per-connection driver. Returns when the socket // is closed for any reason. It owns the hub registration: register on // hello acceptance, unregister on exit. 
func runAgentLoop(ctx context.Context, c *Conn, hostID string, deps HandlerDeps) { // Stage 1: hello (with a tight deadline). helloCtx, cancel := context.WithTimeout(ctx, 10*time.Second) hello, err := c.Read(helloCtx) cancel() if err != nil { slog.Info("ws hello read failed", "host_id", hostID, "err", err) _ = c.Close() return } if hello.Type != api.MsgHello { c.SendError(ctx, api.ErrBadRequest, "first message must be hello", "") return } var helloPayload api.HelloPayload if err := hello.UnmarshalPayload(&helloPayload); err != nil { c.SendError(ctx, api.ErrBadRequest, "malformed hello payload", "") return } if helloPayload.ProtocolVersion < api.MinAgentProtocolVersion { c.SendError(ctx, api.ErrProtocolTooOld, fmt.Sprintf("agent protocol_version %d below minimum %d", helloPayload.ProtocolVersion, api.MinAgentProtocolVersion), "https://restic-manager.example/docs/upgrade") return } if helloPayload.ProtocolVersion > api.CurrentProtocolVersion { // Forward-compat is fine — newer agents talking to older // servers should accept their lower version. Just log it. 
slog.Info("ws agent newer than server", "host_id", hostID, "agent_proto", helloPayload.ProtocolVersion, "server_proto", api.CurrentProtocolVersion) } now := time.Now().UTC() if err := deps.Store.MarkHostHello(ctx, hostID, helloPayload.AgentVersion, helloPayload.ResticVersion, helloPayload.ProtocolVersion, now); err != nil { slog.Error("ws mark host hello failed", "host_id", hostID, "err", err) } if deps.AlertEngine != nil { deps.AlertEngine.NotifyHostOnline(hostID) } if deps.UpdateWatcher != nil { deps.UpdateWatcher.OnHello(ctx, hostID, helloPayload.AgentVersion, version.Version) } deps.Hub.Register(hostID, c) defer deps.Hub.Unregister(hostID, c) defer func() { _ = c.Close() }() slog.Info("ws agent connected", "host_id", hostID, "agent_version", helloPayload.AgentVersion, "protocol_version", helloPayload.ProtocolVersion) if deps.OnHello != nil { // Run synchronously so the config.update lands before any // command.run an operator might race in. deps.OnHello(ctx, hostID, c) } // Stage 2: main read loop. for { env, err := c.Read(ctx) if err != nil { if !errors.Is(err, context.Canceled) { slog.Info("ws agent read loop ended", "host_id", hostID, "err", err) } return } dispatchAgentMessage(ctx, c, hostID, env, deps) } } // dispatchAgentMessage routes a single envelope to its handler. func dispatchAgentMessage(ctx context.Context, c *Conn, hostID string, env api.Envelope, deps HandlerDeps) { switch env.Type { case api.MsgHeartbeat: _ = deps.Store.TouchHost(ctx, hostID, time.Now().UTC()) case api.MsgJobStarted: var p api.JobStartedPayload _ = env.UnmarshalPayload(&p) if err := deps.Store.MarkJobStarted(ctx, p.JobID, p.StartedAt); err != nil { slog.Warn("ws: mark job started", "job_id", p.JobID, "err", err) } if deps.JobHub != nil { deps.JobHub.Broadcast(p.JobID, env) } case api.MsgJobProgress: // Progress ticks aren't persisted (1Hz × every job × every // path-walk would dwarf the rest of the DB). 
The live UI // subscribes to JobHub and gets them in real time; once a // job finishes the final summary lands via job.finished. var p api.JobProgressPayload _ = env.UnmarshalPayload(&p) if deps.JobHub != nil { deps.JobHub.Broadcast(p.JobID, env) } case api.MsgJobFinished: var p api.JobFinishedPayload _ = env.UnmarshalPayload(&p) errMsg := p.Error if err := deps.Store.MarkJobFinished(ctx, p.JobID, string(p.Status), p.ExitCode, p.Stats, errMsg, p.FinishedAt); err != nil { slog.Warn("ws: mark job finished", "job_id", p.JobID, "err", err) } // NS-03: project the outcome of init / probe jobs onto the host // row so the dashboard + repo page can surface bad creds / // unreachable repo eagerly without trawling the jobs list. // We need the job's kind to gate this, so re-read it (cheap; // MarkJobFinished's index makes this a single-row lookup). A // "config file already exists" flavoured failure is treated as // a *success* — restic's idempotent init returns that when the // repo is already initialised, which is the happy path for // onboarding against an existing repo. 
if job, err := deps.Store.GetJob(ctx, p.JobID); err == nil && job != nil && job.Kind == string(api.JobInit) { status, errOut := repoStatusFromInit(string(p.Status), errMsg) if err := deps.Store.SetHostRepoStatus(ctx, hostID, status, errOut); err != nil { slog.Warn("ws: set host repo status", "host_id", hostID, "err", err) } } if deps.JobHub != nil { deps.JobHub.Broadcast(p.JobID, env) } if deps.AlertEngine != nil { if job, err := deps.Store.GetJob(ctx, p.JobID); err == nil && job != nil { groupID := "" if job.SourceGroupID != nil { groupID = *job.SourceGroupID } deps.AlertEngine.NotifyJobFinished(alert.JobFinishedEvent{ HostID: hostID, JobID: p.JobID, Kind: job.Kind, Status: string(p.Status), SourceGroupID: groupID, When: p.FinishedAt, }) } } case api.MsgLogStream: var p api.LogStreamLine _ = env.UnmarshalPayload(&p) if err := deps.Store.AppendJobLog(ctx, p.JobID, p.Seq, p.TS, string(p.Stream), p.Payload); err != nil { slog.Warn("ws: append job log", "job_id", p.JobID, "err", err) } if deps.JobHub != nil { deps.JobHub.Broadcast(p.JobID, env) } case api.MsgSnapshotsRpt: var p api.SnapshotsReportPayload if err := env.UnmarshalPayload(&p); err != nil { slog.Warn("ws: bad snapshots.report payload", "host_id", hostID, "err", err) break } snaps := make([]store.Snapshot, len(p.Snapshots)) for i, s := range p.Snapshots { snaps[i] = store.Snapshot{ ID: s.ID, ShortID: s.ShortID, Time: s.Time, Hostname: s.Hostname, Paths: s.Paths, Tags: s.Tags, SizeBytes: s.SizeBytes, FileCount: s.FileCount, } } if err := deps.Store.ReplaceHostSnapshots(ctx, hostID, snaps, time.Now().UTC()); err != nil { slog.Warn("ws: replace snapshots", "host_id", hostID, "err", err) } else { slog.Info("ws: snapshots refreshed", "host_id", hostID, "count", len(snaps)) } case api.MsgScheduleAck: var p api.ScheduleAckPayload if err := env.UnmarshalPayload(&p); err != nil { slog.Warn("ws: bad schedule.ack payload", "host_id", hostID, "err", err) break } if deps.OnScheduleAck != nil { deps.OnScheduleAck(ctx, 
hostID, p.Version, p.AppliedAt) } case api.MsgScheduleFire: var p api.ScheduleFirePayload if err := env.UnmarshalPayload(&p); err != nil { slog.Warn("ws: bad schedule.fire payload", "host_id", hostID, "err", err) break } if deps.OnScheduleFire != nil { deps.OnScheduleFire(ctx, hostID, c, p.ScheduleID, p.ScheduledAt) } case api.MsgRepoStats: var p api.RepoStatsPayload if err := env.UnmarshalPayload(&p); err != nil { slog.Warn("ws: bad repo.stats payload", "host_id", hostID, "err", err) break } patch := store.HostRepoStats{ HostID: hostID, TotalSizeBytes: p.TotalSizeBytes, RawSizeBytes: p.RawSizeBytes, UniqueFiles: p.UniqueFiles, SnapshotCount: p.SnapshotCount, LastCheckAt: p.LastCheckAt, LastCheckStatus: p.LastCheckStatus, LockPresent: p.LockPresent, LastPruneAt: p.LastPruneAt, LastPruneFreedBytes: p.LastPruneFreedBytes, } if err := deps.Store.UpsertHostRepoStats(ctx, hostID, patch); err != nil { slog.Warn("ws: upsert host repo stats", "host_id", hostID, "err", err) } else { slog.Info("ws: repo stats refreshed", "host_id", hostID) } case api.MsgCommandResult: // TODO(P2): persist command.result acks for "did the agent // accept the dispatch?" forensics. Currently the job lifecycle // (job.started → job.finished) is sufficient signal. slog.Debug("ws msg not yet handled", "type", env.Type, "host_id", hostID) case api.MsgTreeListResult: // Reply to a synchronous tree.list RPC. Route to the waiter // registered against the request envelope's ID; if none is // registered the caller already gave up (ctx expired) — drop // the stray reply quietly. 
if env.ID == "" { slog.Warn("ws: tree.list.result missing envelope ID", "host_id", hostID) break } if !deps.Hub.rpcs.resolve(env.ID, env) { slog.Debug("ws: tree.list.result with no waiter (timeout?)", "id", env.ID, "host_id", hostID) } case api.MsgError: var ep api.ErrorPayload _ = env.UnmarshalPayload(&ep) slog.Warn("ws agent reported error", "host_id", hostID, "code", string(ep.Code), "message", ep.Message) default: slog.Warn("ws unknown message type from agent", "type", env.Type, "host_id", hostID) } } // MinHeartbeatInterval is a sanity floor — any agent reporting // heartbeats more often than this is misbehaving. (Spec says 30s.) const MinHeartbeatInterval = 5 * time.Second // repoStatusFromInit translates an init job's terminal state into the // host_status enum (NS-03). Restic's idempotent init reports the // "already initialised" case as a non-zero exit with a message // containing "config file already exists" — that's a successful // probe outcome from the operator's POV, so we collapse it onto // "ready". Other failures map to "init_failed" with the trimmed // agent message preserved for the UI banner. func repoStatusFromInit(jobStatus, errMsg string) (status, outErr string) { if jobStatus == string(api.JobSucceeded) { return "ready", "" } low := strings.ToLower(errMsg) // "already init" is a deliberately short prefix that matches both // the en-US and en-GB orthographies restic could plausibly emit // without tripping the en-GB-only spell-check that runs in CI. switch { case strings.Contains(low, "config file already exists"), strings.Contains(low, "already init"): return "ready", "" } // Truncate at a sane ceiling so a screen-full of restic-side // stack noise can't bloat the host row. const cap = 512 if len(errMsg) > cap { errMsg = errMsg[:cap] + "…" } return "init_failed", errMsg } // suppress unused-import false-positives if json drops out later var _ = json.Marshal