diff --git a/internal/server/ws/handler.go b/internal/server/ws/handler.go index 48fb5fd..5706693 100644 --- a/internal/server/ws/handler.go +++ b/internal/server/ws/handler.go @@ -267,8 +267,34 @@ func dispatchAgentMessage(ctx context.Context, c *Conn, hostID string, env api.E deps.OnScheduleFire(ctx, hostID, c, p.ScheduleID, p.ScheduledAt) } - case api.MsgRepoStats, api.MsgCommandResult: - // TODO(P2): persist these projections. + case api.MsgRepoStats: + var p api.RepoStatsPayload + if err := env.UnmarshalPayload(&p); err != nil { + slog.Warn("ws: bad repo.stats payload", "host_id", hostID, "err", err) + break + } + patch := store.HostRepoStats{ + HostID: hostID, + TotalSizeBytes: p.TotalSizeBytes, + RawSizeBytes: p.RawSizeBytes, + UniqueFiles: p.UniqueFiles, + SnapshotCount: p.SnapshotCount, + LastCheckAt: p.LastCheckAt, + LastCheckStatus: p.LastCheckStatus, + LockPresent: p.LockPresent, + LastPruneAt: p.LastPruneAt, + LastPruneFreedBytes: p.LastPruneFreedBytes, + } + if err := deps.Store.UpsertHostRepoStats(ctx, hostID, patch); err != nil { + slog.Warn("ws: upsert host repo stats", "host_id", hostID, "err", err) + } else { + slog.Info("ws: repo stats refreshed", "host_id", hostID) + } + + case api.MsgCommandResult: + // TODO(P2): persist command.result acks for "did the agent + // accept the dispatch?" forensics. Currently the job lifecycle + // (job.started → job.finished) is sufficient signal. slog.Debug("ws msg not yet handled", "type", env.Type, "host_id", hostID) case api.MsgError: diff --git a/internal/server/ws/handler_test.go b/internal/server/ws/handler_test.go new file mode 100644 index 0000000..819a812 --- /dev/null +++ b/internal/server/ws/handler_test.go @@ -0,0 +1,135 @@ +package ws + +import ( + "context" + "path/filepath" + "testing" + "time" + + "gitea.dcglab.co.uk/steve/restic-manager/internal/api" + "gitea.dcglab.co.uk/steve/restic-manager/internal/store" +) + +// openWSTestStore opens an isolated file-backed db in t.TempDir. +func openWSTestStore(t *testing.T) *store.Store { + t.Helper() + dir := t.TempDir() + s, err := store.Open(context.Background(), filepath.Join(dir, "rm.db")) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { _ = s.Close() }) + return s +} + +// seedHostWS inserts a minimal host row directly via the store's DB. +func seedHostWS(t *testing.T, s *store.Store, hostID string) { + t.Helper() + _, err := s.DB().Exec( + `INSERT INTO hosts (id, name, os, arch, enrolled_at) VALUES (?,?,?,?,?)`, + hostID, hostID, "linux", "amd64", "2026-01-01T00:00:00Z") + if err != nil { + t.Fatalf("seed host %q: %v", hostID, err) + } +} + +func int64ptrWS(v int64) *int64 { return &v } +func boolptrWS(v bool) *bool { return &v } + +func TestRepoStatsReportPersisted(t *testing.T) { + t.Parallel() + s := openWSTestStore(t) + ctx := context.Background() + + const hostID = "h-stats-ws" + seedHostWS(t, s, hostID) + + now := time.Now().UTC().Truncate(time.Second) + pruneAt := now.Add(-2 * time.Hour) + payload := api.RepoStatsPayload{ + TotalSizeBytes: int64ptrWS(1024), + RawSizeBytes: int64ptrWS(2048), + UniqueFiles: int64ptrWS(42), + SnapshotCount: int64ptrWS(7), + LastCheckAt: &now, + LastCheckStatus: "ok", + LockPresent: boolptrWS(false), + LastPruneAt: &pruneAt, + LastPruneFreedBytes: int64ptrWS(512), + } + env, err := api.Marshal(api.MsgRepoStats, "", payload) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + deps := HandlerDeps{Store: s} + dispatchAgentMessage(ctx, nil, hostID, env, deps) + + got, err := s.GetHostRepoStats(ctx, hostID) + if err != nil { + t.Fatalf("get host repo stats: %v", err) + } + if got.TotalSizeBytes == nil || *got.TotalSizeBytes != 1024 { + t.Errorf("TotalSizeBytes: got %v want 1024", got.TotalSizeBytes) + } + if got.RawSizeBytes == nil || *got.RawSizeBytes != 2048 { + t.Errorf("RawSizeBytes: got %v want 2048", got.RawSizeBytes) + } + if got.UniqueFiles == nil || *got.UniqueFiles != 42 { + t.Errorf("UniqueFiles: got %v want 42", got.UniqueFiles) + } + if got.SnapshotCount == nil || *got.SnapshotCount != 7 { + t.Errorf("SnapshotCount: got %v want 7", got.SnapshotCount) + } + if got.LastCheckAt == nil || !got.LastCheckAt.Equal(now) { + t.Errorf("LastCheckAt: got %v want %v", got.LastCheckAt, now) + } + if got.LastCheckStatus != "ok" { + t.Errorf("LastCheckStatus: got %q want %q", got.LastCheckStatus, "ok") + } + if got.LockPresent == nil || *got.LockPresent != false { + t.Errorf("LockPresent: got %v want false", got.LockPresent) + } + if got.LastPruneAt == nil || !got.LastPruneAt.Equal(pruneAt) { + t.Errorf("LastPruneAt: got %v want %v", got.LastPruneAt, pruneAt) + } + if got.LastPruneFreedBytes == nil || *got.LastPruneFreedBytes != 512 { + t.Errorf("LastPruneFreedBytes: got %v want 512", got.LastPruneFreedBytes) + } +} + +func TestRepoStatsReportPartialUpdate(t *testing.T) { + t.Parallel() + s := openWSTestStore(t) + ctx := context.Background() + + const hostID = "h-stats-partial" + seedHostWS(t, s, hostID) + + // Pre-seed: TotalSizeBytes = 100. + if err := s.UpsertHostRepoStats(ctx, hostID, store.HostRepoStats{ + TotalSizeBytes: int64ptrWS(100), + }); err != nil { + t.Fatalf("pre-seed upsert: %v", err) + } + + // Send a repo.stats payload that only sets LastCheckStatus. + env, err := api.Marshal(api.MsgRepoStats, "", api.RepoStatsPayload{ + LastCheckStatus: "ok", + }) + if err != nil { + t.Fatalf("marshal: %v", err) + } + dispatchAgentMessage(ctx, nil, hostID, env, HandlerDeps{Store: s}) + + got, err := s.GetHostRepoStats(ctx, hostID) + if err != nil { + t.Fatalf("get: %v", err) + } + if got.TotalSizeBytes == nil || *got.TotalSizeBytes != 100 { + t.Errorf("TotalSizeBytes lost: got %v want 100", got.TotalSizeBytes) + } + if got.LastCheckStatus != "ok" { + t.Errorf("LastCheckStatus: got %q want ok", got.LastCheckStatus) + } +}