Files
certctl/internal/api/middleware/audit_test.go
T
shankar0123 630831aeac harden(audit+session): full SHA-256 audit hash + cookie segment length cap (MED-15 + Nit-4)
Audit 2026-05-10 Fix 13 Phase F + Fix 14 Phase F partial — close
MED-15 + Nit-4. Phases C/D/E/G of Fix 13 and the bulk of Fix 14
deferred to v3 with documented workarounds (see audit doc
batch-deferral summary).

MED-15: internal/api/middleware/audit.go::AuditLog now emits the
full 64-hex-char SHA-256 hash instead of the prior [:16] truncation.
The audit_events.body_hash schema column is already CHAR(64); the
truncation was an integrity-collision hole — a 64-bit hash prefix is
birthday-attack-feasible (~2^32 ≈ 4 billion attempts). Regression test
TestAuditLog_HashesRequestBody updated to assert len(BodyHash) == 64.

Nit-4: internal/auth/session/service.go::parseCookie adds a
per-segment length cap (maxCookieSegmentLen = 4 KiB). Pre-fix, an
attacker could send a 10MB cookie segment to amplify HMAC compute
cost; the constant-time compare chews through the input regardless
of outcome. The cap is loose enough that no legitimate client trips
it (real cookie segments are well under 1 KiB each), yet tight enough
to bound the work an attacker can force per failed request.

Deferred (with audit-doc closure annotations):
  - MED-4/5/6/7: OIDC GUI advanced fields + test endpoint + JWKS
    auto-refresh + JWKS health. v3 OIDC-operator-experience bundle.
    Workarounds documented.
  - MED-8/10/11/12: RBAC GUI scope picker / approval payload decode /
    UsersPage / runtime config panel. v3 GUI-polish bundle. Backend
    already accepts the scope_type/scope_id fields; the gap is GUI.
  - MED-13: MCP tools for approvals / break-glass / bootstrap.
    v3 MCP-expansion bundle.
  - MED-14: __Host- cookie rename. Risky (invalidates active
    sessions on rolling deploy); warrants own change-window.
  - MED-16/17: Pre-login UA/IP binding + RFC 9207 iss URL check.
    v3 OIDC-hardening bundle.
  - All 12 LOWs + 4 of 5 Nits: v3 cleanup bundle.

Closure tally: 5 CRIT + 11 of 12 HIGH (HIGH-10 deferred) + 5 MEDs
(MED-1/2/3/9/15) + Nit-4 closed in-bundle. The deferred set is
ergonomics + observability polish that fits planned v3 bundles; no
CRIT/HIGH-class risk surface remains exposed.

Refs: cowork/auth-bundles-audit-2026-05-10.md MED-15, Nit-4
Spec: cowork/auth-bundles-fixes-2026-05-10/13-med-bundle.md Phase F
      cowork/auth-bundles-fixes-2026-05-10/14-low-nit-cleanup.md Phase F
2026-05-10 22:02:26 +00:00

557 lines
18 KiB
Go

