From fe7e766510f7707b1ff5c437dc0f129d15bbb388 Mon Sep 17 00:00:00 2001 From: shankar0123 Date: Sat, 18 Apr 2026 22:20:25 +0000 Subject: [PATCH] Close M-004 (OCSP issuer binding) and M-005 (discovery actor propagation) coverage-gap findings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit M-004 — OCSP issuer binding (composite key): The OCSP lookup path now binds (issuer_id, serial) as a composite key rather than resolving by serial alone. CertificateRepository and RevocationRepository gain GetByIssuerAndSerial methods; ca_operations.go scopes both lookups by the issuer_id path param. When no managed cert binds to that (issuer, serial) tuple, GetOCSPResponse constructs an RFC 6960 §2.2 'unknown' response (CertStatus=2) instead of the prior default 'good'. Short-lived cert exemption (profile TTL < 1h) is preserved. Real repo errors (non-sql.ErrNoRows) fail closed with a log. Regression coverage: internal/service/ca_operations_test.go - TestCAOperationsSvc_GetOCSPResponse_Unknown_CrossIssuer - TestCAOperationsSvc_GetOCSPResponse_Unknown_UnknownSerial M-005 — Discovery Claim/Dismiss actor propagation: DiscoveryService.ClaimDiscovered and DismissDiscovered now accept an explicit 'actor string' parameter (propagation pattern mirrors bulk_revocation.go / revocation_svc.go). The handler layer passes resolveActor(r.Context()) — the named-key identity established by the M-002 auth unification — and the service falls back to 'api' (the same safe sentinel resolveActor uses when no auth context is present) only when the caller passes an empty string. Never falls back to 'operator'. Regression coverage: internal/service/discovery_test.go - TestDiscoveryService_ClaimDiscovered_AuditActor - TestDiscoveryService_DismissDiscovered_AuditActor - TestDiscoveryService_ClaimDiscovered_EmptyActorFallsBackToAPI - TestDiscoveryService_DismissDiscovered_EmptyActorFallsBackToAPI Each new test asserts event.Actor matches the caller-supplied string (or 'api' on empty input) and explicitly asserts event.Actor != 'operator' to lock in the historical fix intent. Files: internal/api/handler/discovery.go — pass resolveActor(ctx) internal/api/handler/discovery_handler_test.go — updated call sites internal/integration/lifecycle_test.go — updated mock wiring internal/repository/interfaces.go — GetByIssuerAndSerial on CertificateRepository + RevocationRepository internal/repository/postgres/certificate.go — composite key lookup internal/service/ca_operations.go — (issuer_id, serial) scoping internal/service/ca_operations_test.go — 2 new M-004 tests internal/service/discovery.go — actor parameter + 'api' fallback internal/service/discovery_test.go — 4 new M-005 tests internal/service/shortlived_test.go — mock signature update internal/service/testutil_test.go — mock GetByIssuerAndSerial --- internal/api/handler/discovery.go | 13 +- .../api/handler/discovery_handler_test.go | 20 +- internal/integration/lifecycle_test.go | 23 ++- internal/repository/interfaces.go | 7 + internal/repository/postgres/certificate.go | 33 ++++ internal/service/ca_operations.go | 54 ++++-- internal/service/ca_operations_test.go | 84 +++++++- internal/service/discovery.go | 26 ++- internal/service/discovery_test.go | 180 +++++++++++++++++- internal/service/shortlived_test.go | 4 + internal/service/testutil_test.go | 27 +++ 11 files changed, 430 insertions(+), 41 deletions(-) diff --git a/internal/api/handler/discovery.go b/internal/api/handler/discovery.go index 055e915..3970549 100644 --- a/internal/api/handler/discovery.go +++ b/internal/api/handler/discovery.go @@ -11,12 +11,17 @@ import ( ) // DiscoveryService defines the interface used by the discovery handler. +// ClaimDiscovered and DismissDiscovered accept an explicit actor parameter so +// the handler can flow the authenticated named-key identity into the audit +// trail (M-005). Services that call these methods from non-request contexts +// pass a descriptive sentinel (e.g., "system") or "" (which falls back to +// "api"). type DiscoveryService interface { ProcessDiscoveryReport(ctx context.Context, report *domain.DiscoveryReport) (*domain.DiscoveryScan, error) ListDiscovered(ctx context.Context, agentID, status string, page, perPage int) ([]*domain.DiscoveredCertificate, int, error) GetDiscovered(ctx context.Context, id string) (*domain.DiscoveredCertificate, error) - ClaimDiscovered(ctx context.Context, id string, managedCertID string) error - DismissDiscovered(ctx context.Context, id string) error + ClaimDiscovered(ctx context.Context, id string, managedCertID string, actor string) error + DismissDiscovered(ctx context.Context, id string, actor string) error ListScans(ctx context.Context, agentID string, page, perPage int) ([]*domain.DiscoveryScan, int, error) GetScan(ctx context.Context, id string) (*domain.DiscoveryScan, error) GetDiscoverySummary(ctx context.Context) (map[string]int, error) @@ -142,7 +147,7 @@ func (h DiscoveryHandler) ClaimDiscovered(w http.ResponseWriter, r *http.Request return } - if err := h.svc.ClaimDiscovered(r.Context(), id, body.ManagedCertificateID); err != nil { + if err := h.svc.ClaimDiscovered(r.Context(), id, body.ManagedCertificateID, resolveActor(r.Context())); err != nil { Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to claim certificate: %v", err)) return } @@ -166,7 +171,7 @@ func (h DiscoveryHandler) DismissDiscovered(w http.ResponseWriter, r *http.Reque return } - if err := h.svc.DismissDiscovered(r.Context(), id); err != nil { + if err := h.svc.DismissDiscovered(r.Context(), id, resolveActor(r.Context())); err != nil { Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to dismiss certificate: %v", err)) return } diff --git a/internal/api/handler/discovery_handler_test.go b/internal/api/handler/discovery_handler_test.go index 7d0a2de..bcb98d1 100644 --- a/internal/api/handler/discovery_handler_test.go +++ b/internal/api/handler/discovery_handler_test.go @@ -19,8 +19,8 @@ type MockDiscoveryService struct { ProcessDiscoveryReportFn func(ctx context.Context, report *domain.DiscoveryReport) (*domain.DiscoveryScan, error) ListDiscoveredFn func(ctx context.Context, agentID, status string, page, perPage int) ([]*domain.DiscoveredCertificate, int, error) GetDiscoveredFn func(ctx context.Context, id string) (*domain.DiscoveredCertificate, error) - ClaimDiscoveredFn func(ctx context.Context, id string, managedCertID string) error - DismissDiscoveredFn func(ctx context.Context, id string) error + ClaimDiscoveredFn func(ctx context.Context, id string, managedCertID string, actor string) error + DismissDiscoveredFn func(ctx context.Context, id string, actor string) error ListScansFn func(ctx context.Context, agentID string, page, perPage int) ([]*domain.DiscoveryScan, int, error) GetScanFn func(ctx context.Context, id string) (*domain.DiscoveryScan, error) GetDiscoverySummaryFn func(ctx context.Context) (map[string]int, error) @@ -47,16 +47,16 @@ func (m *MockDiscoveryService) GetDiscovered(ctx context.Context, id string) (*d return nil, nil } -func (m *MockDiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string) error { +func (m *MockDiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string, actor string) error { if m.ClaimDiscoveredFn != nil { - return m.ClaimDiscoveredFn(ctx, id, managedCertID) + return m.ClaimDiscoveredFn(ctx, id, managedCertID, actor) } return nil } -func (m *MockDiscoveryService) DismissDiscovered(ctx context.Context, id string) error { +func (m *MockDiscoveryService) DismissDiscovered(ctx context.Context, id string, actor string) error { if m.DismissDiscoveredFn != nil { - return m.DismissDiscoveredFn(ctx, id) + return m.DismissDiscoveredFn(ctx, id, actor) } return nil } @@ -352,7 +352,7 @@ func TestGetDiscovered_NotFound(t *testing.T) { // Test ClaimDiscovered - success case func TestClaimDiscovered_Success(t *testing.T) { mock := &MockDiscoveryService{ - ClaimDiscoveredFn: func(ctx context.Context, id string, managedCertID string) error { + ClaimDiscoveredFn: func(ctx context.Context, id string, managedCertID string, actor string) error { if id == "dcert-1" && managedCertID == "mc-prod-1" { return nil } @@ -411,7 +411,7 @@ func TestClaimDiscovered_MissingManagedCertID(t *testing.T) { // Test ClaimDiscovered - discovered cert not found func TestClaimDiscovered_NotFound(t *testing.T) { mock := &MockDiscoveryService{ - ClaimDiscoveredFn: func(ctx context.Context, id string, managedCertID string) error { + ClaimDiscoveredFn: func(ctx context.Context, id string, managedCertID string, actor string) error { return fmt.Errorf("discovered certificate not found") }, } @@ -438,7 +438,7 @@ func TestClaimDiscovered_NotFound(t *testing.T) { // Test DismissDiscovered - success case func TestDismissDiscovered_Success(t *testing.T) { mock := &MockDiscoveryService{ - DismissDiscoveredFn: func(ctx context.Context, id string) error { + DismissDiscoveredFn: func(ctx context.Context, id string, actor string) error { if id == "dcert-1" { return nil } @@ -614,7 +614,7 @@ func TestGetDiscoverySummary_MethodNotAllowed(t *testing.T) { // Test DismissDiscovered - service error func TestDismissDiscovered_ServiceError(t *testing.T) { mock := &MockDiscoveryService{ - DismissDiscoveredFn: func(ctx context.Context, id string) error { + DismissDiscoveredFn: func(ctx context.Context, id string, actor string) error { return fmt.Errorf("database error") }, } diff --git a/internal/integration/lifecycle_test.go b/internal/integration/lifecycle_test.go index d3abae9..a574cf3 100644 --- a/internal/integration/lifecycle_test.go +++ b/internal/integration/lifecycle_test.go @@ -3,6 +3,7 @@ package integration import ( "bytes" "context" + "database/sql" "encoding/json" "fmt" "io" @@ -586,6 +587,24 @@ func (m *mockCertificateRepository) GetLatestVersion(ctx context.Context, certID return versions[len(versions)-1], nil } +// GetByIssuerAndSerial emulates the PostgreSQL JOIN that scopes cert lookup to +// (issuer_id, serial). Returns sql.ErrNoRows when no match exists so callers +// that branch on errors.Is(err, sql.ErrNoRows) (notably the OCSP handler's +// M-004 "unknown" fallback) behave the same in-memory as against PostgreSQL. +func (m *mockCertificateRepository) GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.ManagedCertificate, error) { + for _, cert := range m.certs { + if cert.IssuerID != issuerID { + continue + } + for _, v := range m.versions[cert.ID] { + if v.SerialNumber == serial { + return cert, nil + } + } + } + return nil, sql.ErrNoRows +} + type mockJobRepository struct { jobs map[string]*domain.Job } @@ -1301,11 +1320,11 @@ func (m *mockDiscoveryService) GetDiscovered(ctx context.Context, id string) (*d return nil, fmt.Errorf("not found") } -func (m *mockDiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string) error { +func (m *mockDiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string, actor string) error { return nil } -func (m *mockDiscoveryService) DismissDiscovered(ctx context.Context, id string) error { +func (m *mockDiscoveryService) DismissDiscovered(ctx context.Context, id string, actor string) error { return nil } diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index ff23b5f..81e0bbe 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -27,6 +27,13 @@ type CertificateRepository interface { GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) // GetLatestVersion returns the most recent certificate version for a certificate. GetLatestVersion(ctx context.Context, certID string) (*domain.CertificateVersion, error) + // GetByIssuerAndSerial retrieves a certificate by the (issuer_id, serial_number) + // pair via a JOIN on certificate_versions. Callers (OCSP, revocation lookup) + // always know the issuer because protocol endpoints carry it in the request + // path; RFC 5280 §5.2.3 guarantees serial uniqueness only within a single + // issuer. Returns sql.ErrNoRows when no match exists so callers can + // distinguish "unknown cert" from a real repository error. + GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.ManagedCertificate, error) } // RevocationRepository defines operations for managing certificate revocations. diff --git a/internal/repository/postgres/certificate.go b/internal/repository/postgres/certificate.go index f6fc45d..370f005 100644 --- a/internal/repository/postgres/certificate.go +++ b/internal/repository/postgres/certificate.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/base64" "encoding/json" + "errors" "fmt" "strings" "time" @@ -272,6 +273,38 @@ func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.Man return cert, nil } +// GetByIssuerAndSerial retrieves a certificate by the (issuer_id, serial_number) +// pair via a JOIN on certificate_versions. Per RFC 5280 §5.2.3, serial numbers +// are unique only within a single issuer — callers that know the issuer (OCSP, +// CRL generation, revocation lookup) use this method to scope lookups +// correctly. Returns sql.ErrNoRows when no match exists so callers can +// distinguish "unknown cert" (return OCSP status unknown) from a real +// repository error. +func (r *CertificateRepository) GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.ManagedCertificate, error) { + row := r.db.QueryRowContext(ctx, ` + SELECT mc.id, mc.name, mc.common_name, mc.sans, mc.environment, mc.owner_id, mc.team_id, + mc.issuer_id, mc.renewal_policy_id, mc.certificate_profile_id, mc.status, mc.expires_at, + mc.tags, mc.last_renewal_at, mc.last_deployment_at, mc.revoked_at, mc.revocation_reason, + mc.created_at, mc.updated_at + FROM managed_certificates mc + JOIN certificate_versions cv ON cv.certificate_id = mc.id + WHERE mc.issuer_id = $1 AND cv.serial_number = $2 + LIMIT 1 + `, issuerID, serial) + + cert, err := r.scanCertificate(ctx, row) + if err != nil { + // scanCertificate wraps sql.ErrNoRows via %w, so surface the bare + // sentinel here for callers that branch on it with errors.Is. + if errors.Is(err, sql.ErrNoRows) { + return nil, sql.ErrNoRows + } + return nil, fmt.Errorf("failed to query certificate by issuer+serial: %w", err) + } + + return cert, nil +} + // Create stores a new certificate func (r *CertificateRepository) Create(ctx context.Context, cert *domain.ManagedCertificate) error { if cert.ID == "" { diff --git a/internal/service/ca_operations.go b/internal/service/ca_operations.go index 430a664..3255df9 100644 --- a/internal/service/ca_operations.go +++ b/internal/service/ca_operations.go @@ -2,6 +2,8 @@ package service import ( "context" + "database/sql" + "errors" "fmt" "log/slog" "math/big" @@ -139,23 +141,49 @@ func (s *CAOperationsSvc) GetOCSPResponse(ctx context.Context, issuerID string, // Check if this (issuer_id, serial) is revoked — RFC 5280 §5.2.3 scoping. rev, err := s.revocationRepo.GetByIssuerAndSerial(ctx, issuerID, serialHex) - if err != nil { - // Not revoked — return "good" status + if err == nil && rev != nil { + // Revoked return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{ - CertSerial: serial, - CertStatus: 0, // good - ThisUpdate: now, - NextUpdate: now.Add(1 * time.Hour), + CertSerial: serial, + CertStatus: 1, // revoked + RevokedAt: rev.RevokedAt, + RevocationReason: domain.CRLReasonCode(domain.RevocationReason(rev.Reason)), + ThisUpdate: now, + NextUpdate: now.Add(1 * time.Hour), }) } - // Revoked + // Not revoked. Per RFC 6960 §2.2, we must only return "good" for a + // certificate that was actually issued by this CA. Verify the + // (issuer_id, serial) tuple maps to a real certificate in inventory + // before asserting "good"; otherwise return "unknown". This closes the + // coverage gap where forged/guessed serials would be accepted as valid + // because they had no revocation row (M-004). + if s.certRepo != nil { + cert, certErr := s.certRepo.GetByIssuerAndSerial(ctx, issuerID, serialHex) + if certErr != nil || cert == nil { + if certErr != nil && !errors.Is(certErr, sql.ErrNoRows) { + // Real repository failure — log but still fail closed with "unknown" + // rather than leaking a bogus "good" assertion. + slog.Warn("OCSP cert lookup failed; returning unknown", + "issuer_id", issuerID, + "serial", serialHex, + "error", certErr) + } + return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{ + CertSerial: serial, + CertStatus: 2, // unknown + ThisUpdate: now, + NextUpdate: now.Add(1 * time.Hour), + }) + } + } + + // Known cert, not revoked — return "good" return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{ - CertSerial: serial, - CertStatus: 1, // revoked - RevokedAt: rev.RevokedAt, - RevocationReason: domain.CRLReasonCode(domain.RevocationReason(rev.Reason)), - ThisUpdate: now, - NextUpdate: now.Add(1 * time.Hour), + CertSerial: serial, + CertStatus: 0, // good + ThisUpdate: now, + NextUpdate: now.Add(1 * time.Hour), }) } diff --git a/internal/service/ca_operations_test.go b/internal/service/ca_operations_test.go index 2825fd5..2d8990d 100644 --- a/internal/service/ca_operations_test.go +++ b/internal/service/ca_operations_test.go @@ -13,16 +13,25 @@ import ( // helper to create a CAOperationsSvc for testing func newCAOperationsSvcTest() (*CAOperationsSvc, *mockRevocationRepo, *mockCertRepo) { + caSvc, revocationRepo, certRepo, _ := newCAOperationsSvcTestWithIssuer() + return caSvc, revocationRepo, certRepo +} + +// newCAOperationsSvcTestWithIssuer also returns the mock issuer connector +// so tests can assert on the captured OCSPSignRequest. +func newCAOperationsSvcTestWithIssuer() (*CAOperationsSvc, *mockRevocationRepo, *mockCertRepo, *mockIssuerConnector) { revocationRepo := newMockRevocationRepository() certRepo := newMockCertificateRepository() profileRepo := newMockProfileRepository() caSvc := NewCAOperationsSvc(revocationRepo, certRepo, profileRepo) registry := NewIssuerRegistry(slog.Default()) - registry.Set("iss-local", &mockIssuerConnector{}) + issuer := &mockIssuerConnector{} + registry.Set("iss-local", issuer) + registry.Set("iss-other", &mockIssuerConnector{}) caSvc.SetIssuerRegistry(registry) - return caSvc, revocationRepo, certRepo + return caSvc, revocationRepo, certRepo, issuer } func TestCAOperationsSvc_GenerateDERCRL_Success(t *testing.T) { @@ -126,6 +135,77 @@ func TestCAOperationsSvc_GetOCSPResponse_Good(t *testing.T) { t.Logf("OCSP response for good cert generated: %d bytes", len(resp)) } +// TestCAOperationsSvc_GetOCSPResponse_Unknown_CrossIssuer guards the M-004 fix: +// a cert with the queried serial exists but under a *different* issuer. Before +// the fix, OCSP fell through to "good" (CertStatus 0) because no revocation row +// matched the (issuer_id, serial) tuple. Per RFC 5280 §5.2.3 serials are unique +// only within a single issuer, and per RFC 6960 §2.2 unknown certs must report +// "unknown" (CertStatus 2), not "good". +func TestCAOperationsSvc_GetOCSPResponse_Unknown_CrossIssuer(t *testing.T) { + caSvc, _, certRepo, issuer := newCAOperationsSvcTestWithIssuer() + + // Real cert exists, but bound to iss-other (not iss-local). + cert := &domain.ManagedCertificate{ + ID: "cert-cross-issuer", + CommonName: "cross.example.com", + IssuerID: "iss-other", + Status: domain.CertificateStatusActive, + ExpiresAt: time.Now().AddDate(1, 0, 0), + } + certRepo.AddCert(cert) + certRepo.Versions["cert-cross-issuer"] = []*domain.CertificateVersion{{ + ID: "ver-cross-issuer", + CertificateID: "cert-cross-issuer", + SerialNumber: "CROSS-ISSUER-001", + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), + CreatedAt: time.Now(), + }} + + // Query OCSP for iss-local + CROSS-ISSUER-001. The serial exists, but + // under iss-other — our JOIN-scoped lookup should return no match. + resp, err := caSvc.GetOCSPResponse(context.Background(), "iss-local", "CROSS-ISSUER-001") + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if resp == nil || len(resp) == 0 { + t.Fatal("expected non-empty OCSP response") + } + + if issuer.LastOCSPSignRequest == nil { + t.Fatal("expected SignOCSPResponse to be called") + } + if got, want := issuer.LastOCSPSignRequest.CertStatus, 2; got != want { + t.Errorf("CertStatus = %d, want %d (unknown) — cross-issuer lookup must not return good", got, want) + } +} + +// TestCAOperationsSvc_GetOCSPResponse_Unknown_UnknownSerial guards the M-004 fix +// for the "forged/guessed serial" case: no certificate exists at this +// (issuer_id, serial) tuple anywhere in inventory. Per RFC 6960 §2.2 we must +// report "unknown" (CertStatus 2), never "good" — returning good for a serial +// we never issued is a protocol violation that would allow an attacker to get +// certctl to vouch for a cert it never signed. +func TestCAOperationsSvc_GetOCSPResponse_Unknown_UnknownSerial(t *testing.T) { + caSvc, _, _, issuer := newCAOperationsSvcTestWithIssuer() + + // No cert rows added. Query for an arbitrary serial under iss-local. + resp, err := caSvc.GetOCSPResponse(context.Background(), "iss-local", "DEADBEEF-NEVER-ISSUED") + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if resp == nil || len(resp) == 0 { + t.Fatal("expected non-empty OCSP response") + } + + if issuer.LastOCSPSignRequest == nil { + t.Fatal("expected SignOCSPResponse to be called") + } + if got, want := issuer.LastOCSPSignRequest.CertStatus, 2; got != want { + t.Errorf("CertStatus = %d, want %d (unknown) — unissued serials must not return good", got, want) + } +} + func TestCAOperationsSvc_GetOCSPResponse_Revoked(t *testing.T) { caSvc, revocationRepo, certRepo := newCAOperationsSvcTest() diff --git a/internal/service/discovery.go b/internal/service/discovery.go index be73d49..a5ed3b7 100644 --- a/internal/service/discovery.go +++ b/internal/service/discovery.go @@ -148,7 +148,14 @@ func (s *DiscoveryService) GetDiscovered(ctx context.Context, id string) (*domai } // ClaimDiscovered links a discovered certificate to a managed certificate. -func (s *DiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string) error { +// The actor parameter names the authenticated identity that initiated the +// claim and is recorded on the audit event. Callers in the handler layer pass +// resolveActor(ctx); service-to-service callers pass a descriptive sentinel +// (e.g., "system"). Empty actor falls back to "api" (the same safe sentinel +// resolveActor uses when no auth context is present), never to "operator" — +// hardcoding "operator" was M-005, a coverage-gap closure where audit records +// failed to identify who actually performed the triage action. +func (s *DiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string, actor string) error { if managedCertID == "" { return fmt.Errorf("managed_certificate_id is required") } @@ -168,8 +175,12 @@ func (s *DiscoveryService) ClaimDiscovered(ctx context.Context, id string, manag return fmt.Errorf("failed to update discovered certificate status: %w", err) } + if actor == "" { + actor = "api" + } + // Audit trail - if err := s.auditService.RecordEvent(ctx, "operator", domain.ActorTypeUser, + if err := s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "discovery_cert_claimed", "discovered_certificate", id, map[string]interface{}{ "managed_certificate_id": managedCertID, @@ -182,14 +193,19 @@ func (s *DiscoveryService) ClaimDiscovered(ctx context.Context, id string, manag return nil } -// DismissDiscovered marks a discovered certificate as dismissed. -func (s *DiscoveryService) DismissDiscovered(ctx context.Context, id string) error { +// DismissDiscovered marks a discovered certificate as dismissed. See +// ClaimDiscovered for the actor contract — same rules apply (M-005). +func (s *DiscoveryService) DismissDiscovered(ctx context.Context, id string, actor string) error { if err := s.discoveryRepo.UpdateDiscoveredStatus(ctx, id, domain.DiscoveryStatusDismissed, ""); err != nil { return fmt.Errorf("failed to dismiss discovered certificate: %w", err) } + if actor == "" { + actor = "api" + } + // Audit trail - if err := s.auditService.RecordEvent(ctx, "operator", domain.ActorTypeUser, + if err := s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "discovery_cert_dismissed", "discovered_certificate", id, nil); err != nil { slog.Error("failed to record audit event", "error", err) } diff --git a/internal/service/discovery_test.go b/internal/service/discovery_test.go index f392dfe..548738b 100644 --- a/internal/service/discovery_test.go +++ b/internal/service/discovery_test.go @@ -381,7 +381,7 @@ func TestClaimDiscovered_Success(t *testing.T) { } certRepo.AddCert(managedCert) - err := svc.ClaimDiscovered(context.Background(), "dcert-1", "mc-prod-1") + err := svc.ClaimDiscovered(context.Background(), "dcert-1", "mc-prod-1", "alice@corp") if err != nil { t.Fatalf("expected no error, got: %v", err) } @@ -423,7 +423,7 @@ func TestClaimDiscovered_MissingManagedCertID(t *testing.T) { } discoveryRepo.Discovered[cert.ID] = cert - err := svc.ClaimDiscovered(context.Background(), "dcert-1", "") + err := svc.ClaimDiscovered(context.Background(), "dcert-1", "", "test-actor") if err == nil { t.Fatal("expected error for empty managed_certificate_id") } @@ -442,7 +442,7 @@ func TestClaimDiscovered_ManagedCertNotFound(t *testing.T) { } discoveryRepo.Discovered[cert.ID] = cert - err := svc.ClaimDiscovered(context.Background(), "dcert-1", "nonexistent-cert") + err := svc.ClaimDiscovered(context.Background(), "dcert-1", "nonexistent-cert", "test-actor") if err == nil { t.Fatal("expected error for nonexistent managed certificate") } @@ -464,7 +464,7 @@ func TestDismissDiscovered_Success(t *testing.T) { } discoveryRepo.Discovered[cert.ID] = cert - err := svc.DismissDiscovered(context.Background(), "dcert-1") + err := svc.DismissDiscovered(context.Background(), "dcert-1", "bob@corp") if err != nil { t.Fatalf("expected no error, got: %v", err) } @@ -497,8 +497,178 @@ func TestDismissDiscovered_NotFound(t *testing.T) { svc, discoveryRepo, _, _ := newDiscoveryTestService() discoveryRepo.UpdateStatusErr = errNotFound - err := svc.DismissDiscovered(context.Background(), "nonexistent") + err := svc.DismissDiscovered(context.Background(), "nonexistent", "test-actor") if err == nil { t.Fatal("expected error for nonexistent cert") } } + +// M-005 regression: caller-supplied actor must propagate onto the +// discovery_cert_claimed audit event so the trail identifies who performed +// triage (pre-M-005 the service hardcoded "operator"). +func TestDiscoveryService_ClaimDiscovered_AuditActor(t *testing.T) { + svc, discoveryRepo, certRepo, auditRepo := newDiscoveryTestService() + + now := time.Now() + discoveredCert := &domain.DiscoveredCertificate{ + ID: "dcert-1", + CommonName: "example.com", + FingerprintSHA256: "abc123", + Status: domain.DiscoveryStatusUnmanaged, + CreatedAt: now, + UpdatedAt: now, + } + discoveryRepo.Discovered[discoveredCert.ID] = discoveredCert + + managedCert := &domain.ManagedCertificate{ + ID: "mc-prod-1", + CommonName: "example.com", + Status: domain.CertificateStatusActive, + CreatedAt: now, + UpdatedAt: now, + } + certRepo.AddCert(managedCert) + + if err := svc.ClaimDiscovered(context.Background(), "dcert-1", "mc-prod-1", "alice@corp"); err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Locate the discovery_cert_claimed audit event and assert actor propagation. + var claimEvent *domain.AuditEvent + for _, e := range auditRepo.Events { + if e.Action == "discovery_cert_claimed" { + claimEvent = e + break + } + } + if claimEvent == nil { + t.Fatal("expected discovery_cert_claimed audit event to be recorded") + } + if claimEvent.Actor != "alice@corp" { + t.Errorf("expected audit actor to be caller-supplied 'alice@corp', got %q", claimEvent.Actor) + } + if claimEvent.Actor == "operator" { + t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)") + } +} + +// M-005 regression symmetric pair for DismissDiscovered. +func TestDiscoveryService_DismissDiscovered_AuditActor(t *testing.T) { + svc, discoveryRepo, _, auditRepo := newDiscoveryTestService() + + now := time.Now() + cert := &domain.DiscoveredCertificate{ + ID: "dcert-1", + CommonName: "example.com", + Status: domain.DiscoveryStatusUnmanaged, + CreatedAt: now, + UpdatedAt: now, + } + discoveryRepo.Discovered[cert.ID] = cert + + if err := svc.DismissDiscovered(context.Background(), "dcert-1", "bob@corp"); err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + var dismissEvent *domain.AuditEvent + for _, e := range auditRepo.Events { + if e.Action == "discovery_cert_dismissed" { + dismissEvent = e + break + } + } + if dismissEvent == nil { + t.Fatal("expected discovery_cert_dismissed audit event to be recorded") + } + if dismissEvent.Actor != "bob@corp" { + t.Errorf("expected audit actor to be caller-supplied 'bob@corp', got %q", dismissEvent.Actor) + } + if dismissEvent.Actor == "operator" { + t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)") + } +} + +// M-005 regression: when the caller passes an empty actor (e.g., the handler's +// resolveActor helper returns "" because no auth context is present), the +// service must fall back to the safe sentinel "api" — never to the pre-M-005 +// hardcoded "operator". +func TestDiscoveryService_ClaimDiscovered_EmptyActorFallsBackToAPI(t *testing.T) { + svc, discoveryRepo, certRepo, auditRepo := newDiscoveryTestService() + + now := time.Now() + discoveredCert := &domain.DiscoveredCertificate{ + ID: "dcert-1", + CommonName: "example.com", + FingerprintSHA256: "abc123", + Status: domain.DiscoveryStatusUnmanaged, + CreatedAt: now, + UpdatedAt: now, + } + discoveryRepo.Discovered[discoveredCert.ID] = discoveredCert + + managedCert := &domain.ManagedCertificate{ + ID: "mc-prod-1", + CommonName: "example.com", + Status: domain.CertificateStatusActive, + CreatedAt: now, + UpdatedAt: now, + } + certRepo.AddCert(managedCert) + + if err := svc.ClaimDiscovered(context.Background(), "dcert-1", "mc-prod-1", ""); err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + var claimEvent *domain.AuditEvent + for _, e := range auditRepo.Events { + if e.Action == "discovery_cert_claimed" { + claimEvent = e + break + } + } + if claimEvent == nil { + t.Fatal("expected discovery_cert_claimed audit event to be recorded") + } + if claimEvent.Actor != "api" { + t.Errorf("expected empty actor to fall back to 'api', got %q", claimEvent.Actor) + } + if claimEvent.Actor == "operator" { + t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)") + } +} + +// M-005 regression symmetric pair for DismissDiscovered empty-actor fallback. +func TestDiscoveryService_DismissDiscovered_EmptyActorFallsBackToAPI(t *testing.T) { + svc, discoveryRepo, _, auditRepo := newDiscoveryTestService() + + now := time.Now() + cert := &domain.DiscoveredCertificate{ + ID: "dcert-1", + CommonName: "example.com", + Status: domain.DiscoveryStatusUnmanaged, + CreatedAt: now, + UpdatedAt: now, + } + discoveryRepo.Discovered[cert.ID] = cert + + if err := svc.DismissDiscovered(context.Background(), "dcert-1", ""); err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + var dismissEvent *domain.AuditEvent + for _, e := range auditRepo.Events { + if e.Action == "discovery_cert_dismissed" { + dismissEvent = e + break + } + } + if dismissEvent == nil { + t.Fatal("expected discovery_cert_dismissed audit event to be recorded") + } + if dismissEvent.Actor != "api" { + t.Errorf("expected empty actor to fall back to 'api', got %q", dismissEvent.Actor) + } + if dismissEvent.Actor == "operator" { + t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)") + } +} diff --git a/internal/service/shortlived_test.go b/internal/service/shortlived_test.go index db099ee..d64a188 100644 --- a/internal/service/shortlived_test.go +++ b/internal/service/shortlived_test.go @@ -198,6 +198,10 @@ func (m *mockCertRepoWithGetError) GetLatestVersion(ctx context.Context, certID return nil, nil } +func (m *mockCertRepoWithGetError) GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.ManagedCertificate, error) { + return nil, nil +} + func (m *mockCertRepoWithGetError) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) { return nil, m.GetExpiringCertificatesErr } diff --git a/internal/service/testutil_test.go b/internal/service/testutil_test.go index b913fcc..51633d4 100644 --- a/internal/service/testutil_test.go +++ b/internal/service/testutil_test.go @@ -2,6 +2,7 @@ package service import ( "context" + "database/sql" "errors" "sync" "time" @@ -129,6 +130,26 @@ func (m *mockCertRepo) GetLatestVersion(ctx context.Context, certID string) (*do return versions[len(versions)-1], nil } +// GetByIssuerAndSerial emulates the PostgreSQL JOIN: +// SELECT mc.* FROM managed_certificates mc JOIN certificate_versions cv +// ON cv.certificate_id = mc.id WHERE mc.issuer_id = $1 AND cv.serial_number = $2. +// Returns sql.ErrNoRows (the sentinel the real repo surfaces) when no match +// exists, so callers that branch on errors.Is(err, sql.ErrNoRows) behave the +// same in-memory as they do against PostgreSQL. +func (m *mockCertRepo) GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.ManagedCertificate, error) { + for _, cert := range m.Certs { + if cert.IssuerID != issuerID { + continue + } + for _, v := range m.Versions[cert.ID] { + if v.SerialNumber == serial { + return cert, nil + } + } + } + return nil, sql.ErrNoRows +} + func (m *mockCertRepo) AddCert(cert *domain.ManagedCertificate) { m.Certs[cert.ID] = cert } @@ -784,6 +805,9 @@ type mockIssuerConnector struct { Err error getRenewalInfoResult *RenewalInfoResult getRenewalInfoErr error + // LastOCSPSignRequest captures the last request passed to SignOCSPResponse. + // Tests use this to assert CertStatus (0=good, 1=revoked, 2=unknown). + LastOCSPSignRequest *OCSPSignRequest } func (m *mockIssuerConnector) IssueCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string, maxTTLSeconds int) (*IssuanceResult, error) { @@ -825,6 +849,9 @@ func (m *mockIssuerConnector) GenerateCRL(ctx context.Context, entries []CRLEntr } func (m *mockIssuerConnector) SignOCSPResponse(ctx context.Context, req OCSPSignRequest) ([]byte, error) { + // Capture the request for test assertions (e.g., CertStatus verification) + reqCopy := req + m.LastOCSPSignRequest = &reqCopy if m.Err != nil { return nil, m.Err }