Merge branch 'fix/m1-audit-shutdown-drain'

Resolves M-1 (Medium): Audit recorder shutdown drain.

The API audit middleware's detached recording goroutines now drain
during graceful shutdown via AuditMiddleware.Flush (sync.WaitGroup +
timeout-aware select), called between http.Server.Shutdown and
db.Close. Prevents silent audit-event loss on SIGTERM
(CWE-662 / CWE-400).
This commit is contained in:
Shankar
2026-04-17 17:29:54 +00:00
3 changed files with 291 additions and 70 deletions
+13 -2
View File
@@ -589,7 +589,7 @@ func main() {
bodyLimitMiddleware,
corsMiddleware,
authMiddleware,
auditMiddleware,
auditMiddleware.Middleware,
}
// Add rate limiter if enabled
@@ -606,7 +606,7 @@ func main() {
rateLimiter,
corsMiddleware,
authMiddleware,
auditMiddleware,
auditMiddleware.Middleware,
}
logger.Info("rate limiting enabled", "rps", cfg.RateLimit.RPS, "burst", cfg.RateLimit.BurstSize)
}
@@ -724,6 +724,17 @@ func main() {
logger.Error("HTTP server shutdown error", "error", err)
}
// Drain in-flight audit-recording goroutines before closing the DB pool.
// The audit middleware spawns one goroutine per non-excluded request; those
// goroutines run detached from the request context and write to the
// audit_events table via the same *sql.DB. Without this drain, SIGTERM
// would close the DB pool while recordings were mid-flight, silently
// dropping audit events (M-1, CWE-662 / CWE-400).
logger.Info("flushing audit middleware in-flight recordings")
if err := auditMiddleware.Flush(shutdownCtx); err != nil {
logger.Warn("audit middleware flush did not complete in time", "error", err)
}
// Close database connection
if err := db.Close(); err != nil {
logger.Error("error closing database connection", "error", err)
+150 -58
View File
@@ -4,16 +4,22 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"sync"
"time"
)
// AuditRecorder is the interface that the audit middleware uses to record API calls.
// This avoids importing the service package directly, maintaining dependency inversion.
//
// Implementations may perform I/O (e.g., database writes). The middleware invokes
// RecordAPICall from a tracked goroutine so that callers can drain in-flight
// recordings during graceful shutdown via AuditMiddleware.Flush.
type AuditRecorder interface {
RecordAPICall(ctx context.Context, method, path, actor string, bodyHash string, status int, latencyMs int64) error
}
@@ -26,10 +32,42 @@ type AuditConfig struct {
Logger *slog.Logger
}
// NewAuditLog creates a middleware that records every API call to the audit trail.
// It captures method, path, authenticated actor, request body hash, response status, and latency.
// Audit recording is best-effort — failures are logged but don't affect the HTTP response.
func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) func(http.Handler) http.Handler {
// ErrAuditFlushTimeout is returned by AuditMiddleware.Flush when in-flight audit
// recordings do not complete before the provided context is cancelled or its
// deadline elapses. It mirrors scheduler.ErrSchedulerShutdownTimeout so callers
// can branch on graceful-shutdown timeouts consistently across subsystems.
var ErrAuditFlushTimeout = errors.New("audit middleware flush timeout")
// AuditMiddleware is the handle returned by NewAuditLog. It wraps the audit
// logging HTTP middleware and tracks the goroutines spawned to record each API
// call, so that callers can drain them during graceful shutdown (M-1, CWE-662
// / CWE-400). The goroutines themselves still run detached from the request
// context — the shutdown-drain signal flows through this struct's WaitGroup
// instead of the per-request context.
type AuditMiddleware struct {
recorder AuditRecorder
logger *slog.Logger
excludeSet map[string]bool
// wg tracks every audit-recording goroutine spawned by Middleware so Flush
// can block until they complete before the DB pool is torn down.
wg sync.WaitGroup
}
// NewAuditLog constructs the API audit logging middleware. The returned
// *AuditMiddleware exposes the HTTP middleware via the Middleware method value
// (same func(http.Handler) http.Handler shape) and a Flush method that the
// process shutdown path must call after the HTTP server has stopped accepting
// new requests but before the audit recorder's backing store (e.g., the
// database connection pool) is closed.
//
// The middleware records method, path, authenticated actor, request body hash,
// response status, and latency. Recording is best-effort — individual failures
// are logged and do not affect the HTTP response. Shutdown is NOT best-effort:
// Flush must succeed (or time out, returning ErrAuditFlushTimeout) so that
// in-flight events are not lost when the audit recorder's connection pool is
// closed out from under the goroutines.
func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) *AuditMiddleware {
excludeSet := make(map[string]bool, len(cfg.ExcludePaths))
for _, p := range cfg.ExcludePaths {
excludeSet[p] = true
@@ -40,68 +78,122 @@ func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) func(http.Handler) htt
logger = slog.Default()
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip excluded paths (health, readiness probes)
for prefix := range excludeSet {
if strings.HasPrefix(r.URL.Path, prefix) {
next.ServeHTTP(w, r)
return
}
return &AuditMiddleware{
recorder: recorder,
logger: logger,
excludeSet: excludeSet,
}
}
// Middleware is the http.Handler wrapper. It has the standard
// func(http.Handler) http.Handler middleware signature so it can be composed
// into an existing middleware chain via a method value (auditMiddleware.Middleware).
func (a *AuditMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip excluded paths (health, readiness probes)
for prefix := range a.excludeSet {
if strings.HasPrefix(r.URL.Path, prefix) {
next.ServeHTTP(w, r)
return
}
}
start := time.Now()
start := time.Now()
// Hash request body for audit (don't store raw bodies — security + size concerns)
bodyHash := ""
if r.Body != nil && r.Body != http.NoBody {
hasher := sha256.New()
body, err := io.ReadAll(r.Body)
if err == nil && len(body) > 0 {
hasher.Write(body)
bodyHash = hex.EncodeToString(hasher.Sum(nil))[:16] // truncated hash
// Restore the body for downstream handlers
r.Body = io.NopCloser(strings.NewReader(string(body)))
}
// Hash request body for audit (don't store raw bodies — security + size concerns)
bodyHash := ""
if r.Body != nil && r.Body != http.NoBody {
hasher := sha256.New()
body, err := io.ReadAll(r.Body)
if err == nil && len(body) > 0 {
hasher.Write(body)
bodyHash = hex.EncodeToString(hasher.Sum(nil))[:16] // truncated hash
// Restore the body for downstream handlers
r.Body = io.NopCloser(strings.NewReader(string(body)))
}
}
// Extract actor from auth context
actor := "anonymous"
if user, ok := GetUser(r.Context()); ok && user != "" {
actor = user
// Extract actor from auth context
actor := "anonymous"
if user, ok := GetUser(r.Context()); ok && user != "" {
actor = user
}
// Wrap response writer to capture status code
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
next.ServeHTTP(wrapped, r)
latency := time.Since(start).Milliseconds()
// Snapshot request-derived inputs so the goroutine does not race with
// the http.Server reusing r after this handler returns.
method := r.Method
path := r.URL.Path
status := wrapped.statusCode
// Record audit event asynchronously (best-effort, don't block response).
// SECURITY: We intentionally use r.URL.Path (not r.URL.String() or r.RequestURI)
// to prevent query parameters from being recorded in the immutable audit trail.
// Query strings may contain cursor tokens, API keys passed as params, or other
// sensitive filter values. Since the audit trail is append-only with no deletion
// capability, any sensitive data recorded would persist permanently.
//
// The goroutine is tracked in a.wg so AuditMiddleware.Flush can drain
// in-flight recordings during graceful shutdown. Without this (M-1,
// CWE-662 / CWE-400), SIGTERM would close the DB pool while recordings
// were still mid-flight, silently dropping audit events.
a.wg.Add(1)
go func() {
defer a.wg.Done()
if err := a.recorder.RecordAPICall(
context.Background(),
method,
path,
actor,
bodyHash,
status,
latency,
); err != nil {
a.logger.Error("failed to record API audit event",
"error", err,
"method", method,
"path", path,
)
}
}()
})
}
// Wrap response writer to capture status code
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
// Flush blocks until every audit-recording goroutine spawned by Middleware has
// completed, or until ctx is cancelled / its deadline elapses. It must be
// called from the process shutdown path after http.Server.Shutdown has
// returned (so no new requests are being accepted) but before the backing
// audit recorder's resources (DB pool, etc.) are torn down.
//
// On timeout or cancellation Flush returns ErrAuditFlushTimeout wrapped with
// any context error; in-flight goroutines continue to run and may still write
// to the recorder once they unblock — the caller is responsible for deciding
// whether to proceed with teardown anyway or surface the error.
//
// Flush mirrors the idiom used by scheduler.Scheduler.WaitForCompletion so
// that the two subsystems drain identically at shutdown.
func (a *AuditMiddleware) Flush(ctx context.Context) error {
done := make(chan struct{})
go func() {
a.wg.Wait()
close(done)
}()
next.ServeHTTP(wrapped, r)
latency := time.Since(start).Milliseconds()
// Record audit event asynchronously (best-effort, don't block response).
// SECURITY: We intentionally use r.URL.Path (not r.URL.String() or r.RequestURI)
// to prevent query parameters from being recorded in the immutable audit trail.
// Query strings may contain cursor tokens, API keys passed as params, or other
// sensitive filter values. Since the audit trail is append-only with no deletion
// capability, any sensitive data recorded would persist permanently.
go func() {
if err := recorder.RecordAPICall(
context.Background(),
r.Method,
r.URL.Path,
actor,
bodyHash,
wrapped.statusCode,
latency,
); err != nil {
logger.Error("failed to record API audit event",
"error", err,
"method", r.Method,
"path", r.URL.Path,
)
}
}()
})
select {
case <-done:
a.logger.Info("audit middleware flush complete")
return nil
case <-ctx.Done():
a.logger.Warn("audit middleware flush did not complete before context cancellation",
"error", ctx.Err(),
)
return fmt.Errorf("%w: %w", ErrAuditFlushTimeout, ctx.Err())
}
}
+128 -10
View File
@@ -2,6 +2,7 @@ package middleware
import (
"context"
"errors"
"fmt"
"io"
"net/http"
@@ -16,7 +17,8 @@ import (
type mockAuditRecorder struct {
mu sync.Mutex
calls []auditCall
err error // if non-nil, RecordAPICall returns this
err error // if non-nil, RecordAPICall returns this
block chan struct{} // if non-nil, RecordAPICall blocks on receive before returning
}
type auditCall struct {
@@ -29,6 +31,13 @@ type auditCall struct {
}
func (m *mockAuditRecorder) RecordAPICall(ctx context.Context, method, path, actor, bodyHash string, status int, latencyMs int64) error {
// Optional: block the recorder until a signal is received so tests can
// exercise the shutdown-drain path deterministically. The block happens
// before any state mutation so Flush-timeout tests see the call
// "in-flight" (wg counter > 0) with no recorded entries yet.
if m.block != nil {
<-m.block
}
m.mu.Lock()
defer m.mu.Unlock()
m.calls = append(m.calls, auditCall{
@@ -90,7 +99,7 @@ func (w *waitableAuditRecorder) Wait(timeout time.Duration) bool {
func TestAuditLog_RecordsAPICall(t *testing.T) {
recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{})
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
@@ -130,7 +139,7 @@ func TestAuditLog_RecordsAPICall(t *testing.T) {
func TestAuditLog_CapturesStatusCode(t *testing.T) {
recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{})
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
@@ -157,7 +166,7 @@ func TestAuditLog_ExcludesHealth(t *testing.T) {
recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{
ExcludePaths: []string{"/health", "/ready"},
})
}).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
@@ -193,7 +202,7 @@ func TestAuditLog_ExcludesHealth(t *testing.T) {
func TestAuditLog_HashesRequestBody(t *testing.T) {
recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{})
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
// Handler verifies body was restored
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -228,7 +237,7 @@ func TestAuditLog_HashesRequestBody(t *testing.T) {
func TestAuditLog_EmptyBodyNoHash(t *testing.T) {
recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{})
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
@@ -253,7 +262,7 @@ func TestAuditLog_EmptyBodyNoHash(t *testing.T) {
func TestAuditLog_ExtractsAuthenticatedActor(t *testing.T) {
recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{})
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
@@ -285,7 +294,7 @@ func TestAuditLog_ExtractsAuthenticatedActor(t *testing.T) {
func TestAuditLog_RecorderErrorDoesNotBreakResponse(t *testing.T) {
recorder := &mockAuditRecorder{err: fmt.Errorf("db connection lost")}
mw := NewAuditLog(recorder, AuditConfig{})
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
@@ -304,7 +313,7 @@ func TestAuditLog_RecorderErrorDoesNotBreakResponse(t *testing.T) {
func TestAuditLog_CapturesLatency(t *testing.T) {
recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{})
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(10 * time.Millisecond)
@@ -330,7 +339,7 @@ func TestAuditLog_CapturesLatency(t *testing.T) {
func TestAuditLog_ExcludesQueryParamsFromPath(t *testing.T) {
recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{})
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
@@ -429,3 +438,112 @@ func TestAuditServiceAdapter_PropagatesError(t *testing.T) {
t.Errorf("expected database error, got %v", err)
}
}
// TestAuditLog_FlushDrainsInFlightGoroutines verifies the M-1 shutdown-drain
// contract: Flush blocks until every audit-recording goroutine spawned by the
// middleware completes, then returns nil. Without the drain (pre-M-1 code),
// the DB pool would be closed while in-flight goroutines were still calling
// RecordAPICall, silently dropping audit events (CWE-662 / CWE-400).
func TestAuditLog_FlushDrainsInFlightGoroutines(t *testing.T) {
// Recorder blocks on `unblock` until the test releases it. This simulates
// a slow DB write still in flight when shutdown begins.
unblock := make(chan struct{})
recorder := &mockAuditRecorder{block: unblock}
auditMW := NewAuditLog(recorder, AuditConfig{})
handler := auditMW.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Fire a request. Handler returns immediately; recorder goroutine is
// parked on the `unblock` channel inside RecordAPICall.
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Start Flush in a goroutine — it must block on the WaitGroup until we
// release the recorder.
flushDone := make(chan error, 1)
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
flushDone <- auditMW.Flush(ctx)
}()
// Confirm Flush is actually blocked (not returning immediately).
select {
case err := <-flushDone:
t.Fatalf("Flush returned before recorder unblocked: err=%v", err)
case <-time.After(50 * time.Millisecond):
// expected: Flush is blocked on wg.Wait
}
// Release the recorder. Flush should now observe wg counter drop to 0
// and return nil.
close(unblock)
select {
case err := <-flushDone:
if err != nil {
t.Fatalf("expected nil from Flush after drain, got %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("Flush did not return after recorder unblocked")
}
// Verify the audit event was actually recorded (i.e., the goroutine
// completed its write — not just that Flush unblocked).
calls := recorder.getCalls()
if len(calls) != 1 {
t.Fatalf("expected 1 recorded audit call, got %d", len(calls))
}
if calls[0].Path != "/api/v1/certificates" {
t.Errorf("expected path /api/v1/certificates, got %s", calls[0].Path)
}
}
// TestAuditLog_FlushTimeoutReturnsErrAuditFlushTimeout verifies that Flush
// respects its context: when in-flight goroutines exceed the shutdown budget,
// Flush returns an error wrapping ErrAuditFlushTimeout plus ctx.Err(). The
// caller can then decide whether to proceed with teardown anyway.
func TestAuditLog_FlushTimeoutReturnsErrAuditFlushTimeout(t *testing.T) {
// Recorder will never unblock on its own — we unblock at end of test for
// a clean race-safe teardown.
unblock := make(chan struct{})
recorder := &mockAuditRecorder{block: unblock}
auditMW := NewAuditLog(recorder, AuditConfig{})
handler := auditMW.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Flush with a tiny deadline — must time out.
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
err := auditMW.Flush(ctx)
if err == nil {
// Release the blocked goroutine before failing so the race detector
// doesn't trip on teardown.
close(unblock)
t.Fatal("expected Flush to return an error on timeout, got nil")
}
if !errors.Is(err, ErrAuditFlushTimeout) {
close(unblock)
t.Fatalf("expected error to wrap ErrAuditFlushTimeout, got %v", err)
}
if !errors.Is(err, context.DeadlineExceeded) {
close(unblock)
t.Fatalf("expected error to wrap context.DeadlineExceeded, got %v", err)
}
// Race-safe teardown: unblock the recorder goroutine so it exits cleanly
// before the test returns. The goroutine itself is still detached and
// will record to the mock even after Flush timed out — that's the
// documented behavior (Flush surfaces the timeout; caller decides).
close(unblock)
}