notification: Hub fan-out + log writer
This commit is contained in:
@@ -0,0 +1,187 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// Hub fans Payload events out to every enabled channel and persists
|
||||
// the result to notification_log. One Hub per process; thread-safe.
|
||||
type Hub struct {
|
||||
store *store.Store
|
||||
aead *crypto.AEAD
|
||||
baseURL string // e.g. https://restic-manager.example
|
||||
msgIDDomain string // hostname extracted from baseURL for SMTP Message-ID
|
||||
}
|
||||
|
||||
// NewHub constructs a Hub. baseURL is the public root of the server
|
||||
// (used to build /alerts/<id> links and the SMTP Message-ID domain).
|
||||
func NewHub(st *store.Store, aead *crypto.AEAD, baseURL string) *Hub {
|
||||
return &Hub{
|
||||
store: st,
|
||||
aead: aead,
|
||||
baseURL: baseURL,
|
||||
msgIDDomain: extractDomain(baseURL),
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch fans out to every enabled channel. Best-effort — failures
|
||||
// are logged to notification_log but do not propagate to the caller.
|
||||
// Each channel runs in its own goroutine; Dispatch returns only when
|
||||
// all goroutines have settled, so the caller can block briefly for
|
||||
// the test-button case.
|
||||
func (h *Hub) Dispatch(ctx context.Context, p Payload) {
|
||||
chans, err := h.store.ListEnabledNotificationChannels(ctx)
|
||||
if err != nil {
|
||||
slog.Error("notification: list channels", "err", err)
|
||||
return
|
||||
}
|
||||
// Stamp the alert link if the caller left it empty.
|
||||
if p.Link == "" {
|
||||
p.Link = h.baseURL + "/alerts/" + p.AlertID
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, c := range chans {
|
||||
wg.Add(1)
|
||||
go func(c store.NotificationChannel) {
|
||||
defer wg.Done()
|
||||
h.send(ctx, c, p)
|
||||
}(c)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// DispatchOne fires a single channel — used by the "Send test
|
||||
// notification" button. Returns the log entry that was persisted so
|
||||
// the handler can render the result inline.
|
||||
func (h *Hub) DispatchOne(ctx context.Context, channelID string, p Payload) (store.NotificationLogEntry, error) {
|
||||
c, err := h.store.GetNotificationChannel(ctx, channelID)
|
||||
if err != nil {
|
||||
return store.NotificationLogEntry{}, err
|
||||
}
|
||||
if p.Link == "" {
|
||||
p.Link = h.baseURL + "/alerts/" + p.AlertID
|
||||
}
|
||||
return h.send(ctx, *c, p), nil
|
||||
}
|
||||
|
||||
// send builds the channel impl, delivers the payload, and persists a
|
||||
// notification_log row regardless of success or failure.
|
||||
func (h *Hub) send(ctx context.Context, c store.NotificationChannel, p Payload) store.NotificationLogEntry {
|
||||
ch, buildErr := h.buildChannel(c)
|
||||
logEntry := store.NotificationLogEntry{
|
||||
ID: newID(),
|
||||
ChannelID: c.ID,
|
||||
Event: string(p.Event),
|
||||
FiredAt: time.Now().UTC(),
|
||||
}
|
||||
if p.AlertID != "" {
|
||||
aid := p.AlertID
|
||||
logEntry.AlertID = &aid
|
||||
}
|
||||
if buildErr != nil {
|
||||
errStr := buildErr.Error()
|
||||
logEntry.OK = false
|
||||
logEntry.Error = &errStr
|
||||
_ = h.store.AppendNotificationLog(ctx, logEntry)
|
||||
return logEntry
|
||||
}
|
||||
|
||||
code, latency, sendErr := ch.Send(ctx, p)
|
||||
statusCode := code
|
||||
latencyMS := int(latency.Milliseconds())
|
||||
logEntry.StatusCode = &statusCode
|
||||
logEntry.LatencyMS = &latencyMS
|
||||
if sendErr != nil {
|
||||
errStr := sendErr.Error()
|
||||
logEntry.OK = false
|
||||
logEntry.Error = &errStr
|
||||
} else {
|
||||
logEntry.OK = true
|
||||
}
|
||||
if err := h.store.AppendNotificationLog(ctx, logEntry); err != nil {
|
||||
slog.Warn("notification: persist log", "err", err)
|
||||
}
|
||||
return logEntry
|
||||
}
|
||||
|
||||
// buildChannel decrypts the channel config and returns a concrete
|
||||
// Channel implementation for the channel's kind.
|
||||
func (h *Hub) buildChannel(row store.NotificationChannel) (Channel, error) {
|
||||
plain, err := h.aead.Decrypt(string(row.Config), []byte("notification-channel:"+row.ID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch row.Kind {
|
||||
case "webhook":
|
||||
var cfg WebhookConfig
|
||||
if err := json.Unmarshal(plain, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewWebhookChannel(cfg), nil
|
||||
case "ntfy":
|
||||
var cfg NtfyConfig
|
||||
if err := json.Unmarshal(plain, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dp := ""
|
||||
if row.DefaultPriority != nil {
|
||||
dp = *row.DefaultPriority
|
||||
}
|
||||
return NewNtfyChannel(cfg, dp), nil
|
||||
case "smtp":
|
||||
var cfg SMTPConfig
|
||||
if err := json.Unmarshal(plain, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewSMTPChannel(cfg, h.msgIDDomain), nil
|
||||
}
|
||||
return nil, errUnknownKind(row.Kind)
|
||||
}
|
||||
|
||||
// newID returns a 32-hex-char random identifier for notification_log rows.
|
||||
func newID() string {
|
||||
var b [16]byte
|
||||
_, _ = rand.Read(b[:])
|
||||
return hex.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
// extractDomain strips the scheme and path from baseURL, leaving only
|
||||
// the host[:port] component. Used as the right-hand side of SMTP
|
||||
// Message-IDs.
|
||||
func extractDomain(baseURL string) string {
|
||||
s := baseURL
|
||||
if i := indexOf(s, "://"); i >= 0 {
|
||||
s = s[i+3:]
|
||||
}
|
||||
if i := indexOf(s, "/"); i >= 0 {
|
||||
s = s[:i]
|
||||
}
|
||||
if s == "" {
|
||||
return "restic-manager.local"
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// indexOf returns the index of the first occurrence of sub in s, or -1.
|
||||
func indexOf(s, sub string) int {
|
||||
for i := 0; i+len(sub) <= len(s); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
type errUnknownKind string
|
||||
|
||||
func (e errUnknownKind) Error() string { return "notification: unknown kind: " + string(e) }
|
||||
@@ -0,0 +1,99 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/crypto"
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
func setupHub(t *testing.T) (*Hub, *store.Store) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
st, err := store.Open(context.Background(), filepath.Join(dir, "rm.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("store: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = st.Close() })
|
||||
keyPath := filepath.Join(dir, "secret.key")
|
||||
_ = crypto.GenerateKeyFile(keyPath)
|
||||
key, _ := crypto.LoadKeyFromFile(keyPath)
|
||||
aead, _ := crypto.NewAEAD(key)
|
||||
return NewHub(st, aead, "https://rm.example"), st
|
||||
}
|
||||
|
||||
func TestHubDispatchRecordsLogEntries(t *testing.T) {
|
||||
t.Parallel()
|
||||
hub, st := setupHub(t)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
cfg, _ := json.Marshal(WebhookConfig{URL: srv.URL})
|
||||
enc, err := hub.aead.Encrypt(cfg, []byte("notification-channel:test-ch"))
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt: %v", err)
|
||||
}
|
||||
if err := st.CreateNotificationChannel(context.Background(), store.NotificationChannel{
|
||||
ID: "test-ch", Kind: "webhook", Name: "test", Enabled: true,
|
||||
Config: []byte(enc), CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
|
||||
}); err != nil {
|
||||
t.Fatalf("create channel: %v", err)
|
||||
}
|
||||
|
||||
hub.Dispatch(context.Background(), Payload{
|
||||
Event: EventRaised,
|
||||
Severity: "warning",
|
||||
Kind: "backup_failed",
|
||||
HostName: "alfa-01",
|
||||
Message: "x",
|
||||
RaisedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
// Verify a log row landed with ok=1.
|
||||
var n int
|
||||
if err := st.DB().QueryRow(
|
||||
`SELECT COUNT(*) FROM notification_log WHERE channel_id = ? AND ok = 1`, "test-ch",
|
||||
).Scan(&n); err != nil {
|
||||
t.Fatalf("count: %v", err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Fatalf("expected 1 log row, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHubSkipsDisabledChannels(t *testing.T) {
|
||||
t.Parallel()
|
||||
hub, st := setupHub(t)
|
||||
|
||||
cfg, _ := json.Marshal(WebhookConfig{URL: "http://no-such-host.invalid"})
|
||||
enc, _ := hub.aead.Encrypt(cfg, []byte("notification-channel:dis"))
|
||||
_ = st.CreateNotificationChannel(context.Background(), store.NotificationChannel{
|
||||
ID: "dis", Kind: "webhook", Name: "off", Enabled: false,
|
||||
Config: []byte(enc), CreatedAt: time.Now().UTC(), UpdatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
hub.Dispatch(context.Background(), Payload{
|
||||
Event: EventRaised,
|
||||
AlertID: "x",
|
||||
Severity: "warning",
|
||||
Kind: "backup_failed",
|
||||
HostName: "h",
|
||||
Message: "m",
|
||||
RaisedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
var n int
|
||||
_ = st.DB().QueryRow(`SELECT COUNT(*) FROM notification_log`).Scan(&n)
|
||||
if n != 0 {
|
||||
t.Errorf("disabled channel produced log rows: %d", n)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user