fix(m-2): thread context through CertificateService cluster

Collapses CertificateService, RevocationSvc, and CAOperationsSvc to
ctx-accepting method signatures. Removes context.Background() synthesis
at 24 internal call sites across certificate.go, revocation_svc.go, and
ca_operations.go.

- Primary repo calls inherit request cancellation via the passed ctx.
- Audit and notification dispatches use context.WithoutCancel(ctx) so
  they survive client disconnect.
- Collapses TriggerRenewal/TriggerRenewalWithActor,
  TriggerDeployment/TriggerDeploymentWithActor, and
  RevokeCertificate/RevokeCertificateWithActor sibling pairs into single
  canonical ctx-accepting methods (decisions D-1, D-2).

Handlers pass r.Context(). Mocks and tests updated to match new
signatures. No HTTP surface change, no OpenAPI change.

PR 1 of 6 in the M-2 remediation chain. Master green at this commit.

Refs: certctl-audit-report.md M-2 (L143, L224)
This commit is contained in:
shankar0123
2026-04-18 00:29:37 +00:00
parent e951d319d0
commit cdc9d03d5b
12 changed files with 225 additions and 235 deletions
+28 -27
View File
@@ -1,6 +1,7 @@
package handler
import (
"context"
"encoding/json"
"log/slog"
"net/http"
@@ -15,20 +16,20 @@ import (
// CertificateService defines the service interface for certificate operations.
type CertificateService interface {
ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error)
GetCertificate(id string) (*domain.ManagedCertificate, error)
CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
UpdateCertificate(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
ArchiveCertificate(id string) error
GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
TriggerRenewal(certID string) error
TriggerDeployment(certID string, targetID string) error
RevokeCertificate(certID string, reason string) error
GetRevokedCertificates() ([]*domain.CertificateRevocation, error)
GenerateDERCRL(issuerID string) ([]byte, error)
GetOCSPResponse(issuerID string, serialHex string) ([]byte, error)
GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error)
ListCertificates(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
ListCertificatesWithFilter(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error)
GetCertificate(ctx context.Context, id string) (*domain.ManagedCertificate, error)
CreateCertificate(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
UpdateCertificate(ctx context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
ArchiveCertificate(ctx context.Context, id string) error
GetCertificateVersions(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
TriggerRenewal(ctx context.Context, certID string, actor string) error
TriggerDeployment(ctx context.Context, certID string, targetID string, actor string) error
RevokeCertificate(ctx context.Context, certID string, reason string, actor string) error
GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error)
GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error)
GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error)
GetCertificateDeployments(ctx context.Context, certID string) ([]domain.DeploymentTarget, error)
}
// CertificateHandler handles HTTP requests for certificate operations.
@@ -128,7 +129,7 @@ func (h CertificateHandler) ListCertificates(w http.ResponseWriter, r *http.Requ
filter.Fields = strings.Split(fieldsStr, ",")
}
certs, total, err := h.svc.ListCertificatesWithFilter(filter)
certs, total, err := h.svc.ListCertificatesWithFilter(r.Context(), filter)
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list certificates", requestID)
return
@@ -186,7 +187,7 @@ func (h CertificateHandler) GetCertificate(w http.ResponseWriter, r *http.Reques
return
}
cert, err := h.svc.GetCertificate(id)
cert, err := h.svc.GetCertificate(r.Context(), id)
if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
return
@@ -241,7 +242,7 @@ func (h CertificateHandler) CreateCertificate(w http.ResponseWriter, r *http.Req
return
}
created, err := h.svc.CreateCertificate(cert)
created, err := h.svc.CreateCertificate(r.Context(), cert)
if err != nil {
slog.Error("failed to create certificate", "error", err, "request_id", requestID, "common_name", cert.CommonName, "name", cert.Name)
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create certificate", requestID)
@@ -295,7 +296,7 @@ func (h CertificateHandler) UpdateCertificate(w http.ResponseWriter, r *http.Req
}
}
updated, err := h.svc.UpdateCertificate(id, cert)
updated, err := h.svc.UpdateCertificate(r.Context(), id, cert)
if err != nil {
if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
@@ -325,7 +326,7 @@ func (h CertificateHandler) ArchiveCertificate(w http.ResponseWriter, r *http.Re
return
}
if err := h.svc.ArchiveCertificate(id); err != nil {
if err := h.svc.ArchiveCertificate(r.Context(), id); err != nil {
if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
return
@@ -370,7 +371,7 @@ func (h CertificateHandler) GetCertificateVersions(w http.ResponseWriter, r *htt
}
}
versions, total, err := h.svc.GetCertificateVersions(certID, page, perPage)
versions, total, err := h.svc.GetCertificateVersions(r.Context(), certID, page, perPage)
if err != nil {
if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
@@ -410,7 +411,7 @@ func (h CertificateHandler) TriggerRenewal(w http.ResponseWriter, r *http.Reques
}
certID := parts[0]
if err := h.svc.TriggerRenewal(certID); err != nil {
if err := h.svc.TriggerRenewal(r.Context(), certID, "api"); err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
@@ -466,7 +467,7 @@ func (h CertificateHandler) TriggerDeployment(w http.ResponseWriter, r *http.Req
}
}
if err := h.svc.TriggerDeployment(certID, req.TargetID); err != nil {
if err := h.svc.TriggerDeployment(r.Context(), certID, req.TargetID, "api"); err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to trigger deployment", requestID)
return
}
@@ -508,7 +509,7 @@ func (h CertificateHandler) RevokeCertificate(w http.ResponseWriter, r *http.Req
}
}
if err := h.svc.RevokeCertificate(certID, req.Reason); err != nil {
if err := h.svc.RevokeCertificate(r.Context(), certID, req.Reason, "api"); err != nil {
// Distinguish between client errors and server errors
errMsg := err.Error()
if strings.Contains(errMsg, "already revoked") ||
@@ -540,7 +541,7 @@ func (h CertificateHandler) GetCRL(w http.ResponseWriter, r *http.Request) {
requestID := middleware.GetRequestID(r.Context())
revocations, err := h.svc.GetRevokedCertificates()
revocations, err := h.svc.GetRevokedCertificates(r.Context())
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to generate CRL", requestID)
return
@@ -585,7 +586,7 @@ func (h CertificateHandler) GetDERCRL(w http.ResponseWriter, r *http.Request) {
return
}
derBytes, err := h.svc.GenerateDERCRL(issuerID)
derBytes, err := h.svc.GenerateDERCRL(r.Context(), issuerID)
if err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "not found") {
@@ -627,7 +628,7 @@ func (h CertificateHandler) HandleOCSP(w http.ResponseWriter, r *http.Request) {
issuerID := parts[0]
serialHex := parts[1]
derBytes, err := h.svc.GetOCSPResponse(issuerID, serialHex)
derBytes, err := h.svc.GetOCSPResponse(r.Context(), issuerID, serialHex)
if err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "not found") {
@@ -667,7 +668,7 @@ func (h CertificateHandler) GetCertificateDeployments(w http.ResponseWriter, r *
}
certID := parts[0]
deployments, err := h.svc.GetCertificateDeployments(certID)
deployments, err := h.svc.GetCertificateDeployments(r.Context(), certID)
if err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "not found") {