diff --git a/cmd/server/main.go b/cmd/server/main.go index 00573f2..4c75059 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -589,7 +589,7 @@ func main() { bodyLimitMiddleware, corsMiddleware, authMiddleware, - auditMiddleware, + auditMiddleware.Middleware, } // Add rate limiter if enabled @@ -606,7 +606,7 @@ func main() { rateLimiter, corsMiddleware, authMiddleware, - auditMiddleware, + auditMiddleware.Middleware, } logger.Info("rate limiting enabled", "rps", cfg.RateLimit.RPS, "burst", cfg.RateLimit.BurstSize) } @@ -724,6 +724,17 @@ func main() { logger.Error("HTTP server shutdown error", "error", err) } + // Drain in-flight audit-recording goroutines before closing the DB pool. + // The audit middleware spawns one goroutine per non-excluded request; those + // goroutines run detached from the request context and write to the + // audit_events table via the same *sql.DB. Without this drain, SIGTERM + // would close the DB pool while recordings were mid-flight, silently + // dropping audit events (M-1, CWE-662 / CWE-400). + logger.Info("flushing audit middleware in-flight recordings") + if err := auditMiddleware.Flush(shutdownCtx); err != nil { + logger.Warn("audit middleware flush did not complete in time", "error", err) + } + // Close database connection if err := db.Close(); err != nil { logger.Error("error closing database connection", "error", err) diff --git a/internal/api/middleware/audit.go b/internal/api/middleware/audit.go index 63de2b8..1ab7cb2 100644 --- a/internal/api/middleware/audit.go +++ b/internal/api/middleware/audit.go @@ -4,16 +4,22 @@ import ( "context" "crypto/sha256" "encoding/hex" + "errors" "fmt" "io" "log/slog" "net/http" "strings" + "sync" "time" ) // AuditRecorder is the interface that the audit middleware uses to record API calls. // This avoids importing the service package directly, maintaining dependency inversion. +// +// Implementations may perform I/O (e.g., database writes). The middleware invokes +// RecordAPICall from a tracked goroutine so that callers can drain in-flight +// recordings during graceful shutdown via AuditMiddleware.Flush. type AuditRecorder interface { RecordAPICall(ctx context.Context, method, path, actor string, bodyHash string, status int, latencyMs int64) error } @@ -26,10 +32,42 @@ type AuditConfig struct { Logger *slog.Logger } -// NewAuditLog creates a middleware that records every API call to the audit trail. -// It captures method, path, authenticated actor, request body hash, response status, and latency. -// Audit recording is best-effort — failures are logged but don't affect the HTTP response. -func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) func(http.Handler) http.Handler { +// ErrAuditFlushTimeout is returned by AuditMiddleware.Flush when in-flight audit +// recordings do not complete before the provided context is cancelled or its +// deadline elapses. It mirrors scheduler.ErrSchedulerShutdownTimeout so callers +// can branch on graceful-shutdown timeouts consistently across subsystems. +var ErrAuditFlushTimeout = errors.New("audit middleware flush timeout") + +// AuditMiddleware is the handle returned by NewAuditLog. It wraps the audit +// logging HTTP middleware and tracks the goroutines spawned to record each API +// call, so that callers can drain them during graceful shutdown (M-1, CWE-662 +// / CWE-400). The goroutines themselves still run detached from the request +// context — the shutdown-drain signal flows through this struct's WaitGroup +// instead of the per-request context. +type AuditMiddleware struct { + recorder AuditRecorder + logger *slog.Logger + excludeSet map[string]bool + + // wg tracks every audit-recording goroutine spawned by Middleware so Flush + // can block until they complete before the DB pool is torn down. + wg sync.WaitGroup +} + +// NewAuditLog constructs the API audit logging middleware. The returned +// *AuditMiddleware exposes the HTTP middleware via the Middleware method value +// (same func(http.Handler) http.Handler shape) and a Flush method that the +// process shutdown path must call after the HTTP server has stopped accepting +// new requests but before the audit recorder's backing store (e.g., the +// database connection pool) is closed. +// +// The middleware records method, path, authenticated actor, request body hash, +// response status, and latency. Recording is best-effort — individual failures +// are logged and do not affect the HTTP response. Shutdown is NOT best-effort: +// Flush must succeed (or time out, returning ErrAuditFlushTimeout) so that +// in-flight events are not lost when the audit recorder's connection pool is +// closed out from under the goroutines. +func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) *AuditMiddleware { excludeSet := make(map[string]bool, len(cfg.ExcludePaths)) for _, p := range cfg.ExcludePaths { excludeSet[p] = true @@ -40,68 +78,122 @@ func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) func(http.Handler) htt logger = slog.Default() } - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Skip excluded paths (health, readiness probes) - for prefix := range excludeSet { - if strings.HasPrefix(r.URL.Path, prefix) { - next.ServeHTTP(w, r) - return - } + return &AuditMiddleware{ + recorder: recorder, + logger: logger, + excludeSet: excludeSet, + } +} + +// Middleware is the http.Handler wrapper. It has the standard +// func(http.Handler) http.Handler middleware signature so it can be composed +// into an existing middleware chain via a method value (auditMiddleware.Middleware). +func (a *AuditMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip excluded paths (health, readiness probes) + for prefix := range a.excludeSet { + if strings.HasPrefix(r.URL.Path, prefix) { + next.ServeHTTP(w, r) + return } + } - start := time.Now() + start := time.Now() - // Hash request body for audit (don't store raw bodies — security + size concerns) - bodyHash := "" - if r.Body != nil && r.Body != http.NoBody { - hasher := sha256.New() - body, err := io.ReadAll(r.Body) - if err == nil && len(body) > 0 { - hasher.Write(body) - bodyHash = hex.EncodeToString(hasher.Sum(nil))[:16] // truncated hash - // Restore the body for downstream handlers - r.Body = io.NopCloser(strings.NewReader(string(body))) - } + // Hash request body for audit (don't store raw bodies — security + size concerns) + bodyHash := "" + if r.Body != nil && r.Body != http.NoBody { + hasher := sha256.New() + body, err := io.ReadAll(r.Body) + if err == nil && len(body) > 0 { + hasher.Write(body) + bodyHash = hex.EncodeToString(hasher.Sum(nil))[:16] // truncated hash + // Restore the body for downstream handlers + r.Body = io.NopCloser(strings.NewReader(string(body))) } + } - // Extract actor from auth context - actor := "anonymous" - if user, ok := GetUser(r.Context()); ok && user != "" { - actor = user + // Extract actor from auth context + actor := "anonymous" + if user, ok := GetUser(r.Context()); ok && user != "" { + actor = user + } + + // Wrap response writer to capture status code + wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} + + next.ServeHTTP(wrapped, r) + + latency := time.Since(start).Milliseconds() + + // Snapshot request-derived inputs so the goroutine does not race with + // the http.Server reusing r after this handler returns. + method := r.Method + path := r.URL.Path + status := wrapped.statusCode + + // Record audit event asynchronously (best-effort, don't block response). + // SECURITY: We intentionally use r.URL.Path (not r.URL.String() or r.RequestURI) + // to prevent query parameters from being recorded in the immutable audit trail. + // Query strings may contain cursor tokens, API keys passed as params, or other + // sensitive filter values. Since the audit trail is append-only with no deletion + // capability, any sensitive data recorded would persist permanently. + // + // The goroutine is tracked in a.wg so AuditMiddleware.Flush can drain + // in-flight recordings during graceful shutdown. Without this (M-1, + // CWE-662 / CWE-400), SIGTERM would close the DB pool while recordings + // were still mid-flight, silently dropping audit events. + a.wg.Add(1) + go func() { + defer a.wg.Done() + if err := a.recorder.RecordAPICall( + context.Background(), + method, + path, + actor, + bodyHash, + status, + latency, + ); err != nil { + a.logger.Error("failed to record API audit event", + "error", err, + "method", method, + "path", path, + ) } + }() + }) +} - // Wrap response writer to capture status code - wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} +// Flush blocks until every audit-recording goroutine spawned by Middleware has +// completed, or until ctx is cancelled / its deadline elapses. It must be +// called from the process shutdown path after http.Server.Shutdown has +// returned (so no new requests are being accepted) but before the backing +// audit recorder's resources (DB pool, etc.) are torn down. +// +// On timeout or cancellation Flush returns ErrAuditFlushTimeout wrapped with +// any context error; in-flight goroutines continue to run and may still write +// to the recorder once they unblock — the caller is responsible for deciding +// whether to proceed with teardown anyway or surface the error. +// +// Flush mirrors the idiom used by scheduler.Scheduler.WaitForCompletion so +// that the two subsystems drain identically at shutdown. +func (a *AuditMiddleware) Flush(ctx context.Context) error { + done := make(chan struct{}) + go func() { + a.wg.Wait() + close(done) + }() - next.ServeHTTP(wrapped, r) - - latency := time.Since(start).Milliseconds() - - // Record audit event asynchronously (best-effort, don't block response). - // SECURITY: We intentionally use r.URL.Path (not r.URL.String() or r.RequestURI) - // to prevent query parameters from being recorded in the immutable audit trail. - // Query strings may contain cursor tokens, API keys passed as params, or other - // sensitive filter values. Since the audit trail is append-only with no deletion - // capability, any sensitive data recorded would persist permanently. - go func() { - if err := recorder.RecordAPICall( - context.Background(), - r.Method, - r.URL.Path, - actor, - bodyHash, - wrapped.statusCode, - latency, - ); err != nil { - logger.Error("failed to record API audit event", - "error", err, - "method", r.Method, - "path", r.URL.Path, - ) - } - }() - }) + select { + case <-done: + a.logger.Info("audit middleware flush complete") + return nil + case <-ctx.Done(): + a.logger.Warn("audit middleware flush did not complete before context cancellation", + "error", ctx.Err(), + ) + return fmt.Errorf("%w: %w", ErrAuditFlushTimeout, ctx.Err()) } } diff --git a/internal/api/middleware/audit_test.go b/internal/api/middleware/audit_test.go index dbe41f6..90a7924 100644 --- a/internal/api/middleware/audit_test.go +++ b/internal/api/middleware/audit_test.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "errors" "fmt" "io" "net/http" @@ -16,7 +17,8 @@ import ( type mockAuditRecorder struct { mu sync.Mutex calls []auditCall - err error // if non-nil, RecordAPICall returns this + err error // if non-nil, RecordAPICall returns this + block chan struct{} // if non-nil, RecordAPICall blocks on receive before returning } type auditCall struct { @@ -29,6 +31,13 @@ type auditCall struct { } func (m *mockAuditRecorder) RecordAPICall(ctx context.Context, method, path, actor, bodyHash string, status int, latencyMs int64) error { + // Optional: block the recorder until a signal is received so tests can + // exercise the shutdown-drain path deterministically. The block happens + // before any state mutation so Flush-timeout tests see the call + // "in-flight" (wg counter > 0) with no recorded entries yet. + if m.block != nil { + <-m.block + } m.mu.Lock() defer m.mu.Unlock() m.calls = append(m.calls, auditCall{ @@ -90,7 +99,7 @@ func (w *waitableAuditRecorder) Wait(timeout time.Duration) bool { func TestAuditLog_RecordsAPICall(t *testing.T) { recorder := newWaitableAuditRecorder() - mw := NewAuditLog(recorder, AuditConfig{}) + mw := NewAuditLog(recorder, AuditConfig{}).Middleware handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -130,7 +139,7 @@ func TestAuditLog_RecordsAPICall(t *testing.T) { func TestAuditLog_CapturesStatusCode(t *testing.T) { recorder := newWaitableAuditRecorder() - mw := NewAuditLog(recorder, AuditConfig{}) + mw := NewAuditLog(recorder, AuditConfig{}).Middleware handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) @@ -157,7 +166,7 @@ func TestAuditLog_ExcludesHealth(t *testing.T) { recorder := newWaitableAuditRecorder() mw := NewAuditLog(recorder, AuditConfig{ ExcludePaths: []string{"/health", "/ready"}, - }) + }).Middleware handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -193,7 +202,7 @@ func TestAuditLog_ExcludesHealth(t *testing.T) { func TestAuditLog_HashesRequestBody(t *testing.T) { recorder := newWaitableAuditRecorder() - mw := NewAuditLog(recorder, AuditConfig{}) + mw := NewAuditLog(recorder, AuditConfig{}).Middleware // Handler verifies body was restored handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -228,7 +237,7 @@ func TestAuditLog_HashesRequestBody(t *testing.T) { func TestAuditLog_EmptyBodyNoHash(t *testing.T) { recorder := newWaitableAuditRecorder() - mw := NewAuditLog(recorder, AuditConfig{}) + mw := NewAuditLog(recorder, AuditConfig{}).Middleware handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -253,7 +262,7 @@ func TestAuditLog_EmptyBodyNoHash(t *testing.T) { func TestAuditLog_ExtractsAuthenticatedActor(t *testing.T) { recorder := newWaitableAuditRecorder() - mw := NewAuditLog(recorder, AuditConfig{}) + mw := NewAuditLog(recorder, AuditConfig{}).Middleware handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -285,7 +294,7 @@ func TestAuditLog_ExtractsAuthenticatedActor(t *testing.T) { func TestAuditLog_RecorderErrorDoesNotBreakResponse(t *testing.T) { recorder := &mockAuditRecorder{err: fmt.Errorf("db connection lost")} - mw := NewAuditLog(recorder, AuditConfig{}) + mw := NewAuditLog(recorder, AuditConfig{}).Middleware handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -304,7 +313,7 @@ func TestAuditLog_RecorderErrorDoesNotBreakResponse(t *testing.T) { func TestAuditLog_CapturesLatency(t *testing.T) { recorder := newWaitableAuditRecorder() - mw := NewAuditLog(recorder, AuditConfig{}) + mw := NewAuditLog(recorder, AuditConfig{}).Middleware handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(10 * time.Millisecond) @@ -330,7 +339,7 @@ func TestAuditLog_CapturesLatency(t *testing.T) { func TestAuditLog_ExcludesQueryParamsFromPath(t *testing.T) { recorder := newWaitableAuditRecorder() - mw := NewAuditLog(recorder, AuditConfig{}) + mw := NewAuditLog(recorder, AuditConfig{}).Middleware handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -429,3 +438,112 @@ func TestAuditServiceAdapter_PropagatesError(t *testing.T) { t.Errorf("expected database error, got %v", err) } } + +// TestAuditLog_FlushDrainsInFlightGoroutines verifies the M-1 shutdown-drain +// contract: Flush blocks until every audit-recording goroutine spawned by the +// middleware completes, then returns nil. Without the drain (pre-M-1 code), +// the DB pool would be closed while in-flight goroutines were still calling +// RecordAPICall, silently dropping audit events (CWE-662 / CWE-400). +func TestAuditLog_FlushDrainsInFlightGoroutines(t *testing.T) { + // Recorder blocks on `unblock` until the test releases it. This simulates + // a slow DB write still in flight when shutdown begins. + unblock := make(chan struct{}) + recorder := &mockAuditRecorder{block: unblock} + auditMW := NewAuditLog(recorder, AuditConfig{}) + + handler := auditMW.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Fire a request. Handler returns immediately; recorder goroutine is + // parked on the `unblock` channel inside RecordAPICall. + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + // Start Flush in a goroutine — it must block on the WaitGroup until we + // release the recorder. + flushDone := make(chan error, 1) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + flushDone <- auditMW.Flush(ctx) + }() + + // Confirm Flush is actually blocked (not returning immediately). + select { + case err := <-flushDone: + t.Fatalf("Flush returned before recorder unblocked: err=%v", err) + case <-time.After(50 * time.Millisecond): + // expected: Flush is blocked on wg.Wait + } + + // Release the recorder. Flush should now observe wg counter drop to 0 + // and return nil. + close(unblock) + + select { + case err := <-flushDone: + if err != nil { + t.Fatalf("expected nil from Flush after drain, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Flush did not return after recorder unblocked") + } + + // Verify the audit event was actually recorded (i.e., the goroutine + // completed its write — not just that Flush unblocked). + calls := recorder.getCalls() + if len(calls) != 1 { + t.Fatalf("expected 1 recorded audit call, got %d", len(calls)) + } + if calls[0].Path != "/api/v1/certificates" { + t.Errorf("expected path /api/v1/certificates, got %s", calls[0].Path) + } +} + +// TestAuditLog_FlushTimeoutReturnsErrAuditFlushTimeout verifies that Flush +// respects its context: when in-flight goroutines exceed the shutdown budget, +// Flush returns an error wrapping ErrAuditFlushTimeout plus ctx.Err(). The +// caller can then decide whether to proceed with teardown anyway. +func TestAuditLog_FlushTimeoutReturnsErrAuditFlushTimeout(t *testing.T) { + // Recorder will never unblock on its own — we unblock at end of test for + // a clean race-safe teardown. + unblock := make(chan struct{}) + recorder := &mockAuditRecorder{block: unblock} + auditMW := NewAuditLog(recorder, AuditConfig{}) + + handler := auditMW.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + // Flush with a tiny deadline — must time out. + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + err := auditMW.Flush(ctx) + + if err == nil { + // Release the blocked goroutine before failing so the race detector + // doesn't trip on teardown. + close(unblock) + t.Fatal("expected Flush to return an error on timeout, got nil") + } + if !errors.Is(err, ErrAuditFlushTimeout) { + close(unblock) + t.Fatalf("expected error to wrap ErrAuditFlushTimeout, got %v", err) + } + if !errors.Is(err, context.DeadlineExceeded) { + close(unblock) + t.Fatalf("expected error to wrap context.DeadlineExceeded, got %v", err) + } + + // Race-safe teardown: unblock the recorder goroutine so it exits cleanly + // before the test returns. The goroutine itself is still detached and + // will record to the mock even after Flush timed out — that's the + // documented behavior (Flush surfaces the timeout; caller decides). + close(unblock) +}