b3b89045f2
Operator-minted enrollment tokens now carry the repo URL/username/
password as one AEAD blob bound (via additional-data) to the token
hash. ConsumeEnrollmentToken re-encrypts under host_id and writes a
host_credentials row in the same tx as token-burn, so the binding
moves with the credential.
PUT /api/hosts/{id}/repo-credentials lets an operator edit creds
post-enrollment; merges with the existing blob, audits, and pushes
config.update if the agent is connected.
WS handler grows an OnHello hook that the HTTP layer wires to send
the host's decrypted creds as a config.update immediately after the
hello succeeds — synchronously, so a racing command.run lands after
the agent has its repo password.
Schema: 0002_host_credentials.sql adds enc_repo_creds to
enrollment_tokens and a host_credentials table (PK = host_id, FK
ON DELETE CASCADE).
Tests: round-trip token → consume → host_credentials with AAD swap
detection; no-creds path stays compatible.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
246 lines
7.9 KiB
Go
246 lines
7.9 KiB
Go
package ws
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
stdhttp "net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/coder/websocket"
|
|
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/api"
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/auth"
|
|
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
|
)
|
|
|
|
// HandlerDeps is the set of collaborators the agent WS handler needs.
|
|
type HandlerDeps struct {
|
|
Hub *Hub
|
|
Store *store.Store
|
|
// OnHello is called once per successful hello, after the host row
|
|
// has been touched and the conn registered. Used by the HTTP
|
|
// layer to push host_credentials down as a config.update before
|
|
// the agent starts asking for jobs. Optional; nil = no-op.
|
|
OnHello func(ctx context.Context, hostID string, conn *Conn)
|
|
}
|
|
|
|
// AgentHandler is the http.Handler that owns /ws/agent. Agents
|
|
// authenticate with `Authorization: Bearer <token>` (issued at
|
|
// enrollment) before the WS upgrade.
|
|
//
|
|
// Lifecycle:
|
|
// 1. Bearer token resolves to a Host row.
|
|
// 2. Upgrade.
|
|
// 3. First message must be `hello`; protocol_version checked here.
|
|
// 4. Loop: read messages, dispatch by type. Heartbeats touch the
|
|
// host row; job/log/repo messages forward to the relevant
|
|
// handlers (TODO: lands with P1-18 onward).
|
|
// 5. On Read error or context cancel, mark host offline, unregister
|
|
// from the hub.
|
|
func AgentHandler(deps HandlerDeps) stdhttp.Handler {
|
|
return stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
|
|
host, ok := authenticateAgent(r, deps.Store)
|
|
if !ok {
|
|
stdhttp.Error(w, "unauthorized", stdhttp.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
|
InsecureSkipVerify: true, // Origin checks are pointless for an agent CLI.
|
|
})
|
|
if err != nil {
|
|
slog.Warn("ws accept failed", "err", err, "host_id", host.ID)
|
|
return
|
|
}
|
|
|
|
c := NewConn(host.ID, conn)
|
|
// Keep agents alive across NAT boxes; coder/websocket
|
|
// auto-pings under the hood when configured. The default 60s
|
|
// works fine for a 30s heartbeat cadence.
|
|
|
|
runAgentLoop(r.Context(), c, host.ID, deps)
|
|
})
|
|
}
|
|
|
|
// authenticateAgent returns the host that owns the bearer token in
|
|
// the request, or (nil, false) if anything is amiss. The same
|
|
// "false" path is used for missing header, malformed header, unknown
|
|
// token — no information leak about why.
|
|
func authenticateAgent(r *stdhttp.Request, st *store.Store) (*store.Host, bool) {
|
|
hdr := r.Header.Get("Authorization")
|
|
const prefix = "Bearer "
|
|
if !strings.HasPrefix(hdr, prefix) {
|
|
return nil, false
|
|
}
|
|
token := strings.TrimPrefix(hdr, prefix)
|
|
if token == "" {
|
|
return nil, false
|
|
}
|
|
h, err := st.LookupHostByAgentToken(r.Context(), auth.HashToken(token))
|
|
if err != nil {
|
|
return nil, false
|
|
}
|
|
return h, true
|
|
}
|
|
|
|
// runAgentLoop is the per-connection driver. Returns when the socket
|
|
// is closed for any reason. It owns the hub registration: register on
|
|
// hello acceptance, unregister on exit.
|
|
func runAgentLoop(ctx context.Context, c *Conn, hostID string, deps HandlerDeps) {
|
|
// Stage 1: hello (with a tight deadline).
|
|
helloCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
|
hello, err := c.Read(helloCtx)
|
|
cancel()
|
|
if err != nil {
|
|
slog.Info("ws hello read failed", "host_id", hostID, "err", err)
|
|
_ = c.Close()
|
|
return
|
|
}
|
|
if hello.Type != api.MsgHello {
|
|
c.SendError(ctx, api.ErrBadRequest, "first message must be hello", "")
|
|
return
|
|
}
|
|
var helloPayload api.HelloPayload
|
|
if err := hello.UnmarshalPayload(&helloPayload); err != nil {
|
|
c.SendError(ctx, api.ErrBadRequest, "malformed hello payload", "")
|
|
return
|
|
}
|
|
if helloPayload.ProtocolVersion < api.MinAgentProtocolVersion {
|
|
c.SendError(ctx, api.ErrProtocolTooOld,
|
|
fmt.Sprintf("agent protocol_version %d below minimum %d",
|
|
helloPayload.ProtocolVersion, api.MinAgentProtocolVersion),
|
|
"https://restic-manager.example/docs/upgrade")
|
|
return
|
|
}
|
|
if helloPayload.ProtocolVersion > api.CurrentProtocolVersion {
|
|
// Forward-compat is fine — newer agents talking to older
|
|
// servers should accept their lower version. Just log it.
|
|
slog.Info("ws agent newer than server",
|
|
"host_id", hostID,
|
|
"agent_proto", helloPayload.ProtocolVersion,
|
|
"server_proto", api.CurrentProtocolVersion)
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
if err := deps.Store.MarkHostHello(ctx, hostID,
|
|
helloPayload.AgentVersion, helloPayload.ResticVersion,
|
|
helloPayload.ProtocolVersion, now); err != nil {
|
|
slog.Error("ws mark host hello failed", "host_id", hostID, "err", err)
|
|
}
|
|
|
|
deps.Hub.Register(hostID, c)
|
|
defer deps.Hub.Unregister(hostID, c)
|
|
defer func() { _ = c.Close() }()
|
|
|
|
slog.Info("ws agent connected",
|
|
"host_id", hostID,
|
|
"agent_version", helloPayload.AgentVersion,
|
|
"protocol_version", helloPayload.ProtocolVersion)
|
|
|
|
if deps.OnHello != nil {
|
|
// Run synchronously so the config.update lands before any
|
|
// command.run an operator might race in.
|
|
deps.OnHello(ctx, hostID, c)
|
|
}
|
|
|
|
// Stage 2: main read loop.
|
|
for {
|
|
env, err := c.Read(ctx)
|
|
if err != nil {
|
|
if !errors.Is(err, context.Canceled) {
|
|
slog.Info("ws agent read loop ended", "host_id", hostID, "err", err)
|
|
}
|
|
return
|
|
}
|
|
dispatchAgentMessage(ctx, c, hostID, env, deps)
|
|
}
|
|
}
|
|
|
|
// dispatchAgentMessage routes a single envelope to its handler.
|
|
func dispatchAgentMessage(ctx context.Context, c *Conn, hostID string, env api.Envelope, deps HandlerDeps) {
|
|
switch env.Type {
|
|
case api.MsgHeartbeat:
|
|
_ = deps.Store.TouchHost(ctx, hostID, time.Now().UTC())
|
|
|
|
case api.MsgJobStarted:
|
|
var p api.JobStartedPayload
|
|
_ = env.UnmarshalPayload(&p)
|
|
if err := deps.Store.MarkJobStarted(ctx, p.JobID, p.StartedAt); err != nil {
|
|
slog.Warn("ws: mark job started", "job_id", p.JobID, "err", err)
|
|
}
|
|
|
|
case api.MsgJobProgress:
|
|
// We don't persist every progress tick; the live UI subscribes
|
|
// to a fan-out channel that lands with P1-21 / the UI work.
|
|
// TODO: implement the ws fan-out hub for browsers.
|
|
_ = env
|
|
|
|
case api.MsgJobFinished:
|
|
var p api.JobFinishedPayload
|
|
_ = env.UnmarshalPayload(&p)
|
|
errMsg := p.Error
|
|
if err := deps.Store.MarkJobFinished(ctx, p.JobID,
|
|
string(p.Status), p.ExitCode, p.Stats, errMsg, p.FinishedAt); err != nil {
|
|
slog.Warn("ws: mark job finished", "job_id", p.JobID, "err", err)
|
|
}
|
|
|
|
case api.MsgLogStream:
|
|
var p api.LogStreamLine
|
|
_ = env.UnmarshalPayload(&p)
|
|
if err := deps.Store.AppendJobLog(ctx, p.JobID, p.Seq, p.TS,
|
|
string(p.Stream), p.Payload); err != nil {
|
|
slog.Warn("ws: append job log", "job_id", p.JobID, "err", err)
|
|
}
|
|
|
|
case api.MsgSnapshotsRpt:
|
|
var p api.SnapshotsReportPayload
|
|
if err := env.UnmarshalPayload(&p); err != nil {
|
|
slog.Warn("ws: bad snapshots.report payload", "host_id", hostID, "err", err)
|
|
break
|
|
}
|
|
snaps := make([]store.Snapshot, len(p.Snapshots))
|
|
for i, s := range p.Snapshots {
|
|
snaps[i] = store.Snapshot{
|
|
ID: s.ID,
|
|
ShortID: s.ShortID,
|
|
Time: s.Time,
|
|
Hostname: s.Hostname,
|
|
Paths: s.Paths,
|
|
Tags: s.Tags,
|
|
SizeBytes: s.SizeBytes,
|
|
FileCount: s.FileCount,
|
|
}
|
|
}
|
|
if err := deps.Store.ReplaceHostSnapshots(ctx, hostID, snaps, time.Now().UTC()); err != nil {
|
|
slog.Warn("ws: replace snapshots", "host_id", hostID, "err", err)
|
|
} else {
|
|
slog.Info("ws: snapshots refreshed", "host_id", hostID, "count", len(snaps))
|
|
}
|
|
|
|
case api.MsgRepoStats, api.MsgScheduleAck, api.MsgCommandResult:
|
|
// TODO(P2): persist these projections.
|
|
slog.Debug("ws msg not yet handled", "type", env.Type, "host_id", hostID)
|
|
|
|
case api.MsgError:
|
|
var ep api.ErrorPayload
|
|
_ = env.UnmarshalPayload(&ep)
|
|
slog.Warn("ws agent reported error", "host_id", hostID,
|
|
"code", string(ep.Code), "message", ep.Message)
|
|
|
|
default:
|
|
slog.Warn("ws unknown message type from agent",
|
|
"type", env.Type, "host_id", hostID)
|
|
}
|
|
}
|
|
|
|
// MinHeartbeatInterval is the smallest gap between heartbeats that is
// still considered sane: an agent heartbeating more often than this is
// misbehaving. (The spec cadence is 30s.)
const MinHeartbeatInterval time.Duration = 5 * time.Second
|
|
|
|
// Keep the encoding/json import referenced so the file still compiles
// if its other uses of json drop out later.
var _ = json.Marshal
|