fix(m-2): thread context through CertificateService cluster

Collapses CertificateService, RevocationSvc, and CAOperationsSvc to
ctx-accepting method signatures. Removes context.Background() synthesis
at 24 internal call sites across certificate.go, revocation_svc.go, and
ca_operations.go.

- Primary repo calls inherit request cancellation via the passed ctx.
- Audit and notification dispatches use context.WithoutCancel(ctx) so
  they survive client disconnect.
- Collapses TriggerRenewal/TriggerRenewalWithActor,
  TriggerDeployment/TriggerDeploymentWithActor, and
  RevokeCertificate/RevokeCertificateWithActor sibling pairs into single
  canonical ctx-accepting methods (decisions D-1, D-2).

Handlers pass r.Context(). Mocks and tests updated to match new
signatures. No HTTP surface change, no OpenAPI change.

PR 1 of 6 in the M-2 remediation chain. Master green at this commit.

Refs: certctl-audit-report.md M-2 (L143, L224)
This commit is contained in:
shankar0123
2026-04-18 00:29:37 +00:00
parent e951d319d0
commit cdc9d03d5b
12 changed files with 225 additions and 235 deletions
@@ -27,6 +27,7 @@ package handler
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
@@ -120,7 +121,7 @@ func TestGetCertificate_PathInjection(t *testing.T) {
handler, mock := newCertHandlerWithMock()
// Force a 404 so we can distinguish "service was called" from
// "parser accepted the ID"; a 200 with null body is also fine.
mock.GetCertificateFn = func(id string) (*domain.ManagedCertificate, error) {
mock.GetCertificateFn = func(_ context.Context, id string) (*domain.ManagedCertificate, error) {
return nil, ErrMockNotFound
}
@@ -156,7 +157,7 @@ func TestUpdateCertificate_PathInjection(t *testing.T) {
}()
handler, mock := newCertHandlerWithMock()
mock.UpdateCertificateFn = func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
mock.UpdateCertificateFn = func(_ context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
return nil, ErrMockNotFound
}
@@ -184,7 +185,7 @@ func TestArchiveCertificate_PathInjection(t *testing.T) {
}()
handler, mock := newCertHandlerWithMock()
mock.ArchiveCertificateFn = func(id string) error { return ErrMockNotFound }
mock.ArchiveCertificateFn = func(_ context.Context, id string) error { return ErrMockNotFound }
req := httptest.NewRequest(http.MethodDelete, "/api/v1/certificates/x", nil)
req.URL.Path = "/api/v1/certificates/" + tc.input
@@ -227,7 +228,7 @@ func TestGetCertificateVersions_MultiSegment(t *testing.T) {
}()
handler, mock := newCertHandlerWithMock()
mock.GetCertificateVersionsFn = func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
mock.GetCertificateVersionsFn = func(_ context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
return []domain.CertificateVersion{}, 0, nil
}
@@ -277,7 +278,7 @@ func TestHandleOCSP_MultiSegment(t *testing.T) {
}()
handler, mock := newCertHandlerWithMock()
mock.GetOCSPResponseFn = func(issuerID, serialHex string) ([]byte, error) {
mock.GetOCSPResponseFn = func(_ context.Context, issuerID, serialHex string) ([]byte, error) {
return nil, ErrMockNotFound
}
@@ -311,7 +312,7 @@ func TestGetDERCRL_IssuerPathInjection(t *testing.T) {
}()
handler, mock := newCertHandlerWithMock()
mock.GenerateDERCRLFn = func(issuerID string) ([]byte, error) {
mock.GenerateDERCRLFn = func(_ context.Context, issuerID string) ([]byte, error) {
return nil, ErrMockNotFound
}
+12 -11
View File
@@ -19,6 +19,7 @@ package handler
import (
"bytes"
"context"
"fmt"
"net/http"
"net/http/httptest"
@@ -76,7 +77,7 @@ func TestListCertificates_PaginationAbuse(t *testing.T) {
}()
handler, mock := newCertHandlerWithMock()
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
// Sanity: page/perPage on the filter must never be negative
// and perPage must never exceed 500 after parsing.
if filter.Page < 1 {
@@ -133,7 +134,7 @@ func TestListCertificates_SortAbuse(t *testing.T) {
}()
handler, mock := newCertHandlerWithMock()
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
return []domain.ManagedCertificate{}, 0, nil
}
@@ -175,7 +176,7 @@ func TestListCertificates_FieldsAbuse(t *testing.T) {
}()
handler, mock := newCertHandlerWithMock()
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
return []domain.ManagedCertificate{}, 0, nil
}
@@ -219,7 +220,7 @@ func TestListCertificates_TimeRangeAbuse(t *testing.T) {
}()
handler, mock := newCertHandlerWithMock()
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
return []domain.ManagedCertificate{}, 0, nil
}
@@ -263,7 +264,7 @@ func TestListCertificates_CursorAbuse(t *testing.T) {
}()
handler, mock := newCertHandlerWithMock()
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
return []domain.ManagedCertificate{}, 0, nil
}
@@ -314,7 +315,7 @@ func TestListCertificates_FilterInjection(t *testing.T) {
}()
handler, mock := newCertHandlerWithMock()
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
return []domain.ManagedCertificate{}, 0, nil
}
@@ -374,7 +375,7 @@ func TestCreateCertificate_BodyAbuse(t *testing.T) {
}()
handler, mock := newCertHandlerWithMock()
mock.CreateCertificateFn = func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
mock.CreateCertificateFn = func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
// If we ever reach this, the handler accepted a malformed
// body. Return a sentinel that passes but flag it.
c := cert
@@ -419,7 +420,7 @@ func TestCreateCertificate_HugeBody(t *testing.T) {
sb.WriteString(`]}`)
handler, mock := newCertHandlerWithMock()
mock.CreateCertificateFn = func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
mock.CreateCertificateFn = func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
c := cert
c.ID = "mc-huge"
return &c, nil
@@ -476,7 +477,7 @@ func TestRevokeCertificate_ReasonAbuse(t *testing.T) {
handler, mock := newCertHandlerWithMock()
// The mock always returns "invalid revocation reason" so we
// verify the handler's errMsg→status mapping turns it into a 400.
mock.RevokeCertificateFn = func(id string, reason string) error {
mock.RevokeCertificateFn = func(_ context.Context, id string, reason string, _ string) error {
// The service uses domain.IsValidRevocationReason. If we got
// through to here with something bogus, simulate a real
// service error.
@@ -500,7 +501,7 @@ func TestRevokeCertificate_ReasonAbuse(t *testing.T) {
// service error message, which is fragile — this test catches regressions.
func TestRevokeCertificate_AlreadyRevoked(t *testing.T) {
handler, mock := newCertHandlerWithMock()
mock.RevokeCertificateFn = func(id string, reason string) error {
mock.RevokeCertificateFn = func(_ context.Context, id string, reason string, _ string) error {
return fmt.Errorf("cannot revoke: certificate is already revoked")
}
@@ -520,7 +521,7 @@ func TestRevokeCertificate_AlreadyRevoked(t *testing.T) {
// TestRevokeCertificate_NotFound verifies 404 mapping.
func TestRevokeCertificate_NotFound(t *testing.T) {
handler, mock := newCertHandlerWithMock()
mock.RevokeCertificateFn = func(id string, reason string) error {
mock.RevokeCertificateFn = func(_ context.Context, id string, reason string, _ string) error {
return fmt.Errorf("certificate not found")
}
@@ -17,116 +17,116 @@ import (
// MockCertificateService is a mock implementation of CertificateService interface.
type MockCertificateService struct {
ListCertificatesFn func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
ListCertificatesWithFilterFn func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error)
GetCertificateFn func(id string) (*domain.ManagedCertificate, error)
CreateCertificateFn func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
UpdateCertificateFn func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
ArchiveCertificateFn func(id string) error
GetCertificateVersionsFn func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
TriggerRenewalFn func(certID string) error
TriggerDeploymentFn func(certID string, targetID string) error
RevokeCertificateFn func(certID string, reason string) error
GetRevokedCertificatesFn func() ([]*domain.CertificateRevocation, error)
GenerateDERCRLFn func(issuerID string) ([]byte, error)
GetOCSPResponseFn func(issuerID string, serialHex string) ([]byte, error)
GetCertificateDeploymentsFn func(certID string) ([]domain.DeploymentTarget, error)
ListCertificatesFn func(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
ListCertificatesWithFilterFn func(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error)
GetCertificateFn func(ctx context.Context, id string) (*domain.ManagedCertificate, error)
CreateCertificateFn func(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
UpdateCertificateFn func(ctx context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
ArchiveCertificateFn func(ctx context.Context, id string) error
GetCertificateVersionsFn func(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
TriggerRenewalFn func(ctx context.Context, certID string, actor string) error
TriggerDeploymentFn func(ctx context.Context, certID string, targetID string, actor string) error
RevokeCertificateFn func(ctx context.Context, certID string, reason string, actor string) error
GetRevokedCertificatesFn func(ctx context.Context) ([]*domain.CertificateRevocation, error)
GenerateDERCRLFn func(ctx context.Context, issuerID string) ([]byte, error)
GetOCSPResponseFn func(ctx context.Context, issuerID string, serialHex string) ([]byte, error)
GetCertificateDeploymentsFn func(ctx context.Context, certID string) ([]domain.DeploymentTarget, error)
}
func (m *MockCertificateService) ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
func (m *MockCertificateService) ListCertificates(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
if m.ListCertificatesFn != nil {
return m.ListCertificatesFn(status, environment, ownerID, teamID, issuerID, page, perPage)
return m.ListCertificatesFn(ctx, status, environment, ownerID, teamID, issuerID, page, perPage)
}
return nil, 0, nil
}
func (m *MockCertificateService) GetCertificate(id string) (*domain.ManagedCertificate, error) {
func (m *MockCertificateService) GetCertificate(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
if m.GetCertificateFn != nil {
return m.GetCertificateFn(id)
return m.GetCertificateFn(ctx, id)
}
return nil, nil
}
func (m *MockCertificateService) CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
func (m *MockCertificateService) CreateCertificate(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
if m.CreateCertificateFn != nil {
return m.CreateCertificateFn(cert)
return m.CreateCertificateFn(ctx, cert)
}
return nil, nil
}
func (m *MockCertificateService) UpdateCertificate(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
func (m *MockCertificateService) UpdateCertificate(ctx context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
if m.UpdateCertificateFn != nil {
return m.UpdateCertificateFn(id, cert)
return m.UpdateCertificateFn(ctx, id, cert)
}
return nil, nil
}
func (m *MockCertificateService) ArchiveCertificate(id string) error {
func (m *MockCertificateService) ArchiveCertificate(ctx context.Context, id string) error {
if m.ArchiveCertificateFn != nil {
return m.ArchiveCertificateFn(id)
return m.ArchiveCertificateFn(ctx, id)
}
return nil
}
func (m *MockCertificateService) GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
func (m *MockCertificateService) GetCertificateVersions(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
if m.GetCertificateVersionsFn != nil {
return m.GetCertificateVersionsFn(certID, page, perPage)
return m.GetCertificateVersionsFn(ctx, certID, page, perPage)
}
return nil, 0, nil
}
func (m *MockCertificateService) TriggerRenewal(certID string) error {
func (m *MockCertificateService) TriggerRenewal(ctx context.Context, certID string, actor string) error {
if m.TriggerRenewalFn != nil {
return m.TriggerRenewalFn(certID)
return m.TriggerRenewalFn(ctx, certID, actor)
}
return nil
}
func (m *MockCertificateService) TriggerDeployment(certID string, targetID string) error {
func (m *MockCertificateService) TriggerDeployment(ctx context.Context, certID string, targetID string, actor string) error {
if m.TriggerDeploymentFn != nil {
return m.TriggerDeploymentFn(certID, targetID)
return m.TriggerDeploymentFn(ctx, certID, targetID, actor)
}
return nil
}
func (m *MockCertificateService) RevokeCertificate(certID string, reason string) error {
func (m *MockCertificateService) RevokeCertificate(ctx context.Context, certID string, reason string, actor string) error {
if m.RevokeCertificateFn != nil {
return m.RevokeCertificateFn(certID, reason)
return m.RevokeCertificateFn(ctx, certID, reason, actor)
}
return nil
}
func (m *MockCertificateService) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) {
func (m *MockCertificateService) GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error) {
if m.GetRevokedCertificatesFn != nil {
return m.GetRevokedCertificatesFn()
return m.GetRevokedCertificatesFn(ctx)
}
return nil, nil
}
func (m *MockCertificateService) GenerateDERCRL(issuerID string) ([]byte, error) {
func (m *MockCertificateService) GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error) {
if m.GenerateDERCRLFn != nil {
return m.GenerateDERCRLFn(issuerID)
return m.GenerateDERCRLFn(ctx, issuerID)
}
return nil, nil
}
func (m *MockCertificateService) GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) {
func (m *MockCertificateService) GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error) {
if m.GetOCSPResponseFn != nil {
return m.GetOCSPResponseFn(issuerID, serialHex)
return m.GetOCSPResponseFn(ctx, issuerID, serialHex)
}
return nil, nil
}
func (m *MockCertificateService) ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
func (m *MockCertificateService) ListCertificatesWithFilter(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if m.ListCertificatesWithFilterFn != nil {
return m.ListCertificatesWithFilterFn(filter)
return m.ListCertificatesWithFilterFn(ctx, filter)
}
return nil, 0, nil
}
func (m *MockCertificateService) GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error) {
func (m *MockCertificateService) GetCertificateDeployments(ctx context.Context, certID string) ([]domain.DeploymentTarget, error) {
if m.GetCertificateDeploymentsFn != nil {
return m.GetCertificateDeploymentsFn(certID)
return m.GetCertificateDeploymentsFn(ctx, certID)
}
return nil, nil
}
@@ -158,7 +158,7 @@ func TestListCertificates_Success(t *testing.T) {
}
mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.Page == 1 && filter.PerPage == 50 {
return []domain.ManagedCertificate{cert1, cert2}, 2, nil
}
@@ -197,7 +197,7 @@ func TestListCertificates_Success(t *testing.T) {
// Test ListCertificates - with filters
func TestListCertificates_WithFilters(t *testing.T) {
mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.Status == "Active" && filter.Environment == "prod" {
return []domain.ManagedCertificate{}, 0, nil
}
@@ -236,7 +236,7 @@ func TestListCertificates_MethodNotAllowed(t *testing.T) {
// Test ListCertificates - service error
func TestListCertificates_ServiceError(t *testing.T) {
mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
return nil, 0, ErrMockServiceFailed
},
}
@@ -266,7 +266,7 @@ func TestGetCertificate_Success(t *testing.T) {
}
mock := &MockCertificateService{
GetCertificateFn: func(id string) (*domain.ManagedCertificate, error) {
GetCertificateFn: func(_ context.Context, id string) (*domain.ManagedCertificate, error) {
if id == "mc-prod-001" {
return cert, nil
}
@@ -298,7 +298,7 @@ func TestGetCertificate_Success(t *testing.T) {
// Test GetCertificate - not found
func TestGetCertificate_NotFound(t *testing.T) {
mock := &MockCertificateService{
GetCertificateFn: func(id string) (*domain.ManagedCertificate, error) {
GetCertificateFn: func(_ context.Context, id string) (*domain.ManagedCertificate, error) {
return nil, ErrMockNotFound
},
}
@@ -345,7 +345,7 @@ func TestCreateCertificate_Success(t *testing.T) {
}
mock := &MockCertificateService{
CreateCertificateFn: func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
CreateCertificateFn: func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
return created, nil
},
}
@@ -403,7 +403,7 @@ func TestCreateCertificate_InvalidBody(t *testing.T) {
// Test CreateCertificate - service error
func TestCreateCertificate_ServiceError(t *testing.T) {
mock := &MockCertificateService{
CreateCertificateFn: func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
CreateCertificateFn: func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
return nil, ErrMockServiceFailed
},
}
@@ -445,7 +445,7 @@ func TestUpdateCertificate_Success(t *testing.T) {
}
mock := &MockCertificateService{
UpdateCertificateFn: func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
UpdateCertificateFn: func(_ context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
if id == "mc-prod-001" {
return updated, nil
}
@@ -501,7 +501,7 @@ func TestUpdateCertificate_InvalidBody(t *testing.T) {
// Test ArchiveCertificate - success case
func TestArchiveCertificate_Success(t *testing.T) {
mock := &MockCertificateService{
ArchiveCertificateFn: func(id string) error {
ArchiveCertificateFn: func(_ context.Context, id string) error {
if id == "mc-prod-001" {
return nil
}
@@ -524,7 +524,7 @@ func TestArchiveCertificate_Success(t *testing.T) {
// Test ArchiveCertificate - not found
func TestArchiveCertificate_NotFound(t *testing.T) {
mock := &MockCertificateService{
ArchiveCertificateFn: func(id string) error {
ArchiveCertificateFn: func(_ context.Context, id string) error {
return ErrMockNotFound
},
}
@@ -554,7 +554,7 @@ func TestGetCertificateVersions_Success(t *testing.T) {
}
mock := &MockCertificateService{
GetCertificateVersionsFn: func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
GetCertificateVersionsFn: func(_ context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
if certID == "mc-prod-001" {
return []domain.CertificateVersion{ver1}, 1, nil
}
@@ -586,7 +586,7 @@ func TestGetCertificateVersions_Success(t *testing.T) {
// Test GetCertificateVersions - not found
func TestGetCertificateVersions_NotFound(t *testing.T) {
mock := &MockCertificateService{
GetCertificateVersionsFn: func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
GetCertificateVersionsFn: func(_ context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
return nil, 0, ErrMockNotFound
},
}
@@ -606,7 +606,7 @@ func TestGetCertificateVersions_NotFound(t *testing.T) {
// Test TriggerRenewal - success case
func TestTriggerRenewal_Success(t *testing.T) {
mock := &MockCertificateService{
TriggerRenewalFn: func(certID string) error {
TriggerRenewalFn: func(_ context.Context, certID string, _ string) error {
if certID == "mc-prod-001" {
return nil
}
@@ -638,7 +638,7 @@ func TestTriggerRenewal_Success(t *testing.T) {
// Test TriggerRenewal - service error
func TestTriggerRenewal_ServiceError(t *testing.T) {
mock := &MockCertificateService{
TriggerRenewalFn: func(certID string) error {
TriggerRenewalFn: func(_ context.Context, certID string, _ string) error {
return ErrMockServiceFailed
},
}
@@ -658,7 +658,7 @@ func TestTriggerRenewal_ServiceError(t *testing.T) {
// Test TriggerDeployment - success case
func TestTriggerDeployment_Success(t *testing.T) {
mock := &MockCertificateService{
TriggerDeploymentFn: func(certID string, targetID string) error {
TriggerDeploymentFn: func(_ context.Context, certID string, targetID string, _ string) error {
if certID == "mc-prod-001" {
return nil
}
@@ -695,7 +695,7 @@ func TestTriggerDeployment_Success(t *testing.T) {
// Test TriggerDeployment - without target ID
func TestTriggerDeployment_NoTargetID(t *testing.T) {
mock := &MockCertificateService{
TriggerDeploymentFn: func(certID string, targetID string) error {
TriggerDeploymentFn: func(_ context.Context, certID string, targetID string, _ string) error {
// Should accept empty targetID (deploy to all)
return nil
},
@@ -716,7 +716,7 @@ func TestTriggerDeployment_NoTargetID(t *testing.T) {
// Test ListCertificates - invalid page parameter
func TestListCertificates_InvalidPageParam(t *testing.T) {
mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
// Should default to page 1
if filter.Page == 1 {
return []domain.ManagedCertificate{}, 0, nil
@@ -740,7 +740,7 @@ func TestListCertificates_InvalidPageParam(t *testing.T) {
// Test ListCertificates - per_page exceeds max
func TestListCertificates_PerPageExceedsMax(t *testing.T) {
mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
// Should cap perPage at 500
if filter.PerPage == 50 { // defaults to 50 if > 500
return []domain.ManagedCertificate{}, 0, nil
@@ -765,7 +765,7 @@ func TestListCertificates_PerPageExceedsMax(t *testing.T) {
func TestRevokeCertificate_Handler_Success(t *testing.T) {
mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error {
RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
if certID != "mc-prod-001" {
t.Errorf("expected certID mc-prod-001, got %s", certID)
}
@@ -798,7 +798,7 @@ func TestRevokeCertificate_Handler_Success(t *testing.T) {
func TestRevokeCertificate_Handler_NoBody(t *testing.T) {
mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error {
RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
// Empty reason is OK — service defaults to "unspecified"
return nil
},
@@ -818,7 +818,7 @@ func TestRevokeCertificate_Handler_NoBody(t *testing.T) {
func TestRevokeCertificate_Handler_AlreadyRevoked(t *testing.T) {
mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error {
RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
return fmt.Errorf("certificate is already revoked")
},
}
@@ -839,7 +839,7 @@ func TestRevokeCertificate_Handler_AlreadyRevoked(t *testing.T) {
func TestRevokeCertificate_Handler_NotFound(t *testing.T) {
mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error {
RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
return fmt.Errorf("failed to fetch certificate: not found")
},
}
@@ -858,7 +858,7 @@ func TestRevokeCertificate_Handler_NotFound(t *testing.T) {
func TestRevokeCertificate_Handler_InvalidReason(t *testing.T) {
mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error {
RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
return fmt.Errorf("invalid revocation reason: badReason")
},
}
@@ -922,7 +922,7 @@ func TestRevokeCertificate_Handler_EmptyID(t *testing.T) {
func TestRevokeCertificate_Handler_CannotRevokeArchived(t *testing.T) {
mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error {
RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
return fmt.Errorf("cannot revoke archived certificate")
},
}
@@ -941,7 +941,7 @@ func TestRevokeCertificate_Handler_CannotRevokeArchived(t *testing.T) {
func TestRevokeCertificate_Handler_ServerError(t *testing.T) {
mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error {
RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
return fmt.Errorf("database connection lost")
},
}
@@ -962,7 +962,7 @@ func TestRevokeCertificate_Handler_ServerError(t *testing.T) {
func TestGetCRL_Success(t *testing.T) {
mock := &MockCertificateService{
GetRevokedCertificatesFn: func() ([]*domain.CertificateRevocation, error) {
GetRevokedCertificatesFn: func(_ context.Context) ([]*domain.CertificateRevocation, error) {
return []*domain.CertificateRevocation{
{
ID: "rev-1",
@@ -1022,7 +1022,7 @@ func TestGetCRL_Success(t *testing.T) {
func TestGetCRL_Empty(t *testing.T) {
mock := &MockCertificateService{
GetRevokedCertificatesFn: func() ([]*domain.CertificateRevocation, error) {
GetRevokedCertificatesFn: func(_ context.Context) ([]*domain.CertificateRevocation, error) {
return nil, nil
},
}
@@ -1047,7 +1047,7 @@ func TestGetCRL_Empty(t *testing.T) {
func TestGetCRL_ServiceError(t *testing.T) {
mock := &MockCertificateService{
GetRevokedCertificatesFn: func() ([]*domain.CertificateRevocation, error) {
GetRevokedCertificatesFn: func(_ context.Context) ([]*domain.CertificateRevocation, error) {
return nil, fmt.Errorf("revocation repository not configured")
},
}
@@ -1083,7 +1083,7 @@ func TestGetCRL_MethodNotAllowed(t *testing.T) {
func TestGetDERCRL_Success(t *testing.T) {
derCRLData := []byte{0x30, 0x82, 0x01, 0x00} // Mock DER CRL bytes
mock := &MockCertificateService{
GenerateDERCRLFn: func(issuerID string) ([]byte, error) {
GenerateDERCRLFn: func(_ context.Context, issuerID string) ([]byte, error) {
if issuerID == "iss-local" {
return derCRLData, nil
}
@@ -1111,7 +1111,7 @@ func TestGetDERCRL_Success(t *testing.T) {
func TestGetDERCRL_IssuerNotFound(t *testing.T) {
mock := &MockCertificateService{
GenerateDERCRLFn: func(issuerID string) ([]byte, error) {
GenerateDERCRLFn: func(_ context.Context, issuerID string) ([]byte, error) {
return nil, fmt.Errorf("issuer not found")
},
}
@@ -1130,7 +1130,7 @@ func TestGetDERCRL_IssuerNotFound(t *testing.T) {
func TestGetDERCRL_NotSupported(t *testing.T) {
mock := &MockCertificateService{
GenerateDERCRLFn: func(issuerID string) ([]byte, error) {
GenerateDERCRLFn: func(_ context.Context, issuerID string) ([]byte, error) {
return nil, fmt.Errorf("issuer does not support CRL generation")
},
}
@@ -1165,7 +1165,7 @@ func TestGetDERCRL_MethodNotAllowed(t *testing.T) {
func TestHandleOCSP_Success(t *testing.T) {
ocspResponseBytes := []byte{0x30, 0x82, 0x02, 0x00} // Mock OCSP response
mock := &MockCertificateService{
GetOCSPResponseFn: func(issuerID string, serialHex string) ([]byte, error) {
GetOCSPResponseFn: func(_ context.Context, issuerID string, serialHex string) ([]byte, error) {
if issuerID == "iss-local" && serialHex == "12345" {
return ocspResponseBytes, nil
}
@@ -1206,7 +1206,7 @@ func TestHandleOCSP_MissingSerial(t *testing.T) {
func TestHandleOCSP_IssuerNotFound(t *testing.T) {
mock := &MockCertificateService{
GetOCSPResponseFn: func(issuerID string, serialHex string) ([]byte, error) {
GetOCSPResponseFn: func(_ context.Context, issuerID string, serialHex string) ([]byte, error) {
return nil, fmt.Errorf("issuer not found")
},
}
@@ -1225,7 +1225,7 @@ func TestHandleOCSP_IssuerNotFound(t *testing.T) {
func TestHandleOCSP_CertNotFound(t *testing.T) {
mock := &MockCertificateService{
GetOCSPResponseFn: func(issuerID string, serialHex string) ([]byte, error) {
GetOCSPResponseFn: func(_ context.Context, issuerID string, serialHex string) ([]byte, error) {
return nil, fmt.Errorf("certificate not found")
},
}
@@ -1261,7 +1261,7 @@ func TestHandleOCSP_MethodNotAllowed(t *testing.T) {
// TestListCertificates_SortParam tests sort parameter parsing and passing to service.
func TestListCertificates_SortParam(t *testing.T) {
mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
// Handler strips the '-' prefix and sets SortDesc = true
if filter.Sort != "notAfter" || !filter.SortDesc {
t.Errorf("expected sort=notAfter desc=true, got sort=%s desc=%v", filter.Sort, filter.SortDesc)
@@ -1284,7 +1284,7 @@ func TestListCertificates_SortParam(t *testing.T) {
// TestListCertificates_SortParam_Ascending tests sort parameter without '-' prefix (ascending).
func TestListCertificates_SortParam_Ascending(t *testing.T) {
mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.Sort != "createdAt" || filter.SortDesc {
t.Errorf("expected sort=createdAt desc=false, got sort=%s desc=%v", filter.Sort, filter.SortDesc)
}
@@ -1309,7 +1309,7 @@ func TestListCertificates_TimeRangeFilters(t *testing.T) {
after := time.Now().AddDate(0, 0, -90)
mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.ExpiresBefore == nil {
t.Error("expected ExpiresBefore to be set")
}
@@ -1339,7 +1339,7 @@ func TestListCertificates_CreatedAfterFilter(t *testing.T) {
past := time.Now().AddDate(-1, 0, 0)
mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.CreatedAfter == nil {
t.Error("expected CreatedAfter to be set")
}
@@ -1369,7 +1369,7 @@ func TestListCertificates_CursorPagination(t *testing.T) {
}
mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
return []domain.ManagedCertificate{cert}, 1, nil
},
}
@@ -1409,7 +1409,7 @@ func TestListCertificates_SparseFields(t *testing.T) {
}
mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if len(filter.Fields) != 2 {
t.Errorf("expected 2 fields, got %d", len(filter.Fields))
}
@@ -1456,7 +1456,7 @@ func TestListCertificates_SparseFields(t *testing.T) {
// TestListCertificates_ProfileFilter tests profile_id filter.
func TestListCertificates_ProfileFilter(t *testing.T) {
mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.ProfileID != "prof-standard" {
t.Errorf("expected ProfileID=prof-standard, got %s", filter.ProfileID)
}
@@ -1479,7 +1479,7 @@ func TestListCertificates_ProfileFilter(t *testing.T) {
// TestListCertificates_AgentIDFilter tests agent_id filter.
func TestListCertificates_AgentIDFilter(t *testing.T) {
mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.AgentID != "agent-prod-001" {
t.Errorf("expected AgentID=agent-prod-001, got %s", filter.AgentID)
}
@@ -1502,7 +1502,7 @@ func TestListCertificates_AgentIDFilter(t *testing.T) {
// TestListCertificates_CombinedFilters tests multiple filters together.
func TestListCertificates_CombinedFilters(t *testing.T) {
mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.Status != "Active" || filter.Environment != "production" || filter.ProfileID != "prof-standard" {
t.Error("expected all filters to be set")
}
@@ -1540,7 +1540,7 @@ func TestGetCertificateDeployments_Success(t *testing.T) {
}
mock := &MockCertificateService{
GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) {
GetCertificateDeploymentsFn: func(_ context.Context, certID string) ([]domain.DeploymentTarget, error) {
if certID != "mc-prod-001" {
return nil, ErrMockNotFound
}
@@ -1576,7 +1576,7 @@ func TestGetCertificateDeployments_Success(t *testing.T) {
// TestGetCertificateDeployments_NotFound tests 404 for nonexistent certificate.
func TestGetCertificateDeployments_NotFound(t *testing.T) {
mock := &MockCertificateService{
GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) {
GetCertificateDeploymentsFn: func(_ context.Context, certID string) ([]domain.DeploymentTarget, error) {
return nil, fmt.Errorf("certificate not found")
},
}
@@ -1596,7 +1596,7 @@ func TestGetCertificateDeployments_NotFound(t *testing.T) {
// TestGetCertificateDeployments_Empty tests successful response with no deployments.
func TestGetCertificateDeployments_Empty(t *testing.T) {
mock := &MockCertificateService{
GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) {
GetCertificateDeploymentsFn: func(_ context.Context, certID string) ([]domain.DeploymentTarget, error) {
if certID == "mc-no-deployments" {
return []domain.DeploymentTarget{}, nil
}
+28 -27
View File
@@ -1,6 +1,7 @@
package handler
import (
"context"
"encoding/json"
"log/slog"
"net/http"
@@ -15,20 +16,20 @@ import (
// CertificateService defines the service interface for certificate operations.
type CertificateService interface {
ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error)
GetCertificate(id string) (*domain.ManagedCertificate, error)
CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
UpdateCertificate(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
ArchiveCertificate(id string) error
GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
TriggerRenewal(certID string) error
TriggerDeployment(certID string, targetID string) error
RevokeCertificate(certID string, reason string) error
GetRevokedCertificates() ([]*domain.CertificateRevocation, error)
GenerateDERCRL(issuerID string) ([]byte, error)
GetOCSPResponse(issuerID string, serialHex string) ([]byte, error)
GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error)
ListCertificates(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
ListCertificatesWithFilter(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error)
GetCertificate(ctx context.Context, id string) (*domain.ManagedCertificate, error)
CreateCertificate(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
UpdateCertificate(ctx context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
ArchiveCertificate(ctx context.Context, id string) error
GetCertificateVersions(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
TriggerRenewal(ctx context.Context, certID string, actor string) error
TriggerDeployment(ctx context.Context, certID string, targetID string, actor string) error
RevokeCertificate(ctx context.Context, certID string, reason string, actor string) error
GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error)
GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error)
GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error)
GetCertificateDeployments(ctx context.Context, certID string) ([]domain.DeploymentTarget, error)
}
// CertificateHandler handles HTTP requests for certificate operations.
@@ -128,7 +129,7 @@ func (h CertificateHandler) ListCertificates(w http.ResponseWriter, r *http.Requ
filter.Fields = strings.Split(fieldsStr, ",")
}
certs, total, err := h.svc.ListCertificatesWithFilter(filter)
certs, total, err := h.svc.ListCertificatesWithFilter(r.Context(), filter)
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list certificates", requestID)
return
@@ -186,7 +187,7 @@ func (h CertificateHandler) GetCertificate(w http.ResponseWriter, r *http.Reques
return
}
cert, err := h.svc.GetCertificate(id)
cert, err := h.svc.GetCertificate(r.Context(), id)
if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
return
@@ -241,7 +242,7 @@ func (h CertificateHandler) CreateCertificate(w http.ResponseWriter, r *http.Req
return
}
created, err := h.svc.CreateCertificate(cert)
created, err := h.svc.CreateCertificate(r.Context(), cert)
if err != nil {
slog.Error("failed to create certificate", "error", err, "request_id", requestID, "common_name", cert.CommonName, "name", cert.Name)
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create certificate", requestID)
@@ -295,7 +296,7 @@ func (h CertificateHandler) UpdateCertificate(w http.ResponseWriter, r *http.Req
}
}
updated, err := h.svc.UpdateCertificate(id, cert)
updated, err := h.svc.UpdateCertificate(r.Context(), id, cert)
if err != nil {
if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
@@ -325,7 +326,7 @@ func (h CertificateHandler) ArchiveCertificate(w http.ResponseWriter, r *http.Re
return
}
if err := h.svc.ArchiveCertificate(id); err != nil {
if err := h.svc.ArchiveCertificate(r.Context(), id); err != nil {
if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
return
@@ -370,7 +371,7 @@ func (h CertificateHandler) GetCertificateVersions(w http.ResponseWriter, r *htt
}
}
versions, total, err := h.svc.GetCertificateVersions(certID, page, perPage)
versions, total, err := h.svc.GetCertificateVersions(r.Context(), certID, page, perPage)
if err != nil {
if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
@@ -410,7 +411,7 @@ func (h CertificateHandler) TriggerRenewal(w http.ResponseWriter, r *http.Reques
}
certID := parts[0]
if err := h.svc.TriggerRenewal(certID); err != nil {
if err := h.svc.TriggerRenewal(r.Context(), certID, "api"); err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
@@ -466,7 +467,7 @@ func (h CertificateHandler) TriggerDeployment(w http.ResponseWriter, r *http.Req
}
}
if err := h.svc.TriggerDeployment(certID, req.TargetID); err != nil {
if err := h.svc.TriggerDeployment(r.Context(), certID, req.TargetID, "api"); err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to trigger deployment", requestID)
return
}
@@ -508,7 +509,7 @@ func (h CertificateHandler) RevokeCertificate(w http.ResponseWriter, r *http.Req
}
}
if err := h.svc.RevokeCertificate(certID, req.Reason); err != nil {
if err := h.svc.RevokeCertificate(r.Context(), certID, req.Reason, "api"); err != nil {
// Distinguish between client errors and server errors
errMsg := err.Error()
if strings.Contains(errMsg, "already revoked") ||
@@ -540,7 +541,7 @@ func (h CertificateHandler) GetCRL(w http.ResponseWriter, r *http.Request) {
requestID := middleware.GetRequestID(r.Context())
revocations, err := h.svc.GetRevokedCertificates()
revocations, err := h.svc.GetRevokedCertificates(r.Context())
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to generate CRL", requestID)
return
@@ -585,7 +586,7 @@ func (h CertificateHandler) GetDERCRL(w http.ResponseWriter, r *http.Request) {
return
}
derBytes, err := h.svc.GenerateDERCRL(issuerID)
derBytes, err := h.svc.GenerateDERCRL(r.Context(), issuerID)
if err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "not found") {
@@ -627,7 +628,7 @@ func (h CertificateHandler) HandleOCSP(w http.ResponseWriter, r *http.Request) {
issuerID := parts[0]
serialHex := parts[1]
derBytes, err := h.svc.GetOCSPResponse(issuerID, serialHex)
derBytes, err := h.svc.GetOCSPResponse(r.Context(), issuerID, serialHex)
if err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "not found") {
@@ -667,7 +668,7 @@ func (h CertificateHandler) GetCertificateDeployments(w http.ResponseWriter, r *
}
certID := parts[0]
deployments, err := h.svc.GetCertificateDeployments(certID)
deployments, err := h.svc.GetCertificateDeployments(r.Context(), certID)
if err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "not found") {
+13 -13
View File
@@ -41,7 +41,7 @@ func (s *CAOperationsSvc) SetIssuerRegistry(registry *IssuerRegistry) {
// GenerateDERCRL generates a DER-encoded X.509 CRL for the given issuer.
// Short-lived certificates (profile TTL < 1 hour) are excluded from the CRL.
func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) {
func (s *CAOperationsSvc) GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error) {
if s.revocationRepo == nil {
return nil, fmt.Errorf("revocation repository not configured")
}
@@ -54,7 +54,7 @@ func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) {
return nil, fmt.Errorf("issuer not found: %s", issuerID)
}
revocations, err := s.revocationRepo.ListAll(context.Background())
revocations, err := s.revocationRepo.ListAll(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list revocations: %w", err)
}
@@ -69,9 +69,9 @@ func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) {
// Check short-lived exemption: look up the cert's profile
if s.profileRepo != nil && s.certRepo != nil {
cert, err := s.certRepo.Get(context.Background(), rev.CertificateID)
cert, err := s.certRepo.Get(ctx, rev.CertificateID)
if err == nil && cert.CertificateProfileID != "" {
profile, err := s.profileRepo.Get(context.Background(), cert.CertificateProfileID)
profile, err := s.profileRepo.Get(ctx, cert.CertificateProfileID)
if err == nil && profile.IsShortLived() {
slog.Debug("skipping short-lived cert from CRL",
"certificate_id", rev.CertificateID,
@@ -92,11 +92,11 @@ func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) {
})
}
return issuerConn.GenerateCRL(context.Background(), entries)
return issuerConn.GenerateCRL(ctx, entries)
}
// GetOCSPResponse generates a signed OCSP response for the given certificate serial.
func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) {
func (s *CAOperationsSvc) GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error) {
if s.revocationRepo == nil {
return nil, fmt.Errorf("revocation repository not configured")
}
@@ -120,13 +120,13 @@ func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]
// Look up cert by (issuer_id, serial) — per RFC 5280 §5.2.3, serial numbers
// are unique only within a single issuer. The OCSP URL path carries issuer_id,
// so we scope the lookup to avoid cross-issuer collisions.
rev, _ := s.revocationRepo.GetByIssuerAndSerial(context.Background(), issuerID, serialHex)
rev, _ := s.revocationRepo.GetByIssuerAndSerial(ctx, issuerID, serialHex)
if rev != nil {
cert, err := s.certRepo.Get(context.Background(), rev.CertificateID)
cert, err := s.certRepo.Get(ctx, rev.CertificateID)
if err == nil && cert.CertificateProfileID != "" {
profile, err := s.profileRepo.Get(context.Background(), cert.CertificateProfileID)
profile, err := s.profileRepo.Get(ctx, cert.CertificateProfileID)
if err == nil && profile.IsShortLived() {
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{
return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{
CertSerial: serial,
CertStatus: 0, // good — short-lived exemption
ThisUpdate: now,
@@ -138,10 +138,10 @@ func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]
}
// Check if this (issuer_id, serial) is revoked — RFC 5280 §5.2.3 scoping.
rev, err := s.revocationRepo.GetByIssuerAndSerial(context.Background(), issuerID, serialHex)
rev, err := s.revocationRepo.GetByIssuerAndSerial(ctx, issuerID, serialHex)
if err != nil {
// Not revoked — return "good" status
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{
return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{
CertSerial: serial,
CertStatus: 0, // good
ThisUpdate: now,
@@ -150,7 +150,7 @@ func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]
}
// Revoked
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{
return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{
CertSerial: serial,
CertStatus: 1, // revoked
RevokedAt: rev.RevokedAt,
+5 -4
View File
@@ -3,6 +3,7 @@
package service
import (
"context"
"log/slog"
"testing"
"time"
@@ -48,7 +49,7 @@ func TestCAOperationsSvc_GenerateDERCRL_Success(t *testing.T) {
},
}
crl, err := caSvc.GenerateDERCRL("iss-local")
crl, err := caSvc.GenerateDERCRL(context.Background(), "iss-local")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
@@ -71,7 +72,7 @@ func TestCAOperationsSvc_GenerateDERCRL_EmptyCRL(t *testing.T) {
// No revoked certs for this issuer
revocationRepo.Revocations = []*domain.CertificateRevocation{}
crl, err := caSvc.GenerateDERCRL("iss-local")
crl, err := caSvc.GenerateDERCRL(context.Background(), "iss-local")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
@@ -112,7 +113,7 @@ func TestCAOperationsSvc_GetOCSPResponse_Good(t *testing.T) {
certRepo.Versions["cert-ocsp-good"] = []*domain.CertificateVersion{version}
// Request OCSP response for good cert
resp, err := caSvc.GetOCSPResponse("iss-local", "OCSP-GOOD-001")
resp, err := caSvc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-GOOD-001")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
@@ -165,7 +166,7 @@ func TestCAOperationsSvc_GetOCSPResponse_Revoked(t *testing.T) {
}
// Request OCSP response for revoked cert
resp, err := caSvc.GetOCSPResponse("iss-local", "OCSP-REVOKED-001")
resp, err := caSvc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-REVOKED-001")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
+31 -46
View File
@@ -71,8 +71,8 @@ func (s *CertificateService) List(ctx context.Context, filter *repository.Certif
// ListCertificatesWithFilter returns a list of certificates with advanced filtering (M20).
// This method supports the new M20 filters and returns domain.ManagedCertificate (not pointers).
func (s *CertificateService) ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
certs, total, err := s.certRepo.List(context.Background(), filter)
func (s *CertificateService) ListCertificatesWithFilter(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
certs, total, err := s.certRepo.List(ctx, filter)
if err != nil {
return nil, 0, fmt.Errorf("failed to list certificates with filter: %w", err)
}
@@ -206,10 +206,10 @@ func (s *CertificateService) GetVersions(ctx context.Context, certID string) ([]
return versions, nil
}
// TriggerRenewalWithActor initiates a renewal job if the certificate is eligible.
// TriggerRenewal initiates a renewal job if the certificate is eligible.
// Creates a Renewal job (or Issuance for new certs) so the scheduler's job processor
// can pick it up and route it through the issuer connector.
func (s *CertificateService) TriggerRenewalWithActor(ctx context.Context, certID string, actor string) error {
func (s *CertificateService) TriggerRenewal(ctx context.Context, certID string, actor string) error {
cert, err := s.certRepo.Get(ctx, certID)
if err != nil {
return fmt.Errorf("failed to fetch certificate: %w", err)
@@ -283,8 +283,11 @@ func (s *CertificateService) TriggerRenewalWithActor(ctx context.Context, certID
return nil
}
// TriggerDeploymentWithActor creates deployment jobs for all targets of a certificate.
func (s *CertificateService) TriggerDeploymentWithActor(ctx context.Context, certID string, actor string) error {
// TriggerDeployment creates deployment jobs for all targets of a certificate.
// The targetID parameter is accepted from the handler interface but currently unused;
// deployment coordination happens per-certificate across all of its targets.
func (s *CertificateService) TriggerDeployment(ctx context.Context, certID string, targetID string, actor string) error {
_ = targetID
cert, err := s.certRepo.Get(ctx, certID)
if err != nil {
return fmt.Errorf("failed to fetch certificate: %w", err)
@@ -306,7 +309,7 @@ func (s *CertificateService) TriggerDeploymentWithActor(ctx context.Context, cer
}
// ListCertificates returns paginated certificates with optional filtering (handler interface method).
func (s *CertificateService) ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
func (s *CertificateService) ListCertificates(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
if page < 1 {
page = 1
}
@@ -325,7 +328,7 @@ func (s *CertificateService) ListCertificates(status, environment, ownerID, team
PerPage: perPage,
}
certs, total, err := s.certRepo.List(context.Background(), filter)
certs, total, err := s.certRepo.List(ctx, filter)
if err != nil {
return nil, 0, fmt.Errorf("failed to list certificates: %w", err)
}
@@ -341,12 +344,12 @@ func (s *CertificateService) ListCertificates(status, environment, ownerID, team
}
// GetCertificate returns a single certificate (handler interface method).
func (s *CertificateService) GetCertificate(id string) (*domain.ManagedCertificate, error) {
return s.certRepo.Get(context.Background(), id)
func (s *CertificateService) GetCertificate(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
return s.certRepo.Get(ctx, id)
}
// CreateCertificate creates a new certificate (handler interface method).
func (s *CertificateService) CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
func (s *CertificateService) CreateCertificate(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
if cert.ID == "" {
cert.ID = generateID("cert")
}
@@ -365,16 +368,14 @@ func (s *CertificateService) CreateCertificate(cert domain.ManagedCertificate) (
if cert.Tags == nil {
cert.Tags = make(map[string]string)
}
if err := s.certRepo.Create(context.Background(), &cert); err != nil {
if err := s.certRepo.Create(ctx, &cert); err != nil {
return nil, fmt.Errorf("failed to create certificate: %w", err)
}
return &cert, nil
}
// UpdateCertificate modifies a certificate (handler interface method).
func (s *CertificateService) UpdateCertificate(id string, patch domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
ctx := context.Background()
func (s *CertificateService) UpdateCertificate(ctx context.Context, id string, patch domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
// Fetch existing certificate so partial updates don't zero out fields
existing, err := s.certRepo.Get(ctx, id)
if err != nil {
@@ -425,12 +426,12 @@ func (s *CertificateService) UpdateCertificate(id string, patch domain.ManagedCe
}
// ArchiveCertificate marks a certificate as archived (handler interface method).
func (s *CertificateService) ArchiveCertificate(id string) error {
return s.certRepo.Archive(context.Background(), id)
func (s *CertificateService) ArchiveCertificate(ctx context.Context, id string) error {
return s.certRepo.Archive(ctx, id)
}
// GetCertificateVersions returns certificate versions (handler interface method).
func (s *CertificateService) GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
func (s *CertificateService) GetCertificateVersions(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
if page < 1 {
page = 1
}
@@ -438,7 +439,7 @@ func (s *CertificateService) GetCertificateVersions(certID string, page, perPage
perPage = 50
}
versions, err := s.certRepo.ListVersions(context.Background(), certID)
versions, err := s.certRepo.ListVersions(ctx, certID)
if err != nil {
return nil, 0, fmt.Errorf("failed to list certificate versions: %w", err)
}
@@ -463,24 +464,8 @@ func (s *CertificateService) GetCertificateVersions(certID string, page, perPage
return result, total, nil
}
// TriggerRenewal initiates renewal (handler interface method).
func (s *CertificateService) TriggerRenewal(certID string) error {
return s.TriggerRenewalWithActor(context.Background(), certID, "api")
}
// TriggerDeployment triggers deployment (handler interface method).
func (s *CertificateService) TriggerDeployment(certID string, targetID string) error {
return s.TriggerDeploymentWithActor(context.Background(), certID, "api")
}
// RevokeCertificate revokes a certificate with the given reason (handler interface method).
func (s *CertificateService) RevokeCertificate(certID string, reason string) error {
return s.RevokeCertificateWithActor(context.Background(), certID, reason, "api")
}
// RevokeCertificateWithActor performs revocation with actor tracking.
// Delegates to RevocationSvc.
func (s *CertificateService) RevokeCertificateWithActor(ctx context.Context, certID string, reason string, actor string) error {
// RevokeCertificate performs revocation with actor tracking. Delegates to RevocationSvc.
func (s *CertificateService) RevokeCertificate(ctx context.Context, certID string, reason string, actor string) error {
if s.revSvc == nil {
return fmt.Errorf("revocation service not configured")
}
@@ -489,35 +474,35 @@ func (s *CertificateService) RevokeCertificateWithActor(ctx context.Context, cer
// GetRevokedCertificates returns all revoked certificate records (for CRL generation).
// Delegates to RevocationSvc.
func (s *CertificateService) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) {
func (s *CertificateService) GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error) {
if s.revSvc == nil {
return nil, fmt.Errorf("revocation service not configured")
}
return s.revSvc.GetRevokedCertificates()
return s.revSvc.GetRevokedCertificates(ctx)
}
// GenerateDERCRL generates a DER-encoded X.509 CRL for the given issuer.
// Delegates to CAOperationsSvc.
func (s *CertificateService) GenerateDERCRL(issuerID string) ([]byte, error) {
func (s *CertificateService) GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error) {
if s.caSvc == nil {
return nil, fmt.Errorf("CA operations service not configured")
}
return s.caSvc.GenerateDERCRL(issuerID)
return s.caSvc.GenerateDERCRL(ctx, issuerID)
}
// GetOCSPResponse generates a signed OCSP response for the given certificate serial.
// Delegates to CAOperationsSvc.
func (s *CertificateService) GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) {
func (s *CertificateService) GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error) {
if s.caSvc == nil {
return nil, fmt.Errorf("CA operations service not configured")
}
return s.caSvc.GetOCSPResponse(issuerID, serialHex)
return s.caSvc.GetOCSPResponse(ctx, issuerID, serialHex)
}
// GetCertificateDeployments returns all deployment targets for a certificate (M20).
func (s *CertificateService) GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error) {
func (s *CertificateService) GetCertificateDeployments(ctx context.Context, certID string) ([]domain.DeploymentTarget, error) {
// Verify certificate exists
_, err := s.certRepo.Get(context.Background(), certID)
_, err := s.certRepo.Get(ctx, certID)
if err != nil {
return nil, fmt.Errorf("certificate not found: %w", err)
}
@@ -527,7 +512,7 @@ func (s *CertificateService) GetCertificateDeployments(certID string) ([]domain.
}
// Get targets from repository
targets, err := s.targetRepo.ListByCertificate(context.Background(), certID)
targets, err := s.targetRepo.ListByCertificate(ctx, certID)
if err != nil {
return nil, fmt.Errorf("failed to list deployment targets: %w", err)
}
+11 -11
View File
@@ -34,7 +34,7 @@ func TestCertificateService_RevokeCertificate_RevocationSvcNil(t *testing.T) {
certRepo.AddCert(cert)
// Call RevokeCertificateWithActor with nil RevocationSvc
err := certService.RevokeCertificateWithActor(context.Background(), "cert-1", "keyCompromise", "admin")
err := certService.RevokeCertificate(context.Background(), "cert-1", "keyCompromise", "admin")
// Assert: Should return error, NOT panic
if err == nil {
@@ -64,7 +64,7 @@ func TestCertificateService_GenerateDERCRL_CAOpsSvcNil(t *testing.T) {
// Note: NOT calling certService.SetCAOperationsSvc(...)
// Call GenerateDERCRL with nil CAOperationsSvc
_, err := certService.GenerateDERCRL("iss-local")
_, err := certService.GenerateDERCRL(context.Background(), "iss-local")
// Assert: Should return error, NOT panic
if err == nil {
@@ -94,7 +94,7 @@ func TestCertificateService_GetOCSPResponse_CAOpsSvcNil(t *testing.T) {
// Note: NOT calling certService.SetCAOperationsSvc(...)
// Call GetOCSPResponse with nil CAOperationsSvc
_, err := certService.GetOCSPResponse("iss-local", "serial123")
_, err := certService.GetOCSPResponse(context.Background(), "iss-local", "serial123")
// Assert: Should return error, NOT panic
if err == nil {
@@ -124,7 +124,7 @@ func TestCertificateService_GetRevokedCertificates_RevocationSvcNil(t *testing.T
// Note: NOT calling certService.SetRevocationSvc(...)
// Call GetRevokedCertificates with nil RevocationSvc
_, err := certService.GetRevokedCertificates()
_, err := certService.GetRevokedCertificates(context.Background())
// Assert: Should return error, NOT panic
if err == nil {
@@ -177,7 +177,7 @@ func TestCertificateService_GetCertificateDeployments_Success(t *testing.T) {
targetRepo.AddTarget(target2)
// Call GetCertificateDeployments
deployments, err := certService.GetCertificateDeployments("cert-1")
deployments, err := certService.GetCertificateDeployments(context.Background(), "cert-1")
// Assert: Should return deployment list successfully
if err != nil {
@@ -218,7 +218,7 @@ func TestCertificateService_GetCertificateDeployments_RepositoryError(t *testing
certRepo.AddCert(cert)
// Call GetCertificateDeployments with repo error
_, err := certService.GetCertificateDeployments("cert-1")
_, err := certService.GetCertificateDeployments(context.Background(), "cert-1")
// Assert: Should return error, NOT panic
if err == nil {
@@ -247,7 +247,7 @@ func TestCertificateService_GetCertificateDeployments_CertNotFound(t *testing.T)
certService.SetTargetRepo(targetRepo)
// Call GetCertificateDeployments with nonexistent certificate
_, err := certService.GetCertificateDeployments("nonexistent-cert")
_, err := certService.GetCertificateDeployments(context.Background(), "nonexistent-cert")
// Assert: Should return error
if err == nil {
@@ -283,7 +283,7 @@ func TestCertificateService_GetCertificateDeployments_NilTargetRepo(t *testing.T
certRepo.AddCert(cert)
// Call GetCertificateDeployments with nil TargetRepo
deployments, err := certService.GetCertificateDeployments("cert-1")
deployments, err := certService.GetCertificateDeployments(context.Background(), "cert-1")
// Assert: Should return empty list gracefully (not panic)
if err != nil {
@@ -337,19 +337,19 @@ func TestCertificateService_Multiple_NilSafetyChecks(t *testing.T) {
revSvc.SetIssuerRegistry(registry)
// Test 1: RevokeCertificateWithActor should succeed (RevocationSvc is set)
errRevoke := certService.RevokeCertificateWithActor(context.Background(), "cert-1", "keyCompromise", "admin")
errRevoke := certService.RevokeCertificate(context.Background(), "cert-1", "keyCompromise", "admin")
if errRevoke != nil {
t.Fatalf("RevokeCertificateWithActor failed unexpectedly: %v", errRevoke)
}
// Test 2: GenerateDERCRL should fail gracefully (CAOperationsSvc is nil)
_, errCRL := certService.GenerateDERCRL("iss-local")
_, errCRL := certService.GenerateDERCRL(context.Background(), "iss-local")
if errCRL == nil {
t.Fatal("GenerateDERCRL expected error, got nil")
}
// Test 3: GetOCSPResponse should fail gracefully (CAOperationsSvc is nil)
_, errOCSP := certService.GetOCSPResponse("iss-local", "ABC123")
_, errOCSP := certService.GetOCSPResponse(context.Background(), "iss-local", "ABC123")
if errOCSP == nil {
t.Fatal("GetOCSPResponse expected error, got nil")
}
+4 -3
View File
@@ -294,7 +294,7 @@ func TestTriggerRenewal(t *testing.T) {
auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService)
err := certService.TriggerRenewalWithActor(ctx, "cert-001", "user-1")
err := certService.TriggerRenewal(ctx, "cert-001", "user-1")
if err != nil {
t.Fatalf("TriggerRenewal failed: %v", err)
}
@@ -333,13 +333,14 @@ func TestTriggerRenewal_Archived(t *testing.T) {
auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService)
err := certService.TriggerRenewalWithActor(ctx, "cert-001", "user-1")
err := certService.TriggerRenewal(ctx, "cert-001", "user-1")
if err == nil {
t.Fatal("expected error for archived certificate")
}
}
func TestListCertificates(t *testing.T) {
ctx := context.Background()
now := time.Now()
cert1 := &domain.ManagedCertificate{
ID: "cert-001",
@@ -369,7 +370,7 @@ func TestListCertificates(t *testing.T) {
auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService)
certs, total, err := certService.ListCertificates("", "", "", "", "", 1, 50)
certs, total, err := certService.ListCertificates(ctx, "", "", "", "", "", 1, 50)
if err != nil {
t.Fatalf("ListCertificates failed: %v", err)
}
+2 -2
View File
@@ -151,9 +151,9 @@ func (s *RevocationSvc) RevokeCertificateWithActor(ctx context.Context, certID s
}
// GetRevokedCertificates returns all revoked certificate records (for CRL generation).
func (s *RevocationSvc) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) {
func (s *RevocationSvc) GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error) {
if s.revocationRepo == nil {
return nil, fmt.Errorf("revocation repository not configured")
}
return s.revocationRepo.ListAll(context.Background())
return s.revocationRepo.ListAll(ctx)
}
+1 -1
View File
@@ -122,7 +122,7 @@ func TestRevocationSvc_GetRevokedCertificates_Success(t *testing.T) {
{ID: "rev-2", CertificateID: "cert-2", SerialNumber: "SER-2", Reason: "superseded", RevokedAt: time.Now()},
}
revocations, err := revSvc.GetRevokedCertificates()
revocations, err := revSvc.GetRevokedCertificates(context.Background())
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
+23 -23
View File
@@ -62,7 +62,7 @@ func TestRevokeCertificate_Success(t *testing.T) {
certRepo.Versions["cert-1"] = []*domain.CertificateVersion{version}
// Revoke
err := svc.RevokeCertificateWithActor(context.Background(), "cert-1", "keyCompromise", "admin")
err := svc.RevokeCertificate(context.Background(), "cert-1", "keyCompromise", "admin")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
@@ -125,7 +125,7 @@ func TestRevokeCertificate_DefaultReason(t *testing.T) {
}
// Revoke with empty reason — should default to "unspecified"
err := svc.RevokeCertificateWithActor(context.Background(), "cert-2", "", "api")
err := svc.RevokeCertificate(context.Background(), "cert-2", "", "api")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
@@ -158,7 +158,7 @@ func TestRevokeCertificate_AlreadyRevoked(t *testing.T) {
}
certRepo.AddCert(cert)
err := svc.RevokeCertificateWithActor(context.Background(), "cert-3", "superseded", "admin")
err := svc.RevokeCertificate(context.Background(), "cert-3", "superseded", "admin")
if err == nil {
t.Fatal("expected error for already revoked certificate")
}
@@ -179,7 +179,7 @@ func TestRevokeCertificate_ArchivedCert(t *testing.T) {
}
certRepo.AddCert(cert)
err := svc.RevokeCertificateWithActor(context.Background(), "cert-4", "keyCompromise", "admin")
err := svc.RevokeCertificate(context.Background(), "cert-4", "keyCompromise", "admin")
if err == nil {
t.Fatal("expected error for archived certificate")
}
@@ -200,7 +200,7 @@ func TestRevokeCertificate_InvalidReason(t *testing.T) {
}
certRepo.AddCert(cert)
err := svc.RevokeCertificateWithActor(context.Background(), "cert-5", "notAValidReason", "admin")
err := svc.RevokeCertificate(context.Background(), "cert-5", "notAValidReason", "admin")
if err == nil {
t.Fatal("expected error for invalid reason")
}
@@ -212,7 +212,7 @@ func TestRevokeCertificate_InvalidReason(t *testing.T) {
func TestRevokeCertificate_NotFound(t *testing.T) {
svc, _, _, _ := newRevocationTestService()
err := svc.RevokeCertificateWithActor(context.Background(), "nonexistent-cert", "keyCompromise", "admin")
err := svc.RevokeCertificate(context.Background(), "nonexistent-cert", "keyCompromise", "admin")
if err == nil {
t.Fatal("expected error for nonexistent certificate")
}
@@ -231,7 +231,7 @@ func TestRevokeCertificate_NoVersion(t *testing.T) {
certRepo.AddCert(cert)
// No versions added — should fail
err := svc.RevokeCertificateWithActor(context.Background(), "cert-6", "keyCompromise", "admin")
err := svc.RevokeCertificate(context.Background(), "cert-6", "keyCompromise", "admin")
if err == nil {
t.Fatal("expected error when no certificate version exists")
}
@@ -258,7 +258,7 @@ func TestRevokeCertificate_WithIssuerNotification(t *testing.T) {
{ID: "ver-7", CertificateID: "cert-7", SerialNumber: "GHI789", CreatedAt: time.Now()},
}
err := svc.RevokeCertificateWithActor(context.Background(), "cert-7", "cessationOfOperation", "admin")
err := svc.RevokeCertificate(context.Background(), "cert-7", "cessationOfOperation", "admin")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
@@ -293,7 +293,7 @@ func TestRevokeCertificate_WithNotificationService(t *testing.T) {
{ID: "ver-8", CertificateID: "cert-8", SerialNumber: "JKL012", CreatedAt: time.Now()},
}
err := svc.RevokeCertificateWithActor(context.Background(), "cert-8", "keyCompromise", "admin")
err := svc.RevokeCertificate(context.Background(), "cert-8", "keyCompromise", "admin")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
@@ -336,7 +336,7 @@ func TestRevokeCertificate_AllValidReasons(t *testing.T) {
{ID: "ver-" + reason, CertificateID: "cert-" + reason, SerialNumber: "SER-" + reason, CreatedAt: time.Now()},
}
err := svc.RevokeCertificateWithActor(context.Background(), "cert-"+reason, reason, "admin")
err := svc.RevokeCertificate(context.Background(), "cert-"+reason, reason, "admin")
if err != nil {
t.Fatalf("expected no error for reason %s, got: %v", reason, err)
}
@@ -358,7 +358,7 @@ func TestGetRevokedCertificates_Success(t *testing.T) {
{ID: "rev-2", CertificateID: "cert-2", SerialNumber: "SER-2", Reason: "superseded", RevokedAt: time.Now()},
}
revocations, err := svc.GetRevokedCertificates()
revocations, err := svc.GetRevokedCertificates(context.Background())
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
@@ -370,7 +370,7 @@ func TestGetRevokedCertificates_Success(t *testing.T) {
func TestGetRevokedCertificates_Empty(t *testing.T) {
svc, _, _, _ := newRevocationTestService()
revocations, err := svc.GetRevokedCertificates()
revocations, err := svc.GetRevokedCertificates(context.Background())
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
@@ -390,7 +390,7 @@ func TestGetRevokedCertificates_NoRepo(t *testing.T) {
svc := NewCertificateService(certRepo, policyService, auditService)
// Do NOT set revocation repo
_, err := svc.GetRevokedCertificates()
_, err := svc.GetRevokedCertificates(context.Background())
if err == nil {
t.Fatal("expected error when revocation repo not configured")
}
@@ -411,8 +411,8 @@ func TestRevokeCertificate_HandlerInterfaceMethod(t *testing.T) {
{ID: "ver-handler", CertificateID: "cert-handler", SerialNumber: "SER-HANDLER", CreatedAt: time.Now()},
}
// Test the handler interface method (no actor param)
err := svc.RevokeCertificate("cert-handler", "superseded")
// Test the handler interface method (actor collapsed to required positional arg per D-2)
err := svc.RevokeCertificate(context.Background(), "cert-handler", "superseded", "api")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
@@ -449,7 +449,7 @@ func TestGenerateDERCRL_Success(t *testing.T) {
},
}
crl, err := svc.GenerateDERCRL("iss-local")
crl, err := svc.GenerateDERCRL(context.Background(), "iss-local")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
@@ -472,7 +472,7 @@ func TestGenerateDERCRL_EmptyCRL(t *testing.T) {
// No revoked certs for this issuer
revocationRepo.Revocations = []*domain.CertificateRevocation{}
crl, err := svc.GenerateDERCRL("iss-local")
crl, err := svc.GenerateDERCRL(context.Background(), "iss-local")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
@@ -493,7 +493,7 @@ func TestGenerateDERCRL_IssuerNotFound(t *testing.T) {
svc, _, _, _ := newRevocationTestService()
// Try to generate CRL for unknown issuer
crl, err := svc.GenerateDERCRL("iss-unknown")
crl, err := svc.GenerateDERCRL(context.Background(), "iss-unknown")
// Should return error or nil CRL depending on implementation
if crl != nil && err == nil {
@@ -527,7 +527,7 @@ func TestGetOCSPResponse_Good(t *testing.T) {
certRepo.Versions["cert-ocsp-good"] = []*domain.CertificateVersion{version}
// Request OCSP response for good cert
resp, err := svc.GetOCSPResponse("iss-local", "OCSP-GOOD-001")
resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-GOOD-001")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
@@ -580,7 +580,7 @@ func TestGetOCSPResponse_Revoked(t *testing.T) {
}
// Request OCSP response for revoked cert
resp, err := svc.GetOCSPResponse("iss-local", "OCSP-REVOKED-001")
resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-REVOKED-001")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
@@ -597,7 +597,7 @@ func TestGetOCSPResponse_Unknown(t *testing.T) {
svc, _, _, _ := newRevocationTestService()
// Request OCSP response for unknown cert
resp, err := svc.GetOCSPResponse("iss-local", "UNKNOWN-SERIAL")
resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "UNKNOWN-SERIAL")
if err != nil {
t.Fatalf("expected no error (should return unknown status), got: %v", err)
@@ -615,7 +615,7 @@ func TestGetOCSPResponse_IssuerNotFound(t *testing.T) {
svc, _, _, _ := newRevocationTestService()
// Request OCSP response for unknown issuer
resp, err := svc.GetOCSPResponse("iss-unknown", "SOME-SERIAL")
resp, err := svc.GetOCSPResponse(context.Background(), "iss-unknown", "SOME-SERIAL")
// Should return error since issuer doesn't exist
if err == nil && resp != nil {
@@ -629,7 +629,7 @@ func TestGetOCSPResponse_InvalidSerial(t *testing.T) {
svc, _, _, _ := newRevocationTestService()
// Request OCSP response with invalid serial format
resp, err := svc.GetOCSPResponse("iss-local", "")
resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "")
if err == nil && resp != nil {
// Empty serial might return unknown status; that's ok