diff --git a/internal/api/middleware/audit_test.go b/internal/api/middleware/audit_test.go index 400c568..57ee1ce 100644 --- a/internal/api/middleware/audit_test.go +++ b/internal/api/middleware/audit_test.go @@ -50,8 +50,46 @@ func (m *mockAuditRecorder) getCalls() []auditCall { return out } +// waitableAuditRecorder wraps a mockAuditRecorder and signals when a recording completes. +// This allows tests to synchronously wait for async audit records without using time.Sleep. +type waitableAuditRecorder struct { + inner *mockAuditRecorder + recorded chan struct{} +} + +func newWaitableAuditRecorder() *waitableAuditRecorder { + return &waitableAuditRecorder{ + inner: &mockAuditRecorder{}, + recorded: make(chan struct{}, 100), // buffered to avoid blocking + } +} + +func (w *waitableAuditRecorder) RecordAPICall(ctx context.Context, method, path, actor, bodyHash string, status int, latencyMs int64) error { + err := w.inner.RecordAPICall(ctx, method, path, actor, bodyHash, status, latencyMs) + // Signal that a recording was completed + select { + case w.recorded <- struct{}{}: + default: + } + return err +} + +func (w *waitableAuditRecorder) getCalls() []auditCall { + return w.inner.getCalls() +} + +// Wait blocks until a recording is signaled or timeout expires. Returns true if recording completed, false on timeout. +func (w *waitableAuditRecorder) Wait(timeout time.Duration) bool { + select { + case <-w.recorded: + return true + case <-time.After(timeout): + return false + } +} + func TestAuditLog_RecordsAPICall(t *testing.T) { - recorder := &mockAuditRecorder{} + recorder := newWaitableAuditRecorder() mw := NewAuditLog(recorder, AuditConfig{}) handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -67,8 +105,10 @@ func TestAuditLog_RecordsAPICall(t *testing.T) { t.Fatalf("expected 200, got %d", rr.Code) } - // Audit recording is async — give goroutine time to complete - time.Sleep(50 * time.Millisecond) + // Audit recording is async — wait for goroutine to complete + if !recorder.Wait(1 * time.Second) { + t.Fatal("timeout waiting for audit record") + } calls := recorder.getCalls() if len(calls) != 1 { @@ -89,7 +129,7 @@ func TestAuditLog_RecordsAPICall(t *testing.T) { } func TestAuditLog_CapturesStatusCode(t *testing.T) { - recorder := &mockAuditRecorder{} + recorder := newWaitableAuditRecorder() mw := NewAuditLog(recorder, AuditConfig{}) handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -100,7 +140,9 @@ func TestAuditLog_CapturesStatusCode(t *testing.T) { rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - time.Sleep(50 * time.Millisecond) + if !recorder.Wait(1 * time.Second) { + t.Fatal("timeout waiting for audit record") + } calls := recorder.getCalls() if len(calls) != 1 { @@ -112,7 +154,7 @@ func TestAuditLog_CapturesStatusCode(t *testing.T) { } func TestAuditLog_ExcludesHealth(t *testing.T) { - recorder := &mockAuditRecorder{} + recorder := newWaitableAuditRecorder() mw := NewAuditLog(recorder, AuditConfig{ ExcludePaths: []string{"/health", "/ready"}, }) @@ -136,7 +178,9 @@ func TestAuditLog_ExcludesHealth(t *testing.T) { rr3 := httptest.NewRecorder() handler.ServeHTTP(rr3, req3) - time.Sleep(50 * time.Millisecond) + if !recorder.Wait(1 * time.Second) { + t.Fatal("timeout waiting for audit record") + } calls := recorder.getCalls() if len(calls) != 1 { @@ -148,7 +192,7 @@ func TestAuditLog_ExcludesHealth(t *testing.T) { } func TestAuditLog_HashesRequestBody(t *testing.T) { - recorder := &mockAuditRecorder{} + recorder := newWaitableAuditRecorder() mw := NewAuditLog(recorder, AuditConfig{}) // Handler verifies body was restored @@ -165,7 +209,9 @@ func TestAuditLog_HashesRequestBody(t *testing.T) { rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - time.Sleep(50 * time.Millisecond) + if !recorder.Wait(1 * time.Second) { + t.Fatal("timeout waiting for audit record") + } calls := recorder.getCalls() if len(calls) != 1 { @@ -181,7 +227,7 @@ func TestAuditLog_HashesRequestBody(t *testing.T) { } func TestAuditLog_EmptyBodyNoHash(t *testing.T) { - recorder := &mockAuditRecorder{} + recorder := newWaitableAuditRecorder() mw := NewAuditLog(recorder, AuditConfig{}) handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -192,7 +238,9 @@ func TestAuditLog_EmptyBodyNoHash(t *testing.T) { rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - time.Sleep(50 * time.Millisecond) + if !recorder.Wait(1 * time.Second) { + t.Fatal("timeout waiting for audit record") + } calls := recorder.getCalls() if len(calls) != 1 { @@ -204,7 +252,7 @@ func TestAuditLog_EmptyBodyNoHash(t *testing.T) { } func TestAuditLog_ExtractsAuthenticatedActor(t *testing.T) { - recorder := &mockAuditRecorder{} + recorder := newWaitableAuditRecorder() mw := NewAuditLog(recorder, AuditConfig{}) handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -219,7 +267,9 @@ func TestAuditLog_ExtractsAuthenticatedActor(t *testing.T) { rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - time.Sleep(50 * time.Millisecond) + if !recorder.Wait(1 * time.Second) { + t.Fatal("timeout waiting for audit record") + } calls := recorder.getCalls() if len(calls) != 1 { @@ -253,7 +303,7 @@ func TestAuditLog_RecorderErrorDoesNotBreakResponse(t *testing.T) { } func TestAuditLog_CapturesLatency(t *testing.T) { - recorder := &mockAuditRecorder{} + recorder := newWaitableAuditRecorder() mw := NewAuditLog(recorder, AuditConfig{}) handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -265,7 +315,9 @@ func TestAuditLog_CapturesLatency(t *testing.T) { rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - time.Sleep(50 * time.Millisecond) + if !recorder.Wait(1 * time.Second) { + t.Fatal("timeout waiting for audit record") + } calls := recorder.getCalls() if len(calls) != 1 {