package middleware
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/certctl-io/certctl/internal/auth"
)
// mockAuditRecorder is a test double that captures RecordAPICall invocations
// so tests can inspect what the audit middleware recorded.
type mockAuditRecorder struct {
// mu guards calls; both RecordAPICall and getCalls take it.
mu sync.Mutex
// calls holds one entry per completed RecordAPICall, in completion order.
calls []auditCall
err error // if non-nil, RecordAPICall returns this (the call is still appended to calls)
block chan struct{} // if non-nil, RecordAPICall blocks on a receive from it before recording or returning
}
// auditCall is a snapshot of the arguments passed to one RecordAPICall
// invocation, stored by mockAuditRecorder for later assertions. Each field
// mirrors the RecordAPICall parameter of the same name.
type auditCall struct {
Method string // method argument (e.g. "GET" in the tests below)
Path string // path argument
Actor string // actor argument
BodyHash string // bodyHash argument; the tests expect "" when the request had no body
Status int // status argument (HTTP status code in the tests below)
LatencyMs int64 // latencyMs argument
}
// RecordAPICall stores one auditCall built from its arguments and returns the
// recorder's configured err (nil unless the test set one). ctx is not used.
//
// When block is non-nil the method first parks on a receive from that channel.
// The receive happens before the lock is taken or calls is touched, so a test
// can hold a recording "in flight" with nothing recorded yet, which is what
// the shutdown-drain and Flush-timeout tests rely on.
func (m *mockAuditRecorder) RecordAPICall(ctx context.Context, method, path, actor, bodyHash string, status int, latencyMs int64) error {
	if gate := m.block; gate != nil {
		<-gate
	}

	entry := auditCall{
		Method:    method,
		Path:      path,
		Actor:     actor,
		BodyHash:  bodyHash,
		Status:    status,
		LatencyMs: latencyMs,
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	m.calls = append(m.calls, entry)
	return m.err
}
// getCalls returns a copy of the calls recorded so far, taken under the lock,
// so the caller can inspect it while recordings may still be arriving.
func (m *mockAuditRecorder) getCalls() []auditCall {
	m.mu.Lock()
	defer m.mu.Unlock()

	snapshot := make([]auditCall, 0, len(m.calls))
	return append(snapshot, m.calls...)
}
// waitableAuditRecorder wraps a mockAuditRecorder and signals on a channel each
// time a recording completes. This lets tests wait synchronously for an
// asynchronous audit record instead of sleeping for a fixed time.
type waitableAuditRecorder struct {
inner *mockAuditRecorder // receives every RecordAPICall and stores the captured calls
recorded chan struct{} // gets one non-blocking send per completed RecordAPICall; drained by Wait
}
// newWaitableAuditRecorder builds a waitableAuditRecorder around a fresh
// mockAuditRecorder. The signal channel is buffered so that recordings do not
// block when no test is currently waiting.
func newWaitableAuditRecorder() *waitableAuditRecorder {
	const signalBuffer = 100 // buffered to avoid blocking

	w := new(waitableAuditRecorder)
	w.inner = new(mockAuditRecorder)
	w.recorded = make(chan struct{}, signalBuffer)
	return w
}
// RecordAPICall forwards the call to the inner mock, then announces completion
// on the recorded channel and returns whatever the inner mock returned.
func (w *waitableAuditRecorder) RecordAPICall(ctx context.Context, method, path, actor, bodyHash string, status int, latencyMs int64) error {
	result := w.inner.RecordAPICall(ctx, method, path, actor, bodyHash, status, latencyMs)

	// Non-blocking send: if the buffer is full the signal is dropped rather
	// than stalling the caller.
	select {
	case w.recorded <- struct{}{}:
	default:
	}

	return result
}
// getCalls returns the inner mock's copy of the calls recorded so far.
func (w *waitableAuditRecorder) getCalls() []auditCall {
	snapshot := w.inner.getCalls()
	return snapshot
}
// Wait blocks until one recording signal is received or the timeout elapses.
// It reports true when a signal was consumed and false on timeout.
func (w *waitableAuditRecorder) Wait(timeout time.Duration) bool {
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	select {
	case <-w.recorded:
		return true
	case <-deadline.C:
		return false
	}
}
// TestAuditLog_RecordsAPICall checks the basic path: one unauthenticated GET
// through the middleware yields exactly one audit record carrying the method,
// path, "anonymous" actor and response status.
func TestAuditLog_RecordsAPICall(t *testing.T) {
	rec := newWaitableAuditRecorder()
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"ok":true}`))
	})
	handler := NewAuditLog(rec, AuditConfig{}).Middleware(next)

	resp := httptest.NewRecorder()
	handler.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil))
	if resp.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", resp.Code)
	}

	// The audit write happens asynchronously; block until it is signaled.
	if !rec.Wait(time.Second) {
		t.Fatal("timeout waiting for audit record")
	}

	calls := rec.getCalls()
	if len(calls) != 1 {
		t.Fatalf("expected 1 audit call, got %d", len(calls))
	}

	got := calls[0]
	if got.Method != "GET" {
		t.Errorf("expected method GET, got %s", got.Method)
	}
	if got.Path != "/api/v1/certificates" {
		t.Errorf("expected path /api/v1/certificates, got %s", got.Path)
	}
	if got.Actor != "anonymous" {
		t.Errorf("expected actor anonymous, got %s", got.Actor)
	}
	if got.Status != 200 {
		t.Errorf("expected status 200, got %d", got.Status)
	}
}
// TestAuditLog_CapturesStatusCode checks that the status written by the
// wrapped handler (here 404) is the status stored in the audit record.
func TestAuditLog_CapturesStatusCode(t *testing.T) {
	rec := newWaitableAuditRecorder()
	notFound := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
	})
	handler := NewAuditLog(rec, AuditConfig{}).Middleware(notFound)

	handler.ServeHTTP(
		httptest.NewRecorder(),
		httptest.NewRequest(http.MethodGet, "/api/v1/certs/mc-nonexistent", nil),
	)

	if !rec.Wait(time.Second) {
		t.Fatal("timeout waiting for audit record")
	}

	calls := rec.getCalls()
	if n := len(calls); n != 1 {
		t.Fatalf("expected 1 audit call, got %d", n)
	}
	if status := calls[0].Status; status != 404 {
		t.Errorf("expected status 404, got %d", status)
	}
}
// TestAuditLog_ExcludesHealth checks that requests to paths listed in
// AuditConfig.ExcludePaths produce no audit record while other paths still do.
//
// The test drains the middleware with Flush before inspecting the recorder.
// Waiting for only the first recording signal would let a wrongly audited
// /health or /ready request slip through if its recording goroutine had not
// finished yet; after Flush returns nil no recording is still in flight, so
// the call count is final.
func TestAuditLog_ExcludesHealth(t *testing.T) {
	recorder := newWaitableAuditRecorder()
	auditMW := NewAuditLog(recorder, AuditConfig{
		ExcludePaths: []string{"/health", "/ready"},
	})
	handler := auditMW.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// The first two paths are excluded; only the API path should be recorded.
	for _, path := range []string{"/health", "/ready", "/api/v1/certificates"} {
		req := httptest.NewRequest(http.MethodGet, path, nil)
		handler.ServeHTTP(httptest.NewRecorder(), req)
	}

	// Wait for every in-flight audit goroutine, not just the first one.
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()
	if err := auditMW.Flush(ctx); err != nil {
		t.Fatalf("timeout waiting for audit record: %v", err)
	}

	calls := recorder.getCalls()
	if len(calls) != 1 {
		t.Fatalf("expected 1 audit call (health/ready excluded), got %d", len(calls))
	}
	if calls[0].Path != "/api/v1/certificates" {
		t.Errorf("expected path /api/v1/certificates, got %s", calls[0].Path)
	}
}
// TestAuditLog_HashesRequestBody checks two things for a request with a body:
// the wrapped handler still sees the complete body after the middleware has
// read it, and the audit record carries a full-length hex body hash.
func TestAuditLog_HashesRequestBody(t *testing.T) {
	const payload = `{"name":"test"}`

	recorder := newWaitableAuditRecorder()
	mw := NewAuditLog(recorder, AuditConfig{}).Middleware
	// Handler verifies body was restored. It runs synchronously inside
	// ServeHTTP on the test goroutine, so calling t.Errorf here is safe.
	handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("reading restored body: %v", err)
		}
		if string(body) != payload {
			t.Errorf("body was not restored: got %q", string(body))
		}
		w.WriteHeader(http.StatusCreated)
	}))

	req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", strings.NewReader(payload))
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if !recorder.Wait(1 * time.Second) {
		t.Fatal("timeout waiting for audit record")
	}
	calls := recorder.getCalls()
	if len(calls) != 1 {
		t.Fatalf("expected 1 audit call, got %d", len(calls))
	}

	// Audit 2026-05-10 MED-15 closure — body hash is now the full
	// 64-char hex SHA-256 (was [:16] truncated). The body_hash schema
	// column is CHAR(64); the truncation was an integrity-collision
	// hole that allowed an attacker to craft tampered audit payloads
	// matching the 16-hex prefix.
	hash := calls[0].BodyHash
	if len(hash) != 64 {
		t.Errorf("expected 64-char SHA-256 body hash, got %q (len=%d)", hash, len(hash))
	}
	// A length check alone would accept any 64-byte string. Trimming every
	// hex digit from both ends leaves a non-empty remainder exactly when the
	// hash contains a non-hex character.
	if rest := strings.Trim(hash, "0123456789abcdefABCDEF"); rest != "" {
		t.Errorf("expected hex-encoded body hash, got %q", hash)
	}
	if calls[0].Status != 201 {
		t.Errorf("expected status 201, got %d", calls[0].Status)
	}
}
// TestAuditLog_EmptyBodyNoHash checks that a body-less GET is recorded with an
// empty BodyHash.
func TestAuditLog_EmptyBodyNoHash(t *testing.T) {
	rec := newWaitableAuditRecorder()
	ok := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	handler := NewAuditLog(rec, AuditConfig{}).Middleware(ok)

	handler.ServeHTTP(
		httptest.NewRecorder(),
		httptest.NewRequest(http.MethodGet, "/api/v1/agents", nil),
	)

	if !rec.Wait(time.Second) {
		t.Fatal("timeout waiting for audit record")
	}

	calls := rec.getCalls()
	if n := len(calls); n != 1 {
		t.Fatalf("expected 1 audit call, got %d", n)
	}
	if hash := calls[0].BodyHash; hash != "" {
		t.Errorf("expected empty body hash for GET, got %q", hash)
	}
}
// TestAuditLog_ExtractsAuthenticatedActor checks that when the request context
// carries an identity under auth.UserKey{}, that identity is recorded as the
// actor.
func TestAuditLog_ExtractsAuthenticatedActor(t *testing.T) {
	rec := newWaitableAuditRecorder()
	ok := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	handler := NewAuditLog(rec, AuditConfig{}).Middleware(ok)

	// Stand in for the auth middleware by placing the named-key identity in
	// the request context (post-M-002: actor is the named-key name, not the
	// old "api-key-user").
	base := httptest.NewRequest(http.MethodDelete, "/api/v1/certificates/mc-1", nil)
	req := base.WithContext(context.WithValue(base.Context(), auth.UserKey{}, "ops-admin"))

	handler.ServeHTTP(httptest.NewRecorder(), req)

	if !rec.Wait(time.Second) {
		t.Fatal("timeout waiting for audit record")
	}

	calls := rec.getCalls()
	if n := len(calls); n != 1 {
		t.Fatalf("expected 1 audit call, got %d", n)
	}

	got := calls[0]
	if got.Actor != "ops-admin" {
		t.Errorf("expected actor ops-admin, got %s", got.Actor)
	}
	if got.Method != "DELETE" {
		t.Errorf("expected method DELETE, got %s", got.Method)
	}
}
// TestAuditLog_RecorderErrorDoesNotBreakResponse checks that a recorder which
// always fails does not change the HTTP response seen by the client.
func TestAuditLog_RecorderErrorDoesNotBreakResponse(t *testing.T) {
	failing := &mockAuditRecorder{err: fmt.Errorf("db connection lost")}
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"ok":true}`))
	})
	handler := NewAuditLog(failing, AuditConfig{}).Middleware(next)

	resp := httptest.NewRecorder()
	handler.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/api/v1/stats/summary", nil))

	// The audit failure must stay invisible to the caller.
	if got := resp.Code; got != http.StatusOK {
		t.Errorf("expected 200 despite recorder error, got %d", got)
	}
}
// TestAuditLog_CapturesLatency checks that the recorded latency covers the
// time spent in the wrapped handler: the handler sleeps 10ms, so the record
// must report at least 10ms.
func TestAuditLog_CapturesLatency(t *testing.T) {
	rec := newWaitableAuditRecorder()
	slow := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(10 * time.Millisecond)
		w.WriteHeader(http.StatusOK)
	})
	handler := NewAuditLog(rec, AuditConfig{}).Middleware(slow)

	handler.ServeHTTP(
		httptest.NewRecorder(),
		httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil),
	)

	if !rec.Wait(time.Second) {
		t.Fatal("timeout waiting for audit record")
	}

	calls := rec.getCalls()
	if n := len(calls); n != 1 {
		t.Fatalf("expected 1 audit call, got %d", n)
	}
	if latency := calls[0].LatencyMs; latency < 10 {
		t.Errorf("expected latency >= 10ms, got %dms", latency)
	}
}
// TestAuditLog_ExcludesQueryParamsFromPath checks that the recorded path is
// the bare URL path: no query string, and in particular none of the sensitive
// parameter names or values sent with the request.
func TestAuditLog_ExcludesQueryParamsFromPath(t *testing.T) {
	rec := newWaitableAuditRecorder()
	ok := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	handler := NewAuditLog(rec, AuditConfig{}).Middleware(ok)

	// The query string deliberately carries a credential-looking parameter.
	target := "/api/v1/certificates?api_key=secret123&cursor=abc&status=active"
	handler.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, target, nil))

	if !rec.Wait(time.Second) {
		t.Fatal("timeout waiting for audit record")
	}

	calls := rec.getCalls()
	if n := len(calls); n != 1 {
		t.Fatalf("expected 1 audit call, got %d", n)
	}

	// Only the path component may appear in the audit trail.
	recordedPath := calls[0].Path
	if recordedPath != "/api/v1/certificates" {
		t.Errorf("expected path /api/v1/certificates (no query params), got %s", recordedPath)
	}

	leakChecks := []struct {
		needle  string
		message string
	}{
		{"api_key", "audit path contains 'api_key' — query parameters leaked into audit trail"},
		{"secret123", "audit path contains sensitive value 'secret123' — query parameters leaked into audit trail"},
		{"cursor", "audit path contains 'cursor' — query parameters leaked into audit trail"},
		{"?", "audit path contains '?' — query string leaked into audit trail"},
	}
	for _, check := range leakChecks {
		if strings.Contains(recordedPath, check.needle) {
			t.Error(check.message)
		}
	}
}
// TestAuditServiceAdapter_TranslatesCallToEvent checks how the adapter maps a
// RecordAPICall onto the audit-service callback: actor passed through, actor
// type "User", action "api_post" for a POST, resource type "api", the path as
// resource ID, and method/status/latency/body hash in the details map.
func TestAuditServiceAdapter_TranslatesCallToEvent(t *testing.T) {
	// Everything the callback receives is captured here for inspection.
	var seen struct {
		actor, actorType, action string
		resourceType, resourceID string
		details                  map[string]interface{}
	}

	adapter := NewAuditServiceAdapter(func(ctx context.Context, actor, actorType, action, resourceType, resourceID string, details map[string]interface{}) error {
		seen.actor, seen.actorType, seen.action = actor, actorType, action
		seen.resourceType, seen.resourceID = resourceType, resourceID
		seen.details = details
		return nil
	})

	if err := adapter.RecordAPICall(context.Background(), "POST", "/api/v1/certificates", "admin", "abc123", 201, 42); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if seen.actor != "admin" {
		t.Errorf("expected actor admin, got %s", seen.actor)
	}
	if seen.actorType != "User" {
		t.Errorf("expected actorType User, got %s", seen.actorType)
	}
	if seen.action != "api_post" {
		t.Errorf("expected action api_post, got %s", seen.action)
	}
	if seen.resourceType != "api" {
		t.Errorf("expected resourceType api, got %s", seen.resourceType)
	}
	if seen.resourceID != "/api/v1/certificates" {
		t.Errorf("expected resourceID /api/v1/certificates, got %s", seen.resourceID)
	}

	// Details are compared as interface values, so the dynamic type matters:
	// status must be an int and latency_ms an int64 for these to match.
	if got := seen.details["method"]; got != "POST" {
		t.Errorf("expected details.method POST, got %v", got)
	}
	if got := seen.details["status"]; got != 201 {
		t.Errorf("expected details.status 201, got %v", got)
	}
	if got := seen.details["latency_ms"]; got != int64(42) {
		t.Errorf("expected details.latency_ms 42, got %v", got)
	}
	if got := seen.details["body_hash"]; got != "abc123" {
		t.Errorf("expected details.body_hash abc123, got %v", got)
	}
}
// TestAuditServiceAdapter_PropagatesError checks that an error returned by the
// audit-service callback surfaces from RecordAPICall with its message intact.
func TestAuditServiceAdapter_PropagatesError(t *testing.T) {
	adapter := NewAuditServiceAdapter(func(ctx context.Context, actor, actorType, action, resourceType, resourceID string, details map[string]interface{}) error {
		return fmt.Errorf("database error")
	})

	err := adapter.RecordAPICall(context.Background(), "GET", "/api/v1/agents", "user", "", 200, 5)
	switch {
	case err == nil:
		t.Fatal("expected error, got nil")
	case !strings.Contains(err.Error(), "database error"):
		t.Errorf("expected database error, got %v", err)
	}
}
// TestAuditLog_FlushDrainsInFlightGoroutines verifies the M-1 shutdown-drain
// contract: Flush does not return while an audit-recording goroutine started
// by the middleware is still running, and returns nil once it has finished.
// Without the drain (pre-M-1 code), the DB pool would be closed while
// in-flight goroutines were still calling RecordAPICall, silently dropping
// audit events (CWE-662 / CWE-400).
func TestAuditLog_FlushDrainsInFlightGoroutines(t *testing.T) {
	// The recorder parks on `release` until the test closes it, standing in
	// for a slow DB write that is still running when shutdown starts.
	release := make(chan struct{})
	rec := &mockAuditRecorder{block: release}
	audit := NewAuditLog(rec, AuditConfig{})
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	handler := audit.Middleware(next)

	// One request: the handler returns right away while the recording
	// goroutine stays parked inside RecordAPICall.
	handler.ServeHTTP(
		httptest.NewRecorder(),
		httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil),
	)

	// Run Flush concurrently; it has to stay blocked until the recorder is
	// released.
	flushResult := make(chan error, 1)
	go func() {
		flushCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		flushResult <- audit.Flush(flushCtx)
	}()

	// Phase 1: Flush must NOT have returned yet.
	select {
	case err := <-flushResult:
		t.Fatalf("Flush returned before recorder unblocked: err=%v", err)
	case <-time.After(50 * time.Millisecond):
		// expected: Flush is still waiting on the in-flight recording
	}

	// Phase 2: release the recorder; Flush should now finish with nil.
	close(release)
	select {
	case err := <-flushResult:
		if err != nil {
			t.Fatalf("expected nil from Flush after drain, got %v", err)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("Flush did not return after recorder unblocked")
	}

	// Phase 3: the event itself must have been written, not merely Flush
	// unblocked.
	calls := rec.getCalls()
	if n := len(calls); n != 1 {
		t.Fatalf("expected 1 recorded audit call, got %d", n)
	}
	if got := calls[0].Path; got != "/api/v1/certificates" {
		t.Errorf("expected path /api/v1/certificates, got %s", got)
	}
}
// TestAuditLog_FlushTimeoutReturnsErrAuditFlushTimeout verifies that Flush
// respects its context: when in-flight goroutines exceed the shutdown budget,
// Flush returns an error wrapping ErrAuditFlushTimeout plus ctx.Err(). The
// caller can then decide whether to proceed with teardown anyway.
func TestAuditLog_FlushTimeoutReturnsErrAuditFlushTimeout(t *testing.T) {
	// Recorder will never unblock on its own. The deferred close releases it
	// on every exit path, including t.Fatal (which runs deferred calls) and
	// panics, so the parked goroutine is not left blocked after the test.
	// The goroutine itself is still detached and will record to the mock
	// even after Flush timed out — that's the documented behavior (Flush
	// surfaces the timeout; caller decides).
	unblock := make(chan struct{})
	defer close(unblock)

	recorder := &mockAuditRecorder{block: unblock}
	auditMW := NewAuditLog(recorder, AuditConfig{})
	handler := auditMW.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	// Flush with a tiny deadline — must time out.
	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
	defer cancel()

	err := auditMW.Flush(ctx)
	if err == nil {
		t.Fatal("expected Flush to return an error on timeout, got nil")
	}
	if !errors.Is(err, ErrAuditFlushTimeout) {
		t.Fatalf("expected error to wrap ErrAuditFlushTimeout, got %v", err)
	}
	if !errors.Is(err, context.DeadlineExceeded) {
		t.Fatalf("expected error to wrap context.DeadlineExceeded, got %v", err)
	}
}