diff --git a/internal/api/handler/discovery.go b/internal/api/handler/discovery.go index 055e915..3970549 100644 --- a/internal/api/handler/discovery.go +++ b/internal/api/handler/discovery.go @@ -11,12 +11,17 @@ import ( ) // DiscoveryService defines the interface used by the discovery handler. +// ClaimDiscovered and DismissDiscovered accept an explicit actor parameter so +// the handler can flow the authenticated named-key identity into the audit +// trail (M-005). Services that call these methods from non-request contexts +// pass a descriptive sentinel (e.g., "system") or "" (which falls back to +// "api"). type DiscoveryService interface { ProcessDiscoveryReport(ctx context.Context, report *domain.DiscoveryReport) (*domain.DiscoveryScan, error) ListDiscovered(ctx context.Context, agentID, status string, page, perPage int) ([]*domain.DiscoveredCertificate, int, error) GetDiscovered(ctx context.Context, id string) (*domain.DiscoveredCertificate, error) - ClaimDiscovered(ctx context.Context, id string, managedCertID string) error - DismissDiscovered(ctx context.Context, id string) error + ClaimDiscovered(ctx context.Context, id string, managedCertID string, actor string) error + DismissDiscovered(ctx context.Context, id string, actor string) error ListScans(ctx context.Context, agentID string, page, perPage int) ([]*domain.DiscoveryScan, int, error) GetScan(ctx context.Context, id string) (*domain.DiscoveryScan, error) GetDiscoverySummary(ctx context.Context) (map[string]int, error) @@ -142,7 +147,7 @@ func (h DiscoveryHandler) ClaimDiscovered(w http.ResponseWriter, r *http.Request return } - if err := h.svc.ClaimDiscovered(r.Context(), id, body.ManagedCertificateID); err != nil { + if err := h.svc.ClaimDiscovered(r.Context(), id, body.ManagedCertificateID, resolveActor(r.Context())); err != nil { Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to claim certificate: %v", err)) return } @@ -166,7 +171,7 @@ func (h DiscoveryHandler) DismissDiscovered(w http.ResponseWriter, r *http.Reque return } - if err := h.svc.DismissDiscovered(r.Context(), id); err != nil { + if err := h.svc.DismissDiscovered(r.Context(), id, resolveActor(r.Context())); err != nil { Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to dismiss certificate: %v", err)) return } diff --git a/internal/api/handler/discovery_handler_test.go b/internal/api/handler/discovery_handler_test.go index 7d0a2de..bcb98d1 100644 --- a/internal/api/handler/discovery_handler_test.go +++ b/internal/api/handler/discovery_handler_test.go @@ -19,8 +19,8 @@ type MockDiscoveryService struct { ProcessDiscoveryReportFn func(ctx context.Context, report *domain.DiscoveryReport) (*domain.DiscoveryScan, error) ListDiscoveredFn func(ctx context.Context, agentID, status string, page, perPage int) ([]*domain.DiscoveredCertificate, int, error) GetDiscoveredFn func(ctx context.Context, id string) (*domain.DiscoveredCertificate, error) - ClaimDiscoveredFn func(ctx context.Context, id string, managedCertID string) error - DismissDiscoveredFn func(ctx context.Context, id string) error + ClaimDiscoveredFn func(ctx context.Context, id string, managedCertID string, actor string) error + DismissDiscoveredFn func(ctx context.Context, id string, actor string) error ListScansFn func(ctx context.Context, agentID string, page, perPage int) ([]*domain.DiscoveryScan, int, error) GetScanFn func(ctx context.Context, id string) (*domain.DiscoveryScan, error) GetDiscoverySummaryFn func(ctx context.Context) (map[string]int, error) @@ -47,16 +47,16 @@ func (m *MockDiscoveryService) GetDiscovered(ctx context.Context, id string) (*d return nil, nil } -func (m *MockDiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string) error { +func (m *MockDiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string, actor string) error { if m.ClaimDiscoveredFn != nil { - return m.ClaimDiscoveredFn(ctx, id, managedCertID) + return m.ClaimDiscoveredFn(ctx, id, managedCertID, actor) } return nil } -func (m *MockDiscoveryService) DismissDiscovered(ctx context.Context, id string) error { +func (m *MockDiscoveryService) DismissDiscovered(ctx context.Context, id string, actor string) error { if m.DismissDiscoveredFn != nil { - return m.DismissDiscoveredFn(ctx, id) + return m.DismissDiscoveredFn(ctx, id, actor) } return nil } @@ -352,7 +352,7 @@ func TestGetDiscovered_NotFound(t *testing.T) { // Test ClaimDiscovered - success case func TestClaimDiscovered_Success(t *testing.T) { mock := &MockDiscoveryService{ - ClaimDiscoveredFn: func(ctx context.Context, id string, managedCertID string) error { + ClaimDiscoveredFn: func(ctx context.Context, id string, managedCertID string, actor string) error { if id == "dcert-1" && managedCertID == "mc-prod-1" { return nil } @@ -411,7 +411,7 @@ func TestClaimDiscovered_MissingManagedCertID(t *testing.T) { // Test ClaimDiscovered - discovered cert not found func TestClaimDiscovered_NotFound(t *testing.T) { mock := &MockDiscoveryService{ - ClaimDiscoveredFn: func(ctx context.Context, id string, managedCertID string) error { + ClaimDiscoveredFn: func(ctx context.Context, id string, managedCertID string, actor string) error { return fmt.Errorf("discovered certificate not found") }, } @@ -438,7 +438,7 @@ func TestClaimDiscovered_NotFound(t *testing.T) { // Test DismissDiscovered - success case func TestDismissDiscovered_Success(t *testing.T) { mock := &MockDiscoveryService{ - DismissDiscoveredFn: func(ctx context.Context, id string) error { + DismissDiscoveredFn: func(ctx context.Context, id string, actor string) error { if id == "dcert-1" { return nil } @@ -614,7 +614,7 @@ func TestGetDiscoverySummary_MethodNotAllowed(t *testing.T) { // Test DismissDiscovered - service error func TestDismissDiscovered_ServiceError(t *testing.T) { mock := &MockDiscoveryService{ - DismissDiscoveredFn: func(ctx context.Context, id string) error { + DismissDiscoveredFn: func(ctx context.Context, id string, actor string) error { return fmt.Errorf("database error") }, } diff --git a/internal/integration/lifecycle_test.go b/internal/integration/lifecycle_test.go index d3abae9..a574cf3 100644 --- a/internal/integration/lifecycle_test.go +++ b/internal/integration/lifecycle_test.go @@ -3,6 +3,7 @@ package integration import ( "bytes" "context" + "database/sql" "encoding/json" "fmt" "io" @@ -586,6 +587,24 @@ func (m *mockCertificateRepository) GetLatestVersion(ctx context.Context, certID return versions[len(versions)-1], nil } +// GetByIssuerAndSerial emulates the PostgreSQL JOIN that scopes cert lookup to +// (issuer_id, serial). Returns sql.ErrNoRows when no match exists so callers +// that branch on errors.Is(err, sql.ErrNoRows) (notably the OCSP handler's +// M-004 "unknown" fallback) behave the same in-memory as against PostgreSQL. +func (m *mockCertificateRepository) GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.ManagedCertificate, error) { + for _, cert := range m.certs { + if cert.IssuerID != issuerID { + continue + } + for _, v := range m.versions[cert.ID] { + if v.SerialNumber == serial { + return cert, nil + } + } + } + return nil, sql.ErrNoRows +} + type mockJobRepository struct { jobs map[string]*domain.Job } @@ -1301,11 +1320,11 @@ func (m *mockDiscoveryService) GetDiscovered(ctx context.Context, id string) (*d return nil, fmt.Errorf("not found") } -func (m *mockDiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string) error { +func (m *mockDiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string, actor string) error { return nil } -func (m *mockDiscoveryService) DismissDiscovered(ctx context.Context, id string) error { +func (m *mockDiscoveryService) DismissDiscovered(ctx context.Context, id string, actor string) error { return nil } diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index ff23b5f..81e0bbe 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -27,6 +27,13 @@ type CertificateRepository interface { GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) // GetLatestVersion returns the most recent certificate version for a certificate. GetLatestVersion(ctx context.Context, certID string) (*domain.CertificateVersion, error) + // GetByIssuerAndSerial retrieves a certificate by the (issuer_id, serial_number) + // pair via a JOIN on certificate_versions. Callers (OCSP, revocation lookup) + // always know the issuer because protocol endpoints carry it in the request + // path; RFC 5280 §5.2.3 guarantees serial uniqueness only within a single + // issuer. Returns sql.ErrNoRows when no match exists so callers can + // distinguish "unknown cert" from a real repository error. + GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.ManagedCertificate, error) } // RevocationRepository defines operations for managing certificate revocations. diff --git a/internal/repository/postgres/certificate.go b/internal/repository/postgres/certificate.go index f6fc45d..370f005 100644 --- a/internal/repository/postgres/certificate.go +++ b/internal/repository/postgres/certificate.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/base64" "encoding/json" + "errors" "fmt" "strings" "time" @@ -272,6 +273,38 @@ func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.Man return cert, nil } +// GetByIssuerAndSerial retrieves a certificate by the (issuer_id, serial_number) +// pair via a JOIN on certificate_versions. Per RFC 5280 §5.2.3, serial numbers +// are unique only within a single issuer — callers that know the issuer (OCSP, +// CRL generation, revocation lookup) use this method to scope lookups +// correctly. Returns sql.ErrNoRows when no match exists so callers can +// distinguish "unknown cert" (return OCSP status unknown) from a real +// repository error. +func (r *CertificateRepository) GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.ManagedCertificate, error) { + row := r.db.QueryRowContext(ctx, ` + SELECT mc.id, mc.name, mc.common_name, mc.sans, mc.environment, mc.owner_id, mc.team_id, + mc.issuer_id, mc.renewal_policy_id, mc.certificate_profile_id, mc.status, mc.expires_at, + mc.tags, mc.last_renewal_at, mc.last_deployment_at, mc.revoked_at, mc.revocation_reason, + mc.created_at, mc.updated_at + FROM managed_certificates mc + JOIN certificate_versions cv ON cv.certificate_id = mc.id + WHERE mc.issuer_id = $1 AND cv.serial_number = $2 + LIMIT 1 + `, issuerID, serial) + + cert, err := r.scanCertificate(ctx, row) + if err != nil { + // scanCertificate wraps sql.ErrNoRows via %w, so surface the bare + // sentinel here for callers that branch on it with errors.Is. + if errors.Is(err, sql.ErrNoRows) { + return nil, sql.ErrNoRows + } + return nil, fmt.Errorf("failed to query certificate by issuer+serial: %w", err) + } + + return cert, nil +} + // Create stores a new certificate func (r *CertificateRepository) Create(ctx context.Context, cert *domain.ManagedCertificate) error { if cert.ID == "" { diff --git a/internal/service/ca_operations.go b/internal/service/ca_operations.go index 430a664..3255df9 100644 --- a/internal/service/ca_operations.go +++ b/internal/service/ca_operations.go @@ -2,6 +2,8 @@ package service import ( "context" + "database/sql" + "errors" "fmt" "log/slog" "math/big" @@ -139,23 +141,49 @@ func (s *CAOperationsSvc) GetOCSPResponse(ctx context.Context, issuerID string, // Check if this (issuer_id, serial) is revoked — RFC 5280 §5.2.3 scoping. rev, err := s.revocationRepo.GetByIssuerAndSerial(ctx, issuerID, serialHex) - if err != nil { - // Not revoked — return "good" status + if err == nil && rev != nil { + // Revoked return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{ - CertSerial: serial, - CertStatus: 0, // good - ThisUpdate: now, - NextUpdate: now.Add(1 * time.Hour), + CertSerial: serial, + CertStatus: 1, // revoked + RevokedAt: rev.RevokedAt, + RevocationReason: domain.CRLReasonCode(domain.RevocationReason(rev.Reason)), + ThisUpdate: now, + NextUpdate: now.Add(1 * time.Hour), }) } - // Revoked + // Not revoked. Per RFC 6960 §2.2, we must only return "good" for a + // certificate that was actually issued by this CA. Verify the + // (issuer_id, serial) tuple maps to a real certificate in inventory + // before asserting "good"; otherwise return "unknown". This closes the + // coverage gap where forged/guessed serials would be accepted as valid + // because they had no revocation row (M-004). + if s.certRepo != nil { + cert, certErr := s.certRepo.GetByIssuerAndSerial(ctx, issuerID, serialHex) + if certErr != nil || cert == nil { + if certErr != nil && !errors.Is(certErr, sql.ErrNoRows) { + // Real repository failure — log but still fail closed with "unknown" + // rather than leaking a bogus "good" assertion. + slog.Warn("OCSP cert lookup failed; returning unknown", + "issuer_id", issuerID, + "serial", serialHex, + "error", certErr) + } + return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{ + CertSerial: serial, + CertStatus: 2, // unknown + ThisUpdate: now, + NextUpdate: now.Add(1 * time.Hour), + }) + } + } + + // Known cert, not revoked — return "good" return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{ - CertSerial: serial, - CertStatus: 1, // revoked - RevokedAt: rev.RevokedAt, - RevocationReason: domain.CRLReasonCode(domain.RevocationReason(rev.Reason)), - ThisUpdate: now, - NextUpdate: now.Add(1 * time.Hour), + CertSerial: serial, + CertStatus: 0, // good + ThisUpdate: now, + NextUpdate: now.Add(1 * time.Hour), }) } diff --git a/internal/service/ca_operations_test.go b/internal/service/ca_operations_test.go index 2825fd5..2d8990d 100644 --- a/internal/service/ca_operations_test.go +++ b/internal/service/ca_operations_test.go @@ -13,16 +13,25 @@ import ( // helper to create a CAOperationsSvc for testing func newCAOperationsSvcTest() (*CAOperationsSvc, *mockRevocationRepo, *mockCertRepo) { + caSvc, revocationRepo, certRepo, _ := newCAOperationsSvcTestWithIssuer() + return caSvc, revocationRepo, certRepo +} + +// newCAOperationsSvcTestWithIssuer also returns the mock issuer connector +// so tests can assert on the captured OCSPSignRequest. +func newCAOperationsSvcTestWithIssuer() (*CAOperationsSvc, *mockRevocationRepo, *mockCertRepo, *mockIssuerConnector) { revocationRepo := newMockRevocationRepository() certRepo := newMockCertificateRepository() profileRepo := newMockProfileRepository() caSvc := NewCAOperationsSvc(revocationRepo, certRepo, profileRepo) registry := NewIssuerRegistry(slog.Default()) - registry.Set("iss-local", &mockIssuerConnector{}) + issuer := &mockIssuerConnector{} + registry.Set("iss-local", issuer) + registry.Set("iss-other", &mockIssuerConnector{}) caSvc.SetIssuerRegistry(registry) - return caSvc, revocationRepo, certRepo + return caSvc, revocationRepo, certRepo, issuer } func TestCAOperationsSvc_GenerateDERCRL_Success(t *testing.T) { @@ -126,6 +135,77 @@ func TestCAOperationsSvc_GetOCSPResponse_Good(t *testing.T) { t.Logf("OCSP response for good cert generated: %d bytes", len(resp)) } +// TestCAOperationsSvc_GetOCSPResponse_Unknown_CrossIssuer guards the M-004 fix: +// a cert with the queried serial exists but under a *different* issuer. Before +// the fix, OCSP fell through to "good" (CertStatus 0) because no revocation row +// matched the (issuer_id, serial) tuple. Per RFC 5280 §5.2.3 serials are unique +// only within a single issuer, and per RFC 6960 §2.2 unknown certs must report +// "unknown" (CertStatus 2), not "good". +func TestCAOperationsSvc_GetOCSPResponse_Unknown_CrossIssuer(t *testing.T) { + caSvc, _, certRepo, issuer := newCAOperationsSvcTestWithIssuer() + + // Real cert exists, but bound to iss-other (not iss-local). + cert := &domain.ManagedCertificate{ + ID: "cert-cross-issuer", + CommonName: "cross.example.com", + IssuerID: "iss-other", + Status: domain.CertificateStatusActive, + ExpiresAt: time.Now().AddDate(1, 0, 0), + } + certRepo.AddCert(cert) + certRepo.Versions["cert-cross-issuer"] = []*domain.CertificateVersion{{ + ID: "ver-cross-issuer", + CertificateID: "cert-cross-issuer", + SerialNumber: "CROSS-ISSUER-001", + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), + CreatedAt: time.Now(), + }} + + // Query OCSP for iss-local + CROSS-ISSUER-001. The serial exists, but + // under iss-other — our JOIN-scoped lookup should return no match. + resp, err := caSvc.GetOCSPResponse(context.Background(), "iss-local", "CROSS-ISSUER-001") + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if resp == nil || len(resp) == 0 { + t.Fatal("expected non-empty OCSP response") + } + + if issuer.LastOCSPSignRequest == nil { + t.Fatal("expected SignOCSPResponse to be called") + } + if got, want := issuer.LastOCSPSignRequest.CertStatus, 2; got != want { + t.Errorf("CertStatus = %d, want %d (unknown) — cross-issuer lookup must not return good", got, want) + } +} + +// TestCAOperationsSvc_GetOCSPResponse_Unknown_UnknownSerial guards the M-004 fix +// for the "forged/guessed serial" case: no certificate exists at this +// (issuer_id, serial) tuple anywhere in inventory. Per RFC 6960 §2.2 we must +// report "unknown" (CertStatus 2), never "good" — returning good for a serial +// we never issued is a protocol violation that would allow an attacker to get +// certctl to vouch for a cert it never signed. +func TestCAOperationsSvc_GetOCSPResponse_Unknown_UnknownSerial(t *testing.T) { + caSvc, _, _, issuer := newCAOperationsSvcTestWithIssuer() + + // No cert rows added. Query for an arbitrary serial under iss-local. + resp, err := caSvc.GetOCSPResponse(context.Background(), "iss-local", "DEADBEEF-NEVER-ISSUED") + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if resp == nil || len(resp) == 0 { + t.Fatal("expected non-empty OCSP response") + } + + if issuer.LastOCSPSignRequest == nil { + t.Fatal("expected SignOCSPResponse to be called") + } + if got, want := issuer.LastOCSPSignRequest.CertStatus, 2; got != want { + t.Errorf("CertStatus = %d, want %d (unknown) — unissued serials must not return good", got, want) + } +} + func TestCAOperationsSvc_GetOCSPResponse_Revoked(t *testing.T) { caSvc, revocationRepo, certRepo := newCAOperationsSvcTest() diff --git a/internal/service/discovery.go b/internal/service/discovery.go index be73d49..a5ed3b7 100644 --- a/internal/service/discovery.go +++ b/internal/service/discovery.go @@ -148,7 +148,14 @@ func (s *DiscoveryService) GetDiscovered(ctx context.Context, id string) (*domai } // ClaimDiscovered links a discovered certificate to a managed certificate. -func (s *DiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string) error { +// The actor parameter names the authenticated identity that initiated the +// claim and is recorded on the audit event. Callers in the handler layer pass +// resolveActor(ctx); service-to-service callers pass a descriptive sentinel +// (e.g., "system"). Empty actor falls back to "api" (the same safe sentinel +// resolveActor uses when no auth context is present), never to "operator" — +// hardcoding "operator" was M-005, a coverage-gap closure where audit records +// failed to identify who actually performed the triage action. +func (s *DiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string, actor string) error { if managedCertID == "" { return fmt.Errorf("managed_certificate_id is required") } @@ -168,8 +175,12 @@ func (s *DiscoveryService) ClaimDiscovered(ctx context.Context, id string, manag return fmt.Errorf("failed to update discovered certificate status: %w", err) } + if actor == "" { + actor = "api" + } + // Audit trail - if err := s.auditService.RecordEvent(ctx, "operator", domain.ActorTypeUser, + if err := s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "discovery_cert_claimed", "discovered_certificate", id, map[string]interface{}{ "managed_certificate_id": managedCertID, @@ -182,14 +193,19 @@ func (s *DiscoveryService) ClaimDiscovered(ctx context.Context, id string, manag return nil } -// DismissDiscovered marks a discovered certificate as dismissed. -func (s *DiscoveryService) DismissDiscovered(ctx context.Context, id string) error { +// DismissDiscovered marks a discovered certificate as dismissed. See +// ClaimDiscovered for the actor contract — same rules apply (M-005). +func (s *DiscoveryService) DismissDiscovered(ctx context.Context, id string, actor string) error { if err := s.discoveryRepo.UpdateDiscoveredStatus(ctx, id, domain.DiscoveryStatusDismissed, ""); err != nil { return fmt.Errorf("failed to dismiss discovered certificate: %w", err) } + if actor == "" { + actor = "api" + } + // Audit trail - if err := s.auditService.RecordEvent(ctx, "operator", domain.ActorTypeUser, + if err := s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "discovery_cert_dismissed", "discovered_certificate", id, nil); err != nil { slog.Error("failed to record audit event", "error", err) } diff --git a/internal/service/discovery_test.go b/internal/service/discovery_test.go index f392dfe..548738b 100644 --- a/internal/service/discovery_test.go +++ b/internal/service/discovery_test.go @@ -381,7 +381,7 @@ func TestClaimDiscovered_Success(t *testing.T) { } certRepo.AddCert(managedCert) - err := svc.ClaimDiscovered(context.Background(), "dcert-1", "mc-prod-1") + err := svc.ClaimDiscovered(context.Background(), "dcert-1", "mc-prod-1", "alice@corp") if err != nil { t.Fatalf("expected no error, got: %v", err) } @@ -423,7 +423,7 @@ func TestClaimDiscovered_MissingManagedCertID(t *testing.T) { } discoveryRepo.Discovered[cert.ID] = cert - err := svc.ClaimDiscovered(context.Background(), "dcert-1", "") + err := svc.ClaimDiscovered(context.Background(), "dcert-1", "", "test-actor") if err == nil { t.Fatal("expected error for empty managed_certificate_id") } @@ -442,7 +442,7 @@ func TestClaimDiscovered_ManagedCertNotFound(t *testing.T) { } discoveryRepo.Discovered[cert.ID] = cert - err := svc.ClaimDiscovered(context.Background(), "dcert-1", "nonexistent-cert") + err := svc.ClaimDiscovered(context.Background(), "dcert-1", "nonexistent-cert", "test-actor") if err == nil { t.Fatal("expected error for nonexistent managed certificate") } @@ -464,7 +464,7 @@ func TestDismissDiscovered_Success(t *testing.T) { } discoveryRepo.Discovered[cert.ID] = cert - err := svc.DismissDiscovered(context.Background(), "dcert-1") + err := svc.DismissDiscovered(context.Background(), "dcert-1", "bob@corp") if err != nil { t.Fatalf("expected no error, got: %v", err) } @@ -497,8 +497,178 @@ func TestDismissDiscovered_NotFound(t *testing.T) { svc, discoveryRepo, _, _ := newDiscoveryTestService() discoveryRepo.UpdateStatusErr = errNotFound - err := svc.DismissDiscovered(context.Background(), "nonexistent") + err := svc.DismissDiscovered(context.Background(), "nonexistent", "test-actor") if err == nil { t.Fatal("expected error for nonexistent cert") } } + +// M-005 regression: caller-supplied actor must propagate onto the +// discovery_cert_claimed audit event so the trail identifies who performed +// triage (pre-M-005 the service hardcoded "operator"). +func TestDiscoveryService_ClaimDiscovered_AuditActor(t *testing.T) { + svc, discoveryRepo, certRepo, auditRepo := newDiscoveryTestService() + + now := time.Now() + discoveredCert := &domain.DiscoveredCertificate{ + ID: "dcert-1", + CommonName: "example.com", + FingerprintSHA256: "abc123", + Status: domain.DiscoveryStatusUnmanaged, + CreatedAt: now, + UpdatedAt: now, + } + discoveryRepo.Discovered[discoveredCert.ID] = discoveredCert + + managedCert := &domain.ManagedCertificate{ + ID: "mc-prod-1", + CommonName: "example.com", + Status: domain.CertificateStatusActive, + CreatedAt: now, + UpdatedAt: now, + } + certRepo.AddCert(managedCert) + + if err := svc.ClaimDiscovered(context.Background(), "dcert-1", "mc-prod-1", "alice@corp"); err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + // Locate the discovery_cert_claimed audit event and assert actor propagation. + var claimEvent *domain.AuditEvent + for _, e := range auditRepo.Events { + if e.Action == "discovery_cert_claimed" { + claimEvent = e + break + } + } + if claimEvent == nil { + t.Fatal("expected discovery_cert_claimed audit event to be recorded") + } + if claimEvent.Actor != "alice@corp" { + t.Errorf("expected audit actor to be caller-supplied 'alice@corp', got %q", claimEvent.Actor) + } + if claimEvent.Actor == "operator" { + t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)") + } +} + +// M-005 regression symmetric pair for DismissDiscovered. +func TestDiscoveryService_DismissDiscovered_AuditActor(t *testing.T) { + svc, discoveryRepo, _, auditRepo := newDiscoveryTestService() + + now := time.Now() + cert := &domain.DiscoveredCertificate{ + ID: "dcert-1", + CommonName: "example.com", + Status: domain.DiscoveryStatusUnmanaged, + CreatedAt: now, + UpdatedAt: now, + } + discoveryRepo.Discovered[cert.ID] = cert + + if err := svc.DismissDiscovered(context.Background(), "dcert-1", "bob@corp"); err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + var dismissEvent *domain.AuditEvent + for _, e := range auditRepo.Events { + if e.Action == "discovery_cert_dismissed" { + dismissEvent = e + break + } + } + if dismissEvent == nil { + t.Fatal("expected discovery_cert_dismissed audit event to be recorded") + } + if dismissEvent.Actor != "bob@corp" { + t.Errorf("expected audit actor to be caller-supplied 'bob@corp', got %q", dismissEvent.Actor) + } + if dismissEvent.Actor == "operator" { + t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)") + } +} + +// M-005 regression: when the caller passes an empty actor (e.g., the handler's +// resolveActor helper returns "" because no auth context is present), the +// service must fall back to the safe sentinel "api" — never to the pre-M-005 +// hardcoded "operator". +func TestDiscoveryService_ClaimDiscovered_EmptyActorFallsBackToAPI(t *testing.T) { + svc, discoveryRepo, certRepo, auditRepo := newDiscoveryTestService() + + now := time.Now() + discoveredCert := &domain.DiscoveredCertificate{ + ID: "dcert-1", + CommonName: "example.com", + FingerprintSHA256: "abc123", + Status: domain.DiscoveryStatusUnmanaged, + CreatedAt: now, + UpdatedAt: now, + } + discoveryRepo.Discovered[discoveredCert.ID] = discoveredCert + + managedCert := &domain.ManagedCertificate{ + ID: "mc-prod-1", + CommonName: "example.com", + Status: domain.CertificateStatusActive, + CreatedAt: now, + UpdatedAt: now, + } + certRepo.AddCert(managedCert) + + if err := svc.ClaimDiscovered(context.Background(), "dcert-1", "mc-prod-1", ""); err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + var claimEvent *domain.AuditEvent + for _, e := range auditRepo.Events { + if e.Action == "discovery_cert_claimed" { + claimEvent = e + break + } + } + if claimEvent == nil { + t.Fatal("expected discovery_cert_claimed audit event to be recorded") + } + if claimEvent.Actor != "api" { + t.Errorf("expected empty actor to fall back to 'api', got %q", claimEvent.Actor) + } + if claimEvent.Actor == "operator" { + t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)") + } +} + +// M-005 regression symmetric pair for DismissDiscovered empty-actor fallback. +func TestDiscoveryService_DismissDiscovered_EmptyActorFallsBackToAPI(t *testing.T) { + svc, discoveryRepo, _, auditRepo := newDiscoveryTestService() + + now := time.Now() + cert := &domain.DiscoveredCertificate{ + ID: "dcert-1", + CommonName: "example.com", + Status: domain.DiscoveryStatusUnmanaged, + CreatedAt: now, + UpdatedAt: now, + } + discoveryRepo.Discovered[cert.ID] = cert + + if err := svc.DismissDiscovered(context.Background(), "dcert-1", ""); err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + var dismissEvent *domain.AuditEvent + for _, e := range auditRepo.Events { + if e.Action == "discovery_cert_dismissed" { + dismissEvent = e + break + } + } + if dismissEvent == nil { + t.Fatal("expected discovery_cert_dismissed audit event to be recorded") + } + if dismissEvent.Actor != "api" { + t.Errorf("expected empty actor to fall back to 'api', got %q", dismissEvent.Actor) + } + if dismissEvent.Actor == "operator" { + t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)") + } +} diff --git a/internal/service/shortlived_test.go b/internal/service/shortlived_test.go index db099ee..d64a188 100644 --- a/internal/service/shortlived_test.go +++ b/internal/service/shortlived_test.go @@ -198,6 +198,10 @@ func (m *mockCertRepoWithGetError) GetLatestVersion(ctx context.Context, certID return nil, nil } +func (m *mockCertRepoWithGetError) GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.ManagedCertificate, error) { + return nil, nil +} + func (m *mockCertRepoWithGetError) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) { return nil, m.GetExpiringCertificatesErr } diff --git a/internal/service/testutil_test.go b/internal/service/testutil_test.go index b913fcc..51633d4 100644 --- a/internal/service/testutil_test.go +++ b/internal/service/testutil_test.go @@ -2,6 +2,7 @@ package service import ( "context" + "database/sql" "errors" "sync" "time" @@ -129,6 +130,26 @@ func (m *mockCertRepo) GetLatestVersion(ctx context.Context, certID string) (*do return versions[len(versions)-1], nil } +// GetByIssuerAndSerial emulates the PostgreSQL JOIN: +// SELECT mc.* FROM managed_certificates mc JOIN certificate_versions cv +// ON cv.certificate_id = mc.id WHERE mc.issuer_id = $1 AND cv.serial_number = $2. +// Returns sql.ErrNoRows (the sentinel the real repo surfaces) when no match +// exists, so callers that branch on errors.Is(err, sql.ErrNoRows) behave the +// same in-memory as they do against PostgreSQL. +func (m *mockCertRepo) GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.ManagedCertificate, error) { + for _, cert := range m.Certs { + if cert.IssuerID != issuerID { + continue + } + for _, v := range m.Versions[cert.ID] { + if v.SerialNumber == serial { + return cert, nil + } + } + } + return nil, sql.ErrNoRows +} + func (m *mockCertRepo) AddCert(cert *domain.ManagedCertificate) { m.Certs[cert.ID] = cert } @@ -784,6 +805,9 @@ type mockIssuerConnector struct { Err error getRenewalInfoResult *RenewalInfoResult getRenewalInfoErr error + // LastOCSPSignRequest captures the last request passed to SignOCSPResponse. + // Tests use this to assert CertStatus (0=good, 1=revoked, 2=unknown). + LastOCSPSignRequest *OCSPSignRequest } func (m *mockIssuerConnector) IssueCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string, maxTTLSeconds int) (*IssuanceResult, error) { @@ -825,6 +849,9 @@ func (m *mockIssuerConnector) GenerateCRL(ctx context.Context, entries []CRLEntr } func (m *mockIssuerConnector) SignOCSPResponse(ctx context.Context, req OCSPSignRequest) ([]byte, error) { + // Capture the request for test assertions (e.g., CertStatus verification) + reqCopy := req + m.LastOCSPSignRequest = &reqCopy if m.Err != nil { return nil, m.Err }