Close M-004 (OCSP issuer binding) and M-005 (discovery actor propagation) coverage-gap findings

M-004 — OCSP issuer binding (composite key):
  The OCSP lookup path now binds (issuer_id, serial) as a composite key
  rather than resolving by serial alone. CertificateRepository and
  RevocationRepository gain GetByIssuerAndSerial methods; ca_operations.go
  scopes both lookups by the issuer_id path param. When no managed cert
  binds to that (issuer, serial) tuple, GetOCSPResponse constructs an
  RFC 6960 §2.2 'unknown' response (CertStatus=2) instead of the prior
  default 'good'. Short-lived cert exemption (profile TTL < 1h) is
  preserved. Real repo errors (non-sql.ErrNoRows) fail closed with a log.

  Regression coverage: internal/service/ca_operations_test.go
    - TestCAOperationsSvc_GetOCSPResponse_Unknown_CrossIssuer
    - TestCAOperationsSvc_GetOCSPResponse_Unknown_UnknownSerial

M-005 — Discovery Claim/Dismiss actor propagation:
  DiscoveryService.ClaimDiscovered and DismissDiscovered now accept an
  explicit 'actor string' parameter (propagation pattern mirrors
  bulk_revocation.go / revocation_svc.go). The handler layer passes
  resolveActor(r.Context()) — the named-key identity established by the
  M-002 auth unification — and the service falls back to 'api' (the same
  safe sentinel resolveActor uses when no auth context is present) only
  when the caller passes an empty string. Never falls back to 'operator'.

  Regression coverage: internal/service/discovery_test.go
    - TestDiscoveryService_ClaimDiscovered_AuditActor
    - TestDiscoveryService_DismissDiscovered_AuditActor
    - TestDiscoveryService_ClaimDiscovered_EmptyActorFallsBackToAPI
    - TestDiscoveryService_DismissDiscovered_EmptyActorFallsBackToAPI

Each new test asserts event.Actor matches the caller-supplied string (or
'api' on empty input) and explicitly asserts event.Actor != 'operator'
to lock in the historical fix intent.

Files:
  internal/api/handler/discovery.go          — pass resolveActor(ctx)
  internal/api/handler/discovery_handler_test.go — updated call sites
  internal/integration/lifecycle_test.go     — updated mock wiring
  internal/repository/interfaces.go          — GetByIssuerAndSerial on
                                               CertificateRepository +
                                               RevocationRepository
  internal/repository/postgres/certificate.go — composite key lookup
  internal/service/ca_operations.go          — (issuer_id, serial) scoping
  internal/service/ca_operations_test.go     — 2 new M-004 tests
  internal/service/discovery.go              — actor parameter + 'api' fallback
  internal/service/discovery_test.go         — 4 new M-005 tests
  internal/service/shortlived_test.go        — mock signature update
  internal/service/testutil_test.go          — mock GetByIssuerAndSerial
