From 16e71a07086b34e32e1e2fff3c28a2da27a23738 Mon Sep 17 00:00:00 2001
From: Steve Cliff
Date: Mon, 4 May 2026 19:44:31 +0100
Subject: [PATCH] notification: Hub fan-out + log writer

---
 internal/notification/hub.go | 187 ++++++++++++++++++++++++++++++
 internal/notification/hub_test.go | 99 ++++++++++++++++
 2 files changed, 286 insertions(+)
 create mode 100644 internal/notification/hub.go
 create mode 100644 internal/notification/hub_test.go

diff --git a/internal/notification/hub.go b/internal/notification/hub.go
new file mode 100644
index 0000000..337b7f4
--- /dev/null
+++ b/internal/notification/hub.go
@@ -0,0 +1,187 @@
+package notification
+
+import (
+	"context"
+	"crypto/rand"
+	"encoding/hex"
+	"encoding/json"
+	"log/slog"
+	"sync"
+	"time"
+
+	"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
+	"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
+)
+
+// Hub fans Payload events out to every enabled channel and persists
+// the result to notification_log. One Hub per process; thread-safe.
+//
+// All fields are assigned once in NewHub; no method in this file
+// writes to them afterwards.
+type Hub struct {
+	store       *store.Store // channel lookups and notification_log writes
+	aead        *crypto.AEAD // decrypts each channel's stored config
+	baseURL     string       // e.g. https://restic-manager.example
+	msgIDDomain string       // hostname extracted from baseURL for SMTP Message-ID
+}
+
+// NewHub constructs a Hub. baseURL is the public root of the server
+// (used to build /alerts/ links and the SMTP Message-ID domain).
+// The Message-ID domain is derived from baseURL here, once, via
+// extractDomain. st and aead are stored as given and are not validated.
+func NewHub(st *store.Store, aead *crypto.AEAD, baseURL string) *Hub {
+	return &Hub{
+		store:       st,
+		aead:        aead,
+		baseURL:     baseURL,
+		msgIDDomain: extractDomain(baseURL),
+	}
+}
+
+// Dispatch fans out to every enabled channel. Best-effort — failures
+// are logged to notification_log but do not propagate to the caller.
+// Each channel runs in its own goroutine; Dispatch returns only when
+// all goroutines have settled, so the caller can block briefly for
+// the test-button case.
+//
+// If the list of enabled channels cannot be loaded, the error is
+// reported through slog and nothing is sent. When p.Link is empty it
+// is filled with the alert's page under the Hub's base URL.
+func (h *Hub) Dispatch(ctx context.Context, p Payload) {
+	chans, err := h.store.ListEnabledNotificationChannels(ctx)
+	if err != nil {
+		slog.Error("notification: list channels", "err", err)
+		return
+	}
+	// Stamp the alert link if the caller left it empty.
+	p = h.withLink(p)
+
+	// One goroutine per channel so a slow endpoint cannot delay the
+	// others; the WaitGroup keeps Dispatch synchronous for the caller.
+	var wg sync.WaitGroup
+	for _, c := range chans {
+		wg.Add(1)
+		go func(c store.NotificationChannel) {
+			defer wg.Done()
+			h.send(ctx, c, p)
+		}(c)
+	}
+	wg.Wait()
+}
+
+// DispatchOne fires a single channel — used by the "Send test
+// notification" button. Returns the log entry that was persisted so
+// the handler can render the result inline.
+//
+// The only error returned is a failure to load the channel; delivery
+// failures are reported inside the returned log entry (OK == false).
+func (h *Hub) DispatchOne(ctx context.Context, channelID string, p Payload) (store.NotificationLogEntry, error) {
+	c, err := h.store.GetNotificationChannel(ctx, channelID)
+	if err != nil {
+		return store.NotificationLogEntry{}, err
+	}
+	return h.send(ctx, *c, h.withLink(p)), nil
+}
+
+// withLink returns p with Link pointing at the alert's page under the
+// Hub's base URL when the caller left Link empty; a caller-supplied
+// Link is returned untouched. Trailing slashes on the base URL are
+// dropped first so that a configured "https://host/" does not produce
+// "https://host//alerts/<id>".
+func (h *Hub) withLink(p Payload) Payload {
+	if p.Link != "" {
+		return p
+	}
+	base := h.baseURL
+	for len(base) > 0 && base[len(base)-1] == '/' {
+		base = base[:len(base)-1]
+	}
+	p.Link = base + "/alerts/" + p.AlertID
+	return p
+}
+
+// send builds the channel impl, delivers the payload, and persists a
+// notification_log row regardless of success or failure.
+func (h *Hub) send(ctx context.Context, c store.NotificationChannel, p Payload) store.NotificationLogEntry {
+	ch, buildErr := h.buildChannel(c)
+	logEntry := store.NotificationLogEntry{
+		ID:        newID(),
+		ChannelID: c.ID,
+		Event:     string(p.Event),
+		FiredAt:   time.Now().UTC(),
+	}
+	// AlertID is optional on the log row; leave it nil when the payload
+	// carries none.
+	if p.AlertID != "" {
+		aid := p.AlertID
+		logEntry.AlertID = &aid
+	}
+
+	// The channel could not be constructed (decrypt/parse/unknown kind):
+	// record the failure without status code or latency.
+	if buildErr != nil {
+		errStr := buildErr.Error()
+		logEntry.OK = false
+		logEntry.Error = &errStr
+		h.persistLog(ctx, logEntry)
+		return logEntry
+	}
+
+	code, latency, sendErr := ch.Send(ctx, p)
+	statusCode := code
+	latencyMS := int(latency.Milliseconds())
+	logEntry.StatusCode = &statusCode
+	logEntry.LatencyMS = &latencyMS
+	logEntry.OK = sendErr == nil
+	if sendErr != nil {
+		errStr := sendErr.Error()
+		logEntry.Error = &errStr
+	}
+	h.persistLog(ctx, logEntry)
+	return logEntry
+}
+
+// persistLog appends entry to notification_log. A failure to write is
+// reported through slog and otherwise ignored, keeping delivery
+// best-effort.
+//
+// The write runs on a context detached from ctx's cancellation: a
+// cancelled or expired ctx is a common reason for a delivery to fail,
+// and reusing it would drop exactly the rows that record those
+// failures. A short timeout bounds the detached write instead.
+func (h *Hub) persistLog(ctx context.Context, entry store.NotificationLogEntry) {
+	const persistTimeout = 5 * time.Second
+	pctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), persistTimeout)
+	defer cancel()
+	if err := h.store.AppendNotificationLog(pctx, entry); err != nil {
+		slog.Warn("notification: persist log", "err", err, "channel_id", entry.ChannelID)
+	}
+}
+
+// buildChannel decrypts the channel config and returns a concrete
+// Channel implementation for the channel's kind.
+func (h *Hub) buildChannel(row store.NotificationChannel) (Channel, error) {
+	// The channel ID is passed as associated data, so a config blob
+	// copied onto a different channel row fails to decrypt.
+	plain, err := h.aead.Decrypt(string(row.Config), []byte("notification-channel:"+row.ID))
+	if err != nil {
+		return nil, err
+	}
+	switch row.Kind {
+	case "webhook":
+		var cfg WebhookConfig
+		if err := json.Unmarshal(plain, &cfg); err != nil {
+			return nil, err
+		}
+		return NewWebhookChannel(cfg), nil
+	case "ntfy":
+		var cfg NtfyConfig
+		if err := json.Unmarshal(plain, &cfg); err != nil {
+			return nil, err
+		}
+		// DefaultPriority is optional on the row; pass "" when unset.
+		var dp string
+		if row.DefaultPriority != nil {
+			dp = *row.DefaultPriority
+		}
+		return NewNtfyChannel(cfg, dp), nil
+	case "smtp":
+		var cfg SMTPConfig
+		if err := json.Unmarshal(plain, &cfg); err != nil {
+			return nil, err
+		}
+		return NewSMTPChannel(cfg, h.msgIDDomain), nil
+	}
+	return nil, errUnknownKind(row.Kind)
+}
+
+// newID returns a 32-hex-char random identifier for notification_log rows.
+func newID() string {
+	var b [16]byte
+	// The error is deliberately ignored: this is a best-effort log ID
+	// and there is no caller that could act on a failure.
+	// NOTE(review): recent Go releases document crypto/rand.Read as
+	// never returning an error — confirm against the toolchain in go.mod.
+	_, _ = rand.Read(b[:])
+	return hex.EncodeToString(b[:])
+}
+
+// extractDomain strips the scheme and path from baseURL, leaving only
+// the host[:port] component. Used as the right-hand side of SMTP
+// Message-IDs.
+//
+// The authority is cut at the first '/', '?' or '#', so a base URL
+// without a path but with a query or fragment still yields a bare
+// host. Any "user:password@" prefix is removed so credentials embedded
+// in baseURL never end up in outgoing mail headers. An empty result
+// falls back to "restic-manager.local".
+func extractDomain(baseURL string) string {
+	host := baseURL
+	if i := indexOf(host, "://"); i >= 0 {
+		host = host[i+3:]
+	}
+	// End of the authority component.
+	for i := 0; i < len(host); i++ {
+		if c := host[i]; c == '/' || c == '?' || c == '#' {
+			host = host[:i]
+			break
+		}
+	}
+	// Drop userinfo; the last '@' separates it from the host.
+	for i := len(host) - 1; i >= 0; i-- {
+		if host[i] == '@' {
+			host = host[i+1:]
+			break
+		}
+	}
+	if host == "" {
+		return "restic-manager.local"
+	}
+	return host
+}
+
+// indexOf returns the index of the first occurrence of sub in s, or -1.
+func indexOf(s, sub string) int {
+	// Plain left-to-right scan; an empty sub matches at index 0.
+	for i := 0; i+len(sub) <= len(s); i++ {
+		if s[i:i+len(sub)] == sub {
+			return i
+		}
+	}
+	return -1
+}
+
+// errUnknownKind is returned by buildChannel for a channel kind it has
+// no implementation for; the value is the offending kind.
+type errUnknownKind string
+
+// Error implements the error interface.
+func (e errUnknownKind) Error() string { return "notification: unknown kind: " + string(e) }
diff --git a/internal/notification/hub_test.go b/internal/notification/hub_test.go
new file mode 100644
index 0000000..89a2389
--- /dev/null
+++ b/internal/notification/hub_test.go
@@ -0,0 +1,99 @@
+package notification
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"path/filepath"
+	"testing"
+	"time"
+
+	"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
+	"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
+)
+
+// setupHub opens a store in a per-test temp dir, generates a fresh key
+// file and returns a Hub wired to both. Any setup failure aborts the
+// test immediately so later assertions never run against a half-built
+// fixture.
+func setupHub(t *testing.T) (*Hub, *store.Store) {
+	t.Helper()
+	dir := t.TempDir()
+	st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db"))
+	if err != nil {
+		t.Fatalf("store: %v", err)
+	}
+	t.Cleanup(func() { _ = st.Close() })
+	keyPath := filepath.Join(dir, "secret.key")
+	if err := crypto.GenerateKeyFile(keyPath); err != nil {
+		t.Fatalf("generate key: %v", err)
+	}
+	key, err := crypto.LoadKeyFromFile(keyPath)
+	if err != nil {
+		t.Fatalf("load key: %v", err)
+	}
+	aead, err := crypto.NewAEAD(key)
+	if err != nil {
+		t.Fatalf("aead: %v", err)
+	}
+	return NewHub(st, aead, "https://rm.example"), st
+}
+
+// TestHubDispatchRecordsLogEntries dispatches to one enabled webhook
+// channel backed by a local test server and expects exactly one
+// successful notification_log row for it.
+func TestHubDispatchRecordsLogEntries(t *testing.T) {
+	t.Parallel()
+	hub, st := setupHub(t)
+
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer srv.Close()
+
+	cfg, err := json.Marshal(WebhookConfig{URL: srv.URL})
+	if err != nil {
+		t.Fatalf("marshal config: %v", err)
+	}
+	enc, err := hub.aead.Encrypt(cfg, []byte("notification-channel:test-ch"))
+	if err != nil {
+		t.Fatalf("encrypt: %v", err)
+	}
+	if err := st.CreateNotificationChannel(context.Background(), store.NotificationChannel{
+		ID: "test-ch", Kind: "webhook", Name: "test", Enabled: true,
+		Config: []byte(enc), CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
+	}); err != nil {
+		t.Fatalf("create channel: %v", err)
+	}
+
+	hub.Dispatch(context.Background(), Payload{
+		Event:    EventRaised,
+		Severity: "warning",
+		Kind:     "backup_failed",
+		HostName: "alfa-01",
+		Message:  "x",
+		RaisedAt: time.Now().UTC(),
+	})
+
+	// Verify a log row landed with ok=1.
+	var n int
+	if err := st.DB().QueryRow(
+		`SELECT COUNT(*) FROM notification_log WHERE channel_id = ? AND ok = 1`, "test-ch",
+	).Scan(&n); err != nil {
+		t.Fatalf("count: %v", err)
+	}
+	if n != 1 {
+		t.Fatalf("expected 1 log row, got %d", n)
+	}
+}
+
+// TestHubSkipsDisabledChannels stores a single disabled channel and
+// expects Dispatch to write no notification_log rows at all. Setup
+// errors are fatal: if the channel row were silently missing, the
+// "no rows" assertion would pass without testing anything.
+func TestHubSkipsDisabledChannels(t *testing.T) {
+	t.Parallel()
+	hub, st := setupHub(t)
+
+	cfg, err := json.Marshal(WebhookConfig{URL: "http://no-such-host.invalid"})
+	if err != nil {
+		t.Fatalf("marshal config: %v", err)
+	}
+	enc, err := hub.aead.Encrypt(cfg, []byte("notification-channel:dis"))
+	if err != nil {
+		t.Fatalf("encrypt: %v", err)
+	}
+	if err := st.CreateNotificationChannel(context.Background(), store.NotificationChannel{
+		ID: "dis", Kind: "webhook", Name: "off", Enabled: false,
+		Config: []byte(enc), CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
+	}); err != nil {
+		t.Fatalf("create channel: %v", err)
+	}
+
+	hub.Dispatch(context.Background(), Payload{
+		Event:    EventRaised,
+		AlertID:  "x",
+		Severity: "warning",
+		Kind:     "backup_failed",
+		HostName: "h",
+		Message:  "m",
+		RaisedAt: time.Now().UTC(),
+	})
+
+	var n int
+	if err := st.DB().QueryRow(`SELECT COUNT(*) FROM notification_log`).Scan(&n); err != nil {
+		t.Fatalf("count: %v", err)
+	}
+	if n != 0 {
+		t.Errorf("disabled channel produced log rows: %d", n)
+	}
+}