From c0de973c535db559eed14223fab340c90d9002bd Mon Sep 17 00:00:00 2001 From: Shankar Date: Mon, 23 Mar 2026 17:58:14 -0400 Subject: [PATCH] feat: M19 API audit log + M16a notifier connectors (Slack, Teams, PagerDuty, OpsGenie) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit M19: HTTP middleware records every API call to the immutable audit trail with method, path, actor, SHA-256 body hash, status, and latency. Best-effort async recording via goroutine. Health/ready probes excluded. M16a: Four pluggable notifier connectors — Slack (incoming webhook), Teams (MessageCard), PagerDuty (Events API v2), OpsGenie (Alert API v2). Each enabled by config env var. 30 new tests across middleware and connectors. Co-Authored-By: Claude Opus 4.6 --- .github/workflows/ci.yml | 2 +- cmd/server/main.go | 57 ++- internal/api/middleware/audit.go | 127 +++++++ internal/api/middleware/audit_test.go | 339 ++++++++++++++++++ internal/config/config.go | 24 ++ .../connector/notifier/opsgenie/opsgenie.go | 91 +++++ .../notifier/opsgenie/opsgenie_test.go | 128 +++++++ .../connector/notifier/pagerduty/pagerduty.go | 100 ++++++ .../notifier/pagerduty/pagerduty_test.go | 144 ++++++++ internal/connector/notifier/slack/slack.go | 92 +++++ .../connector/notifier/slack/slack_test.go | 107 ++++++ internal/connector/notifier/teams/teams.go | 93 +++++ .../connector/notifier/teams/teams_test.go | 91 +++++ internal/domain/notification.go | 9 +- 14 files changed, 1399 insertions(+), 5 deletions(-) create mode 100644 internal/api/middleware/audit.go create mode 100644 internal/api/middleware/audit_test.go create mode 100644 internal/connector/notifier/opsgenie/opsgenie.go create mode 100644 internal/connector/notifier/opsgenie/opsgenie_test.go create mode 100644 internal/connector/notifier/pagerduty/pagerduty.go create mode 100644 internal/connector/notifier/pagerduty/pagerduty_test.go create mode 100644 internal/connector/notifier/slack/slack.go create mode 100644 internal/connector/notifier/slack/slack_test.go create mode 100644 internal/connector/notifier/teams/teams.go create mode 100644 internal/connector/notifier/teams/teams_test.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 54ebd86..425ef15 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,7 +32,7 @@ jobs: - name: Go Test with Coverage run: | - go test ./internal/service/... ./internal/api/handler/... ./internal/integration/... ./internal/connector/issuer/... ./internal/connector/target/... ./internal/mcp/... -count=1 -cover -coverprofile=coverage.out + go test ./internal/service/... ./internal/api/handler/... ./internal/api/middleware/... ./internal/integration/... ./internal/connector/issuer/... ./internal/connector/target/... ./internal/connector/notifier/... ./internal/mcp/... -count=1 -cover -coverprofile=coverage.out - name: Check Coverage Thresholds run: | diff --git a/cmd/server/main.go b/cmd/server/main.go index 9412992..e336294 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -16,9 +16,14 @@ import ( "github.com/shankar0123/certctl/internal/api/middleware" "github.com/shankar0123/certctl/internal/api/router" "github.com/shankar0123/certctl/internal/config" + "github.com/shankar0123/certctl/internal/domain" acmeissuer "github.com/shankar0123/certctl/internal/connector/issuer/acme" "github.com/shankar0123/certctl/internal/connector/issuer/local" stepcaissuer "github.com/shankar0123/certctl/internal/connector/issuer/stepca" + notifyopsgenie "github.com/shankar0123/certctl/internal/connector/notifier/opsgenie" + notifypagerduty "github.com/shankar0123/certctl/internal/connector/notifier/pagerduty" + notifyslack "github.com/shankar0123/certctl/internal/connector/notifier/slack" + notifyteams "github.com/shankar0123/certctl/internal/connector/notifier/teams" "github.com/shankar0123/certctl/internal/repository/postgres" "github.com/shankar0123/certctl/internal/scheduler" "github.com/shankar0123/certctl/internal/service" @@ -131,7 +136,43 @@ func main() { auditService := service.NewAuditService(auditRepo) policyService := service.NewPolicyService(policyRepo, auditService) certificateService := service.NewCertificateService(certificateRepo, policyService, auditService) - notificationService := service.NewNotificationService(notificationRepo, make(map[string]service.Notifier)) + notifierRegistry := make(map[string]service.Notifier) + + // Wire notifier connectors from config + if cfg.Notifiers.SlackWebhookURL != "" { + slackNotifier := notifyslack.New(notifyslack.Config{ + WebhookURL: cfg.Notifiers.SlackWebhookURL, + ChannelOverride: cfg.Notifiers.SlackChannel, + Username: cfg.Notifiers.SlackUsername, + }) + notifierRegistry["Slack"] = slackNotifier + logger.Info("Slack notifier enabled") + } + if cfg.Notifiers.TeamsWebhookURL != "" { + teamsNotifier := notifyteams.New(notifyteams.Config{ + WebhookURL: cfg.Notifiers.TeamsWebhookURL, + }) + notifierRegistry["Teams"] = teamsNotifier + logger.Info("Teams notifier enabled") + } + if cfg.Notifiers.PagerDutyRoutingKey != "" { + pdNotifier := notifypagerduty.New(notifypagerduty.Config{ + RoutingKey: cfg.Notifiers.PagerDutyRoutingKey, + Severity: cfg.Notifiers.PagerDutySeverity, + }) + notifierRegistry["PagerDuty"] = pdNotifier + logger.Info("PagerDuty notifier enabled") + } + if cfg.Notifiers.OpsGenieAPIKey != "" { + ogNotifier := notifyopsgenie.New(notifyopsgenie.Config{ + APIKey: cfg.Notifiers.OpsGenieAPIKey, + Priority: cfg.Notifiers.OpsGeniePriority, + }) + notifierRegistry["OpsGenie"] = ogNotifier + logger.Info("OpsGenie notifier enabled") + } + + notificationService := service.NewNotificationService(notificationRepo, notifierRegistry) notificationService.SetOwnerRepo(ownerRepo) // Wire revocation dependencies into CertificateService @@ -231,12 +272,25 @@ func main() { structuredLogger := middleware.NewLogging(logger) + // API audit log middleware — records every API call to the audit trail + auditAdapter := middleware.NewAuditServiceAdapter( + func(ctx context.Context, actor string, actorType string, action string, resourceType string, resourceID string, details map[string]interface{}) error { + return auditService.RecordEvent(ctx, actor, domain.ActorType(actorType), action, resourceType, resourceID, details) + }, + ) + auditMiddleware := middleware.NewAuditLog(auditAdapter, middleware.AuditConfig{ + ExcludePaths: []string{"/health", "/ready"}, + Logger: logger, + }) + logger.Info("API audit logging enabled (excluding /health, /ready)") + middlewareStack := []func(http.Handler) http.Handler{ middleware.RequestID, structuredLogger, middleware.Recovery, corsMiddleware, authMiddleware, + auditMiddleware, } // Add rate limiter if enabled @@ -252,6 +306,7 @@ func main() { rateLimiter, corsMiddleware, authMiddleware, + auditMiddleware, } logger.Info("rate limiting enabled", "rps", cfg.RateLimit.RPS, "burst", cfg.RateLimit.BurstSize) } diff --git a/internal/api/middleware/audit.go b/internal/api/middleware/audit.go new file mode 100644 index 0000000..a5e947a --- /dev/null +++ b/internal/api/middleware/audit.go @@ -0,0 +1,127 @@ +package middleware + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "log/slog" + "net/http" + "strings" + "time" +) + +// AuditRecorder is the interface that the audit middleware uses to record API calls. +// This avoids importing the service package directly, maintaining dependency inversion. +type AuditRecorder interface { + RecordAPICall(ctx context.Context, method, path, actor string, bodyHash string, status int, latencyMs int64) error +} + +// AuditConfig holds configuration for the API audit logging middleware. +type AuditConfig struct { + // ExcludePaths are path prefixes to skip audit logging (e.g., "/health", "/ready"). + ExcludePaths []string + // Logger for audit middleware errors (audit recording failures shouldn't break requests). + Logger *slog.Logger +} + +// NewAuditLog creates a middleware that records every API call to the audit trail. +// It captures method, path, authenticated actor, request body hash, response status, and latency. +// Audit recording is best-effort — failures are logged but don't affect the HTTP response. +func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) func(http.Handler) http.Handler { + excludeSet := make(map[string]bool, len(cfg.ExcludePaths)) + for _, p := range cfg.ExcludePaths { + excludeSet[p] = true + } + + logger := cfg.Logger + if logger == nil { + logger = slog.Default() + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip excluded paths (health, readiness probes) + for prefix := range excludeSet { + if strings.HasPrefix(r.URL.Path, prefix) { + next.ServeHTTP(w, r) + return + } + } + + start := time.Now() + + // Hash request body for audit (don't store raw bodies — security + size concerns) + bodyHash := "" + if r.Body != nil && r.Body != http.NoBody { + hasher := sha256.New() + body, err := io.ReadAll(r.Body) + if err == nil && len(body) > 0 { + hasher.Write(body) + bodyHash = hex.EncodeToString(hasher.Sum(nil))[:16] // truncated hash + // Restore the body for downstream handlers + r.Body = io.NopCloser(strings.NewReader(string(body))) + } + } + + // Extract actor from auth context + actor := "anonymous" + if user, ok := GetUser(r.Context()); ok && user != "" { + actor = user + } + + // Wrap response writer to capture status code + wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} + + next.ServeHTTP(wrapped, r) + + latency := time.Since(start).Milliseconds() + + // Record audit event asynchronously (best-effort, don't block response) + go func() { + if err := recorder.RecordAPICall( + context.Background(), + r.Method, + r.URL.Path, + actor, + bodyHash, + wrapped.statusCode, + latency, + ); err != nil { + logger.Error("failed to record API audit event", + "error", err, + "method", r.Method, + "path", r.URL.Path, + ) + } + }() + }) + } +} + +// AuditServiceAdapter adapts the AuditService to the AuditRecorder interface. +// This keeps the middleware decoupled from the service package. +type AuditServiceAdapter struct { + recordFn func(ctx context.Context, actor string, actorType string, action string, resourceType string, resourceID string, details map[string]interface{}) error +} + +// NewAuditServiceAdapter creates an adapter that bridges the middleware AuditRecorder +// interface to the service layer's RecordEvent method. +func NewAuditServiceAdapter(recordFn func(ctx context.Context, actor string, actorType string, action string, resourceType string, resourceID string, details map[string]interface{}) error) *AuditServiceAdapter { + return &AuditServiceAdapter{recordFn: recordFn} +} + +// RecordAPICall implements AuditRecorder by translating API call data into an audit event. +func (a *AuditServiceAdapter) RecordAPICall(ctx context.Context, method, path, actor string, bodyHash string, status int, latencyMs int64) error { + details := map[string]interface{}{ + "method": method, + "path": path, + "body_hash": bodyHash, + "status": status, + "latency_ms": latencyMs, + } + + action := fmt.Sprintf("api_%s", strings.ToLower(method)) + return a.recordFn(ctx, actor, "User", action, "api", path, details) +} diff --git a/internal/api/middleware/audit_test.go b/internal/api/middleware/audit_test.go new file mode 100644 index 0000000..400c568 --- /dev/null +++ b/internal/api/middleware/audit_test.go @@ -0,0 +1,339 @@ +package middleware + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" +) + +// mockAuditRecorder captures RecordAPICall invocations for testing. +type mockAuditRecorder struct { + mu sync.Mutex + calls []auditCall + err error // if non-nil, RecordAPICall returns this +} + +type auditCall struct { + Method string + Path string + Actor string + BodyHash string + Status int + LatencyMs int64 +} + +func (m *mockAuditRecorder) RecordAPICall(ctx context.Context, method, path, actor, bodyHash string, status int, latencyMs int64) error { + m.mu.Lock() + defer m.mu.Unlock() + m.calls = append(m.calls, auditCall{ + Method: method, + Path: path, + Actor: actor, + BodyHash: bodyHash, + Status: status, + LatencyMs: latencyMs, + }) + return m.err +} + +func (m *mockAuditRecorder) getCalls() []auditCall { + m.mu.Lock() + defer m.mu.Unlock() + out := make([]auditCall, len(m.calls)) + copy(out, m.calls) + return out +} + +func TestAuditLog_RecordsAPICall(t *testing.T) { + recorder := &mockAuditRecorder{} + mw := NewAuditLog(recorder, AuditConfig{}) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"ok":true}`)) + })) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + + // Audit recording is async — give goroutine time to complete + time.Sleep(50 * time.Millisecond) + + calls := recorder.getCalls() + if len(calls) != 1 { + t.Fatalf("expected 1 audit call, got %d", len(calls)) + } + if calls[0].Method != "GET" { + t.Errorf("expected method GET, got %s", calls[0].Method) + } + if calls[0].Path != "/api/v1/certificates" { + t.Errorf("expected path /api/v1/certificates, got %s", calls[0].Path) + } + if calls[0].Actor != "anonymous" { + t.Errorf("expected actor anonymous, got %s", calls[0].Actor) + } + if calls[0].Status != 200 { + t.Errorf("expected status 200, got %d", calls[0].Status) + } +} + +func TestAuditLog_CapturesStatusCode(t *testing.T) { + recorder := &mockAuditRecorder{} + mw := NewAuditLog(recorder, AuditConfig{}) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/certs/mc-nonexistent", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + time.Sleep(50 * time.Millisecond) + + calls := recorder.getCalls() + if len(calls) != 1 { + t.Fatalf("expected 1 audit call, got %d", len(calls)) + } + if calls[0].Status != 404 { + t.Errorf("expected status 404, got %d", calls[0].Status) + } +} + +func TestAuditLog_ExcludesHealth(t *testing.T) { + recorder := &mockAuditRecorder{} + mw := NewAuditLog(recorder, AuditConfig{ + ExcludePaths: []string{"/health", "/ready"}, + }) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Health endpoint — should be excluded + req := httptest.NewRequest(http.MethodGet, "/health", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + // Ready endpoint — should be excluded + req2 := httptest.NewRequest(http.MethodGet, "/ready", nil) + rr2 := httptest.NewRecorder() + handler.ServeHTTP(rr2, req2) + + // API endpoint — should be recorded + req3 := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil) + rr3 := httptest.NewRecorder() + handler.ServeHTTP(rr3, req3) + + time.Sleep(50 * time.Millisecond) + + calls := recorder.getCalls() + if len(calls) != 1 { + t.Fatalf("expected 1 audit call (health/ready excluded), got %d", len(calls)) + } + if calls[0].Path != "/api/v1/certificates" { + t.Errorf("expected path /api/v1/certificates, got %s", calls[0].Path) + } +} + +func TestAuditLog_HashesRequestBody(t *testing.T) { + recorder := &mockAuditRecorder{} + mw := NewAuditLog(recorder, AuditConfig{}) + + // Handler verifies body was restored + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if string(body) != `{"name":"test"}` { + t.Errorf("body was not restored: got %q", string(body)) + } + w.WriteHeader(http.StatusCreated) + })) + + body := strings.NewReader(`{"name":"test"}`) + req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", body) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + time.Sleep(50 * time.Millisecond) + + calls := recorder.getCalls() + if len(calls) != 1 { + t.Fatalf("expected 1 audit call, got %d", len(calls)) + } + // Body hash should be a 16-char hex string (truncated SHA-256) + if len(calls[0].BodyHash) != 16 { + t.Errorf("expected 16-char body hash, got %q (len=%d)", calls[0].BodyHash, len(calls[0].BodyHash)) + } + if calls[0].Status != 201 { + t.Errorf("expected status 201, got %d", calls[0].Status) + } +} + +func TestAuditLog_EmptyBodyNoHash(t *testing.T) { + recorder := &mockAuditRecorder{} + mw := NewAuditLog(recorder, AuditConfig{}) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/agents", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + time.Sleep(50 * time.Millisecond) + + calls := recorder.getCalls() + if len(calls) != 1 { + t.Fatalf("expected 1 audit call, got %d", len(calls)) + } + if calls[0].BodyHash != "" { + t.Errorf("expected empty body hash for GET, got %q", calls[0].BodyHash) + } +} + +func TestAuditLog_ExtractsAuthenticatedActor(t *testing.T) { + recorder := &mockAuditRecorder{} + mw := NewAuditLog(recorder, AuditConfig{}) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodDelete, "/api/v1/certificates/mc-1", nil) + // Simulate auth middleware having set the user in context + ctx := context.WithValue(req.Context(), UserKey{}, "api-key-user") + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + time.Sleep(50 * time.Millisecond) + + calls := recorder.getCalls() + if len(calls) != 1 { + t.Fatalf("expected 1 audit call, got %d", len(calls)) + } + if calls[0].Actor != "api-key-user" { + t.Errorf("expected actor api-key-user, got %s", calls[0].Actor) + } + if calls[0].Method != "DELETE" { + t.Errorf("expected method DELETE, got %s", calls[0].Method) + } +} + +func TestAuditLog_RecorderErrorDoesNotBreakResponse(t *testing.T) { + recorder := &mockAuditRecorder{err: fmt.Errorf("db connection lost")} + mw := NewAuditLog(recorder, AuditConfig{}) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"ok":true}`)) + })) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/stats/summary", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + // Response should still be 200 even though audit recording fails + if rr.Code != http.StatusOK { + t.Errorf("expected 200 despite recorder error, got %d", rr.Code) + } +} + +func TestAuditLog_CapturesLatency(t *testing.T) { + recorder := &mockAuditRecorder{} + mw := NewAuditLog(recorder, AuditConfig{}) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + time.Sleep(50 * time.Millisecond) + + calls := recorder.getCalls() + if len(calls) != 1 { + t.Fatalf("expected 1 audit call, got %d", len(calls)) + } + if calls[0].LatencyMs < 10 { + t.Errorf("expected latency >= 10ms, got %dms", calls[0].LatencyMs) + } +} + +func TestAuditServiceAdapter_TranslatesCallToEvent(t *testing.T) { + var capturedActor, capturedActorType, capturedAction, capturedResourceType, capturedResourceID string + var capturedDetails map[string]interface{} + + adapter := NewAuditServiceAdapter(func(ctx context.Context, actor, actorType, action, resourceType, resourceID string, details map[string]interface{}) error { + capturedActor = actor + capturedActorType = actorType + capturedAction = action + capturedResourceType = resourceType + capturedResourceID = resourceID + capturedDetails = details + return nil + }) + + err := adapter.RecordAPICall(context.Background(), "POST", "/api/v1/certificates", "admin", "abc123", 201, 42) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if capturedActor != "admin" { + t.Errorf("expected actor admin, got %s", capturedActor) + } + if capturedActorType != "User" { + t.Errorf("expected actorType User, got %s", capturedActorType) + } + if capturedAction != "api_post" { + t.Errorf("expected action api_post, got %s", capturedAction) + } + if capturedResourceType != "api" { + t.Errorf("expected resourceType api, got %s", capturedResourceType) + } + if capturedResourceID != "/api/v1/certificates" { + t.Errorf("expected resourceID /api/v1/certificates, got %s", capturedResourceID) + } + if capturedDetails["method"] != "POST" { + t.Errorf("expected details.method POST, got %v", capturedDetails["method"]) + } + if capturedDetails["status"] != 201 { + t.Errorf("expected details.status 201, got %v", capturedDetails["status"]) + } + if capturedDetails["latency_ms"] != int64(42) { + t.Errorf("expected details.latency_ms 42, got %v", capturedDetails["latency_ms"]) + } + if capturedDetails["body_hash"] != "abc123" { + t.Errorf("expected details.body_hash abc123, got %v", capturedDetails["body_hash"]) + } +} + +func TestAuditServiceAdapter_PropagatesError(t *testing.T) { + adapter := NewAuditServiceAdapter(func(ctx context.Context, actor, actorType, action, resourceType, resourceID string, details map[string]interface{}) error { + return fmt.Errorf("database error") + }) + + err := adapter.RecordAPICall(context.Background(), "GET", "/api/v1/agents", "user", "", 200, 5) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "database error") { + t.Errorf("expected database error, got %v", err) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index ba288a0..7878655 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,6 +20,20 @@ type Config struct { CORS CORSConfig Keygen KeygenConfig CA CAConfig + Notifiers NotifierConfig +} + +// NotifierConfig contains configuration for notification connectors. +// Each notifier is enabled by setting its required env var (webhook URL or API key). +type NotifierConfig struct { + SlackWebhookURL string + SlackChannel string + SlackUsername string + TeamsWebhookURL string + PagerDutyRoutingKey string + PagerDutySeverity string + OpsGenieAPIKey string + OpsGeniePriority string } // KeygenConfig controls where private keys are generated. @@ -146,6 +160,16 @@ func Load() (*Config, error) { CertPath: getEnv("CERTCTL_CA_CERT_PATH", ""), KeyPath: getEnv("CERTCTL_CA_KEY_PATH", ""), }, + Notifiers: NotifierConfig{ + SlackWebhookURL: getEnv("CERTCTL_SLACK_WEBHOOK_URL", ""), + SlackChannel: getEnv("CERTCTL_SLACK_CHANNEL", ""), + SlackUsername: getEnv("CERTCTL_SLACK_USERNAME", "certctl"), + TeamsWebhookURL: getEnv("CERTCTL_TEAMS_WEBHOOK_URL", ""), + PagerDutyRoutingKey: getEnv("CERTCTL_PAGERDUTY_ROUTING_KEY", ""), + PagerDutySeverity: getEnv("CERTCTL_PAGERDUTY_SEVERITY", "warning"), + OpsGenieAPIKey: getEnv("CERTCTL_OPSGENIE_API_KEY", ""), + OpsGeniePriority: getEnv("CERTCTL_OPSGENIE_PRIORITY", "P3"), + }, } if err := cfg.Validate(); err != nil { diff --git a/internal/connector/notifier/opsgenie/opsgenie.go b/internal/connector/notifier/opsgenie/opsgenie.go new file mode 100644 index 0000000..232f3f8 --- /dev/null +++ b/internal/connector/notifier/opsgenie/opsgenie.go @@ -0,0 +1,91 @@ +package opsgenie + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +const alertAPIURL = "https://api.opsgenie.com/v2/alerts" + +// Config holds configuration for the OpsGenie notifier. +type Config struct { + // APIKey is the OpsGenie API integration key. + APIKey string `json:"api_key"` + // Priority is the default alert priority (P1-P5). Defaults to "P3". + Priority string `json:"priority,omitempty"` + // Tags are default tags applied to all alerts. + Tags []string `json:"tags,omitempty"` +} + +// Notifier sends notifications to OpsGenie via the Alert API. +type Notifier struct { + config Config + httpClient *http.Client +} + +// New creates a new OpsGenie notifier. +func New(config Config) *Notifier { + if config.Priority == "" { + config.Priority = "P3" + } + return &Notifier{ + config: config, + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, + } +} + +// Channel returns the channel identifier. +func (n *Notifier) Channel() string { + return "OpsGenie" +} + +// Send delivers a notification to OpsGenie as an alert. +func (n *Notifier) Send(ctx context.Context, recipient string, subject string, body string) error { + alert := ogAlert{ + Message: subject, + Description: body, + Priority: n.config.Priority, + Source: "certctl", + Tags: n.config.Tags, + } + + jsonBytes, err := json.Marshal(alert) + if err != nil { + return fmt.Errorf("opsgenie: failed to marshal payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, alertAPIURL, bytes.NewReader(jsonBytes)) + if err != nil { + return fmt.Errorf("opsgenie: failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "GenieKey "+n.config.APIKey) + + resp, err := n.httpClient.Do(req) + if err != nil { + return fmt.Errorf("opsgenie: request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusAccepted && resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("opsgenie: API returned HTTP %d: %s", resp.StatusCode, string(respBody)) + } + + return nil +} + +type ogAlert struct { + Message string `json:"message"` + Description string `json:"description,omitempty"` + Priority string `json:"priority,omitempty"` + Source string `json:"source,omitempty"` + Tags []string `json:"tags,omitempty"` +} diff --git a/internal/connector/notifier/opsgenie/opsgenie_test.go b/internal/connector/notifier/opsgenie/opsgenie_test.go new file mode 100644 index 0000000..c4e62a5 --- /dev/null +++ b/internal/connector/notifier/opsgenie/opsgenie_test.go @@ -0,0 +1,128 @@ +package opsgenie + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestOpsGenie_Channel(t *testing.T) { + n := New(Config{APIKey: "test-key"}) + if n.Channel() != "OpsGenie" { + t.Errorf("expected channel OpsGenie, got %s", n.Channel()) + } +} + +func TestOpsGenie_DefaultPriority(t *testing.T) { + n := New(Config{APIKey: "test-key"}) + if n.config.Priority != "P3" { + t.Errorf("expected default priority P3, got %s", n.config.Priority) + } +} + +func TestOpsGenie_CustomPriority(t *testing.T) { + n := New(Config{APIKey: "test-key", Priority: "P1"}) + if n.config.Priority != "P1" { + t.Errorf("expected priority P1, got %s", n.config.Priority) + } +} + +func TestOpsGenie_SendSuccess(t *testing.T) { + var receivedAlert ogAlert + var receivedAuthHeader string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("expected application/json, got %s", ct) + } + receivedAuthHeader = r.Header.Get("Authorization") + if err := json.NewDecoder(r.Body).Decode(&receivedAlert); err != nil { + t.Fatalf("failed to decode payload: %v", err) + } + w.WriteHeader(http.StatusAccepted) + })) + defer server.Close() + + n := New(Config{ + APIKey: "test-api-key-123", + Priority: "P2", + Tags: []string{"certctl", "production"}, + }) + // Override HTTP client to hit test server + n.httpClient = &http.Client{Transport: &urlRewriteTransport{target: server.URL, transport: http.DefaultTransport}} + + err := n.Send(context.Background(), "ops-team", "Key Compromise", "Certificate mc-api-prod may have compromised private key") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if receivedAuthHeader != "GenieKey test-api-key-123" { + t.Errorf("expected GenieKey auth header, got %s", receivedAuthHeader) + } + if receivedAlert.Message != "Key Compromise" { + t.Errorf("expected message 'Key Compromise', got %s", receivedAlert.Message) + } + if receivedAlert.Description != "Certificate mc-api-prod may have compromised private key" { + t.Errorf("expected description with cert details, got %s", receivedAlert.Description) + } + if receivedAlert.Priority != "P2" { + t.Errorf("expected priority P2, got %s", receivedAlert.Priority) + } + if receivedAlert.Source != "certctl" { + t.Errorf("expected source certctl, got %s", receivedAlert.Source) + } + if len(receivedAlert.Tags) != 2 || receivedAlert.Tags[0] != "certctl" { + t.Errorf("expected tags [certctl, production], got %v", receivedAlert.Tags) + } +} + +func TestOpsGenie_SendHTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"message":"API key is invalid"}`)) + })) + defer server.Close() + + n := New(Config{APIKey: "bad-key"}) + n.httpClient = &http.Client{Transport: &urlRewriteTransport{target: server.URL, transport: http.DefaultTransport}} + + err := n.Send(context.Background(), "", "Test", "body") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "HTTP 401") { + t.Errorf("expected HTTP 401 in error, got %v", err) + } +} + +func TestOpsGenie_SendConnectionError(t *testing.T) { + n := New(Config{APIKey: "test-key"}) + n.httpClient = &http.Client{Transport: &urlRewriteTransport{target: "http://127.0.0.1:1", transport: http.DefaultTransport}} + + err := n.Send(context.Background(), "", "Test", "body") + if err == nil { + t.Fatal("expected connection error, got nil") + } + if !strings.Contains(err.Error(), "request failed") { + t.Errorf("expected 'request failed' in error, got %v", err) + } +} + +// urlRewriteTransport redirects all requests to a test server URL. +type urlRewriteTransport struct { + target string + transport http.RoundTripper +} + +func (t *urlRewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.URL.Scheme = "http" + req.URL.Host = strings.TrimPrefix(t.target, "http://") + return t.transport.RoundTrip(req) +} diff --git a/internal/connector/notifier/pagerduty/pagerduty.go b/internal/connector/notifier/pagerduty/pagerduty.go new file mode 100644 index 0000000..728ed30 --- /dev/null +++ b/internal/connector/notifier/pagerduty/pagerduty.go @@ -0,0 +1,100 @@ +package pagerduty + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +const eventsAPIURL = "https://events.pagerduty.com/v2/enqueue" + +// Config holds configuration for the PagerDuty notifier. +type Config struct { + // RoutingKey is the PagerDuty Events API v2 integration/routing key. + RoutingKey string `json:"routing_key"` + // Severity is the default event severity (critical, error, warning, info). + // Defaults to "warning" if not set. + Severity string `json:"severity,omitempty"` +} + +// Notifier sends notifications to PagerDuty via the Events API v2. +type Notifier struct { + config Config + httpClient *http.Client +} + +// New creates a new PagerDuty notifier. +func New(config Config) *Notifier { + if config.Severity == "" { + config.Severity = "warning" + } + return &Notifier{ + config: config, + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, + } +} + +// Channel returns the channel identifier. +func (n *Notifier) Channel() string { + return "PagerDuty" +} + +// Send delivers a notification to PagerDuty as a trigger event. +func (n *Notifier) Send(ctx context.Context, recipient string, subject string, body string) error { + event := pdEvent{ + RoutingKey: n.config.RoutingKey, + EventAction: "trigger", + Payload: pdPayload{ + Summary: subject, + Severity: n.config.Severity, + Source: "certctl", + CustomDetails: map[string]string{ + "body": body, + "recipient": recipient, + }, + }, + } + + jsonBytes, err := json.Marshal(event) + if err != nil { + return fmt.Errorf("pagerduty: failed to marshal payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, eventsAPIURL, bytes.NewReader(jsonBytes)) + if err != nil { + return fmt.Errorf("pagerduty: failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := n.httpClient.Do(req) + if err != nil { + return fmt.Errorf("pagerduty: request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusAccepted && resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("pagerduty: API returned HTTP %d: %s", resp.StatusCode, string(respBody)) + } + + return nil +} + +type pdEvent struct { + RoutingKey string `json:"routing_key"` + EventAction string `json:"event_action"` + Payload pdPayload `json:"payload"` +} + +type pdPayload struct { + Summary string `json:"summary"` + Severity string `json:"severity"` + Source string `json:"source"` + CustomDetails map[string]string `json:"custom_details,omitempty"` +} diff --git a/internal/connector/notifier/pagerduty/pagerduty_test.go b/internal/connector/notifier/pagerduty/pagerduty_test.go new file mode 100644 index 0000000..287ede1 --- /dev/null +++ b/internal/connector/notifier/pagerduty/pagerduty_test.go @@ -0,0 +1,144 @@ +package pagerduty + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestPagerDuty_Channel(t *testing.T) { + n := New(Config{RoutingKey: "test-key"}) + if n.Channel() != "PagerDuty" { + t.Errorf("expected channel PagerDuty, got %s", n.Channel()) + } +} + +func TestPagerDuty_DefaultSeverity(t *testing.T) { + n := New(Config{RoutingKey: "test-key"}) + if n.config.Severity != "warning" { + t.Errorf("expected default severity 'warning', got %s", n.config.Severity) + } +} + +func TestPagerDuty_CustomSeverity(t *testing.T) { + n := New(Config{RoutingKey: "test-key", Severity: "critical"}) + if n.config.Severity != "critical" { + t.Errorf("expected severity 'critical', got %s", n.config.Severity) + } +} + +func TestPagerDuty_SendSuccess(t *testing.T) { + var receivedEvent pdEvent + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("expected application/json, got %s", ct) + } + if err := json.NewDecoder(r.Body).Decode(&receivedEvent); err != nil { + t.Fatalf("failed to decode payload: %v", err) + } + w.WriteHeader(http.StatusAccepted) + })) + defer server.Close() + + // Override the events URL for testing — use a custom HTTP client that redirects + n := New(Config{RoutingKey: "test-routing-key", Severity: "error"}) + // We can't easily override the const URL, so test with a direct HTTP call approach. + // Instead, test the payload structure by calling Send with a mock server. + // We need to make the notifier use our test server URL. + // The simplest way: create the notifier, then manually set the URL by using the test server. + // Since eventsAPIURL is a const, we'll test by replacing the http client's transport. + + // Alternative approach: just test that the method constructs the right payload + // by using a custom transport that intercepts the request. + n.httpClient = server.Client() + + // For this test, we need to override the target URL. Since it's a package-level const, + // we'll create a custom RoundTripper that redirects to our test server. + originalURL := eventsAPIURL + _ = originalURL // just to avoid unused var in case we reference it + + transport := &urlRewriteTransport{ + target: server.URL, + transport: http.DefaultTransport, + } + n.httpClient = &http.Client{Transport: transport} + + err := n.Send(context.Background(), "oncall@example.com", "Cert Expired", "mc-api-prod has expired") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if receivedEvent.RoutingKey != "test-routing-key" { + t.Errorf("expected routing key test-routing-key, got %s", receivedEvent.RoutingKey) + } + if receivedEvent.EventAction != "trigger" { + t.Errorf("expected event action trigger, got %s", receivedEvent.EventAction) + } + if receivedEvent.Payload.Summary != "Cert Expired" { + t.Errorf("expected summary 'Cert Expired', got %s", receivedEvent.Payload.Summary) + } + if receivedEvent.Payload.Severity != "error" { + t.Errorf("expected severity error, got %s", receivedEvent.Payload.Severity) + } + if receivedEvent.Payload.Source != "certctl" { + t.Errorf("expected source certctl, got %s", receivedEvent.Payload.Source) + } + if receivedEvent.Payload.CustomDetails["body"] != "mc-api-prod has expired" { + t.Errorf("expected body in custom_details, got %v", receivedEvent.Payload.CustomDetails) + } + if receivedEvent.Payload.CustomDetails["recipient"] != "oncall@example.com" { + t.Errorf("expected recipient in custom_details, got %v", receivedEvent.Payload.CustomDetails) + } +} + +func TestPagerDuty_SendHTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"status":"invalid","message":"bad routing key"}`)) + })) + defer server.Close() + + n := New(Config{RoutingKey: "bad-key"}) + n.httpClient = &http.Client{Transport: &urlRewriteTransport{target: server.URL, transport: http.DefaultTransport}} + + err := n.Send(context.Background(), "", "Test", "body") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "HTTP 400") { + t.Errorf("expected HTTP 400 in error, got %v", err) + } +} + +func TestPagerDuty_SendConnectionError(t *testing.T) { + n := New(Config{RoutingKey: "test-key"}) + n.httpClient = &http.Client{Transport: &urlRewriteTransport{target: "http://127.0.0.1:1", transport: http.DefaultTransport}} + + err := n.Send(context.Background(), "", "Test", "body") + if err == nil { + t.Fatal("expected connection error, got nil") + } + if !strings.Contains(err.Error(), "request failed") { + t.Errorf("expected 'request failed' in error, got %v", err) + } +} + +// urlRewriteTransport redirects all requests to a test server URL. +type urlRewriteTransport struct { + target string + transport http.RoundTripper +} + +func (t *urlRewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.URL.Scheme = "http" + req.URL.Host = strings.TrimPrefix(t.target, "http://") + return t.transport.RoundTrip(req) +} diff --git a/internal/connector/notifier/slack/slack.go b/internal/connector/notifier/slack/slack.go new file mode 100644 index 0000000..a48016e --- /dev/null +++ b/internal/connector/notifier/slack/slack.go @@ -0,0 +1,92 @@ +package slack + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +// Config holds configuration for the Slack notifier. +type Config struct { + // WebhookURL is the Slack incoming webhook URL. + WebhookURL string `json:"webhook_url"` + // ChannelOverride optionally overrides the webhook's default channel. + ChannelOverride string `json:"channel,omitempty"` + // Username optionally sets the bot display name. + Username string `json:"username,omitempty"` + // IconEmoji optionally sets the bot icon (e.g., ":lock:"). + IconEmoji string `json:"icon_emoji,omitempty"` +} + +// Notifier sends notifications to Slack via incoming webhooks. +type Notifier struct { + config Config + httpClient *http.Client +} + +// New creates a new Slack notifier. +func New(config Config) *Notifier { + return &Notifier{ + config: config, + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, + } +} + +// Channel returns the channel identifier. +func (n *Notifier) Channel() string { + return "Slack" +} + +// Send delivers a notification to Slack via webhook. +func (n *Notifier) Send(ctx context.Context, recipient string, subject string, body string) error { + payload := slackMessage{ + Text: fmt.Sprintf("*%s*\n%s", subject, body), + } + + if n.config.ChannelOverride != "" { + payload.Channel = n.config.ChannelOverride + } + if n.config.Username != "" { + payload.Username = n.config.Username + } + if n.config.IconEmoji != "" { + payload.IconEmoji = n.config.IconEmoji + } + + jsonBytes, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("slack: failed to marshal payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, n.config.WebhookURL, bytes.NewReader(jsonBytes)) + if err != nil { + return fmt.Errorf("slack: failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := n.httpClient.Do(req) + if err != nil { + return fmt.Errorf("slack: request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("slack: webhook returned HTTP %d: %s", resp.StatusCode, string(respBody)) + } + + return nil +} + +type slackMessage struct { + Text string `json:"text"` + Channel string `json:"channel,omitempty"` + Username string `json:"username,omitempty"` + IconEmoji string `json:"icon_emoji,omitempty"` +} diff --git a/internal/connector/notifier/slack/slack_test.go b/internal/connector/notifier/slack/slack_test.go new file mode 100644 index 0000000..84751eb --- /dev/null +++ b/internal/connector/notifier/slack/slack_test.go @@ -0,0 +1,107 @@ +package slack + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestSlack_Channel(t *testing.T) { + n := New(Config{WebhookURL: "https://hooks.slack.com/test"}) + if n.Channel() != "Slack" { + t.Errorf("expected channel Slack, got %s", n.Channel()) + } +} + +func TestSlack_SendSuccess(t *testing.T) { + var receivedPayload slackMessage + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("expected application/json, got %s", ct) + } + if err := json.NewDecoder(r.Body).Decode(&receivedPayload); err != nil { + t.Fatalf("failed to decode payload: %v", err) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + n := New(Config{WebhookURL: server.URL}) + err := n.Send(context.Background(), "ops@example.com", "Cert Expiring", "mc-api-prod expires in 7 days") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !strings.Contains(receivedPayload.Text, "*Cert Expiring*") { + t.Errorf("expected bold subject in text, got %q", receivedPayload.Text) + } + if !strings.Contains(receivedPayload.Text, "mc-api-prod expires in 7 days") { + t.Errorf("expected body in text, got %q", receivedPayload.Text) + } +} + +func TestSlack_SendWithOverrides(t *testing.T) { + var receivedPayload slackMessage + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&receivedPayload) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + n := New(Config{ + WebhookURL: server.URL, + ChannelOverride: "#alerts", + Username: "certctl-bot", + IconEmoji: ":lock:", + }) + err := n.Send(context.Background(), "", "Test", "body") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if receivedPayload.Channel != "#alerts" { + t.Errorf("expected channel #alerts, got %s", receivedPayload.Channel) + } + if receivedPayload.Username != "certctl-bot" { + t.Errorf("expected username certctl-bot, got %s", receivedPayload.Username) + } + if receivedPayload.IconEmoji != ":lock:" { + t.Errorf("expected icon_emoji :lock:, got %s", receivedPayload.IconEmoji) + } +} + +func TestSlack_SendHTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + w.Write([]byte("invalid_token")) + })) + defer server.Close() + + n := New(Config{WebhookURL: server.URL}) + err := n.Send(context.Background(), "", "Test", "body") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "HTTP 403") { + t.Errorf("expected HTTP 403 in error, got %v", err) + } +} + +func TestSlack_SendConnectionError(t *testing.T) { + n := New(Config{WebhookURL: "http://127.0.0.1:1"}) + err := n.Send(context.Background(), "", "Test", "body") + if err == nil { + t.Fatal("expected connection error, got nil") + } + if !strings.Contains(err.Error(), "request failed") { + t.Errorf("expected 'request failed' in error, got %v", err) + } +} diff --git a/internal/connector/notifier/teams/teams.go b/internal/connector/notifier/teams/teams.go new file mode 100644 index 0000000..60a7132 --- /dev/null +++ b/internal/connector/notifier/teams/teams.go @@ -0,0 +1,93 @@ +package teams + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +// Config holds configuration for the Microsoft Teams notifier. +type Config struct { + // WebhookURL is the Teams incoming webhook URL. + WebhookURL string `json:"webhook_url"` +} + +// Notifier sends notifications to Microsoft Teams via incoming webhooks. +type Notifier struct { + config Config + httpClient *http.Client +} + +// New creates a new Teams notifier. +func New(config Config) *Notifier { + return &Notifier{ + config: config, + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, + } +} + +// Channel returns the channel identifier. +func (n *Notifier) Channel() string { + return "Teams" +} + +// Send delivers a notification to Teams via webhook using MessageCard format. +func (n *Notifier) Send(ctx context.Context, recipient string, subject string, body string) error { + card := teamsMessageCard{ + Type: "MessageCard", + Context: "https://schema.org/extensions", + ThemeColor: "0076D7", + Summary: subject, + Sections: []teamsSection{ + { + ActivityTitle: subject, + Text: body, + Markdown: true, + }, + }, + } + + jsonBytes, err := json.Marshal(card) + if err != nil { + return fmt.Errorf("teams: failed to marshal payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, n.config.WebhookURL, bytes.NewReader(jsonBytes)) + if err != nil { + return fmt.Errorf("teams: failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := n.httpClient.Do(req) + if err != nil { + return fmt.Errorf("teams: request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("teams: webhook returned HTTP %d: %s", resp.StatusCode, string(respBody)) + } + + return nil +} + +type teamsMessageCard struct { + Type string `json:"@type"` + Context string `json:"@context"` + ThemeColor string `json:"themeColor"` + Summary string `json:"summary"` + Sections []teamsSection `json:"sections"` +} + +type teamsSection struct { + ActivityTitle string `json:"activityTitle"` + Text string `json:"text"` + Markdown bool `json:"markdown"` +} diff --git a/internal/connector/notifier/teams/teams_test.go b/internal/connector/notifier/teams/teams_test.go new file mode 100644 index 0000000..0f202f5 --- /dev/null +++ b/internal/connector/notifier/teams/teams_test.go @@ -0,0 +1,91 @@ +package teams + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestTeams_Channel(t *testing.T) { + n := New(Config{WebhookURL: "https://outlook.office.com/webhook/test"}) + if n.Channel() != "Teams" { + t.Errorf("expected channel Teams, got %s", n.Channel()) + } +} + +func TestTeams_SendSuccess(t *testing.T) { + var receivedCard teamsMessageCard + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("expected application/json, got %s", ct) + } + if err := json.NewDecoder(r.Body).Decode(&receivedCard); err != nil { + t.Fatalf("failed to decode payload: %v", err) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + n := New(Config{WebhookURL: server.URL}) + err := n.Send(context.Background(), "team@example.com", "Renewal Failed", "Certificate mc-api-prod renewal failed after 3 attempts") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if receivedCard.Type != "MessageCard" { + t.Errorf("expected @type MessageCard, got %s", receivedCard.Type) + } + if receivedCard.Summary != "Renewal Failed" { + t.Errorf("expected summary 'Renewal Failed', got %s", receivedCard.Summary) + } + if receivedCard.ThemeColor != "0076D7" { + t.Errorf("expected theme color 0076D7, got %s", receivedCard.ThemeColor) + } + if len(receivedCard.Sections) != 1 { + t.Fatalf("expected 1 section, got %d", len(receivedCard.Sections)) + } + if receivedCard.Sections[0].ActivityTitle != "Renewal Failed" { + t.Errorf("expected section title 'Renewal Failed', got %s", receivedCard.Sections[0].ActivityTitle) + } + if !strings.Contains(receivedCard.Sections[0].Text, "mc-api-prod") { + t.Errorf("expected body to contain cert ID, got %s", receivedCard.Sections[0].Text) + } + if !receivedCard.Sections[0].Markdown { + t.Error("expected markdown=true in section") + } +} + +func TestTeams_SendHTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("bad request")) + })) + defer server.Close() + + n := New(Config{WebhookURL: server.URL}) + err := n.Send(context.Background(), "", "Test", "body") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "HTTP 400") { + t.Errorf("expected HTTP 400 in error, got %v", err) + } +} + +func TestTeams_SendConnectionError(t *testing.T) { + n := New(Config{WebhookURL: "http://127.0.0.1:1"}) + err := n.Send(context.Background(), "", "Test", "body") + if err == nil { + t.Fatal("expected connection error, got nil") + } + if !strings.Contains(err.Error(), "request failed") { + t.Errorf("expected 'request failed' in error, got %v", err) + } +} diff --git a/internal/domain/notification.go b/internal/domain/notification.go index be28c9c..8eae9ad 100644 --- a/internal/domain/notification.go +++ b/internal/domain/notification.go @@ -35,7 +35,10 @@ const ( type NotificationChannel string const ( - NotificationChannelEmail NotificationChannel = "Email" - NotificationChannelWebhook NotificationChannel = "Webhook" - NotificationChannelSlack NotificationChannel = "Slack" + NotificationChannelEmail NotificationChannel = "Email" + NotificationChannelWebhook NotificationChannel = "Webhook" + NotificationChannelSlack NotificationChannel = "Slack" + NotificationChannelTeams NotificationChannel = "Teams" + NotificationChannelPagerDuty NotificationChannel = "PagerDuty" + NotificationChannelOpsGenie NotificationChannel = "OpsGenie" )