package ws import ( "context" "errors" "sync" "time" "github.com/oklog/ulid/v2" "gitea.dcglab.co.uk/steve/restic-manager/internal/api" ) // rpcRegistry holds in-flight synchronous RPC calls. SendRPC registers // a channel keyed by the request envelope's ID; the WS read loop's // dispatcher routes incoming reply envelopes to the matching channel // when their type is one of the known reply types (currently just // tree.list.result). // // A single global registry keyed by envelope ID is fine because IDs // are ULIDs — globally unique without coordinating across hubs. type rpcRegistry struct { mu sync.Mutex pending map[string]chan api.Envelope } // register reserves a channel for the given request ID. The channel // is buffered (cap 1) so a slow waiter doesn't block the read loop's // dispatcher when the reply lands. func (r *rpcRegistry) register(id string) chan api.Envelope { ch := make(chan api.Envelope, 1) r.mu.Lock() if r.pending == nil { r.pending = make(map[string]chan api.Envelope) } r.pending[id] = ch r.mu.Unlock() return ch } // resolve delivers an envelope to its waiter and removes the entry. // Returns whether a waiter was actually present (the dispatcher uses // this to decide whether to log a stray-reply warning). func (r *rpcRegistry) resolve(id string, env api.Envelope) bool { r.mu.Lock() ch, ok := r.pending[id] if ok { delete(r.pending, id) } r.mu.Unlock() if !ok { return false } // Buffered chan cap 1 — non-blocking send. The waiter goroutine // owns the receive side so this is the only sender. ch <- env close(ch) return true } // release abandons the entry without delivering a value. Used when // the caller's context expires before a reply arrives — the next // stray reply (if any) will hit the no-waiter case in resolve and // just be dropped. func (r *rpcRegistry) release(id string) { r.mu.Lock() delete(r.pending, id) r.mu.Unlock() } // SendRPC sends a request envelope to the host and blocks until a // matching reply lands or the context expires. The hub picks a fresh // envelope ID, marshals the payload, registers a waiter, and sends. // // timeout caps the wait; a too-aggressive value relative to the // expected restic-side latency will leak the registry entry until the // reply finally arrives (which is then silently dropped). The default // callers use is 30s, which covers a slow network round-trip plus a // restic ls invocation against a remote rest-server. // // If the host disconnects mid-flight, the read loop ends and no reply // will ever come — the caller's ctx.Done()/timeout is the only path // out. We could pre-fail by tracking conn lifetime, but the bound // keeps the code simple and the worst case is a 30s wait. func (h *Hub) SendRPC(ctx context.Context, hostID string, reqType api.MessageType, payload any, timeout time.Duration) (api.Envelope, error) { if timeout <= 0 { timeout = 30 * time.Second } id := ulid.Make().String() env, err := api.Marshal(reqType, id, payload) if err != nil { return api.Envelope{}, err } ch := h.rpcs.register(id) if err := h.Send(ctx, hostID, env); err != nil { h.rpcs.release(id) return api.Envelope{}, err } select { case reply := <-ch: return reply, nil case <-ctx.Done(): h.rpcs.release(id) return api.Envelope{}, ctx.Err() case <-time.After(timeout): h.rpcs.release(id) return api.Envelope{}, errors.New("ws rpc: timed out waiting for reply") } }