Files
certctl/internal/api/middleware/audit_test.go
T
shankar0123 9b0ff37973 feat: M19 API audit log + M16a notifier connectors (Slack, Teams, PagerDuty, OpsGenie)
M19: HTTP middleware records every API call to the immutable audit trail
with method, path, actor, SHA-256 body hash, status, and latency. Best-effort
async recording via goroutine. Health/ready probes excluded.

M16a: Four pluggable notifier connectors — Slack (incoming webhook), Teams
(MessageCard), PagerDuty (Events API v2), OpsGenie (Alert API v2). Each
enabled by config env var. 30 new tests across middleware and connectors.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 17:58:14 -04:00

340 lines
9.8 KiB
Go

package middleware
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
)
// mockAuditRecorder is a test double that remembers the arguments of every
// RecordAPICall invocation so that tests can inspect them afterwards.
type mockAuditRecorder struct {
	mu    sync.Mutex  // held while calls is read or appended to
	calls []auditCall // every recorded invocation, oldest first
	err   error       // if non-nil, RecordAPICall returns this
}

// auditCall is a snapshot of the arguments passed to one RecordAPICall.
type auditCall struct {
	Method    string
	Path      string
	Actor     string
	BodyHash  string
	Status    int
	LatencyMs int64
}

// RecordAPICall stores the supplied arguments and returns the configured err.
// The invocation is stored even when err is non-nil; ctx is not used.
func (m *mockAuditRecorder) RecordAPICall(ctx context.Context, method, path, actor, bodyHash string, status int, latencyMs int64) error {
	entry := auditCall{
		Method:    method,
		Path:      path,
		Actor:     actor,
		BodyHash:  bodyHash,
		Status:    status,
		LatencyMs: latencyMs,
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	m.calls = append(m.calls, entry)
	return m.err
}

// getCalls returns a copy of the invocations recorded so far. The copy does
// not share storage with the recorder, so callers may keep or modify it.
func (m *mockAuditRecorder) getCalls() []auditCall {
	m.mu.Lock()
	defer m.mu.Unlock()
	snapshot := make([]auditCall, 0, len(m.calls))
	return append(snapshot, m.calls...)
}
// TestAuditLog_RecordsAPICall verifies that a request passing through the
// audit middleware yields exactly one audit record carrying the request
// method, path, actor and response status. The request has no user in its
// context, and the expected actor in that case is "anonymous".
func TestAuditLog_RecordsAPICall(t *testing.T) {
	recorder := &mockAuditRecorder{}
	mw := NewAuditLog(recorder, AuditConfig{})
	handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"ok":true}`))
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", rr.Code)
	}

	// Audit recording is async. Poll with a deadline instead of sleeping for a
	// fixed interval: the test continues as soon as the record lands and does
	// not fail spuriously when the goroutine is scheduled late.
	deadline := time.Now().Add(2 * time.Second)
	calls := recorder.getCalls()
	for len(calls) == 0 && time.Now().Before(deadline) {
		time.Sleep(time.Millisecond)
		calls = recorder.getCalls()
	}

	if len(calls) != 1 {
		t.Fatalf("expected 1 audit call, got %d", len(calls))
	}
	if calls[0].Method != "GET" {
		t.Errorf("expected method GET, got %s", calls[0].Method)
	}
	if calls[0].Path != "/api/v1/certificates" {
		t.Errorf("expected path /api/v1/certificates, got %s", calls[0].Path)
	}
	if calls[0].Actor != "anonymous" {
		t.Errorf("expected actor anonymous, got %s", calls[0].Actor)
	}
	if calls[0].Status != 200 {
		t.Errorf("expected status 200, got %d", calls[0].Status)
	}
}
// TestAuditLog_CapturesStatusCode verifies that the status written by the
// wrapped handler (404 here) is the status stored in the audit record.
func TestAuditLog_CapturesStatusCode(t *testing.T) {
	recorder := &mockAuditRecorder{}
	mw := NewAuditLog(recorder, AuditConfig{})
	handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/v1/certs/mc-nonexistent", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	// Audit recording is async. Poll with a deadline instead of sleeping for a
	// fixed interval so the test is fast and not sensitive to scheduling.
	deadline := time.Now().Add(2 * time.Second)
	calls := recorder.getCalls()
	for len(calls) == 0 && time.Now().Before(deadline) {
		time.Sleep(time.Millisecond)
		calls = recorder.getCalls()
	}

	if len(calls) != 1 {
		t.Fatalf("expected 1 audit call, got %d", len(calls))
	}
	if calls[0].Status != 404 {
		t.Errorf("expected status 404, got %d", calls[0].Status)
	}
}
// TestAuditLog_ExcludesHealth verifies that requests to paths listed in
// AuditConfig.ExcludePaths produce no audit record, while a request to any
// other path is still recorded.
func TestAuditLog_ExcludesHealth(t *testing.T) {
	recorder := &mockAuditRecorder{}
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	handler := NewAuditLog(recorder, AuditConfig{
		ExcludePaths: []string{"/health", "/ready"},
	})(okHandler)

	// Health and ready endpoints should be excluded from the audit trail.
	for _, probe := range []string{"/health", "/ready"} {
		handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, probe, nil))
	}

	// A regular API endpoint should be recorded.
	handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil))

	// Recording is async. A fixed settle period is used because the test
	// asserts that certain recordings do NOT appear.
	time.Sleep(50 * time.Millisecond)

	recorded := recorder.getCalls()
	if n := len(recorded); n != 1 {
		t.Fatalf("expected 1 audit call (health/ready excluded), got %d", n)
	}
	if got := recorded[0].Path; got != "/api/v1/certificates" {
		t.Errorf("expected path /api/v1/certificates, got %s", got)
	}
}
// TestAuditLog_HashesRequestBody verifies two things for a request with a
// body: the wrapped handler can still read the complete body after the
// middleware has run, and the audit record carries a 16-character body hash.
func TestAuditLog_HashesRequestBody(t *testing.T) {
	recorder := &mockAuditRecorder{}
	mw := NewAuditLog(recorder, AuditConfig{})

	// Handler verifies body was restored. It runs synchronously inside
	// ServeHTTP below, so reporting through t from here is safe.
	handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, err := io.ReadAll(r.Body)
		if err != nil {
			// Previously discarded; a read failure would otherwise surface
			// only as a confusing "body was not restored" message.
			t.Errorf("reading request body in handler: %v", err)
		}
		if string(body) != `{"name":"test"}` {
			t.Errorf("body was not restored: got %q", string(body))
		}
		w.WriteHeader(http.StatusCreated)
	}))

	body := strings.NewReader(`{"name":"test"}`)
	req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", body)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	// Audit recording is async. Poll with a deadline instead of sleeping for a
	// fixed interval so the test is fast and not sensitive to scheduling.
	deadline := time.Now().Add(2 * time.Second)
	calls := recorder.getCalls()
	for len(calls) == 0 && time.Now().Before(deadline) {
		time.Sleep(time.Millisecond)
		calls = recorder.getCalls()
	}

	if len(calls) != 1 {
		t.Fatalf("expected 1 audit call, got %d", len(calls))
	}
	// Body hash should be a 16-char hex string (truncated SHA-256)
	if len(calls[0].BodyHash) != 16 {
		t.Errorf("expected 16-char body hash, got %q (len=%d)", calls[0].BodyHash, len(calls[0].BodyHash))
	}
	if calls[0].Status != 201 {
		t.Errorf("expected status 201, got %d", calls[0].Status)
	}
}
// TestAuditLog_EmptyBodyNoHash verifies that a request without a body is
// recorded with an empty body hash.
func TestAuditLog_EmptyBodyNoHash(t *testing.T) {
	recorder := &mockAuditRecorder{}
	mw := NewAuditLog(recorder, AuditConfig{})
	handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/v1/agents", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	// Audit recording is async. Poll with a deadline instead of sleeping for a
	// fixed interval so the test is fast and not sensitive to scheduling.
	deadline := time.Now().Add(2 * time.Second)
	calls := recorder.getCalls()
	for len(calls) == 0 && time.Now().Before(deadline) {
		time.Sleep(time.Millisecond)
		calls = recorder.getCalls()
	}

	if len(calls) != 1 {
		t.Fatalf("expected 1 audit call, got %d", len(calls))
	}
	if calls[0].BodyHash != "" {
		t.Errorf("expected empty body hash for GET, got %q", calls[0].BodyHash)
	}
}
// TestAuditLog_ExtractsAuthenticatedActor verifies that when the request
// context carries a user under UserKey{}, that user is recorded as the actor.
func TestAuditLog_ExtractsAuthenticatedActor(t *testing.T) {
	recorder := &mockAuditRecorder{}
	mw := NewAuditLog(recorder, AuditConfig{})
	handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodDelete, "/api/v1/certificates/mc-1", nil)
	// Simulate auth middleware having set the user in context
	ctx := context.WithValue(req.Context(), UserKey{}, "api-key-user")
	req = req.WithContext(ctx)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	// Audit recording is async. Poll with a deadline instead of sleeping for a
	// fixed interval so the test is fast and not sensitive to scheduling.
	deadline := time.Now().Add(2 * time.Second)
	calls := recorder.getCalls()
	for len(calls) == 0 && time.Now().Before(deadline) {
		time.Sleep(time.Millisecond)
		calls = recorder.getCalls()
	}

	if len(calls) != 1 {
		t.Fatalf("expected 1 audit call, got %d", len(calls))
	}
	if calls[0].Actor != "api-key-user" {
		t.Errorf("expected actor api-key-user, got %s", calls[0].Actor)
	}
	if calls[0].Method != "DELETE" {
		t.Errorf("expected method DELETE, got %s", calls[0].Method)
	}
}
// TestAuditLog_RecorderErrorDoesNotBreakResponse verifies that a failing
// recorder does not affect the HTTP response: the status and body written by
// the handler reach the client unchanged. It also confirms that the failing
// recorder was actually invoked, so the error path is really exercised.
func TestAuditLog_RecorderErrorDoesNotBreakResponse(t *testing.T) {
	recorder := &mockAuditRecorder{err: fmt.Errorf("db connection lost")}
	mw := NewAuditLog(recorder, AuditConfig{})
	handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"ok":true}`))
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/v1/stats/summary", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	// Response should still be 200 even though audit recording fails
	if rr.Code != http.StatusOK {
		t.Errorf("expected 200 despite recorder error, got %d", rr.Code)
	}
	if got := rr.Body.String(); got != `{"ok":true}` {
		t.Errorf("expected body to pass through despite recorder error, got %q", got)
	}

	// Audit recording is async. Wait (bounded) for the recorder to be called:
	// without this the test would pass even if the recorder were never
	// invoked, and the recording goroutine could outlive the test.
	deadline := time.Now().Add(2 * time.Second)
	calls := recorder.getCalls()
	for len(calls) == 0 && time.Now().Before(deadline) {
		time.Sleep(time.Millisecond)
		calls = recorder.getCalls()
	}
	if len(calls) != 1 {
		t.Errorf("expected the failing recorder to be invoked once, got %d calls", len(calls))
	}
}
// TestAuditLog_CapturesLatency verifies that the recorded latency covers the
// time spent in the wrapped handler: the handler sleeps for 10ms, so the
// audit record must report at least 10ms.
func TestAuditLog_CapturesLatency(t *testing.T) {
	recorder := &mockAuditRecorder{}
	mw := NewAuditLog(recorder, AuditConfig{})
	handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(10 * time.Millisecond)
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	// Audit recording is async. Poll with a deadline instead of sleeping for a
	// fixed interval so the test is fast and not sensitive to scheduling.
	deadline := time.Now().Add(2 * time.Second)
	calls := recorder.getCalls()
	for len(calls) == 0 && time.Now().Before(deadline) {
		time.Sleep(time.Millisecond)
		calls = recorder.getCalls()
	}

	if len(calls) != 1 {
		t.Fatalf("expected 1 audit call, got %d", len(calls))
	}
	if calls[0].LatencyMs < 10 {
		t.Errorf("expected latency >= 10ms, got %dms", calls[0].LatencyMs)
	}
}
// TestAuditServiceAdapter_TranslatesCallToEvent verifies how the adapter maps
// the arguments of RecordAPICall onto the arguments of the wrapped function:
// actor is passed through, actorType is "User", action is "api_post" for a
// POST, resourceType is "api", resourceID is the request path, and details
// carries method, status, latency_ms and body_hash.
func TestAuditServiceAdapter_TranslatesCallToEvent(t *testing.T) {
	// got holds whatever the adapter forwarded to the wrapped function.
	var got struct {
		actor        string
		actorType    string
		action       string
		resourceType string
		resourceID   string
		details      map[string]interface{}
	}
	adapter := NewAuditServiceAdapter(func(ctx context.Context, actor, actorType, action, resourceType, resourceID string, details map[string]interface{}) error {
		got.actor, got.actorType, got.action = actor, actorType, action
		got.resourceType, got.resourceID = resourceType, resourceID
		got.details = details
		return nil
	})

	if err := adapter.RecordAPICall(context.Background(), "POST", "/api/v1/certificates", "admin", "abc123", 201, 42); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Positional arguments.
	if got.actor != "admin" {
		t.Errorf("expected actor admin, got %s", got.actor)
	}
	if got.actorType != "User" {
		t.Errorf("expected actorType User, got %s", got.actorType)
	}
	if got.action != "api_post" {
		t.Errorf("expected action api_post, got %s", got.action)
	}
	if got.resourceType != "api" {
		t.Errorf("expected resourceType api, got %s", got.resourceType)
	}
	if got.resourceID != "/api/v1/certificates" {
		t.Errorf("expected resourceID /api/v1/certificates, got %s", got.resourceID)
	}

	// Details map. The comparisons are type-sensitive: status must be an int
	// and latency_ms an int64 for these checks to pass.
	if v := got.details["method"]; v != "POST" {
		t.Errorf("expected details.method POST, got %v", v)
	}
	if v := got.details["status"]; v != 201 {
		t.Errorf("expected details.status 201, got %v", v)
	}
	if v := got.details["latency_ms"]; v != int64(42) {
		t.Errorf("expected details.latency_ms 42, got %v", v)
	}
	if v := got.details["body_hash"]; v != "abc123" {
		t.Errorf("expected details.body_hash abc123, got %v", v)
	}
}
// TestAuditServiceAdapter_PropagatesError verifies that an error returned by
// the wrapped function is surfaced by RecordAPICall with its message intact.
func TestAuditServiceAdapter_PropagatesError(t *testing.T) {
	failing := func(ctx context.Context, actor, actorType, action, resourceType, resourceID string, details map[string]interface{}) error {
		return fmt.Errorf("database error")
	}
	adapter := NewAuditServiceAdapter(failing)

	err := adapter.RecordAPICall(context.Background(), "GET", "/api/v1/agents", "user", "", 200, 5)
	switch {
	case err == nil:
		t.Fatal("expected error, got nil")
	case !strings.Contains(err.Error(), "database error"):
		t.Errorf("expected database error, got %v", err)
	}
}