From ad2734c10a91b742f00c7d83754ed2f891cf056b Mon Sep 17 00:00:00 2001 From: Shankar Date: Sat, 18 Apr 2026 00:29:37 +0000 Subject: [PATCH] fix(m-2): thread context through CertificateService cluster Collapses CertificateService, RevocationSvc, and CAOperationsSvc to ctx-accepting method signatures. Removes context.Background() synthesis at 24 internal call sites across certificate.go, revocation_svc.go, and ca_operations.go. - Primary repo calls inherit request cancellation via the passed ctx. - Audit and notification dispatches use context.WithoutCancel(ctx) so they survive client disconnect. - Collapses TriggerRenewal/TriggerRenewalWithActor, TriggerDeployment/TriggerDeploymentWithActor, and RevokeCertificate/RevokeCertificateWithActor sibling pairs into single canonical ctx-accepting methods (decisions D-1, D-2). Handlers pass r.Context(). Mocks and tests updated to match new signatures. No HTTP surface change, no OpenAPI change. PR 1 of 6 in the M-2 remediation chain. Master green at this commit. Refs: certctl-audit-report.md M-2 (L143, L224) --- internal/api/handler/adversarial_path_test.go | 13 +- .../api/handler/adversarial_query_test.go | 23 +-- .../api/handler/certificate_handler_test.go | 176 +++++++++--------- internal/api/handler/certificates.go | 55 +++--- internal/service/ca_operations.go | 26 +-- internal/service/ca_operations_test.go | 9 +- internal/service/certificate.go | 77 +++----- .../service/certificate_nil_safety_test.go | 22 +-- internal/service/certificate_test.go | 7 +- internal/service/revocation_svc.go | 4 +- internal/service/revocation_svc_test.go | 2 +- internal/service/revocation_test.go | 46 ++--- 12 files changed, 225 insertions(+), 235 deletions(-) diff --git a/internal/api/handler/adversarial_path_test.go b/internal/api/handler/adversarial_path_test.go index 4fb94d0..8d69c3f 100644 --- a/internal/api/handler/adversarial_path_test.go +++ b/internal/api/handler/adversarial_path_test.go @@ -27,6 +27,7 @@ package handler import ( "bytes" + "context" "encoding/json" "net/http" "net/http/httptest" @@ -120,7 +121,7 @@ func TestGetCertificate_PathInjection(t *testing.T) { handler, mock := newCertHandlerWithMock() // Force a 404 so we can distinguish "service was called" from // "parser accepted the ID"; a 200 with null body is also fine. - mock.GetCertificateFn = func(id string) (*domain.ManagedCertificate, error) { + mock.GetCertificateFn = func(_ context.Context, id string) (*domain.ManagedCertificate, error) { return nil, ErrMockNotFound } @@ -156,7 +157,7 @@ func TestUpdateCertificate_PathInjection(t *testing.T) { }() handler, mock := newCertHandlerWithMock() - mock.UpdateCertificateFn = func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { + mock.UpdateCertificateFn = func(_ context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { return nil, ErrMockNotFound } @@ -184,7 +185,7 @@ func TestArchiveCertificate_PathInjection(t *testing.T) { }() handler, mock := newCertHandlerWithMock() - mock.ArchiveCertificateFn = func(id string) error { return ErrMockNotFound } + mock.ArchiveCertificateFn = func(_ context.Context, id string) error { return ErrMockNotFound } req := httptest.NewRequest(http.MethodDelete, "/api/v1/certificates/x", nil) req.URL.Path = "/api/v1/certificates/" + tc.input @@ -227,7 +228,7 @@ func TestGetCertificateVersions_MultiSegment(t *testing.T) { }() handler, mock := newCertHandlerWithMock() - mock.GetCertificateVersionsFn = func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) { + mock.GetCertificateVersionsFn = func(_ context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) { return []domain.CertificateVersion{}, 0, nil } @@ -277,7 +278,7 @@ func TestHandleOCSP_MultiSegment(t *testing.T) { }() handler, mock := newCertHandlerWithMock() - mock.GetOCSPResponseFn = func(issuerID, serialHex string) ([]byte, error) { + mock.GetOCSPResponseFn = func(_ context.Context, issuerID, serialHex string) ([]byte, error) { return nil, ErrMockNotFound } @@ -311,7 +312,7 @@ func TestGetDERCRL_IssuerPathInjection(t *testing.T) { }() handler, mock := newCertHandlerWithMock() - mock.GenerateDERCRLFn = func(issuerID string) ([]byte, error) { + mock.GenerateDERCRLFn = func(_ context.Context, issuerID string) ([]byte, error) { return nil, ErrMockNotFound } diff --git a/internal/api/handler/adversarial_query_test.go b/internal/api/handler/adversarial_query_test.go index 570d8d5..13c57ec 100644 --- a/internal/api/handler/adversarial_query_test.go +++ b/internal/api/handler/adversarial_query_test.go @@ -19,6 +19,7 @@ package handler import ( "bytes" + "context" "fmt" "net/http" "net/http/httptest" @@ -76,7 +77,7 @@ func TestListCertificates_PaginationAbuse(t *testing.T) { }() handler, mock := newCertHandlerWithMock() - mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { // Sanity: page/perPage on the filter must never be negative // and perPage must never exceed 500 after parsing. if filter.Page < 1 { @@ -133,7 +134,7 @@ func TestListCertificates_SortAbuse(t *testing.T) { }() handler, mock := newCertHandlerWithMock() - mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { return []domain.ManagedCertificate{}, 0, nil } @@ -175,7 +176,7 @@ func TestListCertificates_FieldsAbuse(t *testing.T) { }() handler, mock := newCertHandlerWithMock() - mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { return []domain.ManagedCertificate{}, 0, nil } @@ -219,7 +220,7 @@ func TestListCertificates_TimeRangeAbuse(t *testing.T) { }() handler, mock := newCertHandlerWithMock() - mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { return []domain.ManagedCertificate{}, 0, nil } @@ -263,7 +264,7 @@ func TestListCertificates_CursorAbuse(t *testing.T) { }() handler, mock := newCertHandlerWithMock() - mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { return []domain.ManagedCertificate{}, 0, nil } @@ -314,7 +315,7 @@ func TestListCertificates_FilterInjection(t *testing.T) { }() handler, mock := newCertHandlerWithMock() - mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { return []domain.ManagedCertificate{}, 0, nil } @@ -374,7 +375,7 @@ func TestCreateCertificate_BodyAbuse(t *testing.T) { }() handler, mock := newCertHandlerWithMock() - mock.CreateCertificateFn = func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { + mock.CreateCertificateFn = func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { // If we ever reach this, the handler accepted a malformed // body. Return a sentinel that passes but flag it. c := cert @@ -419,7 +420,7 @@ func TestCreateCertificate_HugeBody(t *testing.T) { sb.WriteString(`]}`) handler, mock := newCertHandlerWithMock() - mock.CreateCertificateFn = func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { + mock.CreateCertificateFn = func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { c := cert c.ID = "mc-huge" return &c, nil @@ -476,7 +477,7 @@ func TestRevokeCertificate_ReasonAbuse(t *testing.T) { handler, mock := newCertHandlerWithMock() // The mock always returns "invalid revocation reason" so we // verify the handler's errMsg→status mapping turns it into a 400. - mock.RevokeCertificateFn = func(id string, reason string) error { + mock.RevokeCertificateFn = func(_ context.Context, id string, reason string, _ string) error { // The service uses domain.IsValidRevocationReason. If we got // through to here with something bogus, simulate a real // service error. @@ -500,7 +501,7 @@ func TestRevokeCertificate_ReasonAbuse(t *testing.T) { // service error message, which is fragile — this test catches regressions. func TestRevokeCertificate_AlreadyRevoked(t *testing.T) { handler, mock := newCertHandlerWithMock() - mock.RevokeCertificateFn = func(id string, reason string) error { + mock.RevokeCertificateFn = func(_ context.Context, id string, reason string, _ string) error { return fmt.Errorf("cannot revoke: certificate is already revoked") } @@ -520,7 +521,7 @@ func TestRevokeCertificate_AlreadyRevoked(t *testing.T) { // TestRevokeCertificate_NotFound verifies 404 mapping. func TestRevokeCertificate_NotFound(t *testing.T) { handler, mock := newCertHandlerWithMock() - mock.RevokeCertificateFn = func(id string, reason string) error { + mock.RevokeCertificateFn = func(_ context.Context, id string, reason string, _ string) error { return fmt.Errorf("certificate not found") } diff --git a/internal/api/handler/certificate_handler_test.go b/internal/api/handler/certificate_handler_test.go index bb0b1a4..7da514f 100644 --- a/internal/api/handler/certificate_handler_test.go +++ b/internal/api/handler/certificate_handler_test.go @@ -17,116 +17,116 @@ import ( // MockCertificateService is a mock implementation of CertificateService interface. type MockCertificateService struct { - ListCertificatesFn func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) - ListCertificatesWithFilterFn func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) - GetCertificateFn func(id string) (*domain.ManagedCertificate, error) - CreateCertificateFn func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) - UpdateCertificateFn func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) - ArchiveCertificateFn func(id string) error - GetCertificateVersionsFn func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) - TriggerRenewalFn func(certID string) error - TriggerDeploymentFn func(certID string, targetID string) error - RevokeCertificateFn func(certID string, reason string) error - GetRevokedCertificatesFn func() ([]*domain.CertificateRevocation, error) - GenerateDERCRLFn func(issuerID string) ([]byte, error) - GetOCSPResponseFn func(issuerID string, serialHex string) ([]byte, error) - GetCertificateDeploymentsFn func(certID string) ([]domain.DeploymentTarget, error) + ListCertificatesFn func(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) + ListCertificatesWithFilterFn func(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) + GetCertificateFn func(ctx context.Context, id string) (*domain.ManagedCertificate, error) + CreateCertificateFn func(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) + UpdateCertificateFn func(ctx context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) + ArchiveCertificateFn func(ctx context.Context, id string) error + GetCertificateVersionsFn func(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) + TriggerRenewalFn func(ctx context.Context, certID string, actor string) error + TriggerDeploymentFn func(ctx context.Context, certID string, targetID string, actor string) error + RevokeCertificateFn func(ctx context.Context, certID string, reason string, actor string) error + GetRevokedCertificatesFn func(ctx context.Context) ([]*domain.CertificateRevocation, error) + GenerateDERCRLFn func(ctx context.Context, issuerID string) ([]byte, error) + GetOCSPResponseFn func(ctx context.Context, issuerID string, serialHex string) ([]byte, error) + GetCertificateDeploymentsFn func(ctx context.Context, certID string) ([]domain.DeploymentTarget, error) } -func (m *MockCertificateService) ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) { +func (m *MockCertificateService) ListCertificates(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) { if m.ListCertificatesFn != nil { - return m.ListCertificatesFn(status, environment, ownerID, teamID, issuerID, page, perPage) + return m.ListCertificatesFn(ctx, status, environment, ownerID, teamID, issuerID, page, perPage) } return nil, 0, nil } -func (m *MockCertificateService) GetCertificate(id string) (*domain.ManagedCertificate, error) { +func (m *MockCertificateService) GetCertificate(ctx context.Context, id string) (*domain.ManagedCertificate, error) { if m.GetCertificateFn != nil { - return m.GetCertificateFn(id) + return m.GetCertificateFn(ctx, id) } return nil, nil } -func (m *MockCertificateService) CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { +func (m *MockCertificateService) CreateCertificate(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { if m.CreateCertificateFn != nil { - return m.CreateCertificateFn(cert) + return m.CreateCertificateFn(ctx, cert) } return nil, nil } -func (m *MockCertificateService) UpdateCertificate(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { +func (m *MockCertificateService) UpdateCertificate(ctx context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { if m.UpdateCertificateFn != nil { - return m.UpdateCertificateFn(id, cert) + return m.UpdateCertificateFn(ctx, id, cert) } return nil, nil } -func (m *MockCertificateService) ArchiveCertificate(id string) error { +func (m *MockCertificateService) ArchiveCertificate(ctx context.Context, id string) error { if m.ArchiveCertificateFn != nil { - return m.ArchiveCertificateFn(id) + return m.ArchiveCertificateFn(ctx, id) } return nil } -func (m *MockCertificateService) GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) { +func (m *MockCertificateService) GetCertificateVersions(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) { if m.GetCertificateVersionsFn != nil { - return m.GetCertificateVersionsFn(certID, page, perPage) + return m.GetCertificateVersionsFn(ctx, certID, page, perPage) } return nil, 0, nil } -func (m *MockCertificateService) TriggerRenewal(certID string) error { +func (m *MockCertificateService) TriggerRenewal(ctx context.Context, certID string, actor string) error { if m.TriggerRenewalFn != nil { - return m.TriggerRenewalFn(certID) + return m.TriggerRenewalFn(ctx, certID, actor) } return nil } -func (m *MockCertificateService) TriggerDeployment(certID string, targetID string) error { +func (m *MockCertificateService) TriggerDeployment(ctx context.Context, certID string, targetID string, actor string) error { if m.TriggerDeploymentFn != nil { - return m.TriggerDeploymentFn(certID, targetID) + return m.TriggerDeploymentFn(ctx, certID, targetID, actor) } return nil } -func (m *MockCertificateService) RevokeCertificate(certID string, reason string) error { +func (m *MockCertificateService) RevokeCertificate(ctx context.Context, certID string, reason string, actor string) error { if m.RevokeCertificateFn != nil { - return m.RevokeCertificateFn(certID, reason) + return m.RevokeCertificateFn(ctx, certID, reason, actor) } return nil } -func (m *MockCertificateService) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) { +func (m *MockCertificateService) GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error) { if m.GetRevokedCertificatesFn != nil { - return m.GetRevokedCertificatesFn() + return m.GetRevokedCertificatesFn(ctx) } return nil, nil } -func (m *MockCertificateService) GenerateDERCRL(issuerID string) ([]byte, error) { +func (m *MockCertificateService) GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error) { if m.GenerateDERCRLFn != nil { - return m.GenerateDERCRLFn(issuerID) + return m.GenerateDERCRLFn(ctx, issuerID) } return nil, nil } -func (m *MockCertificateService) GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) { +func (m *MockCertificateService) GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error) { if m.GetOCSPResponseFn != nil { - return m.GetOCSPResponseFn(issuerID, serialHex) + return m.GetOCSPResponseFn(ctx, issuerID, serialHex) } return nil, nil } -func (m *MockCertificateService) ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { +func (m *MockCertificateService) ListCertificatesWithFilter(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { if m.ListCertificatesWithFilterFn != nil { - return m.ListCertificatesWithFilterFn(filter) + return m.ListCertificatesWithFilterFn(ctx, filter) } return nil, 0, nil } -func (m *MockCertificateService) GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error) { +func (m *MockCertificateService) GetCertificateDeployments(ctx context.Context, certID string) ([]domain.DeploymentTarget, error) { if m.GetCertificateDeploymentsFn != nil { - return m.GetCertificateDeploymentsFn(certID) + return m.GetCertificateDeploymentsFn(ctx, certID) } return nil, nil } @@ -158,7 +158,7 @@ func TestListCertificates_Success(t *testing.T) { } mock := &MockCertificateService{ - ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { if filter.Page == 1 && filter.PerPage == 50 { return []domain.ManagedCertificate{cert1, cert2}, 2, nil } @@ -197,7 +197,7 @@ func TestListCertificates_Success(t *testing.T) { // Test ListCertificates - with filters func TestListCertificates_WithFilters(t *testing.T) { mock := &MockCertificateService{ - ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { if filter.Status == "Active" && filter.Environment == "prod" { return []domain.ManagedCertificate{}, 0, nil } @@ -236,7 +236,7 @@ func TestListCertificates_MethodNotAllowed(t *testing.T) { // Test ListCertificates - service error func TestListCertificates_ServiceError(t *testing.T) { mock := &MockCertificateService{ - ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { return nil, 0, ErrMockServiceFailed }, } @@ -266,7 +266,7 @@ func TestGetCertificate_Success(t *testing.T) { } mock := &MockCertificateService{ - GetCertificateFn: func(id string) (*domain.ManagedCertificate, error) { + GetCertificateFn: func(_ context.Context, id string) (*domain.ManagedCertificate, error) { if id == "mc-prod-001" { return cert, nil } @@ -298,7 +298,7 @@ func TestGetCertificate_Success(t *testing.T) { // Test GetCertificate - not found func TestGetCertificate_NotFound(t *testing.T) { mock := &MockCertificateService{ - GetCertificateFn: func(id string) (*domain.ManagedCertificate, error) { + GetCertificateFn: func(_ context.Context, id string) (*domain.ManagedCertificate, error) { return nil, ErrMockNotFound }, } @@ -345,7 +345,7 @@ func TestCreateCertificate_Success(t *testing.T) { } mock := &MockCertificateService{ - CreateCertificateFn: func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { + CreateCertificateFn: func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { return created, nil }, } @@ -403,7 +403,7 @@ func TestCreateCertificate_InvalidBody(t *testing.T) { // Test CreateCertificate - service error func TestCreateCertificate_ServiceError(t *testing.T) { mock := &MockCertificateService{ - CreateCertificateFn: func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { + CreateCertificateFn: func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { return nil, ErrMockServiceFailed }, } @@ -445,7 +445,7 @@ func TestUpdateCertificate_Success(t *testing.T) { } mock := &MockCertificateService{ - UpdateCertificateFn: func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { + UpdateCertificateFn: func(_ context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { if id == "mc-prod-001" { return updated, nil } @@ -501,7 +501,7 @@ func TestUpdateCertificate_InvalidBody(t *testing.T) { // Test ArchiveCertificate - success case func TestArchiveCertificate_Success(t *testing.T) { mock := &MockCertificateService{ - ArchiveCertificateFn: func(id string) error { + ArchiveCertificateFn: func(_ context.Context, id string) error { if id == "mc-prod-001" { return nil } @@ -524,7 +524,7 @@ func TestArchiveCertificate_Success(t *testing.T) { // Test ArchiveCertificate - not found func TestArchiveCertificate_NotFound(t *testing.T) { mock := &MockCertificateService{ - ArchiveCertificateFn: func(id string) error { + ArchiveCertificateFn: func(_ context.Context, id string) error { return ErrMockNotFound }, } @@ -554,7 +554,7 @@ func TestGetCertificateVersions_Success(t *testing.T) { } mock := &MockCertificateService{ - GetCertificateVersionsFn: func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) { + GetCertificateVersionsFn: func(_ context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) { if certID == "mc-prod-001" { return []domain.CertificateVersion{ver1}, 1, nil } @@ -586,7 +586,7 @@ func TestGetCertificateVersions_Success(t *testing.T) { // Test GetCertificateVersions - not found func TestGetCertificateVersions_NotFound(t *testing.T) { mock := &MockCertificateService{ - GetCertificateVersionsFn: func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) { + GetCertificateVersionsFn: func(_ context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) { return nil, 0, ErrMockNotFound }, } @@ -606,7 +606,7 @@ func TestGetCertificateVersions_NotFound(t *testing.T) { // Test TriggerRenewal - success case func TestTriggerRenewal_Success(t *testing.T) { mock := &MockCertificateService{ - TriggerRenewalFn: func(certID string) error { + TriggerRenewalFn: func(_ context.Context, certID string, _ string) error { if certID == "mc-prod-001" { return nil } @@ -638,7 +638,7 @@ func TestTriggerRenewal_Success(t *testing.T) { // Test TriggerRenewal - service error func TestTriggerRenewal_ServiceError(t *testing.T) { mock := &MockCertificateService{ - TriggerRenewalFn: func(certID string) error { + TriggerRenewalFn: func(_ context.Context, certID string, _ string) error { return ErrMockServiceFailed }, } @@ -658,7 +658,7 @@ func TestTriggerRenewal_ServiceError(t *testing.T) { // Test TriggerDeployment - success case func TestTriggerDeployment_Success(t *testing.T) { mock := &MockCertificateService{ - TriggerDeploymentFn: func(certID string, targetID string) error { + TriggerDeploymentFn: func(_ context.Context, certID string, targetID string, _ string) error { if certID == "mc-prod-001" { return nil } @@ -695,7 +695,7 @@ func TestTriggerDeployment_Success(t *testing.T) { // Test TriggerDeployment - without target ID func TestTriggerDeployment_NoTargetID(t *testing.T) { mock := &MockCertificateService{ - TriggerDeploymentFn: func(certID string, targetID string) error { + TriggerDeploymentFn: func(_ context.Context, certID string, targetID string, _ string) error { // Should accept empty targetID (deploy to all) return nil }, @@ -716,7 +716,7 @@ func TestTriggerDeployment_NoTargetID(t *testing.T) { // Test ListCertificates - invalid page parameter func TestListCertificates_InvalidPageParam(t *testing.T) { mock := &MockCertificateService{ - ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { // Should default to page 1 if filter.Page == 1 { return []domain.ManagedCertificate{}, 0, nil @@ -740,7 +740,7 @@ func TestListCertificates_InvalidPageParam(t *testing.T) { // Test ListCertificates - per_page exceeds max func TestListCertificates_PerPageExceedsMax(t *testing.T) { mock := &MockCertificateService{ - ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { // Should cap perPage at 500 if filter.PerPage == 50 { // defaults to 50 if > 500 return []domain.ManagedCertificate{}, 0, nil @@ -765,7 +765,7 @@ func TestListCertificates_PerPageExceedsMax(t *testing.T) { func TestRevokeCertificate_Handler_Success(t *testing.T) { mock := &MockCertificateService{ - RevokeCertificateFn: func(certID string, reason string) error { + RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error { if certID != "mc-prod-001" { t.Errorf("expected certID mc-prod-001, got %s", certID) } @@ -798,7 +798,7 @@ func TestRevokeCertificate_Handler_Success(t *testing.T) { func TestRevokeCertificate_Handler_NoBody(t *testing.T) { mock := &MockCertificateService{ - RevokeCertificateFn: func(certID string, reason string) error { + RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error { // Empty reason is OK — service defaults to "unspecified" return nil }, @@ -818,7 +818,7 @@ func TestRevokeCertificate_Handler_NoBody(t *testing.T) { func TestRevokeCertificate_Handler_AlreadyRevoked(t *testing.T) { mock := &MockCertificateService{ - RevokeCertificateFn: func(certID string, reason string) error { + RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error { return fmt.Errorf("certificate is already revoked") }, } @@ -839,7 +839,7 @@ func TestRevokeCertificate_Handler_AlreadyRevoked(t *testing.T) { func TestRevokeCertificate_Handler_NotFound(t *testing.T) { mock := &MockCertificateService{ - RevokeCertificateFn: func(certID string, reason string) error { + RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error { return fmt.Errorf("failed to fetch certificate: not found") }, } @@ -858,7 +858,7 @@ func TestRevokeCertificate_Handler_NotFound(t *testing.T) { func TestRevokeCertificate_Handler_InvalidReason(t *testing.T) { mock := &MockCertificateService{ - RevokeCertificateFn: func(certID string, reason string) error { + RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error { return fmt.Errorf("invalid revocation reason: badReason") }, } @@ -922,7 +922,7 @@ func TestRevokeCertificate_Handler_EmptyID(t *testing.T) { func TestRevokeCertificate_Handler_CannotRevokeArchived(t *testing.T) { mock := &MockCertificateService{ - RevokeCertificateFn: func(certID string, reason string) error { + RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error { return fmt.Errorf("cannot revoke archived certificate") }, } @@ -941,7 +941,7 @@ func TestRevokeCertificate_Handler_CannotRevokeArchived(t *testing.T) { func TestRevokeCertificate_Handler_ServerError(t *testing.T) { mock := &MockCertificateService{ - RevokeCertificateFn: func(certID string, reason string) error { + RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error { return fmt.Errorf("database connection lost") }, } @@ -962,7 +962,7 @@ func TestRevokeCertificate_Handler_ServerError(t *testing.T) { func TestGetCRL_Success(t *testing.T) { mock := &MockCertificateService{ - GetRevokedCertificatesFn: func() ([]*domain.CertificateRevocation, error) { + GetRevokedCertificatesFn: func(_ context.Context) ([]*domain.CertificateRevocation, error) { return []*domain.CertificateRevocation{ { ID: "rev-1", @@ -1022,7 +1022,7 @@ func TestGetCRL_Success(t *testing.T) { func TestGetCRL_Empty(t *testing.T) { mock := &MockCertificateService{ - GetRevokedCertificatesFn: func() ([]*domain.CertificateRevocation, error) { + GetRevokedCertificatesFn: func(_ context.Context) ([]*domain.CertificateRevocation, error) { return nil, nil }, } @@ -1047,7 +1047,7 @@ func TestGetCRL_Empty(t *testing.T) { func TestGetCRL_ServiceError(t *testing.T) { mock := &MockCertificateService{ - GetRevokedCertificatesFn: func() ([]*domain.CertificateRevocation, error) { + GetRevokedCertificatesFn: func(_ context.Context) ([]*domain.CertificateRevocation, error) { return nil, fmt.Errorf("revocation repository not configured") }, } @@ -1083,7 +1083,7 @@ func TestGetCRL_MethodNotAllowed(t *testing.T) { func TestGetDERCRL_Success(t *testing.T) { derCRLData := []byte{0x30, 0x82, 0x01, 0x00} // Mock DER CRL bytes mock := &MockCertificateService{ - GenerateDERCRLFn: func(issuerID string) ([]byte, error) { + GenerateDERCRLFn: func(_ context.Context, issuerID string) ([]byte, error) { if issuerID == "iss-local" { return derCRLData, nil } @@ -1111,7 +1111,7 @@ func TestGetDERCRL_Success(t *testing.T) { func TestGetDERCRL_IssuerNotFound(t *testing.T) { mock := &MockCertificateService{ - GenerateDERCRLFn: func(issuerID string) ([]byte, error) { + GenerateDERCRLFn: func(_ context.Context, issuerID string) ([]byte, error) { return nil, fmt.Errorf("issuer not found") }, } @@ -1130,7 +1130,7 @@ func TestGetDERCRL_IssuerNotFound(t *testing.T) { func TestGetDERCRL_NotSupported(t *testing.T) { mock := &MockCertificateService{ - GenerateDERCRLFn: func(issuerID string) ([]byte, error) { + GenerateDERCRLFn: func(_ context.Context, issuerID string) ([]byte, error) { return nil, fmt.Errorf("issuer does not support CRL generation") }, } @@ -1165,7 +1165,7 @@ func TestGetDERCRL_MethodNotAllowed(t *testing.T) { func TestHandleOCSP_Success(t *testing.T) { ocspResponseBytes := []byte{0x30, 0x82, 0x02, 0x00} // Mock OCSP response mock := &MockCertificateService{ - GetOCSPResponseFn: func(issuerID string, serialHex string) ([]byte, error) { + GetOCSPResponseFn: func(_ context.Context, issuerID string, serialHex string) ([]byte, error) { if issuerID == "iss-local" && serialHex == "12345" { return ocspResponseBytes, nil } @@ -1206,7 +1206,7 @@ func TestHandleOCSP_MissingSerial(t *testing.T) { func TestHandleOCSP_IssuerNotFound(t *testing.T) { mock := &MockCertificateService{ - GetOCSPResponseFn: func(issuerID string, serialHex string) ([]byte, error) { + GetOCSPResponseFn: func(_ context.Context, issuerID string, serialHex string) ([]byte, error) { return nil, fmt.Errorf("issuer not found") }, } @@ -1225,7 +1225,7 @@ func TestHandleOCSP_IssuerNotFound(t *testing.T) { func TestHandleOCSP_CertNotFound(t *testing.T) { mock := &MockCertificateService{ - GetOCSPResponseFn: func(issuerID string, serialHex string) ([]byte, error) { + GetOCSPResponseFn: func(_ context.Context, issuerID string, serialHex string) ([]byte, error) { return nil, fmt.Errorf("certificate not found") }, } @@ -1261,7 +1261,7 @@ func TestHandleOCSP_MethodNotAllowed(t *testing.T) { // TestListCertificates_SortParam tests sort parameter parsing and passing to service. func TestListCertificates_SortParam(t *testing.T) { mock := &MockCertificateService{ - ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { // Handler strips the '-' prefix and sets SortDesc = true if filter.Sort != "notAfter" || !filter.SortDesc { t.Errorf("expected sort=notAfter desc=true, got sort=%s desc=%v", filter.Sort, filter.SortDesc) @@ -1284,7 +1284,7 @@ func TestListCertificates_SortParam(t *testing.T) { // TestListCertificates_SortParam_Ascending tests sort parameter without '-' prefix (ascending). func TestListCertificates_SortParam_Ascending(t *testing.T) { mock := &MockCertificateService{ - ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { if filter.Sort != "createdAt" || filter.SortDesc { t.Errorf("expected sort=createdAt desc=false, got sort=%s desc=%v", filter.Sort, filter.SortDesc) } @@ -1309,7 +1309,7 @@ func TestListCertificates_TimeRangeFilters(t *testing.T) { after := time.Now().AddDate(0, 0, -90) mock := &MockCertificateService{ - ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { if filter.ExpiresBefore == nil { t.Error("expected ExpiresBefore to be set") } @@ -1339,7 +1339,7 @@ func TestListCertificates_CreatedAfterFilter(t *testing.T) { past := time.Now().AddDate(-1, 0, 0) mock := &MockCertificateService{ - ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { if filter.CreatedAfter == nil { t.Error("expected CreatedAfter to be set") } @@ -1369,7 +1369,7 @@ func TestListCertificates_CursorPagination(t *testing.T) { } mock := &MockCertificateService{ - ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { return []domain.ManagedCertificate{cert}, 1, nil }, } @@ -1409,7 +1409,7 @@ func TestListCertificates_SparseFields(t *testing.T) { } mock := &MockCertificateService{ - ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { if len(filter.Fields) != 2 { t.Errorf("expected 2 fields, got %d", len(filter.Fields)) } @@ -1456,7 +1456,7 @@ func TestListCertificates_SparseFields(t *testing.T) { // TestListCertificates_ProfileFilter tests profile_id filter. func TestListCertificates_ProfileFilter(t *testing.T) { mock := &MockCertificateService{ - ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { if filter.ProfileID != "prof-standard" { t.Errorf("expected ProfileID=prof-standard, got %s", filter.ProfileID) } @@ -1479,7 +1479,7 @@ func TestListCertificates_ProfileFilter(t *testing.T) { // TestListCertificates_AgentIDFilter tests agent_id filter. func TestListCertificates_AgentIDFilter(t *testing.T) { mock := &MockCertificateService{ - ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { if filter.AgentID != "agent-prod-001" { t.Errorf("expected AgentID=agent-prod-001, got %s", filter.AgentID) } @@ -1502,7 +1502,7 @@ func TestListCertificates_AgentIDFilter(t *testing.T) { // TestListCertificates_CombinedFilters tests multiple filters together. func TestListCertificates_CombinedFilters(t *testing.T) { mock := &MockCertificateService{ - ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { if filter.Status != "Active" || filter.Environment != "production" || filter.ProfileID != "prof-standard" { t.Error("expected all filters to be set") } @@ -1540,7 +1540,7 @@ func TestGetCertificateDeployments_Success(t *testing.T) { } mock := &MockCertificateService{ - GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) { + GetCertificateDeploymentsFn: func(_ context.Context, certID string) ([]domain.DeploymentTarget, error) { if certID != "mc-prod-001" { return nil, ErrMockNotFound } @@ -1576,7 +1576,7 @@ func TestGetCertificateDeployments_Success(t *testing.T) { // TestGetCertificateDeployments_NotFound tests 404 for nonexistent certificate. func TestGetCertificateDeployments_NotFound(t *testing.T) { mock := &MockCertificateService{ - GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) { + GetCertificateDeploymentsFn: func(_ context.Context, certID string) ([]domain.DeploymentTarget, error) { return nil, fmt.Errorf("certificate not found") }, } @@ -1596,7 +1596,7 @@ func TestGetCertificateDeployments_NotFound(t *testing.T) { // TestGetCertificateDeployments_Empty tests successful response with no deployments. func TestGetCertificateDeployments_Empty(t *testing.T) { mock := &MockCertificateService{ - GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) { + GetCertificateDeploymentsFn: func(_ context.Context, certID string) ([]domain.DeploymentTarget, error) { if certID == "mc-no-deployments" { return []domain.DeploymentTarget{}, nil } diff --git a/internal/api/handler/certificates.go b/internal/api/handler/certificates.go index b8e17cb..365fa35 100644 --- a/internal/api/handler/certificates.go +++ b/internal/api/handler/certificates.go @@ -1,6 +1,7 @@ package handler import ( + "context" "encoding/json" "log/slog" "net/http" @@ -15,20 +16,20 @@ import ( // CertificateService defines the service interface for certificate operations. type CertificateService interface { - ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) - ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) - GetCertificate(id string) (*domain.ManagedCertificate, error) - CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) - UpdateCertificate(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) - ArchiveCertificate(id string) error - GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) - TriggerRenewal(certID string) error - TriggerDeployment(certID string, targetID string) error - RevokeCertificate(certID string, reason string) error - GetRevokedCertificates() ([]*domain.CertificateRevocation, error) - GenerateDERCRL(issuerID string) ([]byte, error) - GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) - GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error) + ListCertificates(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) + ListCertificatesWithFilter(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) + GetCertificate(ctx context.Context, id string) (*domain.ManagedCertificate, error) + CreateCertificate(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) + UpdateCertificate(ctx context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) + ArchiveCertificate(ctx context.Context, id string) error + GetCertificateVersions(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) + TriggerRenewal(ctx context.Context, certID string, actor string) error + TriggerDeployment(ctx context.Context, certID string, targetID string, actor string) error + RevokeCertificate(ctx context.Context, certID string, reason string, actor string) error + GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error) + GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error) + GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error) + GetCertificateDeployments(ctx context.Context, certID string) ([]domain.DeploymentTarget, error) } // CertificateHandler handles HTTP requests for certificate operations. @@ -128,7 +129,7 @@ func (h CertificateHandler) ListCertificates(w http.ResponseWriter, r *http.Requ filter.Fields = strings.Split(fieldsStr, ",") } - certs, total, err := h.svc.ListCertificatesWithFilter(filter) + certs, total, err := h.svc.ListCertificatesWithFilter(r.Context(), filter) if err != nil { ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list certificates", requestID) return @@ -186,7 +187,7 @@ func (h CertificateHandler) GetCertificate(w http.ResponseWriter, r *http.Reques return } - cert, err := h.svc.GetCertificate(id) + cert, err := h.svc.GetCertificate(r.Context(), id) if err != nil { ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID) return @@ -241,7 +242,7 @@ func (h CertificateHandler) CreateCertificate(w http.ResponseWriter, r *http.Req return } - created, err := h.svc.CreateCertificate(cert) + created, err := h.svc.CreateCertificate(r.Context(), cert) if err != nil { slog.Error("failed to create certificate", "error", err, "request_id", requestID, "common_name", cert.CommonName, "name", cert.Name) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create certificate", requestID) @@ -295,7 +296,7 @@ func (h CertificateHandler) UpdateCertificate(w http.ResponseWriter, r *http.Req } } - updated, err := h.svc.UpdateCertificate(id, cert) + updated, err := h.svc.UpdateCertificate(r.Context(), id, cert) if err != nil { if strings.Contains(err.Error(), "not found") { ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID) @@ -325,7 +326,7 @@ func (h CertificateHandler) ArchiveCertificate(w http.ResponseWriter, r *http.Re return } - if err := h.svc.ArchiveCertificate(id); err != nil { + if err := h.svc.ArchiveCertificate(r.Context(), id); err != nil { if strings.Contains(err.Error(), "not found") { ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID) return @@ -370,7 +371,7 @@ func (h CertificateHandler) GetCertificateVersions(w http.ResponseWriter, r *htt } } - versions, total, err := h.svc.GetCertificateVersions(certID, page, perPage) + versions, total, err := h.svc.GetCertificateVersions(r.Context(), certID, page, perPage) if err != nil { if strings.Contains(err.Error(), "not found") { ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID) @@ -410,7 +411,7 @@ func (h CertificateHandler) TriggerRenewal(w http.ResponseWriter, r *http.Reques } certID := parts[0] - if err := h.svc.TriggerRenewal(certID); err != nil { + if err := h.svc.TriggerRenewal(r.Context(), certID, "api"); err != nil { errMsg := err.Error() if strings.Contains(errMsg, "not found") { ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID) @@ -466,7 +467,7 @@ func (h CertificateHandler) TriggerDeployment(w http.ResponseWriter, r *http.Req } } - if err := h.svc.TriggerDeployment(certID, req.TargetID); err != nil { + if err := h.svc.TriggerDeployment(r.Context(), certID, req.TargetID, "api"); err != nil { ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to trigger deployment", requestID) return } @@ -508,7 +509,7 @@ func (h CertificateHandler) RevokeCertificate(w http.ResponseWriter, r *http.Req } } - if err := h.svc.RevokeCertificate(certID, req.Reason); err != nil { + if err := h.svc.RevokeCertificate(r.Context(), certID, req.Reason, "api"); err != nil { // Distinguish between client errors and server errors errMsg := err.Error() if strings.Contains(errMsg, "already revoked") || @@ -540,7 +541,7 @@ func (h CertificateHandler) GetCRL(w http.ResponseWriter, r *http.Request) { requestID := middleware.GetRequestID(r.Context()) - revocations, err := h.svc.GetRevokedCertificates() + revocations, err := h.svc.GetRevokedCertificates(r.Context()) if err != nil { ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to generate CRL", requestID) return @@ -585,7 +586,7 @@ func (h CertificateHandler) GetDERCRL(w http.ResponseWriter, r *http.Request) { return } - derBytes, err := h.svc.GenerateDERCRL(issuerID) + derBytes, err := h.svc.GenerateDERCRL(r.Context(), issuerID) if err != nil { errMsg := err.Error() if strings.Contains(errMsg, "not found") { @@ -627,7 +628,7 @@ func (h CertificateHandler) HandleOCSP(w http.ResponseWriter, r *http.Request) { issuerID := parts[0] serialHex := parts[1] - derBytes, err := h.svc.GetOCSPResponse(issuerID, serialHex) + derBytes, err := h.svc.GetOCSPResponse(r.Context(), issuerID, serialHex) if err != nil { errMsg := err.Error() if strings.Contains(errMsg, "not found") { @@ -667,7 +668,7 @@ func (h CertificateHandler) GetCertificateDeployments(w http.ResponseWriter, r * } certID := parts[0] - deployments, err := h.svc.GetCertificateDeployments(certID) + deployments, err := h.svc.GetCertificateDeployments(r.Context(), certID) if err != nil { errMsg := err.Error() if strings.Contains(errMsg, "not found") { diff --git a/internal/service/ca_operations.go b/internal/service/ca_operations.go index 835c6cd..430a664 100644 --- a/internal/service/ca_operations.go +++ b/internal/service/ca_operations.go @@ -41,7 +41,7 @@ func (s *CAOperationsSvc) SetIssuerRegistry(registry *IssuerRegistry) { // GenerateDERCRL generates a DER-encoded X.509 CRL for the given issuer. // Short-lived certificates (profile TTL < 1 hour) are excluded from the CRL. -func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) { +func (s *CAOperationsSvc) GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error) { if s.revocationRepo == nil { return nil, fmt.Errorf("revocation repository not configured") } @@ -54,7 +54,7 @@ func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) { return nil, fmt.Errorf("issuer not found: %s", issuerID) } - revocations, err := s.revocationRepo.ListAll(context.Background()) + revocations, err := s.revocationRepo.ListAll(ctx) if err != nil { return nil, fmt.Errorf("failed to list revocations: %w", err) } @@ -69,9 +69,9 @@ func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) { // Check short-lived exemption: look up the cert's profile if s.profileRepo != nil && s.certRepo != nil { - cert, err := s.certRepo.Get(context.Background(), rev.CertificateID) + cert, err := s.certRepo.Get(ctx, rev.CertificateID) if err == nil && cert.CertificateProfileID != "" { - profile, err := s.profileRepo.Get(context.Background(), cert.CertificateProfileID) + profile, err := s.profileRepo.Get(ctx, cert.CertificateProfileID) if err == nil && profile.IsShortLived() { slog.Debug("skipping short-lived cert from CRL", "certificate_id", rev.CertificateID, @@ -92,11 +92,11 @@ func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) { }) } - return issuerConn.GenerateCRL(context.Background(), entries) + return issuerConn.GenerateCRL(ctx, entries) } // GetOCSPResponse generates a signed OCSP response for the given certificate serial. -func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) { +func (s *CAOperationsSvc) GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error) { if s.revocationRepo == nil { return nil, fmt.Errorf("revocation repository not configured") } @@ -120,13 +120,13 @@ func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([] // Look up cert by (issuer_id, serial) — per RFC 5280 §5.2.3, serial numbers // are unique only within a single issuer. The OCSP URL path carries issuer_id, // so we scope the lookup to avoid cross-issuer collisions. - rev, _ := s.revocationRepo.GetByIssuerAndSerial(context.Background(), issuerID, serialHex) + rev, _ := s.revocationRepo.GetByIssuerAndSerial(ctx, issuerID, serialHex) if rev != nil { - cert, err := s.certRepo.Get(context.Background(), rev.CertificateID) + cert, err := s.certRepo.Get(ctx, rev.CertificateID) if err == nil && cert.CertificateProfileID != "" { - profile, err := s.profileRepo.Get(context.Background(), cert.CertificateProfileID) + profile, err := s.profileRepo.Get(ctx, cert.CertificateProfileID) if err == nil && profile.IsShortLived() { - return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{ + return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{ CertSerial: serial, CertStatus: 0, // good — short-lived exemption ThisUpdate: now, @@ -138,10 +138,10 @@ func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([] } // Check if this (issuer_id, serial) is revoked — RFC 5280 §5.2.3 scoping. - rev, err := s.revocationRepo.GetByIssuerAndSerial(context.Background(), issuerID, serialHex) + rev, err := s.revocationRepo.GetByIssuerAndSerial(ctx, issuerID, serialHex) if err != nil { // Not revoked — return "good" status - return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{ + return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{ CertSerial: serial, CertStatus: 0, // good ThisUpdate: now, @@ -150,7 +150,7 @@ func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([] } // Revoked - return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{ + return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{ CertSerial: serial, CertStatus: 1, // revoked RevokedAt: rev.RevokedAt, diff --git a/internal/service/ca_operations_test.go b/internal/service/ca_operations_test.go index 59845e2..2825fd5 100644 --- a/internal/service/ca_operations_test.go +++ b/internal/service/ca_operations_test.go @@ -3,6 +3,7 @@ package service import ( + "context" "log/slog" "testing" "time" @@ -48,7 +49,7 @@ func TestCAOperationsSvc_GenerateDERCRL_Success(t *testing.T) { }, } - crl, err := caSvc.GenerateDERCRL("iss-local") + crl, err := caSvc.GenerateDERCRL(context.Background(), "iss-local") if err != nil { t.Fatalf("expected no error, got: %v", err) @@ -71,7 +72,7 @@ func TestCAOperationsSvc_GenerateDERCRL_EmptyCRL(t *testing.T) { // No revoked certs for this issuer revocationRepo.Revocations = []*domain.CertificateRevocation{} - crl, err := caSvc.GenerateDERCRL("iss-local") + crl, err := caSvc.GenerateDERCRL(context.Background(), "iss-local") if err != nil { t.Fatalf("expected no error, got: %v", err) @@ -112,7 +113,7 @@ func TestCAOperationsSvc_GetOCSPResponse_Good(t *testing.T) { certRepo.Versions["cert-ocsp-good"] = []*domain.CertificateVersion{version} // Request OCSP response for good cert - resp, err := caSvc.GetOCSPResponse("iss-local", "OCSP-GOOD-001") + resp, err := caSvc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-GOOD-001") if err != nil { t.Fatalf("expected no error, got: %v", err) @@ -165,7 +166,7 @@ func TestCAOperationsSvc_GetOCSPResponse_Revoked(t *testing.T) { } // Request OCSP response for revoked cert - resp, err := caSvc.GetOCSPResponse("iss-local", "OCSP-REVOKED-001") + resp, err := caSvc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-REVOKED-001") if err != nil { t.Fatalf("expected no error, got: %v", err) diff --git a/internal/service/certificate.go b/internal/service/certificate.go index 9744ce5..b8c6db2 100644 --- a/internal/service/certificate.go +++ b/internal/service/certificate.go @@ -71,8 +71,8 @@ func (s *CertificateService) List(ctx context.Context, filter *repository.Certif // ListCertificatesWithFilter returns a list of certificates with advanced filtering (M20). // This method supports the new M20 filters and returns domain.ManagedCertificate (not pointers). -func (s *CertificateService) ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { - certs, total, err := s.certRepo.List(context.Background(), filter) +func (s *CertificateService) ListCertificatesWithFilter(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + certs, total, err := s.certRepo.List(ctx, filter) if err != nil { return nil, 0, fmt.Errorf("failed to list certificates with filter: %w", err) } @@ -206,10 +206,10 @@ func (s *CertificateService) GetVersions(ctx context.Context, certID string) ([] return versions, nil } -// TriggerRenewalWithActor initiates a renewal job if the certificate is eligible. +// TriggerRenewal initiates a renewal job if the certificate is eligible. // Creates a Renewal job (or Issuance for new certs) so the scheduler's job processor // can pick it up and route it through the issuer connector. -func (s *CertificateService) TriggerRenewalWithActor(ctx context.Context, certID string, actor string) error { +func (s *CertificateService) TriggerRenewal(ctx context.Context, certID string, actor string) error { cert, err := s.certRepo.Get(ctx, certID) if err != nil { return fmt.Errorf("failed to fetch certificate: %w", err) @@ -283,8 +283,11 @@ func (s *CertificateService) TriggerRenewalWithActor(ctx context.Context, certID return nil } -// TriggerDeploymentWithActor creates deployment jobs for all targets of a certificate. -func (s *CertificateService) TriggerDeploymentWithActor(ctx context.Context, certID string, actor string) error { +// TriggerDeployment creates deployment jobs for all targets of a certificate. +// The targetID parameter is accepted from the handler interface but currently unused; +// deployment coordination happens per-certificate across all of its targets. +func (s *CertificateService) TriggerDeployment(ctx context.Context, certID string, targetID string, actor string) error { + _ = targetID cert, err := s.certRepo.Get(ctx, certID) if err != nil { return fmt.Errorf("failed to fetch certificate: %w", err) @@ -306,7 +309,7 @@ func (s *CertificateService) TriggerDeploymentWithActor(ctx context.Context, cer } // ListCertificates returns paginated certificates with optional filtering (handler interface method). -func (s *CertificateService) ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) { +func (s *CertificateService) ListCertificates(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) { if page < 1 { page = 1 } @@ -325,7 +328,7 @@ func (s *CertificateService) ListCertificates(status, environment, ownerID, team PerPage: perPage, } - certs, total, err := s.certRepo.List(context.Background(), filter) + certs, total, err := s.certRepo.List(ctx, filter) if err != nil { return nil, 0, fmt.Errorf("failed to list certificates: %w", err) } @@ -341,12 +344,12 @@ func (s *CertificateService) ListCertificates(status, environment, ownerID, team } // GetCertificate returns a single certificate (handler interface method). -func (s *CertificateService) GetCertificate(id string) (*domain.ManagedCertificate, error) { - return s.certRepo.Get(context.Background(), id) +func (s *CertificateService) GetCertificate(ctx context.Context, id string) (*domain.ManagedCertificate, error) { + return s.certRepo.Get(ctx, id) } // CreateCertificate creates a new certificate (handler interface method). -func (s *CertificateService) CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { +func (s *CertificateService) CreateCertificate(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { if cert.ID == "" { cert.ID = generateID("cert") } @@ -365,16 +368,14 @@ func (s *CertificateService) CreateCertificate(cert domain.ManagedCertificate) ( if cert.Tags == nil { cert.Tags = make(map[string]string) } - if err := s.certRepo.Create(context.Background(), &cert); err != nil { + if err := s.certRepo.Create(ctx, &cert); err != nil { return nil, fmt.Errorf("failed to create certificate: %w", err) } return &cert, nil } // UpdateCertificate modifies a certificate (handler interface method). -func (s *CertificateService) UpdateCertificate(id string, patch domain.ManagedCertificate) (*domain.ManagedCertificate, error) { - ctx := context.Background() - +func (s *CertificateService) UpdateCertificate(ctx context.Context, id string, patch domain.ManagedCertificate) (*domain.ManagedCertificate, error) { // Fetch existing certificate so partial updates don't zero out fields existing, err := s.certRepo.Get(ctx, id) if err != nil { @@ -425,12 +426,12 @@ func (s *CertificateService) UpdateCertificate(id string, patch domain.ManagedCe } // ArchiveCertificate marks a certificate as archived (handler interface method). -func (s *CertificateService) ArchiveCertificate(id string) error { - return s.certRepo.Archive(context.Background(), id) +func (s *CertificateService) ArchiveCertificate(ctx context.Context, id string) error { + return s.certRepo.Archive(ctx, id) } // GetCertificateVersions returns certificate versions (handler interface method). -func (s *CertificateService) GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) { +func (s *CertificateService) GetCertificateVersions(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) { if page < 1 { page = 1 } @@ -438,7 +439,7 @@ func (s *CertificateService) GetCertificateVersions(certID string, page, perPage perPage = 50 } - versions, err := s.certRepo.ListVersions(context.Background(), certID) + versions, err := s.certRepo.ListVersions(ctx, certID) if err != nil { return nil, 0, fmt.Errorf("failed to list certificate versions: %w", err) } @@ -463,24 +464,8 @@ func (s *CertificateService) GetCertificateVersions(certID string, page, perPage return result, total, nil } -// TriggerRenewal initiates renewal (handler interface method). -func (s *CertificateService) TriggerRenewal(certID string) error { - return s.TriggerRenewalWithActor(context.Background(), certID, "api") -} - -// TriggerDeployment triggers deployment (handler interface method). -func (s *CertificateService) TriggerDeployment(certID string, targetID string) error { - return s.TriggerDeploymentWithActor(context.Background(), certID, "api") -} - -// RevokeCertificate revokes a certificate with the given reason (handler interface method). -func (s *CertificateService) RevokeCertificate(certID string, reason string) error { - return s.RevokeCertificateWithActor(context.Background(), certID, reason, "api") -} - -// RevokeCertificateWithActor performs revocation with actor tracking. -// Delegates to RevocationSvc. -func (s *CertificateService) RevokeCertificateWithActor(ctx context.Context, certID string, reason string, actor string) error { +// RevokeCertificate performs revocation with actor tracking. Delegates to RevocationSvc. +func (s *CertificateService) RevokeCertificate(ctx context.Context, certID string, reason string, actor string) error { if s.revSvc == nil { return fmt.Errorf("revocation service not configured") } @@ -489,35 +474,35 @@ func (s *CertificateService) RevokeCertificateWithActor(ctx context.Context, cer // GetRevokedCertificates returns all revoked certificate records (for CRL generation). // Delegates to RevocationSvc. -func (s *CertificateService) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) { +func (s *CertificateService) GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error) { if s.revSvc == nil { return nil, fmt.Errorf("revocation service not configured") } - return s.revSvc.GetRevokedCertificates() + return s.revSvc.GetRevokedCertificates(ctx) } // GenerateDERCRL generates a DER-encoded X.509 CRL for the given issuer. // Delegates to CAOperationsSvc. -func (s *CertificateService) GenerateDERCRL(issuerID string) ([]byte, error) { +func (s *CertificateService) GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error) { if s.caSvc == nil { return nil, fmt.Errorf("CA operations service not configured") } - return s.caSvc.GenerateDERCRL(issuerID) + return s.caSvc.GenerateDERCRL(ctx, issuerID) } // GetOCSPResponse generates a signed OCSP response for the given certificate serial. // Delegates to CAOperationsSvc. -func (s *CertificateService) GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) { +func (s *CertificateService) GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error) { if s.caSvc == nil { return nil, fmt.Errorf("CA operations service not configured") } - return s.caSvc.GetOCSPResponse(issuerID, serialHex) + return s.caSvc.GetOCSPResponse(ctx, issuerID, serialHex) } // GetCertificateDeployments returns all deployment targets for a certificate (M20). -func (s *CertificateService) GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error) { +func (s *CertificateService) GetCertificateDeployments(ctx context.Context, certID string) ([]domain.DeploymentTarget, error) { // Verify certificate exists - _, err := s.certRepo.Get(context.Background(), certID) + _, err := s.certRepo.Get(ctx, certID) if err != nil { return nil, fmt.Errorf("certificate not found: %w", err) } @@ -527,7 +512,7 @@ func (s *CertificateService) GetCertificateDeployments(certID string) ([]domain. } // Get targets from repository - targets, err := s.targetRepo.ListByCertificate(context.Background(), certID) + targets, err := s.targetRepo.ListByCertificate(ctx, certID) if err != nil { return nil, fmt.Errorf("failed to list deployment targets: %w", err) } diff --git a/internal/service/certificate_nil_safety_test.go b/internal/service/certificate_nil_safety_test.go index 9d86f1f..1134c86 100644 --- a/internal/service/certificate_nil_safety_test.go +++ b/internal/service/certificate_nil_safety_test.go @@ -34,7 +34,7 @@ func TestCertificateService_RevokeCertificate_RevocationSvcNil(t *testing.T) { certRepo.AddCert(cert) // Call RevokeCertificateWithActor with nil RevocationSvc - err := certService.RevokeCertificateWithActor(context.Background(), "cert-1", "keyCompromise", "admin") + err := certService.RevokeCertificate(context.Background(), "cert-1", "keyCompromise", "admin") // Assert: Should return error, NOT panic if err == nil { @@ -64,7 +64,7 @@ func TestCertificateService_GenerateDERCRL_CAOpsSvcNil(t *testing.T) { // Note: NOT calling certService.SetCAOperationsSvc(...) // Call GenerateDERCRL with nil CAOperationsSvc - _, err := certService.GenerateDERCRL("iss-local") + _, err := certService.GenerateDERCRL(context.Background(), "iss-local") // Assert: Should return error, NOT panic if err == nil { @@ -94,7 +94,7 @@ func TestCertificateService_GetOCSPResponse_CAOpsSvcNil(t *testing.T) { // Note: NOT calling certService.SetCAOperationsSvc(...) // Call GetOCSPResponse with nil CAOperationsSvc - _, err := certService.GetOCSPResponse("iss-local", "serial123") + _, err := certService.GetOCSPResponse(context.Background(), "iss-local", "serial123") // Assert: Should return error, NOT panic if err == nil { @@ -124,7 +124,7 @@ func TestCertificateService_GetRevokedCertificates_RevocationSvcNil(t *testing.T // Note: NOT calling certService.SetRevocationSvc(...) // Call GetRevokedCertificates with nil RevocationSvc - _, err := certService.GetRevokedCertificates() + _, err := certService.GetRevokedCertificates(context.Background()) // Assert: Should return error, NOT panic if err == nil { @@ -177,7 +177,7 @@ func TestCertificateService_GetCertificateDeployments_Success(t *testing.T) { targetRepo.AddTarget(target2) // Call GetCertificateDeployments - deployments, err := certService.GetCertificateDeployments("cert-1") + deployments, err := certService.GetCertificateDeployments(context.Background(), "cert-1") // Assert: Should return deployment list successfully if err != nil { @@ -218,7 +218,7 @@ func TestCertificateService_GetCertificateDeployments_RepositoryError(t *testing certRepo.AddCert(cert) // Call GetCertificateDeployments with repo error - _, err := certService.GetCertificateDeployments("cert-1") + _, err := certService.GetCertificateDeployments(context.Background(), "cert-1") // Assert: Should return error, NOT panic if err == nil { @@ -247,7 +247,7 @@ func TestCertificateService_GetCertificateDeployments_CertNotFound(t *testing.T) certService.SetTargetRepo(targetRepo) // Call GetCertificateDeployments with nonexistent certificate - _, err := certService.GetCertificateDeployments("nonexistent-cert") + _, err := certService.GetCertificateDeployments(context.Background(), "nonexistent-cert") // Assert: Should return error if err == nil { @@ -283,7 +283,7 @@ func TestCertificateService_GetCertificateDeployments_NilTargetRepo(t *testing.T certRepo.AddCert(cert) // Call GetCertificateDeployments with nil TargetRepo - deployments, err := certService.GetCertificateDeployments("cert-1") + deployments, err := certService.GetCertificateDeployments(context.Background(), "cert-1") // Assert: Should return empty list gracefully (not panic) if err != nil { @@ -337,19 +337,19 @@ func TestCertificateService_Multiple_NilSafetyChecks(t *testing.T) { revSvc.SetIssuerRegistry(registry) // Test 1: RevokeCertificateWithActor should succeed (RevocationSvc is set) - errRevoke := certService.RevokeCertificateWithActor(context.Background(), "cert-1", "keyCompromise", "admin") + errRevoke := certService.RevokeCertificate(context.Background(), "cert-1", "keyCompromise", "admin") if errRevoke != nil { t.Fatalf("RevokeCertificateWithActor failed unexpectedly: %v", errRevoke) } // Test 2: GenerateDERCRL should fail gracefully (CAOperationsSvc is nil) - _, errCRL := certService.GenerateDERCRL("iss-local") + _, errCRL := certService.GenerateDERCRL(context.Background(), "iss-local") if errCRL == nil { t.Fatal("GenerateDERCRL expected error, got nil") } // Test 3: GetOCSPResponse should fail gracefully (CAOperationsSvc is nil) - _, errOCSP := certService.GetOCSPResponse("iss-local", "ABC123") + _, errOCSP := certService.GetOCSPResponse(context.Background(), "iss-local", "ABC123") if errOCSP == nil { t.Fatal("GetOCSPResponse expected error, got nil") } diff --git a/internal/service/certificate_test.go b/internal/service/certificate_test.go index bb111aa..2bea43a 100644 --- a/internal/service/certificate_test.go +++ b/internal/service/certificate_test.go @@ -294,7 +294,7 @@ func TestTriggerRenewal(t *testing.T) { auditService := NewAuditService(auditRepo) certService := NewCertificateService(certRepo, policyService, auditService) - err := certService.TriggerRenewalWithActor(ctx, "cert-001", "user-1") + err := certService.TriggerRenewal(ctx, "cert-001", "user-1") if err != nil { t.Fatalf("TriggerRenewal failed: %v", err) } @@ -333,13 +333,14 @@ func TestTriggerRenewal_Archived(t *testing.T) { auditService := NewAuditService(auditRepo) certService := NewCertificateService(certRepo, policyService, auditService) - err := certService.TriggerRenewalWithActor(ctx, "cert-001", "user-1") + err := certService.TriggerRenewal(ctx, "cert-001", "user-1") if err == nil { t.Fatal("expected error for archived certificate") } } func TestListCertificates(t *testing.T) { + ctx := context.Background() now := time.Now() cert1 := &domain.ManagedCertificate{ ID: "cert-001", @@ -369,7 +370,7 @@ func TestListCertificates(t *testing.T) { auditService := NewAuditService(auditRepo) certService := NewCertificateService(certRepo, policyService, auditService) - certs, total, err := certService.ListCertificates("", "", "", "", "", 1, 50) + certs, total, err := certService.ListCertificates(ctx, "", "", "", "", "", 1, 50) if err != nil { t.Fatalf("ListCertificates failed: %v", err) } diff --git a/internal/service/revocation_svc.go b/internal/service/revocation_svc.go index 2456082..660f608 100644 --- a/internal/service/revocation_svc.go +++ b/internal/service/revocation_svc.go @@ -151,9 +151,9 @@ func (s *RevocationSvc) RevokeCertificateWithActor(ctx context.Context, certID s } // GetRevokedCertificates returns all revoked certificate records (for CRL generation). -func (s *RevocationSvc) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) { +func (s *RevocationSvc) GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error) { if s.revocationRepo == nil { return nil, fmt.Errorf("revocation repository not configured") } - return s.revocationRepo.ListAll(context.Background()) + return s.revocationRepo.ListAll(ctx) } diff --git a/internal/service/revocation_svc_test.go b/internal/service/revocation_svc_test.go index 7016789..f7ef15d 100644 --- a/internal/service/revocation_svc_test.go +++ b/internal/service/revocation_svc_test.go @@ -122,7 +122,7 @@ func TestRevocationSvc_GetRevokedCertificates_Success(t *testing.T) { {ID: "rev-2", CertificateID: "cert-2", SerialNumber: "SER-2", Reason: "superseded", RevokedAt: time.Now()}, } - revocations, err := revSvc.GetRevokedCertificates() + revocations, err := revSvc.GetRevokedCertificates(context.Background()) if err != nil { t.Fatalf("expected no error, got: %v", err) } diff --git a/internal/service/revocation_test.go b/internal/service/revocation_test.go index dcfd459..4b673fd 100644 --- a/internal/service/revocation_test.go +++ b/internal/service/revocation_test.go @@ -62,7 +62,7 @@ func TestRevokeCertificate_Success(t *testing.T) { certRepo.Versions["cert-1"] = []*domain.CertificateVersion{version} // Revoke - err := svc.RevokeCertificateWithActor(context.Background(), "cert-1", "keyCompromise", "admin") + err := svc.RevokeCertificate(context.Background(), "cert-1", "keyCompromise", "admin") if err != nil { t.Fatalf("expected no error, got: %v", err) } @@ -125,7 +125,7 @@ func TestRevokeCertificate_DefaultReason(t *testing.T) { } // Revoke with empty reason — should default to "unspecified" - err := svc.RevokeCertificateWithActor(context.Background(), "cert-2", "", "api") + err := svc.RevokeCertificate(context.Background(), "cert-2", "", "api") if err != nil { t.Fatalf("expected no error, got: %v", err) } @@ -158,7 +158,7 @@ func TestRevokeCertificate_AlreadyRevoked(t *testing.T) { } certRepo.AddCert(cert) - err := svc.RevokeCertificateWithActor(context.Background(), "cert-3", "superseded", "admin") + err := svc.RevokeCertificate(context.Background(), "cert-3", "superseded", "admin") if err == nil { t.Fatal("expected error for already revoked certificate") } @@ -179,7 +179,7 @@ func TestRevokeCertificate_ArchivedCert(t *testing.T) { } certRepo.AddCert(cert) - err := svc.RevokeCertificateWithActor(context.Background(), "cert-4", "keyCompromise", "admin") + err := svc.RevokeCertificate(context.Background(), "cert-4", "keyCompromise", "admin") if err == nil { t.Fatal("expected error for archived certificate") } @@ -200,7 +200,7 @@ func TestRevokeCertificate_InvalidReason(t *testing.T) { } certRepo.AddCert(cert) - err := svc.RevokeCertificateWithActor(context.Background(), "cert-5", "notAValidReason", "admin") + err := svc.RevokeCertificate(context.Background(), "cert-5", "notAValidReason", "admin") if err == nil { t.Fatal("expected error for invalid reason") } @@ -212,7 +212,7 @@ func TestRevokeCertificate_InvalidReason(t *testing.T) { func TestRevokeCertificate_NotFound(t *testing.T) { svc, _, _, _ := newRevocationTestService() - err := svc.RevokeCertificateWithActor(context.Background(), "nonexistent-cert", "keyCompromise", "admin") + err := svc.RevokeCertificate(context.Background(), "nonexistent-cert", "keyCompromise", "admin") if err == nil { t.Fatal("expected error for nonexistent certificate") } @@ -231,7 +231,7 @@ func TestRevokeCertificate_NoVersion(t *testing.T) { certRepo.AddCert(cert) // No versions added — should fail - err := svc.RevokeCertificateWithActor(context.Background(), "cert-6", "keyCompromise", "admin") + err := svc.RevokeCertificate(context.Background(), "cert-6", "keyCompromise", "admin") if err == nil { t.Fatal("expected error when no certificate version exists") } @@ -258,7 +258,7 @@ func TestRevokeCertificate_WithIssuerNotification(t *testing.T) { {ID: "ver-7", CertificateID: "cert-7", SerialNumber: "GHI789", CreatedAt: time.Now()}, } - err := svc.RevokeCertificateWithActor(context.Background(), "cert-7", "cessationOfOperation", "admin") + err := svc.RevokeCertificate(context.Background(), "cert-7", "cessationOfOperation", "admin") if err != nil { t.Fatalf("expected no error, got: %v", err) } @@ -293,7 +293,7 @@ func TestRevokeCertificate_WithNotificationService(t *testing.T) { {ID: "ver-8", CertificateID: "cert-8", SerialNumber: "JKL012", CreatedAt: time.Now()}, } - err := svc.RevokeCertificateWithActor(context.Background(), "cert-8", "keyCompromise", "admin") + err := svc.RevokeCertificate(context.Background(), "cert-8", "keyCompromise", "admin") if err != nil { t.Fatalf("expected no error, got: %v", err) } @@ -336,7 +336,7 @@ func TestRevokeCertificate_AllValidReasons(t *testing.T) { {ID: "ver-" + reason, CertificateID: "cert-" + reason, SerialNumber: "SER-" + reason, CreatedAt: time.Now()}, } - err := svc.RevokeCertificateWithActor(context.Background(), "cert-"+reason, reason, "admin") + err := svc.RevokeCertificate(context.Background(), "cert-"+reason, reason, "admin") if err != nil { t.Fatalf("expected no error for reason %s, got: %v", reason, err) } @@ -358,7 +358,7 @@ func TestGetRevokedCertificates_Success(t *testing.T) { {ID: "rev-2", CertificateID: "cert-2", SerialNumber: "SER-2", Reason: "superseded", RevokedAt: time.Now()}, } - revocations, err := svc.GetRevokedCertificates() + revocations, err := svc.GetRevokedCertificates(context.Background()) if err != nil { t.Fatalf("expected no error, got: %v", err) } @@ -370,7 +370,7 @@ func TestGetRevokedCertificates_Success(t *testing.T) { func TestGetRevokedCertificates_Empty(t *testing.T) { svc, _, _, _ := newRevocationTestService() - revocations, err := svc.GetRevokedCertificates() + revocations, err := svc.GetRevokedCertificates(context.Background()) if err != nil { t.Fatalf("expected no error, got: %v", err) } @@ -390,7 +390,7 @@ func TestGetRevokedCertificates_NoRepo(t *testing.T) { svc := NewCertificateService(certRepo, policyService, auditService) // Do NOT set revocation repo - _, err := svc.GetRevokedCertificates() + _, err := svc.GetRevokedCertificates(context.Background()) if err == nil { t.Fatal("expected error when revocation repo not configured") } @@ -411,8 +411,8 @@ func TestRevokeCertificate_HandlerInterfaceMethod(t *testing.T) { {ID: "ver-handler", CertificateID: "cert-handler", SerialNumber: "SER-HANDLER", CreatedAt: time.Now()}, } - // Test the handler interface method (no actor param) - err := svc.RevokeCertificate("cert-handler", "superseded") + // Test the handler interface method (actor collapsed to required positional arg per D-2) + err := svc.RevokeCertificate(context.Background(), "cert-handler", "superseded", "api") if err != nil { t.Fatalf("expected no error, got: %v", err) } @@ -449,7 +449,7 @@ func TestGenerateDERCRL_Success(t *testing.T) { }, } - crl, err := svc.GenerateDERCRL("iss-local") + crl, err := svc.GenerateDERCRL(context.Background(), "iss-local") if err != nil { t.Fatalf("expected no error, got: %v", err) @@ -472,7 +472,7 @@ func TestGenerateDERCRL_EmptyCRL(t *testing.T) { // No revoked certs for this issuer revocationRepo.Revocations = []*domain.CertificateRevocation{} - crl, err := svc.GenerateDERCRL("iss-local") + crl, err := svc.GenerateDERCRL(context.Background(), "iss-local") if err != nil { t.Fatalf("expected no error, got: %v", err) @@ -493,7 +493,7 @@ func TestGenerateDERCRL_IssuerNotFound(t *testing.T) { svc, _, _, _ := newRevocationTestService() // Try to generate CRL for unknown issuer - crl, err := svc.GenerateDERCRL("iss-unknown") + crl, err := svc.GenerateDERCRL(context.Background(), "iss-unknown") // Should return error or nil CRL depending on implementation if crl != nil && err == nil { @@ -527,7 +527,7 @@ func TestGetOCSPResponse_Good(t *testing.T) { certRepo.Versions["cert-ocsp-good"] = []*domain.CertificateVersion{version} // Request OCSP response for good cert - resp, err := svc.GetOCSPResponse("iss-local", "OCSP-GOOD-001") + resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-GOOD-001") if err != nil { t.Fatalf("expected no error, got: %v", err) @@ -580,7 +580,7 @@ func TestGetOCSPResponse_Revoked(t *testing.T) { } // Request OCSP response for revoked cert - resp, err := svc.GetOCSPResponse("iss-local", "OCSP-REVOKED-001") + resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-REVOKED-001") if err != nil { t.Fatalf("expected no error, got: %v", err) @@ -597,7 +597,7 @@ func TestGetOCSPResponse_Unknown(t *testing.T) { svc, _, _, _ := newRevocationTestService() // Request OCSP response for unknown cert - resp, err := svc.GetOCSPResponse("iss-local", "UNKNOWN-SERIAL") + resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "UNKNOWN-SERIAL") if err != nil { t.Fatalf("expected no error (should return unknown status), got: %v", err) @@ -615,7 +615,7 @@ func TestGetOCSPResponse_IssuerNotFound(t *testing.T) { svc, _, _, _ := newRevocationTestService() // Request OCSP response for unknown issuer - resp, err := svc.GetOCSPResponse("iss-unknown", "SOME-SERIAL") + resp, err := svc.GetOCSPResponse(context.Background(), "iss-unknown", "SOME-SERIAL") // Should return error since issuer doesn't exist if err == nil && resp != nil { @@ -629,7 +629,7 @@ func TestGetOCSPResponse_InvalidSerial(t *testing.T) { svc, _, _, _ := newRevocationTestService() // Request OCSP response with invalid serial format - resp, err := svc.GetOCSPResponse("iss-local", "") + resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "") if err == nil && resp != nil { // Empty serial might return unknown status; that's ok