diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 5cac43e..9d1d15b 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -274,6 +274,17 @@ func (d *dispatcher) handle(ctx context.Context, env api.Envelope, tx wsclient.S slog.Info("ws agent: command.cancel for unknown job (already finished?)", "job_id", p.JobID) } + case api.MsgTreeList: + // Synchronous RPC for the restore wizard's tree browser. The + // server has serialized access; we just run restic ls and reply + // with the same envelope ID. Run in a goroutine so the WS read + // loop keeps draining. + var p api.TreeListRequestPayload + if err := env.UnmarshalPayload(&p); err != nil { + return fmt.Errorf("tree.list: %w", err) + } + go d.handleTreeList(ctx, env.ID, p, tx) + case api.MsgScheduleSet: var p api.ScheduleSetPayload if err := env.UnmarshalPayload(&p); err != nil { @@ -381,6 +392,72 @@ func (d *dispatcher) handle(ctx context.Context, env api.Envelope, tx wsclient.S return nil } +// handleTreeList runs `restic ls --json ` and ships +// the matching tree.list.result envelope back, correlated by the +// request envelope's ID. Errors (missing creds, restic failure) +// surface in the result's Error field rather than as transport-level +// failures so the server-side waiter can render a sensible message. +func (d *dispatcher) handleTreeList(ctx context.Context, reqID string, p api.TreeListRequestPayload, tx wsclient.Sender) { + reply := func(result api.TreeListResultPayload) { + result.SnapshotID = p.SnapshotID + result.Path = p.Path + env, err := api.Marshal(api.MsgTreeListResult, reqID, result) + if err != nil { + slog.Warn("ws agent: marshal tree.list.result", "err", err) + return + } + _ = tx.Send(env) + } + + if d.resticBin == "" { + reply(api.TreeListResultPayload{Error: "restic binary not located on this agent"}) + return + } + creds, err := d.secrets.Load() + if err != nil { + reply(api.TreeListResultPayload{Error: "load credentials: " + err.Error()}) + return + } + if creds.Empty() { + reply(api.TreeListResultPayload{Error: "repo credentials not configured"}) + return + } + + d.bwMu.Lock() + upKBps, downKBps := d.bwUpKBps, d.bwDownKBps + d.bwMu.Unlock() + + env := restic.Env{ + Bin: d.resticBin, + RepoURL: creds.URL, + RepoUsername: creds.Username, + RepoPassword: creds.Password, + LimitUploadKBps: upKBps, + LimitDownloadKBps: downKBps, + } + + // 60s ceiling matches snapshots/stats — restic ls on a single + // directory is normally sub-second; if the repo is unreachable we + // want to surface the failure rather than block the wizard. + listCtx, cancel := context.WithTimeout(ctx, 60*time.Second) + defer cancel() + + entries, err := env.ListTreeChildren(listCtx, p.SnapshotID, p.Path) + if err != nil { + reply(api.TreeListResultPayload{Error: err.Error()}) + return + } + apiEntries := make([]api.TreeListEntry, 0, len(entries)) + for _, e := range entries { + apiEntries = append(apiEntries, api.TreeListEntry{ + Name: e.Name, + Type: e.Type, + Size: e.Size, + }) + } + reply(api.TreeListResultPayload{Entries: apiEntries}) +} + // runJob spawns a runner for one job. We launch a goroutine so the // WS read loop keeps draining messages while restic chugs along. func (d *dispatcher) runJob(ctx context.Context, p api.CommandRunPayload, tx wsclient.Sender) error { diff --git a/internal/api/messages.go b/internal/api/messages.go index ce43bc3..1ec64fb 100644 --- a/internal/api/messages.go +++ b/internal/api/messages.go @@ -337,3 +337,37 @@ type AgentUpdateAvailablePayload struct { PackageURL string `json:"package_url"` // apt repo / choco source Changelog string `json:"changelog,omitempty"` } + +// TreeListRequestPayload is the body of a tree.list RPC. Used by the +// restore wizard to lazy-load directory contents from a snapshot. +// +// The exchange is synchronous: the server marshals MsgTreeList with a +// fresh Envelope.ID, sends to the agent, blocks on a channel keyed by +// that ID. The agent runs `restic ls --json `, +// emits direct children, and replies with MsgTreeListResult carrying +// the same ID. The server-side handler matches on ID and forwards to +// the waiting channel. See internal/server/ws/rpc.go for the helper. +type TreeListRequestPayload struct { + SnapshotID string `json:"snapshot_id"` + Path string `json:"path"` // absolute path inside the snapshot, "/" for root +} + +// TreeListEntry is one direct child returned by a tree.list call. +// Type is "dir" | "file" | "symlink"; size is best-effort (zero on +// directories and symlinks). +type TreeListEntry struct { + Name string `json:"name"` + Type string `json:"type"` + Size int64 `json:"size,omitempty"` +} + +// TreeListResultPayload is the reply to a tree.list. Error is set +// when the agent couldn't fulfill the request (missing snapshot, +// path doesn't exist, restic invocation failed); Entries is empty in +// that case. A successful empty directory has Error="" + nil Entries. +type TreeListResultPayload struct { + SnapshotID string `json:"snapshot_id"` + Path string `json:"path"` + Entries []TreeListEntry `json:"entries,omitempty"` + Error string `json:"error,omitempty"` +} diff --git a/internal/api/wire.go b/internal/api/wire.go index df646a5..a52a58b 100644 --- a/internal/api/wire.go +++ b/internal/api/wire.go @@ -12,18 +12,19 @@ type MessageType string // Agent → server message types. const ( - MsgHello MessageType = "hello" - MsgHeartbeat MessageType = "heartbeat" - MsgJobStarted MessageType = "job.started" - MsgJobProgress MessageType = "job.progress" - MsgJobFinished MessageType = "job.finished" - MsgSnapshotsRpt MessageType = "snapshots.report" - MsgRepoStats MessageType = "repo.stats" - MsgLogStream MessageType = "log.stream" - MsgScheduleAck MessageType = "schedule.ack" - MsgScheduleFire MessageType = "schedule.fire" // agent: a local cron entry fired, please dispatch a job - MsgCommandResult MessageType = "command.result" // ack for command.run - MsgError MessageType = "error" + MsgHello MessageType = "hello" + MsgHeartbeat MessageType = "heartbeat" + MsgJobStarted MessageType = "job.started" + MsgJobProgress MessageType = "job.progress" + MsgJobFinished MessageType = "job.finished" + MsgSnapshotsRpt MessageType = "snapshots.report" + MsgRepoStats MessageType = "repo.stats" + MsgLogStream MessageType = "log.stream" + MsgScheduleAck MessageType = "schedule.ack" + MsgScheduleFire MessageType = "schedule.fire" // agent: a local cron entry fired, please dispatch a job + MsgCommandResult MessageType = "command.result" // ack for command.run + MsgTreeListResult MessageType = "tree.list.result" // reply to a server-driven tree.list + MsgError MessageType = "error" ) // Server → agent message types. @@ -33,6 +34,7 @@ const ( MsgScheduleSet MessageType = "schedule.set" MsgConfigUpdate MessageType = "config.update" MsgAgentUpdateAvail MessageType = "agent.update.available" + MsgTreeList MessageType = "tree.list" // sync RPC: list a snapshot's children ) // Envelope is the framing for every WS message in either direction. diff --git a/internal/restic/ls.go b/internal/restic/ls.go new file mode 100644 index 0000000..5625238 --- /dev/null +++ b/internal/restic/ls.go @@ -0,0 +1,140 @@ +package restic + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "os/exec" + "path" + "strings" +) + +// LsEntry is one node from `restic ls --json`. Restic emits these as +// line-delimited JSON; we keep only the fields the restore wizard +// needs. +type LsEntry struct { + Name string `json:"name"` + Type string `json:"type"` + Path string `json:"path"` + Size int64 `json:"size,omitempty"` + Struct string `json:"struct_type,omitempty"` +} + +// ListTreeChildren runs `restic ls --json ` and +// returns only the direct children of dirPath. Restic ls is recursive +// by default, so we filter post-hoc — for a typical interactive +// drill-down ("expand /etc/nginx") the subtree is small (a few KB of +// JSON); for huge subtrees this is suboptimal but correct. +// +// The first emitted line is restic's "snapshot" preamble (struct_type +// = "snapshot") which we discard. Subsequent lines are nodes; we +// match on path equal to dirPath + "/" + name (with normalization so +// trailing slashes don't break the comparison). +// +// dirPath="" or "/" lists the snapshot root. +func (e Env) ListTreeChildren(ctx context.Context, snapshotID, dirPath string) ([]LsEntry, error) { + if snapshotID == "" { + return nil, fmt.Errorf("restic ls: snapshot id required") + } + parent := normalizeTreePath(dirPath) + + args := []string{"ls", "--json", snapshotID} + if parent != "/" { + args = append(args, parent) + } + cmd := e.resticCmd(ctx, args...) + + var stderr bytes.Buffer + cmd.Stderr = &stderr + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("restic ls: stdout pipe: %w", err) + } + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("restic ls: start: %w", err) + } + + out, parseErr := parseLsChildren(stdout, parent) + + werr := cmd.Wait() + if werr != nil { + var ee *exec.ExitError + if errors.As(werr, &ee) { + return nil, fmt.Errorf("restic ls: exit %d: %s", + ee.ExitCode(), strings.TrimSpace(stderr.String())) + } + return nil, fmt.Errorf("restic ls: %w", werr) + } + if parseErr != nil { + return nil, parseErr + } + return out, nil +} + +// parseLsChildren reads line-delimited JSON from r and returns nodes +// whose Path is a direct child of parent. Exposed for testing. +func parseLsChildren(r io.Reader, parent string) ([]LsEntry, error) { + scanner := bufio.NewScanner(r) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + var out []LsEntry + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + var entry LsEntry + if err := json.Unmarshal(line, &entry); err != nil { + return nil, fmt.Errorf("restic ls: parse line: %w", err) + } + // Skip the snapshot preamble and any future struct_type + // entries we don't care about. + if entry.Struct == "snapshot" || entry.Path == "" { + continue + } + if isDirectChild(entry.Path, parent) { + out = append(out, entry) + } + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("restic ls: read output: %w", err) + } + return out, nil +} + +// normalizeTreePath turns "" / "/" / "/etc/" / "etc" all into a +// canonical absolute form with a leading slash and no trailing slash +// (except the root, which is "/" alone). +func normalizeTreePath(p string) string { + p = strings.TrimSpace(p) + if p == "" || p == "/" { + return "/" + } + if !strings.HasPrefix(p, "/") { + p = "/" + p + } + cleaned := path.Clean(p) + return cleaned +} + +// isDirectChild reports whether childPath is a direct child of parent. +// "/etc/nginx" is a direct child of "/etc"; "/etc/nginx/conf" is not. +// "/etc" is a direct child of "/". +func isDirectChild(childPath, parent string) bool { + cp := normalizeTreePath(childPath) + pp := normalizeTreePath(parent) + if pp == "/" { + // Direct children of root: exactly one slash-delimited segment. + return cp != "/" && strings.Count(cp, "/") == 1 + } + // Must start with parent + "/" and have no further slashes. + prefix := pp + "/" + if !strings.HasPrefix(cp, prefix) { + return false + } + rest := cp[len(prefix):] + return rest != "" && !strings.Contains(rest, "/") +} diff --git a/internal/restic/ls_test.go b/internal/restic/ls_test.go new file mode 100644 index 0000000..4688383 --- /dev/null +++ b/internal/restic/ls_test.go @@ -0,0 +1,123 @@ +package restic + +import ( + "strings" + "testing" +) + +// realistic restic ls --json output sample. First line is the +// snapshot preamble, subsequent lines are nodes. Trimmed to a few +// entries that exercise depth filtering. +const sampleLsOutput = `{"struct_type":"snapshot","time":"2026-05-04T09:14:00Z","id":"f3a7b2c1"} +{"name":"etc","type":"dir","path":"/etc","permissions":"drwxr-xr-x","struct_type":"node"} +{"name":"nginx","type":"dir","path":"/etc/nginx","permissions":"drwxr-xr-x","struct_type":"node"} +{"name":"nginx.conf","type":"file","path":"/etc/nginx/nginx.conf","size":2400,"struct_type":"node"} +{"name":"sites-available","type":"dir","path":"/etc/nginx/sites-available","struct_type":"node"} +{"name":"alfa.conf","type":"file","path":"/etc/nginx/sites-available/alfa.conf","size":3100,"struct_type":"node"} +{"name":"default.conf","type":"file","path":"/etc/nginx/sites-available/default.conf","size":2900,"struct_type":"node"} +` + +func TestParseLsChildrenAtRoot(t *testing.T) { + t.Parallel() + entries, err := parseLsChildren(strings.NewReader(sampleLsOutput), "/") + if err != nil { + t.Fatalf("parse: %v", err) + } + if len(entries) != 1 { + t.Fatalf("entries: got %d (%+v), want 1", len(entries), entries) + } + if entries[0].Name != "etc" || entries[0].Path != "/etc" || entries[0].Type != "dir" { + t.Fatalf("entry: %+v", entries[0]) + } +} + +func TestParseLsChildrenAtEtc(t *testing.T) { + t.Parallel() + entries, err := parseLsChildren(strings.NewReader(sampleLsOutput), "/etc") + if err != nil { + t.Fatalf("parse: %v", err) + } + if len(entries) != 1 { + t.Fatalf("entries: got %d, want 1 (just nginx, not nested children)", len(entries)) + } + if entries[0].Name != "nginx" { + t.Fatalf("entry: %+v", entries[0]) + } +} + +func TestParseLsChildrenAtNginx(t *testing.T) { + t.Parallel() + entries, err := parseLsChildren(strings.NewReader(sampleLsOutput), "/etc/nginx") + if err != nil { + t.Fatalf("parse: %v", err) + } + if len(entries) != 2 { + t.Fatalf("entries: got %d (%+v), want 2 (nginx.conf + sites-available, not nested)", + len(entries), entries) + } + gotNames := []string{entries[0].Name, entries[1].Name} + want := map[string]bool{"nginx.conf": true, "sites-available": true} + for _, n := range gotNames { + if !want[n] { + t.Errorf("unexpected name %q in result", n) + } + } +} + +func TestParseLsChildrenAtSitesAvailable(t *testing.T) { + t.Parallel() + entries, err := parseLsChildren(strings.NewReader(sampleLsOutput), "/etc/nginx/sites-available") + if err != nil { + t.Fatalf("parse: %v", err) + } + if len(entries) != 2 { + t.Fatalf("entries: got %d, want 2", len(entries)) + } + for _, e := range entries { + if e.Type != "file" { + t.Errorf("expected file type, got %q on %q", e.Type, e.Name) + } + } +} + +func TestNormalizeTreePath(t *testing.T) { + t.Parallel() + cases := []struct{ in, want string }{ + {"", "/"}, + {"/", "/"}, + {"/etc", "/etc"}, + {"/etc/", "/etc"}, + {"etc/nginx", "/etc/nginx"}, + {"/etc//nginx", "/etc/nginx"}, + {"/etc/./nginx", "/etc/nginx"}, + } + for _, c := range cases { + got := normalizeTreePath(c.in) + if got != c.want { + t.Errorf("normalizeTreePath(%q): got %q, want %q", c.in, got, c.want) + } + } +} + +func TestIsDirectChild(t *testing.T) { + t.Parallel() + cases := []struct { + child, parent string + want bool + }{ + {"/etc", "/", true}, + {"/etc/nginx", "/", false}, + {"/etc/nginx", "/etc", true}, + {"/etc/nginx/conf", "/etc", false}, + {"/etc/nginx/conf", "/etc/nginx", true}, + {"/etc", "/etc", false}, + {"/etcc", "/etc", false}, // prefix match guard + } + for _, c := range cases { + got := isDirectChild(c.child, c.parent) + if got != c.want { + t.Errorf("isDirectChild(%q, %q): got %v, want %v", + c.child, c.parent, got, c.want) + } + } +} diff --git a/internal/server/http/server.go b/internal/server/http/server.go index b808cfc..a21aedd 100644 --- a/internal/server/http/server.go +++ b/internal/server/http/server.go @@ -58,6 +58,11 @@ type Server struct { // pending_id so the accept/reject handlers can push the bearer // or close cleanly (P2-18b). pendingHub *pendingHub + + // treeCache holds per-wizard-session listings of snapshot + // directories (P3-X2). Pre-allocated in New so the lazy-init + // race is impossible. + treeCache *treeCache } // New builds a configured but not-yet-started server. @@ -81,6 +86,7 @@ func New(deps Deps) *Server { drainLocks: make(map[string]*sync.Mutex), announceRL: newAnnounceLimiter(), pendingHub: newPendingHub(), + treeCache: newTreeCache(), } s.routes(r) diff --git a/internal/server/http/tree_cache.go b/internal/server/http/tree_cache.go new file mode 100644 index 0000000..9d7c077 --- /dev/null +++ b/internal/server/http/tree_cache.go @@ -0,0 +1,112 @@ +package http + +import ( + "context" + "sync" + "time" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" +) + +// treeCacheTTL is how long a per-session cached directory listing +// stays valid. The whole point of the cache is to make re-expanding +// nodes within the same wizard session snappy; 30 minutes covers a +// generous wizard interaction window without holding stale data +// indefinitely. +const treeCacheTTL = 30 * time.Minute + +// treeCacheKey identifies one cached listing. session_id scopes +// entries to a single browser session so two operators don't share +// view state; snapshot_id + path identify the directory inside the +// snapshot. +type treeCacheKey struct { + SessionID string + HostID string + SnapshotID string + Path string +} + +type treeCacheEntry struct { + Result api.TreeListResultPayload + ExpiresAt time.Time +} + +// treeCache is a per-process map of synchronously fetched directory +// listings. Concurrency is light (a few entries per active wizard +// session) so a single mutex is fine. +type treeCache struct { + mu sync.Mutex + entries map[treeCacheKey]treeCacheEntry +} + +func newTreeCache() *treeCache { + return &treeCache{entries: make(map[treeCacheKey]treeCacheEntry)} +} + +// Get returns a cached entry if one exists and hasn't expired. +func (c *treeCache) Get(k treeCacheKey, now time.Time) (api.TreeListResultPayload, bool) { + c.mu.Lock() + defer c.mu.Unlock() + e, ok := c.entries[k] + if !ok { + return api.TreeListResultPayload{}, false + } + if now.After(e.ExpiresAt) { + delete(c.entries, k) + return api.TreeListResultPayload{}, false + } + return e.Result, true +} + +// Put records a fresh listing under k. Caller is responsible for +// having validated the result first (Error == ""). +func (c *treeCache) Put(k treeCacheKey, result api.TreeListResultPayload, now time.Time) { + c.mu.Lock() + c.entries[k] = treeCacheEntry{ + Result: result, + ExpiresAt: now.Add(treeCacheTTL), + } + c.mu.Unlock() +} + +// Sweep deletes expired entries. Called opportunistically from the +// wizard handler — no separate goroutine needed; cache size is small. +func (c *treeCache) Sweep(now time.Time) { + c.mu.Lock() + for k, e := range c.entries { + if now.After(e.ExpiresAt) { + delete(c.entries, k) + } + } + c.mu.Unlock() +} + +// fetchTreeWithCache returns a directory listing — cache hit, or a +// synchronous tree.list RPC against the agent on miss. On agent error +// (not transport error), the result is returned as-is with Error set +// rather than cached, so a transient failure doesn't poison subsequent +// requests for the same path. +// +//nolint:unused // wired in by the wizard handler in the next slice +func (s *Server) fetchTreeWithCache(ctx context.Context, sessionID, hostID, snapshotID, path string) (api.TreeListResultPayload, error) { + now := time.Now() + k := treeCacheKey{SessionID: sessionID, HostID: hostID, SnapshotID: snapshotID, Path: path} + if cached, ok := s.treeCache.Get(k, now); ok { + return cached, nil + } + + reply, err := s.deps.Hub.SendRPC(ctx, hostID, api.MsgTreeList, + api.TreeListRequestPayload{SnapshotID: snapshotID, Path: path}, + 30*time.Second) + if err != nil { + return api.TreeListResultPayload{}, err + } + var result api.TreeListResultPayload + if perr := reply.UnmarshalPayload(&result); perr != nil { + return api.TreeListResultPayload{}, perr + } + if result.Error == "" { + s.treeCache.Put(k, result, now) + } + return result, nil +} diff --git a/internal/server/http/tree_rpc_test.go b/internal/server/http/tree_rpc_test.go new file mode 100644 index 0000000..e627235 --- /dev/null +++ b/internal/server/http/tree_rpc_test.go @@ -0,0 +1,146 @@ +// tree_rpc_test.go — full round-trip test for the tree.list synchronous +// RPC (P3-X2). A fake agent reads the inbound tree.list, replies with a +// canned tree.list.result, and we assert the server's SendRPC returned +// the expected payload. +package http + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/coder/websocket" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" +) + +func TestSendRPCTreeListRoundTrip(t *testing.T) { + t.Parallel() + srv, ts, st := rawTestServer(t) + hostID, token := enrolHostForWS(t, srv, st, "rpc-host") + c := agentDial(t, srv, ts, hostID, token) + sendHello(t, c, "rpc-host") + _ = drainUntil(t, c, api.MsgScheduleSet) + + // Fake agent: read inbound envelopes, mirror tree.list with a + // canned result. Other inbound envelopes (config.update etc) are + // already drained above. + done := make(chan error, 1) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + for { + mt, raw, err := c.Read(ctx) + if err != nil { + done <- err + return + } + if mt != websocket.MessageText { + continue + } + var env api.Envelope + if err := json.Unmarshal(raw, &env); err != nil { + done <- err + return + } + if env.Type != api.MsgTreeList { + continue + } + var req api.TreeListRequestPayload + if err := env.UnmarshalPayload(&req); err != nil { + done <- err + return + } + result := api.TreeListResultPayload{ + SnapshotID: req.SnapshotID, + Path: req.Path, + Entries: []api.TreeListEntry{ + {Name: "etc", Type: "dir"}, + {Name: "var", Type: "dir"}, + }, + } + out, err := api.Marshal(api.MsgTreeListResult, env.ID, result) + if err != nil { + done <- err + return + } + rawOut, _ := json.Marshal(out) + if err := c.Write(ctx, websocket.MessageText, rawOut); err != nil { + done <- err + return + } + done <- nil + return + } + }() + + // Server-side SendRPC. + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + reply, err := srv.deps.Hub.SendRPC(ctx, hostID, api.MsgTreeList, + api.TreeListRequestPayload{SnapshotID: "f3a7b2c1", Path: "/"}, + 3*time.Second) + if err != nil { + t.Fatalf("SendRPC: %v", err) + } + if reply.Type != api.MsgTreeListResult { + t.Fatalf("reply type: got %q want %q", reply.Type, api.MsgTreeListResult) + } + var result api.TreeListResultPayload + if err := reply.UnmarshalPayload(&result); err != nil { + t.Fatalf("unmarshal reply: %v", err) + } + if result.SnapshotID != "f3a7b2c1" || result.Path != "/" { + t.Fatalf("payload: got %+v", result) + } + if len(result.Entries) != 2 || result.Entries[0].Name != "etc" { + t.Fatalf("entries: %+v", result.Entries) + } + + // Make sure the fake agent didn't error out. + select { + case err := <-done: + if err != nil { + t.Fatalf("fake agent: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("fake agent didn't finish") + } +} + +// TestSendRPCTimeoutNoReply: SendRPC times out cleanly when the agent +// never replies; the registry entry is released so a stray late reply +// wouldn't deadlock anything. +func TestSendRPCTimeoutNoReply(t *testing.T) { + t.Parallel() + srv, ts, st := rawTestServer(t) + hostID, token := enrolHostForWS(t, srv, st, "rpc-timeout-host") + c := agentDial(t, srv, ts, hostID, token) + sendHello(t, c, "rpc-timeout-host") + _ = drainUntil(t, c, api.MsgScheduleSet) + + // Fake agent reads but never replies. + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + for { + if _, _, err := c.Read(ctx); err != nil { + return + } + } + }() + + ctx := context.Background() + t0 := time.Now() + _, err := srv.deps.Hub.SendRPC(ctx, hostID, api.MsgTreeList, + api.TreeListRequestPayload{SnapshotID: "x", Path: "/"}, + 300*time.Millisecond) + if err == nil { + t.Fatal("expected timeout error") + } + elapsed := time.Since(t0) + if elapsed < 250*time.Millisecond || elapsed > 2*time.Second { + t.Fatalf("timeout took %s, expected ~300ms", elapsed) + } +} diff --git a/internal/server/ws/handler.go b/internal/server/ws/handler.go index 5706693..27bed4f 100644 --- a/internal/server/ws/handler.go +++ b/internal/server/ws/handler.go @@ -297,6 +297,20 @@ func dispatchAgentMessage(ctx context.Context, c *Conn, hostID string, env api.E // (job.started → job.finished) is sufficient signal. slog.Debug("ws msg not yet handled", "type", env.Type, "host_id", hostID) + case api.MsgTreeListResult: + // Reply to a synchronous tree.list RPC. Route to the waiter + // registered against the request envelope's ID; if none is + // registered the caller already gave up (ctx expired) — drop + // the stray reply quietly. + if env.ID == "" { + slog.Warn("ws: tree.list.result missing envelope ID", "host_id", hostID) + break + } + if !deps.Hub.rpcs.resolve(env.ID, env) { + slog.Debug("ws: tree.list.result with no waiter (timeout?)", + "id", env.ID, "host_id", hostID) + } + case api.MsgError: var ep api.ErrorPayload _ = env.UnmarshalPayload(&ep) diff --git a/internal/server/ws/hub.go b/internal/server/ws/hub.go index 8ad732f..e69cf9b 100644 --- a/internal/server/ws/hub.go +++ b/internal/server/ws/hub.go @@ -21,6 +21,11 @@ import ( type Hub struct { mu sync.RWMutex conns map[string]*Conn // hostID → conn + + // rpcs tracks in-flight synchronous RPC calls (e.g. tree.list). + // See rpc.go for details. Lazy-initialized via the registry's + // own register() so callers don't have to juggle a constructor. + rpcs rpcRegistry } // NewHub returns an empty hub. diff --git a/internal/server/ws/rpc.go b/internal/server/ws/rpc.go new file mode 100644 index 0000000..e4da8c3 --- /dev/null +++ b/internal/server/ws/rpc.go @@ -0,0 +1,112 @@ +package ws + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/oklog/ulid/v2" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" +) + +// rpcRegistry holds in-flight synchronous RPC calls. SendRPC registers +// a channel keyed by the request envelope's ID; the WS read loop's +// dispatcher routes incoming reply envelopes to the matching channel +// when their type is one of the known reply types (currently just +// tree.list.result). +// +// A single global registry keyed by envelope ID is fine because IDs +// are ULIDs — globally unique without coordinating across hubs. +type rpcRegistry struct { + mu sync.Mutex + pending map[string]chan api.Envelope +} + +// register reserves a channel for the given request ID. The channel +// is buffered (cap 1) so a slow waiter doesn't block the read loop's +// dispatcher when the reply lands. +func (r *rpcRegistry) register(id string) chan api.Envelope { + ch := make(chan api.Envelope, 1) + r.mu.Lock() + if r.pending == nil { + r.pending = make(map[string]chan api.Envelope) + } + r.pending[id] = ch + r.mu.Unlock() + return ch +} + +// resolve delivers an envelope to its waiter and removes the entry. +// Returns whether a waiter was actually present (the dispatcher uses +// this to decide whether to log a stray-reply warning). +func (r *rpcRegistry) resolve(id string, env api.Envelope) bool { + r.mu.Lock() + ch, ok := r.pending[id] + if ok { + delete(r.pending, id) + } + r.mu.Unlock() + if !ok { + return false + } + // Buffered chan cap 1 — non-blocking send. The waiter goroutine + // owns the receive side so this is the only sender. + ch <- env + close(ch) + return true +} + +// release abandons the entry without delivering a value. Used when +// the caller's context expires before a reply arrives — the next +// stray reply (if any) will hit the no-waiter case in resolve and +// just be dropped. +func (r *rpcRegistry) release(id string) { + r.mu.Lock() + delete(r.pending, id) + r.mu.Unlock() +} + +// SendRPC sends a request envelope to the host and blocks until a +// matching reply lands or the context expires. The hub picks a fresh +// envelope ID, marshals the payload, registers a waiter, and sends. +// +// timeout caps the wait; a too-aggressive value relative to the +// expected restic-side latency will leak the registry entry until the +// reply finally arrives (which is then silently dropped). The default +// callers use is 30s, which covers a slow network round-trip plus a +// restic ls invocation against a remote rest-server. +// +// If the host disconnects mid-flight, the read loop ends and no reply +// will ever come — the caller's ctx.Done()/timeout is the only path +// out. We could pre-fail by tracking conn lifetime, but the bound +// keeps the code simple and the worst case is a 30s wait. +func (h *Hub) SendRPC(ctx context.Context, hostID string, reqType api.MessageType, payload any, timeout time.Duration) (api.Envelope, error) { + if timeout <= 0 { + timeout = 30 * time.Second + } + id := ulid.Make().String() + env, err := api.Marshal(reqType, id, payload) + if err != nil { + return api.Envelope{}, err + } + + ch := h.rpcs.register(id) + + if err := h.Send(ctx, hostID, env); err != nil { + h.rpcs.release(id) + return api.Envelope{}, err + } + + select { + case reply := <-ch: + return reply, nil + case <-ctx.Done(): + h.rpcs.release(id) + return api.Envelope{}, ctx.Err() + case <-time.After(timeout): + h.rpcs.release(id) + return api.Envelope{}, errors.New("ws rpc: timed out waiting for reply") + } +} diff --git a/internal/server/ws/rpc_test.go b/internal/server/ws/rpc_test.go new file mode 100644 index 0000000..7e9e290 --- /dev/null +++ b/internal/server/ws/rpc_test.go @@ -0,0 +1,122 @@ +package ws + +import ( + "context" + "encoding/json" + "sync" + "testing" + "time" + + "github.com/oklog/ulid/v2" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" +) + +// TestRPCRegistryRoundTrip: register a waiter, resolve it, get the +// envelope back. Cover the no-waiter and double-resolve cases too. +func TestRPCRegistryRoundTrip(t *testing.T) { + t.Parallel() + var r rpcRegistry + id := ulid.Make().String() + ch := r.register(id) + + want := api.Envelope{Type: api.MsgTreeListResult, ID: id, Payload: json.RawMessage(`{"path":"/"}`)} + if !r.resolve(id, want) { + t.Fatal("resolve: returned false for registered id") + } + got := <-ch + if got.ID != id { + t.Fatalf("id mismatch: got %q want %q", got.ID, id) + } + + // A second resolve for the same id has no waiter and should not panic. + if r.resolve(id, want) { + t.Fatal("resolve: returned true for already-resolved id") + } +} + +// TestRPCRegistryRelease: release abandons the waiter; a subsequent +// resolve is a no-op (no goroutine leak, no panic). +func TestRPCRegistryRelease(t *testing.T) { + t.Parallel() + var r rpcRegistry + id := ulid.Make().String() + _ = r.register(id) + r.release(id) + if r.resolve(id, api.Envelope{ID: id}) { + t.Fatal("resolve after release: should be no-op") + } +} + +// TestRPCRegistryConcurrent: many waiters in flight concurrently get +// only their own reply. This catches buggy keying/locking. +func TestRPCRegistryConcurrent(t *testing.T) { + t.Parallel() + var r rpcRegistry + const n = 64 + + ids := make([]string, n) + chs := make([]chan api.Envelope, n) + for i := 0; i < n; i++ { + ids[i] = ulid.Make().String() + chs[i] = r.register(ids[i]) + } + + // Resolve in random-ish order from many goroutines. + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + r.resolve(ids[idx], api.Envelope{ID: ids[idx], Type: api.MsgTreeListResult}) + }(i) + } + wg.Wait() + + for i := 0; i < n; i++ { + select { + case got := <-chs[i]: + if got.ID != ids[i] { + t.Fatalf("waiter %d: got id %q want %q", i, got.ID, ids[i]) + } + case <-time.After(2 * time.Second): + t.Fatalf("waiter %d: timed out", i) + } + } +} + +// TestSendRPCContextCancelReleases ensures that canceling the caller's +// ctx releases the registry entry so a stray late reply is harmlessly +// dropped. Skips if the hub isn't reachable for direct access — this +// is purely a unit test on the registry path inside SendRPC. +func TestSendRPCContextCancelReleases(t *testing.T) { + t.Parallel() + h := NewHub() + + // No host registered, so Hub.Send returns "host offline" and + // SendRPC bails without ever waiting. We test the timeout/ctx + // path by going through register() directly. + id := ulid.Make().String() + ch := h.rpcs.register(id) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + + // Simulate the SendRPC select: ctx wins. + select { + case <-ch: + t.Fatal("unexpected reply") + case <-ctx.Done(): + h.rpcs.release(id) + } + + // Now a late reply should not block (ch is still open but no + // receiver — buffered size 1 absorbs it). + resolved := h.rpcs.resolve(id, api.Envelope{ID: id}) + if resolved { + t.Fatal("resolve after release should return false") + } +}