feat: M15a — certificate revocation API, CRL endpoint, and revocation notifications

Implements core revocation infrastructure: POST /api/v1/certificates/{id}/revoke
with all 8 RFC 5280 reason codes, JSON-formatted CRL at GET /api/v1/crl, webhook
and email revocation notifications, best-effort issuer notification, and immutable
revocation audit trail. Includes 48 new tests across service, handler, integration,
and domain layers (600+ total). Fixes 3 pre-existing test bugs (team_test error
matching, agent_group delete status code, team handler per_page validation).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Shankar
2026-03-22 10:59:18 -04:00
parent f971662302
commit 5cd9e890f4
27 changed files with 1710 additions and 37 deletions
+8
View File
@@ -124,12 +124,20 @@ func main() {
}
logger.Info("issuer registry configured", "issuers", len(issuerRegistry))
// Initialize revocation repository
revocationRepo := postgres.NewRevocationRepository(db)
// Initialize services (following the dependency graph)
auditService := service.NewAuditService(auditRepo)
policyService := service.NewPolicyService(policyRepo, auditService)
certificateService := service.NewCertificateService(certificateRepo, policyService, auditService)
notificationService := service.NewNotificationService(notificationRepo, make(map[string]service.Notifier))
notificationService.SetOwnerRepo(ownerRepo)
// Wire revocation dependencies into CertificateService
certificateService.SetRevocationRepo(revocationRepo)
certificateService.SetNotificationService(notificationService)
certificateService.SetIssuerRegistry(issuerRegistry)
renewalService := service.NewRenewalService(certificateRepo, jobRepo, renewalPolicyRepo, profileRepo, auditService, notificationService, issuerRegistry, cfg.Keygen.Mode)
deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certificateRepo, auditService, notificationService)
jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger)
+6 -4
View File
@@ -563,6 +563,8 @@ Resources: certificates, issuers, targets, agents, jobs, policies, profiles, tea
Jobs support additional action endpoints: `POST /api/v1/jobs/{id}/cancel`, `POST /api/v1/jobs/{id}/approve`, `POST /api/v1/jobs/{id}/reject`.
Certificate revocation: `POST /api/v1/certificates/{id}/revoke` with optional `{"reason": "keyCompromise"}`. Supports RFC 5280 reason codes (unspecified, keyCompromise, caCompromise, affiliationChanged, superseded, cessationOfOperation, certificateHold, privilegeWithdrawn). Returns the updated certificate status. Best-effort issuer notification — the revocation succeeds even if the issuer connector is unavailable. A JSON-formatted CRL is available at `GET /api/v1/crl` (DER-encoded X.509 CRL planned for M15b).
Health checks live outside the API prefix: `GET /health` and `GET /ready`.
## Deployment Topologies
@@ -615,13 +617,13 @@ For production, you would also add an ingress controller, TLS termination for th
## Testing Strategy
certctl uses a layered testing approach aligned with the handler → service → repository architecture, with 525+ tests across five layers (service, handler, integration, connector, and frontend). The goal is high-confidence regression prevention at the service and handler layers, where the most complex business logic lives, combined with integration tests that exercise the full request path from HTTP to database.
certctl uses a layered testing approach aligned with the handler → service → repository architecture, with 600+ tests across five layers (service, handler, integration, connector, and frontend). The goal is high-confidence regression prevention at the service and handler layers, where the most complex business logic lives, combined with integration tests that exercise the full request path from HTTP to database.
**Service layer unit tests** (`internal/service/*_test.go`) — 192 test functions across 14 files with mock repositories. These test all business logic in isolation: certificate CRUD with validation, agent lifecycle (registration, heartbeat, CSR submission with both keygen modes), job state machine (creation, processing, cancellation, retry logic), policy evaluation (all 5 rule types, violation creation), renewal and issuance flow (server-side and agent-side keygen paths), notification deduplication (threshold tag matching, channel routing), team/owner/agent group CRUD with pagination and audit recording, issuer service CRUD with connection testing, and the issuer connector adapter (type translation between connector and service layers). Mock repositories are simple structs with function fields, avoiding heavy mocking frameworks — this keeps tests readable and avoids coupling to mock library APIs.
**Service layer unit tests** (`internal/service/*_test.go`) — 207 test functions across 15 files with mock repositories. These test all business logic in isolation: certificate CRUD with validation, certificate revocation (success, already-revoked, archived, invalid reason, all RFC 5280 reason codes, issuer notification, notification service integration), agent lifecycle (registration, heartbeat, CSR submission with both keygen modes), job state machine (creation, processing, cancellation, retry logic), policy evaluation (all 5 rule types, violation creation), renewal and issuance flow (server-side and agent-side keygen paths), notification deduplication (threshold tag matching, channel routing), team/owner/agent group CRUD with pagination and audit recording, issuer service CRUD with connection testing, and the issuer connector adapter (type translation between connector and service layers including revocation). Mock repositories are simple structs with function fields, avoiding heavy mocking frameworks — this keeps tests readable and avoids coupling to mock library APIs.
**Handler layer tests** (`internal/api/handler/*_test.go`) — 212 test functions across 11 files using Go's `httptest` package. Every handler file has a corresponding test file: certificates (22 tests), agents (28 tests), jobs (21 tests including approve/reject), notifications (11 tests), policies (19 tests), profiles (18 tests), issuers (17 tests), targets (17 tests), agent groups (12 tests), teams (26 tests), and owners (21 tests). Each test file follows the same pattern: a mock service struct with function fields, `httptest.NewRecorder` for capturing responses, and a shared `contextWithRequestID()` helper. Tests cover the happy path, input validation (missing fields, invalid JSON, empty IDs, name length limits), error propagation from the service layer, method-not-allowed responses, and pagination parameters.
**Handler layer tests** (`internal/api/handler/*_test.go`) — 226 test functions across 11 files using Go's `httptest` package. Every handler file has a corresponding test file: certificates (36 tests including revocation and CRL), agents (28 tests), jobs (21 tests including approve/reject), notifications (11 tests), policies (19 tests), profiles (18 tests), issuers (17 tests), targets (17 tests), agent groups (12 tests), teams (26 tests), and owners (21 tests). Each test file follows the same pattern: a mock service struct with function fields, `httptest.NewRecorder` for capturing responses, and a shared `contextWithRequestID()` helper. Tests cover the happy path, input validation (missing fields, invalid JSON, empty IDs, name length limits), error propagation from the service layer, method-not-allowed responses, and pagination parameters.
**Integration tests** (`internal/integration/`) — Two test files exercising the full stack from HTTP request through router, handler, service, and postgres repository layers. `lifecycle_test.go` has 11 subtests covering the complete certificate lifecycle: team/owner creation, certificate creation, issuer verification, renewal trigger, job verification, agent registration, CSR submission, deployment, and status reporting. `negative_test.go` has 14 subtests covering error paths plus 19 M11b endpoint tests: nonexistent resource lookups (404s), invalid request bodies (malformed JSON, missing required fields), invalid CSR submission, heartbeat for nonexistent agents, wrong HTTP methods on list endpoints, empty list responses, renewal on nonexistent certificates, expired certificate lifecycle, and team/owner/agent group CRUD validation (create with name validation, get not found, list empty, delete, method not allowed). Both use a shared `setupTestServer()` that builds a fully-wired server with real postgres repositories and the Local CA issuer connector.
**Integration tests** (`internal/integration/`) — Two test files exercising the full stack from HTTP request through router, handler, service, and postgres repository layers. `lifecycle_test.go` has 11 subtests covering the complete certificate lifecycle: team/owner creation, certificate creation, issuer verification, renewal trigger, job verification, agent registration, CSR submission, deployment, and status reporting. `negative_test.go` has 14 subtests covering error paths, 19 M11b endpoint tests, and 4 revocation endpoint tests: nonexistent resource lookups (404s), invalid request bodies (malformed JSON, missing required fields), invalid CSR submission, heartbeat for nonexistent agents, wrong HTTP methods on list endpoints, empty list responses, renewal on nonexistent certificates, expired certificate lifecycle, team/owner/agent group CRUD validation, revocation success, already-revoked rejection, not-found revocation, and CRL retrieval. Both use a shared `setupTestServer()` that builds a fully-wired server with real postgres repositories and the Local CA issuer connector.
**Frontend tests** (`web/src/api/client.test.ts`, `web/src/api/utils.test.ts`) — 53 Vitest tests covering the API client and utility functions. The API client tests mock `globalThis.fetch` and verify all endpoint functions (certificates, agents, jobs, policies, issuers, targets, notifications, audit, health) send correct HTTP methods, URLs, headers, and request bodies. They also test API key management (store/retrieve/clear), auth header propagation, 401 event dispatching, and error handling (server messages, error fields, status text fallback). The utility tests use `vi.useFakeTimers()` for deterministic date testing and cover `formatDate`, `formatDateTime`, `timeAgo`, `daysUntil`, and `expiryColor`. The test environment uses jsdom with `@testing-library/jest-dom` matchers.
@@ -248,8 +248,8 @@ func TestDeleteAgentGroup_Success(t *testing.T) {
h.DeleteAgentGroup(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
if w.Code != http.StatusNoContent {
t.Fatalf("expected status 204, got %d", w.Code)
}
if deletedID != "ag-linux" {
t.Errorf("expected deleted ID 'ag-linux', got '%s'", deletedID)
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
@@ -23,6 +24,8 @@ type MockCertificateService struct {
GetCertificateVersionsFn func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
TriggerRenewalFn func(certID string) error
TriggerDeploymentFn func(certID string, targetID string) error
RevokeCertificateFn func(certID string, reason string) error
GetRevokedCertificatesFn func() ([]*domain.CertificateRevocation, error)
}
func (m *MockCertificateService) ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
@@ -81,6 +84,20 @@ func (m *MockCertificateService) TriggerDeployment(certID string, targetID strin
return nil
}
func (m *MockCertificateService) RevokeCertificate(certID string, reason string) error {
if m.RevokeCertificateFn != nil {
return m.RevokeCertificateFn(certID, reason)
}
return nil
}
func (m *MockCertificateService) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) {
if m.GetRevokedCertificatesFn != nil {
return m.GetRevokedCertificatesFn()
}
return nil, nil
}
// Helper function to create context with request ID.
func contextWithRequestID() context.Context {
return context.WithValue(context.Background(), middleware.RequestIDKey{}, "test-request-id-123")
@@ -708,3 +725,320 @@ func TestListCertificates_PerPageExceedsMax(t *testing.T) {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// === Revocation Handler Tests ===
func TestRevokeCertificate_Handler_Success(t *testing.T) {
mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error {
if certID != "mc-prod-001" {
t.Errorf("expected certID mc-prod-001, got %s", certID)
}
if reason != "keyCompromise" {
t.Errorf("expected reason keyCompromise, got %s", reason)
}
return nil
},
}
handler := NewCertificateHandler(mock)
body := `{"reason":"keyCompromise"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/revoke", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.RevokeCertificate(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string]string
json.NewDecoder(w.Body).Decode(&resp)
if resp["status"] != "revoked" {
t.Errorf("expected status 'revoked', got %s", resp["status"])
}
}
func TestRevokeCertificate_Handler_NoBody(t *testing.T) {
mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error {
// Empty reason is OK — service defaults to "unspecified"
return nil
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/revoke", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.RevokeCertificate(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
func TestRevokeCertificate_Handler_AlreadyRevoked(t *testing.T) {
mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error {
return fmt.Errorf("certificate is already revoked")
},
}
handler := NewCertificateHandler(mock)
body := `{"reason":"keyCompromise"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/revoke", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.RevokeCertificate(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
}
func TestRevokeCertificate_Handler_NotFound(t *testing.T) {
mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error {
return fmt.Errorf("failed to fetch certificate: not found")
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/nonexistent/revoke", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.RevokeCertificate(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("expected status %d, got %d", http.StatusNotFound, w.Code)
}
}
func TestRevokeCertificate_Handler_InvalidReason(t *testing.T) {
mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error {
return fmt.Errorf("invalid revocation reason: badReason")
},
}
handler := NewCertificateHandler(mock)
body := `{"reason":"badReason"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/revoke", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.RevokeCertificate(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
}
func TestRevokeCertificate_Handler_InvalidBody(t *testing.T) {
mock := &MockCertificateService{}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/revoke", bytes.NewBufferString("{invalid json"))
req.Header.Set("Content-Type", "application/json")
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.RevokeCertificate(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
}
func TestRevokeCertificate_Handler_MethodNotAllowed(t *testing.T) {
mock := &MockCertificateService{}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/mc-prod-001/revoke", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.RevokeCertificate(w, req)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
}
}
func TestRevokeCertificate_Handler_EmptyID(t *testing.T) {
mock := &MockCertificateService{}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates//revoke", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.RevokeCertificate(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
}
func TestRevokeCertificate_Handler_CannotRevokeArchived(t *testing.T) {
mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error {
return fmt.Errorf("cannot revoke archived certificate")
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-archived/revoke", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.RevokeCertificate(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
}
func TestRevokeCertificate_Handler_ServerError(t *testing.T) {
mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error {
return fmt.Errorf("database connection lost")
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/revoke", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.RevokeCertificate(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
}
}
// === CRL Handler Tests ===
func TestGetCRL_Success(t *testing.T) {
mock := &MockCertificateService{
GetRevokedCertificatesFn: func() ([]*domain.CertificateRevocation, error) {
return []*domain.CertificateRevocation{
{
ID: "rev-1",
CertificateID: "cert-1",
SerialNumber: "ABC123",
Reason: "keyCompromise",
RevokedAt: time.Date(2026, 3, 20, 10, 0, 0, 0, time.UTC),
},
{
ID: "rev-2",
CertificateID: "cert-2",
SerialNumber: "DEF456",
Reason: "superseded",
RevokedAt: time.Date(2026, 3, 21, 14, 30, 0, 0, time.UTC),
},
}, nil
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/crl", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.GetCRL(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string]interface{}
json.NewDecoder(w.Body).Decode(&resp)
if resp["version"] != float64(1) {
t.Errorf("expected version 1, got %v", resp["version"])
}
if resp["total"] != float64(2) {
t.Errorf("expected total 2, got %v", resp["total"])
}
entries, ok := resp["entries"].([]interface{})
if !ok {
t.Fatal("expected entries to be an array")
}
if len(entries) != 2 {
t.Errorf("expected 2 entries, got %d", len(entries))
}
entry1 := entries[0].(map[string]interface{})
if entry1["serial_number"] != "ABC123" {
t.Errorf("expected serial ABC123, got %v", entry1["serial_number"])
}
if entry1["revocation_reason"] != "keyCompromise" {
t.Errorf("expected reason keyCompromise, got %v", entry1["revocation_reason"])
}
}
func TestGetCRL_Empty(t *testing.T) {
mock := &MockCertificateService{
GetRevokedCertificatesFn: func() ([]*domain.CertificateRevocation, error) {
return nil, nil
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/crl", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.GetCRL(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string]interface{}
json.NewDecoder(w.Body).Decode(&resp)
if resp["total"] != float64(0) {
t.Errorf("expected total 0, got %v", resp["total"])
}
}
func TestGetCRL_ServiceError(t *testing.T) {
mock := &MockCertificateService{
GetRevokedCertificatesFn: func() ([]*domain.CertificateRevocation, error) {
return nil, fmt.Errorf("revocation repository not configured")
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/crl", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.GetCRL(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
}
}
func TestGetCRL_MethodNotAllowed(t *testing.T) {
mock := &MockCertificateService{}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/crl", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.GetCRL(w, req)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
}
}
+94
View File
@@ -5,6 +5,7 @@ import (
"net/http"
"strconv"
"strings"
"time"
"github.com/shankar0123/certctl/internal/api/middleware"
"github.com/shankar0123/certctl/internal/domain"
@@ -20,6 +21,8 @@ type CertificateService interface {
GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
TriggerRenewal(certID string) error
TriggerDeployment(certID string, targetID string) error
RevokeCertificate(certID string, reason string) error
GetRevokedCertificates() ([]*domain.CertificateRevocation, error)
}
// CertificateHandler handles HTTP requests for certificate operations.
@@ -350,3 +353,94 @@ func (h CertificateHandler) TriggerDeployment(w http.ResponseWriter, r *http.Req
JSON(w, http.StatusAccepted, response)
}
// RevokeCertificate revokes a certificate with an optional reason code.
// POST /api/v1/certificates/{id}/revoke
func (h CertificateHandler) RevokeCertificate(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
return
}
requestID := middleware.GetRequestID(r.Context())
// Extract certificate ID from path /api/v1/certificates/{id}/revoke
path := strings.TrimPrefix(r.URL.Path, "/api/v1/certificates/")
parts := strings.Split(path, "/")
if len(parts) < 2 || parts[0] == "" {
ErrorWithRequestID(w, http.StatusBadRequest, "Certificate ID is required", requestID)
return
}
certID := parts[0]
// Parse optional reason from request body
var req struct {
Reason string `json:"reason"`
}
if r.Body != nil && r.Header.Get("Content-Type") == "application/json" {
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
ErrorWithRequestID(w, http.StatusBadRequest, "Invalid request body", requestID)
return
}
}
if err := h.svc.RevokeCertificate(certID, req.Reason); err != nil {
// Distinguish between client errors and server errors
errMsg := err.Error()
if strings.Contains(errMsg, "already revoked") ||
strings.Contains(errMsg, "cannot revoke") ||
strings.Contains(errMsg, "invalid revocation reason") {
ErrorWithRequestID(w, http.StatusBadRequest, errMsg, requestID)
return
}
if strings.Contains(errMsg, "not found") || strings.Contains(errMsg, "failed to fetch") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
return
}
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to revoke certificate", requestID)
return
}
JSON(w, http.StatusOK, map[string]string{"status": "revoked"})
}
// GetCRL returns the Certificate Revocation List as structured JSON.
// GET /api/v1/crl
// Note: DER-encoded X.509 CRL generation (requiring CA key access) is planned for M15b
// alongside the embedded OCSP responder. This endpoint provides the same data in JSON format.
func (h CertificateHandler) GetCRL(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
return
}
requestID := middleware.GetRequestID(r.Context())
revocations, err := h.svc.GetRevokedCertificates()
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to generate CRL", requestID)
return
}
type CRLEntry struct {
SerialNumber string `json:"serial_number"`
RevocationDate string `json:"revocation_date"`
RevocationReason string `json:"revocation_reason"`
}
entries := make([]CRLEntry, 0, len(revocations))
for _, rev := range revocations {
entries = append(entries, CRLEntry{
SerialNumber: rev.SerialNumber,
RevocationDate: rev.RevokedAt.Format("2006-01-02T15:04:05Z"),
RevocationReason: rev.Reason,
})
}
JSON(w, http.StatusOK, map[string]interface{}{
"version": 1,
"entries": entries,
"total": len(entries),
"generated_at": time.Now().UTC().Format("2006-01-02T15:04:05Z"),
})
}
@@ -2,14 +2,12 @@ package handler
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/shankar0123/certctl/internal/api/middleware"
"github.com/shankar0123/certctl/internal/domain"
)
@@ -551,8 +549,3 @@ func TestDeleteOwner_MethodNotAllowed(t *testing.T) {
t.Fatalf("expected status 405, got %d", w.Code)
}
}
// contextWithRequestID returns a context with a test request ID for use in tests.
func contextWithRequestID() context.Context {
return context.WithValue(context.Background(), middleware.RequestIDKey{}, "test-request-id-123")
}
+5 -5
View File
@@ -2,14 +2,12 @@ package handler
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/shankar0123/certctl/internal/api/middleware"
"github.com/shankar0123/certctl/internal/domain"
)
@@ -133,7 +131,8 @@ func TestListTeams_WithQueryParams(t *testing.T) {
}
}
// TestListTeams_PerPageMaxLimit tests that per_page is capped at 500.
// TestListTeams_PerPageMaxLimit tests that per_page values exceeding 500 are rejected
// and fall back to the default of 50 (the handler ignores invalid per_page values).
func TestListTeams_PerPageMaxLimit(t *testing.T) {
var capturedPerPage int
mock := &MockTeamService{
@@ -150,8 +149,9 @@ func TestListTeams_PerPageMaxLimit(t *testing.T) {
handler.ListTeams(w, req)
if capturedPerPage != 500 {
t.Errorf("expected per_page capped at 500, got %d", capturedPerPage)
// Handler rejects per_page > 500 and falls back to default (50)
if capturedPerPage != 50 {
t.Errorf("expected per_page to fall back to default 50 for values > 500, got %d", capturedPerPage)
}
}
+4
View File
@@ -88,6 +88,10 @@ func (r *Router) RegisterHandlers(
r.Register("GET /api/v1/certificates/{id}/versions", http.HandlerFunc(certificates.GetCertificateVersions))
r.Register("POST /api/v1/certificates/{id}/renew", http.HandlerFunc(certificates.TriggerRenewal))
r.Register("POST /api/v1/certificates/{id}/deploy", http.HandlerFunc(certificates.TriggerDeployment))
r.Register("POST /api/v1/certificates/{id}/revoke", http.HandlerFunc(certificates.RevokeCertificate))
// CRL endpoint: /api/v1/crl
r.Register("GET /api/v1/crl", http.HandlerFunc(certificates.GetCRL))
// Issuers routes: /api/v1/issuers
r.Register("GET /api/v1/issuers", http.HandlerFunc(issuers.ListIssuers))
+2
View File
@@ -22,6 +22,8 @@ type ManagedCertificate struct {
Tags map[string]string `json:"tags"`
LastRenewalAt *time.Time `json:"last_renewal_at,omitempty"`
LastDeploymentAt *time.Time `json:"last_deployment_at,omitempty"`
RevokedAt *time.Time `json:"revoked_at,omitempty"`
RevocationReason string `json:"revocation_reason,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
+1
View File
@@ -28,6 +28,7 @@ const (
NotificationTypeDeploymentSuccess NotificationType = "DeploymentSuccess"
NotificationTypeDeploymentFailure NotificationType = "DeploymentFailure"
NotificationTypePolicyViolation NotificationType = "PolicyViolation"
NotificationTypeRevocation NotificationType = "Revocation"
)
// NotificationChannel represents the communication medium for a notification.
+58
View File
@@ -0,0 +1,58 @@
package domain
import "time"
// RevocationReason represents the reason for revoking a certificate.
// Values align with RFC 5280 Section 5.3.1 CRL reason codes.
type RevocationReason string
const (
RevocationReasonUnspecified RevocationReason = "unspecified"
RevocationReasonKeyCompromise RevocationReason = "keyCompromise"
RevocationReasonCACompromise RevocationReason = "caCompromise"
RevocationReasonAffiliationChanged RevocationReason = "affiliationChanged"
RevocationReasonSuperseded RevocationReason = "superseded"
RevocationReasonCessationOfOperation RevocationReason = "cessationOfOperation"
RevocationReasonCertificateHold RevocationReason = "certificateHold"
RevocationReasonPrivilegeWithdrawn RevocationReason = "privilegeWithdrawn"
)
// ValidRevocationReasons contains all valid revocation reason strings.
var ValidRevocationReasons = map[RevocationReason]int{
RevocationReasonUnspecified: 0,
RevocationReasonKeyCompromise: 1,
RevocationReasonCACompromise: 2,
RevocationReasonAffiliationChanged: 3,
RevocationReasonSuperseded: 4,
RevocationReasonCessationOfOperation: 5,
RevocationReasonCertificateHold: 6,
RevocationReasonPrivilegeWithdrawn: 9,
}
// IsValidRevocationReason checks whether a reason string is a valid RFC 5280 reason code.
func IsValidRevocationReason(reason string) bool {
_, ok := ValidRevocationReasons[RevocationReason(reason)]
return ok
}
// CRLReasonCode returns the RFC 5280 integer reason code for a revocation reason.
func CRLReasonCode(reason RevocationReason) int {
if code, ok := ValidRevocationReasons[reason]; ok {
return code
}
return 0 // unspecified
}
// CertificateRevocation records the revocation of a specific certificate version.
// Used as the authoritative source for CRL generation.
type CertificateRevocation struct {
ID string `json:"id"`
CertificateID string `json:"certificate_id"`
SerialNumber string `json:"serial_number"`
Reason string `json:"reason"`
RevokedBy string `json:"revoked_by"`
RevokedAt time.Time `json:"revoked_at"`
IssuerID string `json:"issuer_id"`
IssuerNotified bool `json:"issuer_notified"`
CreatedAt time.Time `json:"created_at"`
}
+57
View File
@@ -0,0 +1,57 @@
package domain
import "testing"
func TestIsValidRevocationReason(t *testing.T) {
tests := []struct {
name string
reason string
want bool
}{
{"unspecified", "unspecified", true},
{"keyCompromise", "keyCompromise", true},
{"caCompromise", "caCompromise", true},
{"affiliationChanged", "affiliationChanged", true},
{"superseded", "superseded", true},
{"cessationOfOperation", "cessationOfOperation", true},
{"certificateHold", "certificateHold", true},
{"privilegeWithdrawn", "privilegeWithdrawn", true},
{"empty string", "", false},
{"random string", "notAValidReason", false},
{"partial match", "key", false},
{"case sensitive", "KeyCompromise", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsValidRevocationReason(tt.reason); got != tt.want {
t.Errorf("IsValidRevocationReason(%q) = %v, want %v", tt.reason, got, tt.want)
}
})
}
}
func TestCRLReasonCode(t *testing.T) {
tests := []struct {
reason RevocationReason
want int
}{
{RevocationReasonUnspecified, 0},
{RevocationReasonKeyCompromise, 1},
{RevocationReasonCACompromise, 2},
{RevocationReasonAffiliationChanged, 3},
{RevocationReasonSuperseded, 4},
{RevocationReasonCessationOfOperation, 5},
{RevocationReasonCertificateHold, 6},
{RevocationReasonPrivilegeWithdrawn, 9},
{RevocationReason("unknown"), 0}, // falls back to unspecified
}
for _, tt := range tests {
t.Run(string(tt.reason), func(t *testing.T) {
if got := CRLReasonCode(tt.reason); got != tt.want {
t.Errorf("CRLReasonCode(%q) = %d, want %d", tt.reason, got, tt.want)
}
})
}
}
+57
View File
@@ -545,6 +545,14 @@ func (m *mockCertificateRepository) GetExpiringCertificates(ctx context.Context,
return expiring, nil
}
func (m *mockCertificateRepository) GetLatestVersion(ctx context.Context, certID string) (*domain.CertificateVersion, error) {
versions := m.versions[certID]
if len(versions) == 0 {
return nil, fmt.Errorf("no versions found")
}
return versions[len(versions)-1], nil
}
type mockJobRepository struct {
jobs map[string]*domain.Job
}
@@ -1048,3 +1056,52 @@ func (m *mockAgentGroupService) DeleteAgentGroup(id string) error {
func (m *mockAgentGroupService) ListMembers(id string) ([]domain.Agent, int64, error) {
return []domain.Agent{}, 0, nil
}
// mockRevocationRepository is a test implementation of RevocationRepository for integration tests.
type mockRevocationRepository struct {
revocations []*domain.CertificateRevocation
}
func newMockRevocationRepository() *mockRevocationRepository {
return &mockRevocationRepository{
revocations: make([]*domain.CertificateRevocation, 0),
}
}
func (m *mockRevocationRepository) Create(ctx context.Context, revocation *domain.CertificateRevocation) error {
m.revocations = append(m.revocations, revocation)
return nil
}
func (m *mockRevocationRepository) GetBySerial(ctx context.Context, serial string) (*domain.CertificateRevocation, error) {
for _, r := range m.revocations {
if r.SerialNumber == serial {
return r, nil
}
}
return nil, fmt.Errorf("revocation not found")
}
func (m *mockRevocationRepository) ListAll(ctx context.Context) ([]*domain.CertificateRevocation, error) {
return m.revocations, nil
}
func (m *mockRevocationRepository) ListByCertificate(ctx context.Context, certID string) ([]*domain.CertificateRevocation, error) {
var result []*domain.CertificateRevocation
for _, r := range m.revocations {
if r.CertificateID == certID {
result = append(result, r)
}
}
return result, nil
}
func (m *mockRevocationRepository) MarkIssuerNotified(ctx context.Context, id string) error {
for _, r := range m.revocations {
if r.ID == id {
r.IssuerNotified = true
return nil
}
}
return fmt.Errorf("revocation not found")
}
+119
View File
@@ -8,6 +8,7 @@ import (
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
@@ -39,10 +40,17 @@ func setupTestServer(t *testing.T) (*httptest.Server, *mockCertificateRepository
"iss-local": service.NewIssuerConnectorAdapter(localCA),
}
revocationRepo := newMockRevocationRepository()
auditService := service.NewAuditService(auditRepo)
policyService := service.NewPolicyService(policyRepo, auditService)
certificateService := service.NewCertificateService(certRepo, policyService, auditService)
notificationService := service.NewNotificationService(notifRepo, make(map[string]service.Notifier))
// Wire revocation dependencies
certificateService.SetRevocationRepo(revocationRepo)
certificateService.SetNotificationService(notificationService)
certificateService.SetIssuerRegistry(issuerRegistry)
renewalService := service.NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, nil, auditService, notificationService, issuerRegistry, "server")
deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notificationService)
jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger)
@@ -671,3 +679,114 @@ func TestM11bEndpoints(t *testing.T) {
})
})
}
// TestRevocationEndpoints exercises the revocation API endpoints through a full integration stack.
func TestRevocationEndpoints(t *testing.T) {
server, certRepo, _, _ := setupTestServer(t)
// Create a test certificate with a version
now := time.Now()
cert := &domain.ManagedCertificate{
ID: "mc-revoke-test",
Name: "Revocation Test Cert",
CommonName: "revoke-test.example.com",
SANs: []string{},
Environment: "test",
OwnerID: "owner-test",
TeamID: "team-test",
IssuerID: "iss-local",
RenewalPolicyID: "policy-1",
Status: domain.CertificateStatusActive,
ExpiresAt: now.AddDate(0, 6, 0),
Tags: map[string]string{},
CreatedAt: now,
UpdatedAt: now,
}
certRepo.certs["mc-revoke-test"] = cert
certRepo.versions["mc-revoke-test"] = []*domain.CertificateVersion{
{
ID: "cv-revoke-test",
CertificateID: "mc-revoke-test",
SerialNumber: "REVOKE-SERIAL-001",
NotBefore: now,
NotAfter: now.AddDate(1, 0, 0),
CreatedAt: now,
},
}
t.Run("RevokeCertificate_Success", func(t *testing.T) {
body := bytes.NewBufferString(`{"reason":"keyCompromise"}`)
resp, err := http.Post(server.URL+"/api/v1/certificates/mc-revoke-test/revoke", "application/json", body)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
var result map[string]string
json.NewDecoder(resp.Body).Decode(&result)
if result["status"] != "revoked" {
t.Errorf("expected status 'revoked', got %s", result["status"])
}
// Verify certificate status updated
if cert.Status != domain.CertificateStatusRevoked {
t.Errorf("expected Revoked status, got %s", cert.Status)
}
})
t.Run("RevokeCertificate_AlreadyRevoked", func(t *testing.T) {
body := bytes.NewBufferString(`{"reason":"superseded"}`)
resp, err := http.Post(server.URL+"/api/v1/certificates/mc-revoke-test/revoke", "application/json", body)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected 400 for already revoked, got %d", resp.StatusCode)
}
})
t.Run("RevokeCertificate_NotFound", func(t *testing.T) {
resp, err := http.Post(server.URL+"/api/v1/certificates/mc-nonexistent/revoke", "application/json", strings.NewReader("{}"))
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
t.Errorf("expected 404, got %d", resp.StatusCode)
}
})
t.Run("GetCRL_Success", func(t *testing.T) {
resp, err := http.Get(server.URL + "/api/v1/crl")
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
var crl map[string]interface{}
json.NewDecoder(resp.Body).Decode(&crl)
if crl["version"] != float64(1) {
t.Errorf("expected CRL version 1, got %v", crl["version"])
}
// Should have at least 1 entry from the revocation above
total, _ := crl["total"].(float64)
if total < 1 {
t.Errorf("expected at least 1 CRL entry, got %v", total)
}
})
}
+16
View File
@@ -25,6 +25,22 @@ type CertificateRepository interface {
CreateVersion(ctx context.Context, version *domain.CertificateVersion) error
// GetExpiringCertificates returns certificates expiring before the given time.
GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error)
// GetLatestVersion returns the most recent certificate version for a certificate.
GetLatestVersion(ctx context.Context, certID string) (*domain.CertificateVersion, error)
}
// RevocationRepository defines operations for managing certificate revocations.
type RevocationRepository interface {
// Create records a new certificate revocation.
Create(ctx context.Context, revocation *domain.CertificateRevocation) error
// GetBySerial retrieves a revocation by serial number.
GetBySerial(ctx context.Context, serial string) (*domain.CertificateRevocation, error)
// ListAll returns all revocations, ordered by revocation time (for CRL generation).
ListAll(ctx context.Context) ([]*domain.CertificateRevocation, error)
// ListByCertificate returns all revocations for a certificate.
ListByCertificate(ctx context.Context, certID string) ([]*domain.CertificateRevocation, error)
// MarkIssuerNotified updates the issuer_notified flag for a revocation.
MarkIssuerNotified(ctx context.Context, id string) error
}
// IssuerRepository defines operations for managing certificate issuers.
+50 -10
View File
@@ -85,7 +85,7 @@ func (r *CertificateRepository) List(ctx context.Context, filter *repository.Cer
offset := (filter.Page - 1) * filter.PerPage
query := fmt.Sprintf(`
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, revoked_at, revocation_reason, created_at, updated_at
FROM managed_certificates
%s
ORDER BY created_at DESC
@@ -120,7 +120,7 @@ func (r *CertificateRepository) List(ctx context.Context, filter *repository.Cer
func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
row := r.db.QueryRowContext(ctx, `
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, revoked_at, revocation_reason, created_at, updated_at
FROM managed_certificates
WHERE id = $1
`, id)
@@ -152,16 +152,23 @@ func (r *CertificateRepository) Create(ctx context.Context, cert *domain.Managed
profileID = &cert.CertificateProfileID
}
var revocationReason *string
if cert.RevocationReason != "" {
revocationReason = &cert.RevocationReason
}
err = r.db.QueryRowContext(ctx, `
INSERT INTO managed_certificates (
id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, revoked_at, revocation_reason, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
RETURNING id
`, cert.ID, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment,
cert.OwnerID, cert.TeamID, cert.IssuerID, cert.RenewalPolicyID, profileID,
cert.Status, cert.ExpiresAt,
tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.CreatedAt, cert.UpdatedAt).Scan(&cert.ID)
tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt,
cert.RevokedAt, revocationReason,
cert.CreatedAt, cert.UpdatedAt).Scan(&cert.ID)
if err != nil {
return fmt.Errorf("failed to create certificate: %w", err)
@@ -182,6 +189,11 @@ func (r *CertificateRepository) Update(ctx context.Context, cert *domain.Managed
profileID = &cert.CertificateProfileID
}
var revocationReason *string
if cert.RevocationReason != "" {
revocationReason = &cert.RevocationReason
}
result, err := r.db.ExecContext(ctx, `
UPDATE managed_certificates SET
name = $1,
@@ -197,11 +209,14 @@ func (r *CertificateRepository) Update(ctx context.Context, cert *domain.Managed
tags = $11,
last_renewal_at = $12,
last_deployment_at = $13,
updated_at = $14
WHERE id = $15
revoked_at = $14,
revocation_reason = $15,
updated_at = $16
WHERE id = $17
`, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment,
cert.OwnerID, cert.TeamID, cert.IssuerID, profileID, cert.Status, cert.ExpiresAt,
tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.UpdatedAt, cert.ID)
tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt,
cert.RevokedAt, revocationReason, cert.UpdatedAt, cert.ID)
if err != nil {
return fmt.Errorf("failed to update certificate: %w", err)
@@ -299,7 +314,7 @@ func (r *CertificateRepository) CreateVersion(ctx context.Context, version *doma
func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, revoked_at, revocation_reason, created_at, updated_at
FROM managed_certificates
WHERE expires_at < $1 AND status != $2
ORDER BY expires_at ASC
@@ -326,6 +341,26 @@ func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, bef
return certs, nil
}
// GetLatestVersion returns the most recent certificate version for a certificate.
func (r *CertificateRepository) GetLatestVersion(ctx context.Context, certID string) (*domain.CertificateVersion, error) {
var v domain.CertificateVersion
err := r.db.QueryRowContext(ctx, `
SELECT id, certificate_id, serial_number, not_before, not_after,
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
FROM certificate_versions
WHERE certificate_id = $1
ORDER BY created_at DESC
LIMIT 1
`, certID).Scan(&v.ID, &v.CertificateID, &v.SerialNumber, &v.NotBefore, &v.NotAfter,
&v.FingerprintSHA256, &v.PEMChain, &v.CSRPEM, &v.KeyAlgorithm, &v.KeySize, &v.CreatedAt)
if err != nil {
return nil, fmt.Errorf("failed to get latest certificate version: %w", err)
}
return &v, nil
}
// scanCertificate scans a certificate from a row or rows
func scanCertificate(scanner interface {
Scan(...interface{}) error
@@ -334,12 +369,14 @@ func scanCertificate(scanner interface {
var tagsJSON []byte
var sans pq.StringArray
var profileID sql.NullString
var revocationReason sql.NullString
err := scanner.Scan(
&cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID,
&cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &profileID,
&cert.Status, &cert.ExpiresAt, &tagsJSON,
&cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.CreatedAt, &cert.UpdatedAt)
&cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.RevokedAt, &revocationReason,
&cert.CreatedAt, &cert.UpdatedAt)
if err != nil {
return nil, fmt.Errorf("failed to scan certificate: %w", err)
@@ -349,6 +386,9 @@ func scanCertificate(scanner interface {
if profileID.Valid {
cert.CertificateProfileID = profileID.String
}
if revocationReason.Valid {
cert.RevocationReason = revocationReason.String
}
// Unmarshal tags
if len(tagsJSON) > 0 {
+130
View File
@@ -0,0 +1,130 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/shankar0123/certctl/internal/domain"
)
// RevocationRepository implements repository.RevocationRepository using PostgreSQL.
type RevocationRepository struct {
db *sql.DB
}
// NewRevocationRepository creates a new RevocationRepository.
func NewRevocationRepository(db *sql.DB) *RevocationRepository {
return &RevocationRepository{db: db}
}
// Create records a new certificate revocation.
func (r *RevocationRepository) Create(ctx context.Context, revocation *domain.CertificateRevocation) error {
_, err := r.db.ExecContext(ctx, `
INSERT INTO certificate_revocations (
id, certificate_id, serial_number, reason, revoked_by, revoked_at,
issuer_id, issuer_notified, created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT (serial_number) DO NOTHING
`, revocation.ID, revocation.CertificateID, revocation.SerialNumber,
revocation.Reason, revocation.RevokedBy, revocation.RevokedAt,
revocation.IssuerID, revocation.IssuerNotified, revocation.CreatedAt)
if err != nil {
return fmt.Errorf("failed to create revocation record: %w", err)
}
return nil
}
// GetBySerial retrieves a revocation by serial number.
func (r *RevocationRepository) GetBySerial(ctx context.Context, serial string) (*domain.CertificateRevocation, error) {
var rev domain.CertificateRevocation
err := r.db.QueryRowContext(ctx, `
SELECT id, certificate_id, serial_number, reason, revoked_by, revoked_at,
issuer_id, issuer_notified, created_at
FROM certificate_revocations
WHERE serial_number = $1
`, serial).Scan(&rev.ID, &rev.CertificateID, &rev.SerialNumber,
&rev.Reason, &rev.RevokedBy, &rev.RevokedAt,
&rev.IssuerID, &rev.IssuerNotified, &rev.CreatedAt)
if err != nil {
return nil, fmt.Errorf("failed to get revocation by serial: %w", err)
}
return &rev, nil
}
// ListAll returns all revocations ordered by revocation time (for CRL generation).
func (r *RevocationRepository) ListAll(ctx context.Context) ([]*domain.CertificateRevocation, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, certificate_id, serial_number, reason, revoked_by, revoked_at,
issuer_id, issuer_notified, created_at
FROM certificate_revocations
ORDER BY revoked_at ASC
`)
if err != nil {
return nil, fmt.Errorf("failed to list revocations: %w", err)
}
defer rows.Close()
return scanRevocations(rows)
}
// ListByCertificate returns all revocations for a certificate.
func (r *RevocationRepository) ListByCertificate(ctx context.Context, certID string) ([]*domain.CertificateRevocation, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, certificate_id, serial_number, reason, revoked_by, revoked_at,
issuer_id, issuer_notified, created_at
FROM certificate_revocations
WHERE certificate_id = $1
ORDER BY revoked_at ASC
`, certID)
if err != nil {
return nil, fmt.Errorf("failed to list revocations by certificate: %w", err)
}
defer rows.Close()
return scanRevocations(rows)
}
// MarkIssuerNotified updates the issuer_notified flag for a revocation.
func (r *RevocationRepository) MarkIssuerNotified(ctx context.Context, id string) error {
result, err := r.db.ExecContext(ctx, `
UPDATE certificate_revocations SET issuer_notified = TRUE WHERE id = $1
`, id)
if err != nil {
return fmt.Errorf("failed to mark issuer notified: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rows == 0 {
return fmt.Errorf("revocation not found")
}
return nil
}
func scanRevocations(rows *sql.Rows) ([]*domain.CertificateRevocation, error) {
var revocations []*domain.CertificateRevocation
for rows.Next() {
var rev domain.CertificateRevocation
if err := rows.Scan(&rev.ID, &rev.CertificateID, &rev.SerialNumber,
&rev.Reason, &rev.RevokedBy, &rev.RevokedAt,
&rev.IssuerID, &rev.IssuerNotified, &rev.CreatedAt); err != nil {
return nil, fmt.Errorf("failed to scan revocation: %w", err)
}
revocations = append(revocations, &rev)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating revocation rows: %w", err)
}
return revocations, nil
}
+141 -3
View File
@@ -12,9 +12,12 @@ import (
// CertificateService provides business logic for certificate management.
type CertificateService struct {
certRepo repository.CertificateRepository
policyService *PolicyService
auditService *AuditService
certRepo repository.CertificateRepository
revocationRepo repository.RevocationRepository
policyService *PolicyService
auditService *AuditService
notificationSvc *NotificationService
issuerRegistry map[string]IssuerConnector
}
// NewCertificateService creates a new certificate service.
@@ -30,6 +33,21 @@ func NewCertificateService(
}
}
// SetRevocationRepo sets the revocation repository (called after construction to avoid init order issues).
func (s *CertificateService) SetRevocationRepo(repo repository.RevocationRepository) {
s.revocationRepo = repo
}
// SetNotificationService sets the notification service for revocation alerts.
func (s *CertificateService) SetNotificationService(svc *NotificationService) {
s.notificationSvc = svc
}
// SetIssuerRegistry sets the issuer registry for issuer-level revocation.
func (s *CertificateService) SetIssuerRegistry(registry map[string]IssuerConnector) {
s.issuerRegistry = registry
}
// List returns a paginated list of certificates matching the filter.
func (s *CertificateService) List(ctx context.Context, filter *repository.CertificateFilter) ([]*domain.ManagedCertificate, int, error) {
certs, total, err := s.certRepo.List(ctx, filter)
@@ -333,3 +351,123 @@ func (s *CertificateService) TriggerRenewal(certID string) error {
func (s *CertificateService) TriggerDeployment(certID string, targetID string) error {
return s.TriggerDeploymentWithActor(context.Background(), certID, "api")
}
// RevokeCertificate revokes a certificate with the given reason.
// Steps:
// 1. Validate the certificate exists and is revocable
// 2. Get the latest certificate version (for serial number)
// 3. Update certificate status to Revoked
// 4. Record revocation in certificate_revocations table
// 5. Notify the issuer connector (best-effort)
// 6. Record audit event
// 7. Send revocation notification
func (s *CertificateService) RevokeCertificate(certID string, reason string) error {
return s.RevokeCertificateWithActor(context.Background(), certID, reason, "api")
}
// RevokeCertificateWithActor performs revocation with actor tracking.
func (s *CertificateService) RevokeCertificateWithActor(ctx context.Context, certID string, reason string, actor string) error {
// 1. Validate certificate exists and is revocable
cert, err := s.certRepo.Get(ctx, certID)
if err != nil {
return fmt.Errorf("failed to fetch certificate: %w", err)
}
if cert.Status == domain.CertificateStatusRevoked {
return fmt.Errorf("certificate is already revoked")
}
if cert.Status == domain.CertificateStatusArchived {
return fmt.Errorf("cannot revoke archived certificate")
}
// Validate reason code
if reason == "" {
reason = string(domain.RevocationReasonUnspecified)
}
if !domain.IsValidRevocationReason(reason) {
return fmt.Errorf("invalid revocation reason: %s", reason)
}
// 2. Get latest certificate version for serial number
version, err := s.certRepo.GetLatestVersion(ctx, certID)
if err != nil {
return fmt.Errorf("failed to get certificate version: %w", err)
}
// 3. Update certificate status to Revoked
now := time.Now()
cert.Status = domain.CertificateStatusRevoked
cert.RevokedAt = &now
cert.RevocationReason = reason
cert.UpdatedAt = now
if err := s.certRepo.Update(ctx, cert); err != nil {
return fmt.Errorf("failed to update certificate status: %w", err)
}
// 4. Record revocation in certificate_revocations table (for CRL generation)
if s.revocationRepo != nil {
revocation := &domain.CertificateRevocation{
ID: generateID("rev"),
CertificateID: certID,
SerialNumber: version.SerialNumber,
Reason: reason,
RevokedBy: actor,
RevokedAt: now,
IssuerID: cert.IssuerID,
CreatedAt: now,
}
if err := s.revocationRepo.Create(ctx, revocation); err != nil {
slog.Error("failed to record revocation for CRL", "error", err, "certificate_id", certID)
// Don't fail the overall revocation — the cert status is already updated
}
}
// 5. Notify the issuer connector (best-effort)
if s.issuerRegistry != nil {
if issuerConn, ok := s.issuerRegistry[cert.IssuerID]; ok {
if err := issuerConn.RevokeCertificate(ctx, version.SerialNumber, reason); err != nil {
slog.Error("failed to notify issuer of revocation",
"error", err,
"issuer_id", cert.IssuerID,
"serial", version.SerialNumber)
// Best-effort — don't fail the overall revocation
} else if s.revocationRepo != nil {
// Mark issuer as notified
revocations, _ := s.revocationRepo.ListByCertificate(ctx, certID)
for _, rev := range revocations {
if rev.SerialNumber == version.SerialNumber {
_ = s.revocationRepo.MarkIssuerNotified(ctx, rev.ID)
}
}
}
}
}
// 6. Record audit event
if err := s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser,
"certificate_revoked", "certificate", certID,
map[string]interface{}{
"common_name": cert.CommonName,
"serial": version.SerialNumber,
"reason": reason,
}); err != nil {
slog.Error("failed to record audit event", "error", err)
}
// 7. Send revocation notification
if s.notificationSvc != nil {
if err := s.notificationSvc.SendRevocationNotification(ctx, cert, reason); err != nil {
slog.Error("failed to send revocation notification", "error", err, "certificate_id", certID)
}
}
return nil
}
// GetRevokedCertificates returns all revoked certificate records (for CRL generation).
func (s *CertificateService) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) {
if s.revocationRepo == nil {
return nil, fmt.Errorf("revocation repository not configured")
}
return s.revocationRepo.ListAll(context.Background())
}
+12
View File
@@ -57,3 +57,15 @@ func (a *IssuerConnectorAdapter) RenewCertificate(ctx context.Context, commonNam
NotAfter: result.NotAfter,
}, nil
}
// RevokeCertificate delegates to the underlying connector's RevokeCertificate method.
func (a *IssuerConnectorAdapter) RevokeCertificate(ctx context.Context, serial string, reason string) error {
var reasonPtr *string
if reason != "" {
reasonPtr = &reason
}
return a.connector.RevokeCertificate(ctx, issuer.RevocationRequest{
Serial: serial,
Reason: reasonPtr,
})
}
+42 -1
View File
@@ -2,6 +2,7 @@ package service
import (
"context"
"encoding/json"
"errors"
"testing"
"time"
@@ -23,7 +24,7 @@ type mockConnectorLayerIssuer struct {
orderStatus *issuer.OrderStatus
}
func (m *mockConnectorLayerIssuer) ValidateConfig(ctx context.Context, config []byte) error {
func (m *mockConnectorLayerIssuer) ValidateConfig(ctx context.Context, config json.RawMessage) error {
return m.validateErr
}
@@ -327,3 +328,43 @@ func TestIssuerConnectorAdapter_RenewCertificate_RequestTranslation(t *testing.T
t.Errorf("expected CSRPEM %s, got %s", csrPEM, mock.lastRenewReq.CSRPEM)
}
}
// Tests for RevokeCertificate
func TestIssuerConnectorAdapter_RevokeCertificate_Success(t *testing.T) {
ctx := context.Background()
mock := &mockConnectorLayerIssuer{}
adapter := NewIssuerConnectorAdapter(mock)
err := adapter.RevokeCertificate(ctx, "serial-123", "keyCompromise")
if err != nil {
t.Fatalf("RevokeCertificate failed: %v", err)
}
}
func TestIssuerConnectorAdapter_RevokeCertificate_Error(t *testing.T) {
ctx := context.Background()
testErr := errors.New("revocation failed at issuer")
mock := &mockConnectorLayerIssuer{revokeErr: testErr}
adapter := NewIssuerConnectorAdapter(mock)
err := adapter.RevokeCertificate(ctx, "serial-123", "keyCompromise")
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, testErr) {
t.Errorf("expected error %v, got %v", testErr, err)
}
}
func TestIssuerConnectorAdapter_RevokeCertificate_EmptyReason(t *testing.T) {
ctx := context.Background()
mock := &mockConnectorLayerIssuer{}
adapter := NewIssuerConnectorAdapter(mock)
// Empty reason should pass nil to the connector
err := adapter.RevokeCertificate(ctx, "serial-456", "")
if err != nil {
t.Fatalf("RevokeCertificate with empty reason failed: %v", err)
}
}
+45
View File
@@ -193,6 +193,51 @@ func (s *NotificationService) SendDeploymentNotification(ctx context.Context, ce
return s.sendNotification(ctx, notif)
}
// SendRevocationNotification sends a certificate revocation notification.
func (s *NotificationService) SendRevocationNotification(ctx context.Context, cert *domain.ManagedCertificate, reason string) error {
body := fmt.Sprintf(
"[REVOKED] The certificate for %s has been revoked.\n\nReason: %s\n\nThis certificate is no longer valid.",
cert.CommonName, reason,
)
notif := &domain.NotificationEvent{
ID: generateID("notif"),
CertificateID: &cert.ID,
Type: domain.NotificationTypeRevocation,
Channel: domain.NotificationChannelWebhook,
Recipient: s.resolveRecipient(ctx, cert.OwnerID),
Message: body,
Status: "pending",
CreatedAt: time.Now(),
}
if err := s.notifRepo.Create(ctx, notif); err != nil {
return fmt.Errorf("failed to create revocation notification: %w", err)
}
// Also send via email channel
emailNotif := &domain.NotificationEvent{
ID: generateID("notif"),
CertificateID: &cert.ID,
Type: domain.NotificationTypeRevocation,
Channel: domain.NotificationChannelEmail,
Recipient: s.resolveRecipient(ctx, cert.OwnerID),
Message: body,
Status: "pending",
CreatedAt: time.Now(),
}
if err := s.notifRepo.Create(ctx, emailNotif); err != nil {
slog.Error("failed to create email revocation notification", "error", err)
}
// Attempt immediate send for both
if err := s.sendNotification(ctx, notif); err != nil {
slog.Error("failed to send webhook revocation notification", "error", err)
}
return s.sendNotification(ctx, emailNotif)
}
// ProcessPendingNotifications sends all pending notifications in batch.
func (s *NotificationService) ProcessPendingNotifications(ctx context.Context) error {
filter := &repository.NotificationFilter{
+2
View File
@@ -37,6 +37,8 @@ type IssuerConnector interface {
IssueCertificate(ctx context.Context, commonName string, sans []string, csrPEM string) (*IssuanceResult, error)
// RenewCertificate renews a certificate using the provided CSR PEM.
RenewCertificate(ctx context.Context, commonName string, sans []string, csrPEM string) (*IssuanceResult, error)
// RevokeCertificate revokes a certificate by serial number with an optional reason.
RevokeCertificate(ctx context.Context, serial string, reason string) error
}
// IssuanceResult holds the result of a certificate issuance or renewal operation.
+410
View File
@@ -0,0 +1,410 @@
package service
import (
"context"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
)
// helper to create a test CertificateService wired for revocation tests
func newRevocationTestService() (*CertificateService, *mockCertRepo, *mockRevocationRepo, *mockAuditRepo) {
certRepo := newMockCertificateRepository()
auditRepo := newMockAuditRepository()
policyRepo := newMockPolicyRepository()
revocationRepo := newMockRevocationRepository()
auditService := NewAuditService(auditRepo)
policyService := NewPolicyService(policyRepo, auditService)
certService := NewCertificateService(certRepo, policyService, auditService)
certService.SetRevocationRepo(revocationRepo)
return certService, certRepo, revocationRepo, auditRepo
}
func TestRevokeCertificate_Success(t *testing.T) {
svc, certRepo, revocationRepo, auditRepo := newRevocationTestService()
// Set up test data
cert := &domain.ManagedCertificate{
ID: "cert-1",
CommonName: "example.com",
IssuerID: "iss-local",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(0, 6, 0),
}
certRepo.AddCert(cert)
// Add a certificate version with a serial number
version := &domain.CertificateVersion{
ID: "ver-1",
CertificateID: "cert-1",
SerialNumber: "ABC123",
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1, 0, 0),
CreatedAt: time.Now(),
}
certRepo.Versions["cert-1"] = []*domain.CertificateVersion{version}
// Revoke
err := svc.RevokeCertificateWithActor(context.Background(), "cert-1", "keyCompromise", "admin")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
// Verify certificate status changed
updated, _ := certRepo.Get(context.Background(), "cert-1")
if updated.Status != domain.CertificateStatusRevoked {
t.Errorf("expected status Revoked, got %s", updated.Status)
}
if updated.RevokedAt == nil {
t.Error("expected RevokedAt to be set")
}
if updated.RevocationReason != "keyCompromise" {
t.Errorf("expected reason keyCompromise, got %s", updated.RevocationReason)
}
// Verify revocation record created
if len(revocationRepo.Revocations) != 1 {
t.Fatalf("expected 1 revocation record, got %d", len(revocationRepo.Revocations))
}
rev := revocationRepo.Revocations[0]
if rev.SerialNumber != "ABC123" {
t.Errorf("expected serial ABC123, got %s", rev.SerialNumber)
}
if rev.Reason != "keyCompromise" {
t.Errorf("expected reason keyCompromise, got %s", rev.Reason)
}
if rev.RevokedBy != "admin" {
t.Errorf("expected revokedBy admin, got %s", rev.RevokedBy)
}
// Verify audit event recorded
if len(auditRepo.Events) == 0 {
t.Error("expected audit event to be recorded")
}
foundRevocationAudit := false
for _, e := range auditRepo.Events {
if e.Action == "certificate_revoked" {
foundRevocationAudit = true
}
}
if !foundRevocationAudit {
t.Error("expected certificate_revoked audit event")
}
}
func TestRevokeCertificate_DefaultReason(t *testing.T) {
svc, certRepo, revocationRepo, _ := newRevocationTestService()
cert := &domain.ManagedCertificate{
ID: "cert-2",
CommonName: "default-reason.com",
IssuerID: "iss-local",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(0, 6, 0),
}
certRepo.AddCert(cert)
certRepo.Versions["cert-2"] = []*domain.CertificateVersion{
{ID: "ver-2", CertificateID: "cert-2", SerialNumber: "DEF456", CreatedAt: time.Now()},
}
// Revoke with empty reason — should default to "unspecified"
err := svc.RevokeCertificateWithActor(context.Background(), "cert-2", "", "api")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
updated, _ := certRepo.Get(context.Background(), "cert-2")
if updated.RevocationReason != "unspecified" {
t.Errorf("expected default reason 'unspecified', got %s", updated.RevocationReason)
}
if len(revocationRepo.Revocations) != 1 {
t.Fatalf("expected 1 revocation, got %d", len(revocationRepo.Revocations))
}
if revocationRepo.Revocations[0].Reason != "unspecified" {
t.Errorf("expected revocation reason 'unspecified', got %s", revocationRepo.Revocations[0].Reason)
}
}
func TestRevokeCertificate_AlreadyRevoked(t *testing.T) {
svc, certRepo, _, _ := newRevocationTestService()
now := time.Now()
cert := &domain.ManagedCertificate{
ID: "cert-3",
CommonName: "already-revoked.com",
IssuerID: "iss-local",
Status: domain.CertificateStatusRevoked,
RevokedAt: &now,
RevocationReason: "keyCompromise",
ExpiresAt: time.Now().AddDate(0, 6, 0),
}
certRepo.AddCert(cert)
err := svc.RevokeCertificateWithActor(context.Background(), "cert-3", "superseded", "admin")
if err == nil {
t.Fatal("expected error for already revoked certificate")
}
if err.Error() != "certificate is already revoked" {
t.Errorf("expected 'already revoked' error, got: %v", err)
}
}
func TestRevokeCertificate_ArchivedCert(t *testing.T) {
svc, certRepo, _, _ := newRevocationTestService()
cert := &domain.ManagedCertificate{
ID: "cert-4",
CommonName: "archived.com",
IssuerID: "iss-local",
Status: domain.CertificateStatusArchived,
ExpiresAt: time.Now().AddDate(0, 6, 0),
}
certRepo.AddCert(cert)
err := svc.RevokeCertificateWithActor(context.Background(), "cert-4", "keyCompromise", "admin")
if err == nil {
t.Fatal("expected error for archived certificate")
}
if err.Error() != "cannot revoke archived certificate" {
t.Errorf("expected 'cannot revoke archived' error, got: %v", err)
}
}
func TestRevokeCertificate_InvalidReason(t *testing.T) {
svc, certRepo, _, _ := newRevocationTestService()
cert := &domain.ManagedCertificate{
ID: "cert-5",
CommonName: "invalid-reason.com",
IssuerID: "iss-local",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(0, 6, 0),
}
certRepo.AddCert(cert)
err := svc.RevokeCertificateWithActor(context.Background(), "cert-5", "notAValidReason", "admin")
if err == nil {
t.Fatal("expected error for invalid reason")
}
if err.Error() != "invalid revocation reason: notAValidReason" {
t.Errorf("unexpected error: %v", err)
}
}
func TestRevokeCertificate_NotFound(t *testing.T) {
svc, _, _, _ := newRevocationTestService()
err := svc.RevokeCertificateWithActor(context.Background(), "nonexistent-cert", "keyCompromise", "admin")
if err == nil {
t.Fatal("expected error for nonexistent certificate")
}
}
func TestRevokeCertificate_NoVersion(t *testing.T) {
svc, certRepo, _, _ := newRevocationTestService()
cert := &domain.ManagedCertificate{
ID: "cert-6",
CommonName: "no-version.com",
IssuerID: "iss-local",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(0, 6, 0),
}
certRepo.AddCert(cert)
// No versions added — should fail
err := svc.RevokeCertificateWithActor(context.Background(), "cert-6", "keyCompromise", "admin")
if err == nil {
t.Fatal("expected error when no certificate version exists")
}
}
func TestRevokeCertificate_WithIssuerNotification(t *testing.T) {
svc, certRepo, revocationRepo, _ := newRevocationTestService()
// Wire up issuer registry with mock
mockIssuer := &mockIssuerConnector{}
svc.SetIssuerRegistry(map[string]IssuerConnector{
"iss-local": mockIssuer,
})
cert := &domain.ManagedCertificate{
ID: "cert-7",
CommonName: "issuer-notify.com",
IssuerID: "iss-local",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(0, 6, 0),
}
certRepo.AddCert(cert)
certRepo.Versions["cert-7"] = []*domain.CertificateVersion{
{ID: "ver-7", CertificateID: "cert-7", SerialNumber: "GHI789", CreatedAt: time.Now()},
}
err := svc.RevokeCertificateWithActor(context.Background(), "cert-7", "cessationOfOperation", "admin")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
// Verify revocation was recorded and issuer was notified
if len(revocationRepo.Revocations) != 1 {
t.Fatalf("expected 1 revocation, got %d", len(revocationRepo.Revocations))
}
if !revocationRepo.Revocations[0].IssuerNotified {
t.Error("expected issuer to be marked as notified")
}
}
func TestRevokeCertificate_WithNotificationService(t *testing.T) {
svc, certRepo, _, _ := newRevocationTestService()
// Wire up notification service
notifRepo := newMockNotificationRepository()
notifService := NewNotificationService(notifRepo, make(map[string]Notifier))
svc.SetNotificationService(notifService)
cert := &domain.ManagedCertificate{
ID: "cert-8",
CommonName: "with-notify.com",
IssuerID: "iss-local",
Status: domain.CertificateStatusActive,
OwnerID: "owner-alice",
ExpiresAt: time.Now().AddDate(0, 6, 0),
}
certRepo.AddCert(cert)
certRepo.Versions["cert-8"] = []*domain.CertificateVersion{
{ID: "ver-8", CertificateID: "cert-8", SerialNumber: "JKL012", CreatedAt: time.Now()},
}
err := svc.RevokeCertificateWithActor(context.Background(), "cert-8", "keyCompromise", "admin")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
// Should have created revocation notifications (webhook + email)
if len(notifRepo.Notifications) < 1 {
t.Error("expected at least one revocation notification to be created")
}
foundRevocationNotif := false
for _, n := range notifRepo.Notifications {
if n.Type == domain.NotificationTypeRevocation {
foundRevocationNotif = true
}
}
if !foundRevocationNotif {
t.Error("expected Revocation type notification")
}
}
func TestRevokeCertificate_AllValidReasons(t *testing.T) {
reasons := []string{
"unspecified", "keyCompromise", "caCompromise", "affiliationChanged",
"superseded", "cessationOfOperation", "certificateHold", "privilegeWithdrawn",
}
for _, reason := range reasons {
t.Run(reason, func(t *testing.T) {
svc, certRepo, _, _ := newRevocationTestService()
cert := &domain.ManagedCertificate{
ID: "cert-" + reason,
CommonName: reason + ".com",
IssuerID: "iss-local",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(0, 6, 0),
}
certRepo.AddCert(cert)
certRepo.Versions["cert-"+reason] = []*domain.CertificateVersion{
{ID: "ver-" + reason, CertificateID: "cert-" + reason, SerialNumber: "SER-" + reason, CreatedAt: time.Now()},
}
err := svc.RevokeCertificateWithActor(context.Background(), "cert-"+reason, reason, "admin")
if err != nil {
t.Fatalf("expected no error for reason %s, got: %v", reason, err)
}
updated, _ := certRepo.Get(context.Background(), "cert-"+reason)
if updated.Status != domain.CertificateStatusRevoked {
t.Errorf("expected Revoked status, got %s", updated.Status)
}
})
}
}
func TestGetRevokedCertificates_Success(t *testing.T) {
svc, _, revocationRepo, _ := newRevocationTestService()
// Pre-populate revocation records
revocationRepo.Revocations = []*domain.CertificateRevocation{
{ID: "rev-1", CertificateID: "cert-1", SerialNumber: "SER-1", Reason: "keyCompromise", RevokedAt: time.Now()},
{ID: "rev-2", CertificateID: "cert-2", SerialNumber: "SER-2", Reason: "superseded", RevokedAt: time.Now()},
}
revocations, err := svc.GetRevokedCertificates()
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if len(revocations) != 2 {
t.Errorf("expected 2 revocations, got %d", len(revocations))
}
}
func TestGetRevokedCertificates_Empty(t *testing.T) {
svc, _, _, _ := newRevocationTestService()
revocations, err := svc.GetRevokedCertificates()
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if revocations == nil {
// nil is acceptable for empty
} else if len(revocations) != 0 {
t.Errorf("expected 0 revocations, got %d", len(revocations))
}
}
func TestGetRevokedCertificates_NoRepo(t *testing.T) {
certRepo := newMockCertificateRepository()
auditRepo := newMockAuditRepository()
policyRepo := newMockPolicyRepository()
auditService := NewAuditService(auditRepo)
policyService := NewPolicyService(policyRepo, auditService)
svc := NewCertificateService(certRepo, policyService, auditService)
// Do NOT set revocation repo
_, err := svc.GetRevokedCertificates()
if err == nil {
t.Fatal("expected error when revocation repo not configured")
}
}
func TestRevokeCertificate_HandlerInterfaceMethod(t *testing.T) {
svc, certRepo, _, _ := newRevocationTestService()
cert := &domain.ManagedCertificate{
ID: "cert-handler",
CommonName: "handler-test.com",
IssuerID: "iss-local",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(0, 6, 0),
}
certRepo.AddCert(cert)
certRepo.Versions["cert-handler"] = []*domain.CertificateVersion{
{ID: "ver-handler", CertificateID: "cert-handler", SerialNumber: "SER-HANDLER", CreatedAt: time.Now()},
}
// Test the handler interface method (no actor param)
err := svc.RevokeCertificate("cert-handler", "superseded")
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
updated, _ := certRepo.Get(context.Background(), "cert-handler")
if updated.Status != domain.CertificateStatusRevoked {
t.Errorf("expected Revoked status, got %s", updated.Status)
}
}
+4 -5
View File
@@ -3,11 +3,10 @@ package service
import (
"context"
"errors"
"strings"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/repository"
)
// mockTeamRepo is a test implementation of TeamRepository
@@ -162,8 +161,8 @@ func TestTeamService_List_RepositoryError(t *testing.T) {
t.Fatalf("expected error, got nil")
}
if !errors.Is(err, errors.New("database error")) {
t.Errorf("expected database error, got %v", err)
if !strings.Contains(err.Error(), "database error") {
t.Errorf("expected error containing 'database error', got %v", err)
}
}
@@ -281,7 +280,7 @@ func TestTeamService_Create(t *testing.T) {
t.Errorf("expected ID to be generated, got empty")
}
if !team.ID[:5] == "team-" {
if !(team.ID[:5] == "team-") {
t.Logf("note: generated ID is %s", team.ID)
}
+72
View File
@@ -103,6 +103,14 @@ func (m *mockCertRepo) GetExpiringCertificates(ctx context.Context, before time.
return expiring, nil
}
func (m *mockCertRepo) GetLatestVersion(ctx context.Context, certID string) (*domain.CertificateVersion, error) {
versions := m.Versions[certID]
if len(versions) == 0 {
return nil, errNotFound
}
return versions[len(versions)-1], nil
}
func (m *mockCertRepo) AddCert(cert *domain.ManagedCertificate) {
m.Certs[cert.ID] = cert
}
@@ -605,6 +613,13 @@ func (m *mockIssuerConnector) RenewCertificate(ctx context.Context, commonName s
return m.IssueCertificate(ctx, commonName, sans, csrPEM)
}
func (m *mockIssuerConnector) RevokeCertificate(ctx context.Context, serial string, reason string) error {
if m.Err != nil {
return m.Err
}
return nil
}
// Constructor functions for mocks
func newMockCertificateRepository() *mockCertRepo {
@@ -725,6 +740,63 @@ func (m *mockIssuerRepository) AddIssuer(issuer *domain.Issuer) {
m.issuers[issuer.ID] = issuer
}
// mockRevocationRepo is a test implementation of RevocationRepository
type mockRevocationRepo struct {
Revocations []*domain.CertificateRevocation
CreateErr error
ListErr error
}
func (m *mockRevocationRepo) Create(ctx context.Context, revocation *domain.CertificateRevocation) error {
if m.CreateErr != nil {
return m.CreateErr
}
m.Revocations = append(m.Revocations, revocation)
return nil
}
func (m *mockRevocationRepo) GetBySerial(ctx context.Context, serial string) (*domain.CertificateRevocation, error) {
for _, r := range m.Revocations {
if r.SerialNumber == serial {
return r, nil
}
}
return nil, errNotFound
}
func (m *mockRevocationRepo) ListAll(ctx context.Context) ([]*domain.CertificateRevocation, error) {
if m.ListErr != nil {
return nil, m.ListErr
}
return m.Revocations, nil
}
func (m *mockRevocationRepo) ListByCertificate(ctx context.Context, certID string) ([]*domain.CertificateRevocation, error) {
var result []*domain.CertificateRevocation
for _, r := range m.Revocations {
if r.CertificateID == certID {
result = append(result, r)
}
}
return result, nil
}
func (m *mockRevocationRepo) MarkIssuerNotified(ctx context.Context, id string) error {
for _, r := range m.Revocations {
if r.ID == id {
r.IssuerNotified = true
return nil
}
}
return errNotFound
}
func newMockRevocationRepository() *mockRevocationRepo {
return &mockRevocationRepo{
Revocations: make([]*domain.CertificateRevocation, 0),
}
}
// mockNotifier is a simple notifier for testing
type mockNotifier struct {
messages []*mockNotifierMessage
+6
View File
@@ -0,0 +1,6 @@
-- Rollback Migration 000005: Revocation Infrastructure
DROP TABLE IF EXISTS certificate_revocations;
ALTER TABLE managed_certificates DROP COLUMN IF EXISTS revoked_at;
ALTER TABLE managed_certificates DROP COLUMN IF EXISTS revocation_reason;
+33
View File
@@ -0,0 +1,33 @@
-- Migration 000005: Revocation Infrastructure
-- Adds revocation tracking to managed_certificates and a dedicated revocations table for CRL generation.
-- Add revocation columns to managed_certificates
ALTER TABLE managed_certificates ADD COLUMN IF NOT EXISTS revoked_at TIMESTAMPTZ;
ALTER TABLE managed_certificates ADD COLUMN IF NOT EXISTS revocation_reason VARCHAR(50);
-- Certificate revocations table for CRL generation
-- Each row represents a revoked certificate version (by serial number).
-- This is the authoritative source for CRL content.
CREATE TABLE IF NOT EXISTS certificate_revocations (
id TEXT PRIMARY KEY,
certificate_id TEXT NOT NULL REFERENCES managed_certificates(id),
serial_number TEXT NOT NULL,
reason VARCHAR(50) NOT NULL DEFAULT 'unspecified',
revoked_by TEXT NOT NULL, -- actor who initiated revocation
revoked_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
issuer_id TEXT REFERENCES issuers(id),
issuer_notified BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- Index for CRL generation (all revoked certs, ordered by revocation time)
CREATE INDEX IF NOT EXISTS idx_certificate_revocations_revoked_at ON certificate_revocations(revoked_at);
-- Index for looking up revocations by certificate
CREATE INDEX IF NOT EXISTS idx_certificate_revocations_cert_id ON certificate_revocations(certificate_id);
-- Index for looking up revocations by serial (OCSP lookup, future M15b)
CREATE UNIQUE INDEX IF NOT EXISTS idx_certificate_revocations_serial ON certificate_revocations(serial_number);
-- Add revocation notification type
-- (NotificationType is enforced in Go code, not DB constraints, so no ALTER needed)