// Package ws hosts the WebSocket transport for agent ↔ server. The // Hub tracks one active connection per host id; subsequent connections // from the same host evict the prior one (last-write-wins). package ws import ( "context" "encoding/json" "errors" "fmt" "log/slog" "sync" "time" "github.com/coder/websocket" "gitea.dcglab.co.uk/steve/restic-manager/internal/api" ) // Hub owns the live agent connections and routes messages. type Hub struct { mu sync.RWMutex conns map[string]*Conn // hostID → conn } // NewHub returns an empty hub. func NewHub() *Hub { return &Hub{conns: make(map[string]*Conn)} } // Conn is one agent WS connection. Send is safe for concurrent use; // Read is single-reader (the connection's run loop). type Conn struct { HostID string c *websocket.Conn writeMu sync.Mutex } // Register installs c as the canonical connection for hostID. Any // previous connection for that host is closed. func (h *Hub) Register(hostID string, c *Conn) { h.mu.Lock() if prev, ok := h.conns[hostID]; ok { // Best-effort close — a stuck old socket shouldn't block new one. go func(old *Conn) { _ = old.c.Close(websocket.StatusPolicyViolation, "superseded") }(prev) } h.conns[hostID] = c h.mu.Unlock() } // Unregister removes c iff it is still the canonical conn (a race // where a newer conn already replaced it must not unregister it). func (h *Hub) Unregister(hostID string, c *Conn) { h.mu.Lock() if cur, ok := h.conns[hostID]; ok && cur == c { delete(h.conns, hostID) } h.mu.Unlock() } // Send delivers an envelope to the host if connected. Returns an error // if the host is offline; caller may queue the message for later. func (h *Hub) Send(ctx context.Context, hostID string, env api.Envelope) error { h.mu.RLock() c, ok := h.conns[hostID] h.mu.RUnlock() if !ok { return fmt.Errorf("ws: host %q is offline", hostID) } return c.Send(ctx, env) } // Connected reports whether hostID has an active connection. func (h *Hub) Connected(hostID string) bool { h.mu.RLock() _, ok := h.conns[hostID] h.mu.RUnlock() return ok } // Conn returns the canonical connection for hostID, or nil if the // host is offline. Tests use this to obtain a *Conn for direct calls // into handlers that take one. Production code should prefer Send, // which avoids holding a reference past the point where a supersede // might have replaced the conn. func (h *Hub) Conn(hostID string) *Conn { h.mu.RLock() defer h.mu.RUnlock() return h.conns[hostID] } // ----- Conn methods -------------------------------------------------- // NewConn wraps a freshly-accepted websocket for a given hostID. func NewConn(hostID string, c *websocket.Conn) *Conn { return &Conn{HostID: hostID, c: c} } // Send writes an envelope as a JSON text message. Concurrent calls // are serialized; the underlying socket is not safe for parallel // writers. func (c *Conn) Send(ctx context.Context, env api.Envelope) error { c.writeMu.Lock() defer c.writeMu.Unlock() raw, err := json.Marshal(env) if err != nil { return fmt.Errorf("ws: marshal envelope: %w", err) } return c.c.Write(ctx, websocket.MessageText, raw) } // SendError writes an error envelope and closes the socket. Used by // the hello handshake when an agent is rejected. func (c *Conn) SendError(ctx context.Context, code api.ErrorCode, msg, helpURL string) { env, err := api.Marshal(api.MsgError, "", api.ErrorPayload{ Code: code, Message: msg, HelpURL: helpURL, }) if err == nil { writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() _ = c.Send(writeCtx, env) } _ = c.c.Close(websocket.StatusPolicyViolation, string(code)) } // Close shuts the socket down with a normal-closure status code. func (c *Conn) Close() error { return c.c.Close(websocket.StatusNormalClosure, "") } // Read pulls the next JSON envelope off the wire. The caller's // context controls cancellation and timeouts (e.g. read deadlines). func (c *Conn) Read(ctx context.Context) (api.Envelope, error) { mt, raw, err := c.c.Read(ctx) if err != nil { return api.Envelope{}, err } if mt != websocket.MessageText { return api.Envelope{}, errors.New("ws: expected text frame") } var env api.Envelope if err := json.Unmarshal(raw, &env); err != nil { return api.Envelope{}, fmt.Errorf("ws: unmarshal envelope: %w", err) } return env, nil } // ----- helpers ------------------------------------------------------- // LogValue emits a slog-friendly representation of a Conn. func (c *Conn) LogValue() slog.Value { return slog.GroupValue(slog.String("host_id", c.HostID)) }