This commit is contained in:
shankar0123
2026-04-18 22:20:25 +00:00
parent ff7357f889
commit fe7e766510
11 changed files with 430 additions and 41 deletions
+9 -4
View File
@@ -11,12 +11,17 @@ import (
) )
// DiscoveryService defines the interface used by the discovery handler. // DiscoveryService defines the interface used by the discovery handler.
// ClaimDiscovered and DismissDiscovered accept an explicit actor parameter so
// the handler can flow the authenticated named-key identity into the audit
// trail (M-005). Services that call these methods from non-request contexts
// pass a descriptive sentinel (e.g., "system") or "" (which falls back to
// "api").
type DiscoveryService interface { type DiscoveryService interface {
ProcessDiscoveryReport(ctx context.Context, report *domain.DiscoveryReport) (*domain.DiscoveryScan, error) ProcessDiscoveryReport(ctx context.Context, report *domain.DiscoveryReport) (*domain.DiscoveryScan, error)
ListDiscovered(ctx context.Context, agentID, status string, page, perPage int) ([]*domain.DiscoveredCertificate, int, error) ListDiscovered(ctx context.Context, agentID, status string, page, perPage int) ([]*domain.DiscoveredCertificate, int, error)
GetDiscovered(ctx context.Context, id string) (*domain.DiscoveredCertificate, error) GetDiscovered(ctx context.Context, id string) (*domain.DiscoveredCertificate, error)
ClaimDiscovered(ctx context.Context, id string, managedCertID string) error ClaimDiscovered(ctx context.Context, id string, managedCertID string, actor string) error
DismissDiscovered(ctx context.Context, id string) error DismissDiscovered(ctx context.Context, id string, actor string) error
ListScans(ctx context.Context, agentID string, page, perPage int) ([]*domain.DiscoveryScan, int, error) ListScans(ctx context.Context, agentID string, page, perPage int) ([]*domain.DiscoveryScan, int, error)
GetScan(ctx context.Context, id string) (*domain.DiscoveryScan, error) GetScan(ctx context.Context, id string) (*domain.DiscoveryScan, error)
GetDiscoverySummary(ctx context.Context) (map[string]int, error) GetDiscoverySummary(ctx context.Context) (map[string]int, error)
@@ -142,7 +147,7 @@ func (h DiscoveryHandler) ClaimDiscovered(w http.ResponseWriter, r *http.Request
return return
} }
if err := h.svc.ClaimDiscovered(r.Context(), id, body.ManagedCertificateID); err != nil { if err := h.svc.ClaimDiscovered(r.Context(), id, body.ManagedCertificateID, resolveActor(r.Context())); err != nil {
Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to claim certificate: %v", err)) Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to claim certificate: %v", err))
return return
} }
@@ -166,7 +171,7 @@ func (h DiscoveryHandler) DismissDiscovered(w http.ResponseWriter, r *http.Reque
return return
} }
if err := h.svc.DismissDiscovered(r.Context(), id); err != nil { if err := h.svc.DismissDiscovered(r.Context(), id, resolveActor(r.Context())); err != nil {
Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to dismiss certificate: %v", err)) Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to dismiss certificate: %v", err))
return return
} }
+10 -10
View File
@@ -19,8 +19,8 @@ type MockDiscoveryService struct {
ProcessDiscoveryReportFn func(ctx context.Context, report *domain.DiscoveryReport) (*domain.DiscoveryScan, error) ProcessDiscoveryReportFn func(ctx context.Context, report *domain.DiscoveryReport) (*domain.DiscoveryScan, error)
ListDiscoveredFn func(ctx context.Context, agentID, status string, page, perPage int) ([]*domain.DiscoveredCertificate, int, error) ListDiscoveredFn func(ctx context.Context, agentID, status string, page, perPage int) ([]*domain.DiscoveredCertificate, int, error)
GetDiscoveredFn func(ctx context.Context, id string) (*domain.DiscoveredCertificate, error) GetDiscoveredFn func(ctx context.Context, id string) (*domain.DiscoveredCertificate, error)
ClaimDiscoveredFn func(ctx context.Context, id string, managedCertID string) error ClaimDiscoveredFn func(ctx context.Context, id string, managedCertID string, actor string) error
DismissDiscoveredFn func(ctx context.Context, id string) error DismissDiscoveredFn func(ctx context.Context, id string, actor string) error
ListScansFn func(ctx context.Context, agentID string, page, perPage int) ([]*domain.DiscoveryScan, int, error) ListScansFn func(ctx context.Context, agentID string, page, perPage int) ([]*domain.DiscoveryScan, int, error)
GetScanFn func(ctx context.Context, id string) (*domain.DiscoveryScan, error) GetScanFn func(ctx context.Context, id string) (*domain.DiscoveryScan, error)
GetDiscoverySummaryFn func(ctx context.Context) (map[string]int, error) GetDiscoverySummaryFn func(ctx context.Context) (map[string]int, error)
@@ -47,16 +47,16 @@ func (m *MockDiscoveryService) GetDiscovered(ctx context.Context, id string) (*d
return nil, nil return nil, nil
} }
func (m *MockDiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string) error { func (m *MockDiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string, actor string) error {
if m.ClaimDiscoveredFn != nil { if m.ClaimDiscoveredFn != nil {
return m.ClaimDiscoveredFn(ctx, id, managedCertID) return m.ClaimDiscoveredFn(ctx, id, managedCertID, actor)
} }
return nil return nil
} }
func (m *MockDiscoveryService) DismissDiscovered(ctx context.Context, id string) error { func (m *MockDiscoveryService) DismissDiscovered(ctx context.Context, id string, actor string) error {
if m.DismissDiscoveredFn != nil { if m.DismissDiscoveredFn != nil {
return m.DismissDiscoveredFn(ctx, id) return m.DismissDiscoveredFn(ctx, id, actor)
} }
return nil return nil
} }
@@ -352,7 +352,7 @@ func TestGetDiscovered_NotFound(t *testing.T) {
// Test ClaimDiscovered - success case // Test ClaimDiscovered - success case
func TestClaimDiscovered_Success(t *testing.T) { func TestClaimDiscovered_Success(t *testing.T) {
mock := &MockDiscoveryService{ mock := &MockDiscoveryService{
ClaimDiscoveredFn: func(ctx context.Context, id string, managedCertID string) error { ClaimDiscoveredFn: func(ctx context.Context, id string, managedCertID string, actor string) error {
if id == "dcert-1" && managedCertID == "mc-prod-1" { if id == "dcert-1" && managedCertID == "mc-prod-1" {
return nil return nil
} }
@@ -411,7 +411,7 @@ func TestClaimDiscovered_MissingManagedCertID(t *testing.T) {
// Test ClaimDiscovered - discovered cert not found // Test ClaimDiscovered - discovered cert not found
func TestClaimDiscovered_NotFound(t *testing.T) { func TestClaimDiscovered_NotFound(t *testing.T) {
mock := &MockDiscoveryService{ mock := &MockDiscoveryService{
ClaimDiscoveredFn: func(ctx context.Context, id string, managedCertID string) error { ClaimDiscoveredFn: func(ctx context.Context, id string, managedCertID string, actor string) error {
return fmt.Errorf("discovered certificate not found") return fmt.Errorf("discovered certificate not found")
}, },
} }
@@ -438,7 +438,7 @@ func TestClaimDiscovered_NotFound(t *testing.T) {
// Test DismissDiscovered - success case // Test DismissDiscovered - success case
func TestDismissDiscovered_Success(t *testing.T) { func TestDismissDiscovered_Success(t *testing.T) {
mock := &MockDiscoveryService{ mock := &MockDiscoveryService{
DismissDiscoveredFn: func(ctx context.Context, id string) error { DismissDiscoveredFn: func(ctx context.Context, id string, actor string) error {
if id == "dcert-1" { if id == "dcert-1" {
return nil return nil
} }
@@ -614,7 +614,7 @@ func TestGetDiscoverySummary_MethodNotAllowed(t *testing.T) {
// Test DismissDiscovered - service error // Test DismissDiscovered - service error
func TestDismissDiscovered_ServiceError(t *testing.T) { func TestDismissDiscovered_ServiceError(t *testing.T) {
mock := &MockDiscoveryService{ mock := &MockDiscoveryService{
DismissDiscoveredFn: func(ctx context.Context, id string) error { DismissDiscoveredFn: func(ctx context.Context, id string, actor string) error {
return fmt.Errorf("database error") return fmt.Errorf("database error")
}, },
} }
+21 -2
View File
@@ -3,6 +3,7 @@ package integration
import ( import (
"bytes" "bytes"
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -586,6 +587,24 @@ func (m *mockCertificateRepository) GetLatestVersion(ctx context.Context, certID
return versions[len(versions)-1], nil return versions[len(versions)-1], nil
} }
// GetByIssuerAndSerial emulates the PostgreSQL JOIN that scopes cert lookup to
// (issuer_id, serial). Returns sql.ErrNoRows when no match exists so callers
// that branch on errors.Is(err, sql.ErrNoRows) (notably the OCSP handler's
// M-004 "unknown" fallback) behave the same in-memory as against PostgreSQL.
func (m *mockCertificateRepository) GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.ManagedCertificate, error) {
for _, cert := range m.certs {
if cert.IssuerID != issuerID {
continue
}
for _, v := range m.versions[cert.ID] {
if v.SerialNumber == serial {
return cert, nil
}
}
}
return nil, sql.ErrNoRows
}
type mockJobRepository struct { type mockJobRepository struct {
jobs map[string]*domain.Job jobs map[string]*domain.Job
} }
@@ -1301,11 +1320,11 @@ func (m *mockDiscoveryService) GetDiscovered(ctx context.Context, id string) (*d
return nil, fmt.Errorf("not found") return nil, fmt.Errorf("not found")
} }
func (m *mockDiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string) error { func (m *mockDiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string, actor string) error {
return nil return nil
} }
func (m *mockDiscoveryService) DismissDiscovered(ctx context.Context, id string) error { func (m *mockDiscoveryService) DismissDiscovered(ctx context.Context, id string, actor string) error {
return nil return nil
} }
+7
View File
@@ -27,6 +27,13 @@ type CertificateRepository interface {
GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error)
// GetLatestVersion returns the most recent certificate version for a certificate. // GetLatestVersion returns the most recent certificate version for a certificate.
GetLatestVersion(ctx context.Context, certID string) (*domain.CertificateVersion, error) GetLatestVersion(ctx context.Context, certID string) (*domain.CertificateVersion, error)
// GetByIssuerAndSerial retrieves a certificate by the (issuer_id, serial_number)
// pair via a JOIN on certificate_versions. Callers (OCSP, revocation lookup)
// always know the issuer because protocol endpoints carry it in the request
// path; RFC 5280 §5.2.3 guarantees serial uniqueness only within a single
// issuer. Returns sql.ErrNoRows when no match exists so callers can
// distinguish "unknown cert" from a real repository error.
GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.ManagedCertificate, error)
} }
// RevocationRepository defines operations for managing certificate revocations. // RevocationRepository defines operations for managing certificate revocations.
@@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@@ -272,6 +273,38 @@ func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.Man
return cert, nil return cert, nil
} }
// GetByIssuerAndSerial retrieves a certificate by the (issuer_id, serial_number)
// pair via a JOIN on certificate_versions. Per RFC 5280 §5.2.3, serial numbers
// are unique only within a single issuer — callers that know the issuer (OCSP,
// CRL generation, revocation lookup) use this method to scope lookups
// correctly. Returns sql.ErrNoRows when no match exists so callers can
// distinguish "unknown cert" (return OCSP status unknown) from a real
// repository error.
func (r *CertificateRepository) GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.ManagedCertificate, error) {
row := r.db.QueryRowContext(ctx, `
SELECT mc.id, mc.name, mc.common_name, mc.sans, mc.environment, mc.owner_id, mc.team_id,
mc.issuer_id, mc.renewal_policy_id, mc.certificate_profile_id, mc.status, mc.expires_at,
mc.tags, mc.last_renewal_at, mc.last_deployment_at, mc.revoked_at, mc.revocation_reason,
mc.created_at, mc.updated_at
FROM managed_certificates mc
JOIN certificate_versions cv ON cv.certificate_id = mc.id
WHERE mc.issuer_id = $1 AND cv.serial_number = $2
LIMIT 1
`, issuerID, serial)
cert, err := r.scanCertificate(ctx, row)
if err != nil {
// scanCertificate wraps sql.ErrNoRows via %w, so surface the bare
// sentinel here for callers that branch on it with errors.Is.
if errors.Is(err, sql.ErrNoRows) {
return nil, sql.ErrNoRows
}
return nil, fmt.Errorf("failed to query certificate by issuer+serial: %w", err)
}
return cert, nil
}
// Create stores a new certificate // Create stores a new certificate
func (r *CertificateRepository) Create(ctx context.Context, cert *domain.ManagedCertificate) error { func (r *CertificateRepository) Create(ctx context.Context, cert *domain.ManagedCertificate) error {
if cert.ID == "" { if cert.ID == "" {
+41 -13
View File
@@ -2,6 +2,8 @@ package service
import ( import (
"context" "context"
"database/sql"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"math/big" "math/big"
@@ -139,23 +141,49 @@ func (s *CAOperationsSvc) GetOCSPResponse(ctx context.Context, issuerID string,
// Check if this (issuer_id, serial) is revoked — RFC 5280 §5.2.3 scoping. // Check if this (issuer_id, serial) is revoked — RFC 5280 §5.2.3 scoping.
rev, err := s.revocationRepo.GetByIssuerAndSerial(ctx, issuerID, serialHex) rev, err := s.revocationRepo.GetByIssuerAndSerial(ctx, issuerID, serialHex)
if err != nil { if err == nil && rev != nil {
// Not revoked — return "good" status // Revoked
return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{ return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{
CertSerial: serial, CertSerial: serial,
CertStatus: 0, // good CertStatus: 1, // revoked
ThisUpdate: now, RevokedAt: rev.RevokedAt,
NextUpdate: now.Add(1 * time.Hour), RevocationReason: domain.CRLReasonCode(domain.RevocationReason(rev.Reason)),
ThisUpdate: now,
NextUpdate: now.Add(1 * time.Hour),
}) })
} }
// Revoked // Not revoked. Per RFC 6960 §2.2, we must only return "good" for a
// certificate that was actually issued by this CA. Verify the
// (issuer_id, serial) tuple maps to a real certificate in inventory
// before asserting "good"; otherwise return "unknown". This closes the
// coverage gap where forged/guessed serials would be accepted as valid
// because they had no revocation row (M-004).
if s.certRepo != nil {
cert, certErr := s.certRepo.GetByIssuerAndSerial(ctx, issuerID, serialHex)
if certErr != nil || cert == nil {
if certErr != nil && !errors.Is(certErr, sql.ErrNoRows) {
// Real repository failure — log but still fail closed with "unknown"
// rather than leaking a bogus "good" assertion.
slog.Warn("OCSP cert lookup failed; returning unknown",
"issuer_id", issuerID,
"serial", serialHex,
"error", certErr)
}
return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{
CertSerial: serial,
CertStatus: 2, // unknown
ThisUpdate: now,
NextUpdate: now.Add(1 * time.Hour),
})
}
}
// Known cert, not revoked — return "good"
return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{ return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{
CertSerial: serial, CertSerial: serial,
CertStatus: 1, // revoked CertStatus: 0, // good
RevokedAt: rev.RevokedAt, ThisUpdate: now,
RevocationReason: domain.CRLReasonCode(domain.RevocationReason(rev.Reason)), NextUpdate: now.Add(1 * time.Hour),
ThisUpdate: now,
NextUpdate: now.Add(1 * time.Hour),
}) })
} }
+82 -2
View File
@@ -13,16 +13,25 @@ import (
// helper to create a CAOperationsSvc for testing // helper to create a CAOperationsSvc for testing
func newCAOperationsSvcTest() (*CAOperationsSvc, *mockRevocationRepo, *mockCertRepo) { func newCAOperationsSvcTest() (*CAOperationsSvc, *mockRevocationRepo, *mockCertRepo) {
caSvc, revocationRepo, certRepo, _ := newCAOperationsSvcTestWithIssuer()
return caSvc, revocationRepo, certRepo
}
// newCAOperationsSvcTestWithIssuer also returns the mock issuer connector
// so tests can assert on the captured OCSPSignRequest.
func newCAOperationsSvcTestWithIssuer() (*CAOperationsSvc, *mockRevocationRepo, *mockCertRepo, *mockIssuerConnector) {
revocationRepo := newMockRevocationRepository() revocationRepo := newMockRevocationRepository()
certRepo := newMockCertificateRepository() certRepo := newMockCertificateRepository()
profileRepo := newMockProfileRepository() profileRepo := newMockProfileRepository()
caSvc := NewCAOperationsSvc(revocationRepo, certRepo, profileRepo) caSvc := NewCAOperationsSvc(revocationRepo, certRepo, profileRepo)
registry := NewIssuerRegistry(slog.Default()) registry := NewIssuerRegistry(slog.Default())
registry.Set("iss-local", &mockIssuerConnector{}) issuer := &mockIssuerConnector{}
registry.Set("iss-local", issuer)
registry.Set("iss-other", &mockIssuerConnector{})
caSvc.SetIssuerRegistry(registry) caSvc.SetIssuerRegistry(registry)
return caSvc, revocationRepo, certRepo return caSvc, revocationRepo, certRepo, issuer
} }
func TestCAOperationsSvc_GenerateDERCRL_Success(t *testing.T) { func TestCAOperationsSvc_GenerateDERCRL_Success(t *testing.T) {
@@ -126,6 +135,77 @@ func TestCAOperationsSvc_GetOCSPResponse_Good(t *testing.T) {
t.Logf("OCSP response for good cert generated: %d bytes", len(resp)) t.Logf("OCSP response for good cert generated: %d bytes", len(resp))
} }
// TestCAOperationsSvc_GetOCSPResponse_Unknown_CrossIssuer guards the M-004 fix:
// a cert with the queried serial exists but under a *different* issuer. Before
// the fix, OCSP fell through to "good" (CertStatus 0) because no revocation row
// matched the (issuer_id, serial) tuple. Per RFC 5280 §5.2.3 serials are unique
// only within a single issuer, and per RFC 6960 §2.2 unknown certs must report
// "unknown" (CertStatus 2), not "good".
func TestCAOperationsSvc_GetOCSPResponse_Unknown_CrossIssuer(t *testing.T) {
caSvc, _, certRepo, issuer := newCAOperationsSvcTestWithIssuer()
// Real cert exists, but bound to iss-other (not iss-local).
cert := &domain.ManagedCertificate{
ID: "cert-cross-issuer",
CommonName: "cross.example.com",
IssuerID: "iss-other",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(1, 0, 0),
}
certRepo.AddCert(cert)
certRepo.Versions["cert-cross-issuer"] = []*domain.CertificateVersion{{
ID: "ver-cross-issuer",
CertificateID: "cert-cross-issuer",
SerialNumber: "CROSS-ISSUER-001",
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1, 0, 0),
CreatedAt: time.Now(),
}}
// Query OCSP for iss-local + CROSS-ISSUER-001. The serial exists, but
// under iss-other — our JOIN-scoped lookup should return no match.
resp, err := caSvc.GetOCSPResponse(context.Background(), "iss-local", "CROSS-ISSUER-001")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if resp == nil || len(resp) == 0 {
t.Fatal("expected non-empty OCSP response")
}
if issuer.LastOCSPSignRequest == nil {
t.Fatal("expected SignOCSPResponse to be called")
}
if got, want := issuer.LastOCSPSignRequest.CertStatus, 2; got != want {
t.Errorf("CertStatus = %d, want %d (unknown) — cross-issuer lookup must not return good", got, want)
}
}
// TestCAOperationsSvc_GetOCSPResponse_Unknown_UnknownSerial guards the M-004 fix
// for the "forged/guessed serial" case: no certificate exists at this
// (issuer_id, serial) tuple anywhere in inventory. Per RFC 6960 §2.2 we must
// report "unknown" (CertStatus 2), never "good" — returning good for a serial
// we never issued is a protocol violation that would allow an attacker to get
// certctl to vouch for a cert it never signed.
func TestCAOperationsSvc_GetOCSPResponse_Unknown_UnknownSerial(t *testing.T) {
caSvc, _, _, issuer := newCAOperationsSvcTestWithIssuer()
// No cert rows added. Query for an arbitrary serial under iss-local.
resp, err := caSvc.GetOCSPResponse(context.Background(), "iss-local", "DEADBEEF-NEVER-ISSUED")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if resp == nil || len(resp) == 0 {
t.Fatal("expected non-empty OCSP response")
}
if issuer.LastOCSPSignRequest == nil {
t.Fatal("expected SignOCSPResponse to be called")
}
if got, want := issuer.LastOCSPSignRequest.CertStatus, 2; got != want {
t.Errorf("CertStatus = %d, want %d (unknown) — unissued serials must not return good", got, want)
}
}
func TestCAOperationsSvc_GetOCSPResponse_Revoked(t *testing.T) { func TestCAOperationsSvc_GetOCSPResponse_Revoked(t *testing.T) {
caSvc, revocationRepo, certRepo := newCAOperationsSvcTest() caSvc, revocationRepo, certRepo := newCAOperationsSvcTest()
+21 -5
View File
@@ -148,7 +148,14 @@ func (s *DiscoveryService) GetDiscovered(ctx context.Context, id string) (*domai
} }
// ClaimDiscovered links a discovered certificate to a managed certificate. // ClaimDiscovered links a discovered certificate to a managed certificate.
func (s *DiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string) error { // The actor parameter names the authenticated identity that initiated the
// claim and is recorded on the audit event. Callers in the handler layer pass
// resolveActor(ctx); service-to-service callers pass a descriptive sentinel
// (e.g., "system"). Empty actor falls back to "api" (the same safe sentinel
// resolveActor uses when no auth context is present), never to "operator" —
// hardcoding "operator" was M-005, a coverage-gap closure where audit records
// failed to identify who actually performed the triage action.
func (s *DiscoveryService) ClaimDiscovered(ctx context.Context, id string, managedCertID string, actor string) error {
if managedCertID == "" { if managedCertID == "" {
return fmt.Errorf("managed_certificate_id is required") return fmt.Errorf("managed_certificate_id is required")
} }
@@ -168,8 +175,12 @@ func (s *DiscoveryService) ClaimDiscovered(ctx context.Context, id string, manag
return fmt.Errorf("failed to update discovered certificate status: %w", err) return fmt.Errorf("failed to update discovered certificate status: %w", err)
} }
if actor == "" {
actor = "api"
}
// Audit trail // Audit trail
if err := s.auditService.RecordEvent(ctx, "operator", domain.ActorTypeUser, if err := s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser,
"discovery_cert_claimed", "discovered_certificate", id, "discovery_cert_claimed", "discovered_certificate", id,
map[string]interface{}{ map[string]interface{}{
"managed_certificate_id": managedCertID, "managed_certificate_id": managedCertID,
@@ -182,14 +193,19 @@ func (s *DiscoveryService) ClaimDiscovered(ctx context.Context, id string, manag
return nil return nil
} }
// DismissDiscovered marks a discovered certificate as dismissed. // DismissDiscovered marks a discovered certificate as dismissed. See
func (s *DiscoveryService) DismissDiscovered(ctx context.Context, id string) error { // ClaimDiscovered for the actor contract — same rules apply (M-005).
func (s *DiscoveryService) DismissDiscovered(ctx context.Context, id string, actor string) error {
if err := s.discoveryRepo.UpdateDiscoveredStatus(ctx, id, domain.DiscoveryStatusDismissed, ""); err != nil { if err := s.discoveryRepo.UpdateDiscoveredStatus(ctx, id, domain.DiscoveryStatusDismissed, ""); err != nil {
return fmt.Errorf("failed to dismiss discovered certificate: %w", err) return fmt.Errorf("failed to dismiss discovered certificate: %w", err)
} }
if actor == "" {
actor = "api"
}
// Audit trail // Audit trail
if err := s.auditService.RecordEvent(ctx, "operator", domain.ActorTypeUser, if err := s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser,
"discovery_cert_dismissed", "discovered_certificate", id, nil); err != nil { "discovery_cert_dismissed", "discovered_certificate", id, nil); err != nil {
slog.Error("failed to record audit event", "error", err) slog.Error("failed to record audit event", "error", err)
} }
+175 -5
View File
@@ -381,7 +381,7 @@ func TestClaimDiscovered_Success(t *testing.T) {
} }
certRepo.AddCert(managedCert) certRepo.AddCert(managedCert)
err := svc.ClaimDiscovered(context.Background(), "dcert-1", "mc-prod-1") err := svc.ClaimDiscovered(context.Background(), "dcert-1", "mc-prod-1", "alice@corp")
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
} }
@@ -423,7 +423,7 @@ func TestClaimDiscovered_MissingManagedCertID(t *testing.T) {
} }
discoveryRepo.Discovered[cert.ID] = cert discoveryRepo.Discovered[cert.ID] = cert
err := svc.ClaimDiscovered(context.Background(), "dcert-1", "") err := svc.ClaimDiscovered(context.Background(), "dcert-1", "", "test-actor")
if err == nil { if err == nil {
t.Fatal("expected error for empty managed_certificate_id") t.Fatal("expected error for empty managed_certificate_id")
} }
@@ -442,7 +442,7 @@ func TestClaimDiscovered_ManagedCertNotFound(t *testing.T) {
} }
discoveryRepo.Discovered[cert.ID] = cert discoveryRepo.Discovered[cert.ID] = cert
err := svc.ClaimDiscovered(context.Background(), "dcert-1", "nonexistent-cert") err := svc.ClaimDiscovered(context.Background(), "dcert-1", "nonexistent-cert", "test-actor")
if err == nil { if err == nil {
t.Fatal("expected error for nonexistent managed certificate") t.Fatal("expected error for nonexistent managed certificate")
} }
@@ -464,7 +464,7 @@ func TestDismissDiscovered_Success(t *testing.T) {
} }
discoveryRepo.Discovered[cert.ID] = cert discoveryRepo.Discovered[cert.ID] = cert
err := svc.DismissDiscovered(context.Background(), "dcert-1") err := svc.DismissDiscovered(context.Background(), "dcert-1", "bob@corp")
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
} }
@@ -497,8 +497,178 @@ func TestDismissDiscovered_NotFound(t *testing.T) {
svc, discoveryRepo, _, _ := newDiscoveryTestService() svc, discoveryRepo, _, _ := newDiscoveryTestService()
discoveryRepo.UpdateStatusErr = errNotFound discoveryRepo.UpdateStatusErr = errNotFound
err := svc.DismissDiscovered(context.Background(), "nonexistent") err := svc.DismissDiscovered(context.Background(), "nonexistent", "test-actor")
if err == nil { if err == nil {
t.Fatal("expected error for nonexistent cert") t.Fatal("expected error for nonexistent cert")
} }
} }
// M-005 regression: caller-supplied actor must propagate onto the
// discovery_cert_claimed audit event so the trail identifies who performed
// triage (pre-M-005 the service hardcoded "operator").
func TestDiscoveryService_ClaimDiscovered_AuditActor(t *testing.T) {
svc, discoveryRepo, certRepo, auditRepo := newDiscoveryTestService()
now := time.Now()
discoveredCert := &domain.DiscoveredCertificate{
ID: "dcert-1",
CommonName: "example.com",
FingerprintSHA256: "abc123",
Status: domain.DiscoveryStatusUnmanaged,
CreatedAt: now,
UpdatedAt: now,
}
discoveryRepo.Discovered[discoveredCert.ID] = discoveredCert
managedCert := &domain.ManagedCertificate{
ID: "mc-prod-1",
CommonName: "example.com",
Status: domain.CertificateStatusActive,
CreatedAt: now,
UpdatedAt: now,
}
certRepo.AddCert(managedCert)
if err := svc.ClaimDiscovered(context.Background(), "dcert-1", "mc-prod-1", "alice@corp"); err != nil {
t.Fatalf("expected no error, got: %v", err)
}
// Locate the discovery_cert_claimed audit event and assert actor propagation.
var claimEvent *domain.AuditEvent
for _, e := range auditRepo.Events {
if e.Action == "discovery_cert_claimed" {
claimEvent = e
break
}
}
if claimEvent == nil {
t.Fatal("expected discovery_cert_claimed audit event to be recorded")
}
if claimEvent.Actor != "alice@corp" {
t.Errorf("expected audit actor to be caller-supplied 'alice@corp', got %q", claimEvent.Actor)
}
if claimEvent.Actor == "operator" {
t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)")
}
}
// M-005 regression symmetric pair for DismissDiscovered.
func TestDiscoveryService_DismissDiscovered_AuditActor(t *testing.T) {
svc, discoveryRepo, _, auditRepo := newDiscoveryTestService()
now := time.Now()
cert := &domain.DiscoveredCertificate{
ID: "dcert-1",
CommonName: "example.com",
Status: domain.DiscoveryStatusUnmanaged,
CreatedAt: now,
UpdatedAt: now,
}
discoveryRepo.Discovered[cert.ID] = cert
if err := svc.DismissDiscovered(context.Background(), "dcert-1", "bob@corp"); err != nil {
t.Fatalf("expected no error, got: %v", err)
}
var dismissEvent *domain.AuditEvent
for _, e := range auditRepo.Events {
if e.Action == "discovery_cert_dismissed" {
dismissEvent = e
break
}
}
if dismissEvent == nil {
t.Fatal("expected discovery_cert_dismissed audit event to be recorded")
}
if dismissEvent.Actor != "bob@corp" {
t.Errorf("expected audit actor to be caller-supplied 'bob@corp', got %q", dismissEvent.Actor)
}
if dismissEvent.Actor == "operator" {
t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)")
}
}
// M-005 regression: when the caller passes an empty actor (e.g., the handler's
// resolveActor helper returns "" because no auth context is present), the
// service must fall back to the safe sentinel "api" — never to the pre-M-005
// hardcoded "operator".
func TestDiscoveryService_ClaimDiscovered_EmptyActorFallsBackToAPI(t *testing.T) {
svc, discoveryRepo, certRepo, auditRepo := newDiscoveryTestService()
now := time.Now()
discoveredCert := &domain.DiscoveredCertificate{
ID: "dcert-1",
CommonName: "example.com",
FingerprintSHA256: "abc123",
Status: domain.DiscoveryStatusUnmanaged,
CreatedAt: now,
UpdatedAt: now,
}
discoveryRepo.Discovered[discoveredCert.ID] = discoveredCert
managedCert := &domain.ManagedCertificate{
ID: "mc-prod-1",
CommonName: "example.com",
Status: domain.CertificateStatusActive,
CreatedAt: now,
UpdatedAt: now,
}
certRepo.AddCert(managedCert)
if err := svc.ClaimDiscovered(context.Background(), "dcert-1", "mc-prod-1", ""); err != nil {
t.Fatalf("expected no error, got: %v", err)
}
var claimEvent *domain.AuditEvent
for _, e := range auditRepo.Events {
if e.Action == "discovery_cert_claimed" {
claimEvent = e
break
}
}
if claimEvent == nil {
t.Fatal("expected discovery_cert_claimed audit event to be recorded")
}
if claimEvent.Actor != "api" {
t.Errorf("expected empty actor to fall back to 'api', got %q", claimEvent.Actor)
}
if claimEvent.Actor == "operator" {
t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)")
}
}
// M-005 regression symmetric pair for DismissDiscovered empty-actor fallback.
func TestDiscoveryService_DismissDiscovered_EmptyActorFallsBackToAPI(t *testing.T) {
svc, discoveryRepo, _, auditRepo := newDiscoveryTestService()
now := time.Now()
cert := &domain.DiscoveredCertificate{
ID: "dcert-1",
CommonName: "example.com",
Status: domain.DiscoveryStatusUnmanaged,
CreatedAt: now,
UpdatedAt: now,
}
discoveryRepo.Discovered[cert.ID] = cert
if err := svc.DismissDiscovered(context.Background(), "dcert-1", ""); err != nil {
t.Fatalf("expected no error, got: %v", err)
}
var dismissEvent *domain.AuditEvent
for _, e := range auditRepo.Events {
if e.Action == "discovery_cert_dismissed" {
dismissEvent = e
break
}
}
if dismissEvent == nil {
t.Fatal("expected discovery_cert_dismissed audit event to be recorded")
}
if dismissEvent.Actor != "api" {
t.Errorf("expected empty actor to fall back to 'api', got %q", dismissEvent.Actor)
}
if dismissEvent.Actor == "operator" {
t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)")
}
}
+4
View File
@@ -198,6 +198,10 @@ func (m *mockCertRepoWithGetError) GetLatestVersion(ctx context.Context, certID
return nil, nil return nil, nil
} }
func (m *mockCertRepoWithGetError) GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.ManagedCertificate, error) {
return nil, nil
}
func (m *mockCertRepoWithGetError) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) { func (m *mockCertRepoWithGetError) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) {
return nil, m.GetExpiringCertificatesErr return nil, m.GetExpiringCertificatesErr
} }
+27
View File
@@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"sync" "sync"
"time" "time"
@@ -129,6 +130,26 @@ func (m *mockCertRepo) GetLatestVersion(ctx context.Context, certID string) (*do
return versions[len(versions)-1], nil return versions[len(versions)-1], nil
} }
// GetByIssuerAndSerial emulates the PostgreSQL JOIN:
// SELECT mc.* FROM managed_certificates mc JOIN certificate_versions cv
// ON cv.certificate_id = mc.id WHERE mc.issuer_id = $1 AND cv.serial_number = $2.
// Returns sql.ErrNoRows (the sentinel the real repo surfaces) when no match
// exists, so callers that branch on errors.Is(err, sql.ErrNoRows) behave the
// same in-memory as they do against PostgreSQL.
func (m *mockCertRepo) GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.ManagedCertificate, error) {
for _, cert := range m.Certs {
if cert.IssuerID != issuerID {
continue
}
for _, v := range m.Versions[cert.ID] {
if v.SerialNumber == serial {
return cert, nil
}
}
}
return nil, sql.ErrNoRows
}
func (m *mockCertRepo) AddCert(cert *domain.ManagedCertificate) { func (m *mockCertRepo) AddCert(cert *domain.ManagedCertificate) {
m.Certs[cert.ID] = cert m.Certs[cert.ID] = cert
} }
@@ -784,6 +805,9 @@ type mockIssuerConnector struct {
Err error Err error
getRenewalInfoResult *RenewalInfoResult getRenewalInfoResult *RenewalInfoResult
getRenewalInfoErr error getRenewalInfoErr error
// LastOCSPSignRequest captures the last request passed to SignOCSPResponse.
// Tests use this to assert CertStatus (0=good, 1=revoked, 2=unknown).
LastOCSPSignRequest *OCSPSignRequest
} }
func (m *mockIssuerConnector) IssueCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string, maxTTLSeconds int) (*IssuanceResult, error) { func (m *mockIssuerConnector) IssueCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string, maxTTLSeconds int) (*IssuanceResult, error) {
@@ -825,6 +849,9 @@ func (m *mockIssuerConnector) GenerateCRL(ctx context.Context, entries []CRLEntr
} }
func (m *mockIssuerConnector) SignOCSPResponse(ctx context.Context, req OCSPSignRequest) ([]byte, error) { func (m *mockIssuerConnector) SignOCSPResponse(ctx context.Context, req OCSPSignRequest) ([]byte, error) {
// Capture the request for test assertions (e.g., CertStatus verification)
reqCopy := req
m.LastOCSPSignRequest = &reqCopy
if m.Err != nil { if m.Err != nil {
return nil, m.Err return nil, m.Err
} }