mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-07 17:51:29 +00:00
docs: synchronize project documentation with codebase
Implements 3 deferred security tickets (TICKET-003, TICKET-007, TICKET-010) and performs comprehensive documentation audit to eliminate drift between code and docs. Code changes: - TICKET-003: Repository integration tests with testcontainers-go (50+ subtests) - TICKET-007: CertificateService decomposition into RevocationSvc + CAOperationsSvc - TICKET-010: Request body size limits via http.MaxBytesReader middleware - Fix missing slog import in certificate.go after service decomposition Documentation updates: - README: Fix endpoint count (97→93), expand env var reference (15→39 vars) - CLAUDE.md: Fix OpenAPI operation count (85→93), update file locations - architecture.md: Add body size limits section, middleware chain ordering - CONTRIBUTING.md: New contributor guide with architecture conventions, test patterns, middleware ordering, CI thresholds - SECURITY_REMEDIATION.md: Removed from repo (moved to cowork, gitignored) - Test files: Add doc comments to all new test files Documentation that should exist but doesn't yet: - Architecture diagrams (C4 model or similar) - Threat model document - Testing philosophy guide - Disaster recovery runbook - Upgrade guide (migration between versions) - API versioning strategy document Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// BodyLimitConfig holds configuration for the body size limit middleware.
|
||||
type BodyLimitConfig struct {
|
||||
MaxBytes int64 // Maximum request body size in bytes; 0 = use default (1MB)
|
||||
}
|
||||
|
||||
// DefaultMaxBodySize is the default maximum request body size (1MB).
|
||||
const DefaultMaxBodySize int64 = 1 * 1024 * 1024
|
||||
|
||||
// NewBodyLimit creates a middleware that limits request body size.
|
||||
// If the body exceeds the configured limit, the server returns 413 Request Entity Too Large.
|
||||
// This prevents clients from sending excessively large payloads that could cause
|
||||
// memory exhaustion or denial of service (CWE-400).
|
||||
func NewBodyLimit(cfg BodyLimitConfig) func(http.Handler) http.Handler {
|
||||
maxBytes := cfg.MaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = DefaultMaxBodySize
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip body limit for requests without bodies
|
||||
if r.Body == nil || r.ContentLength == 0 {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Wrap the body with MaxBytesReader
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
// Tests for the request body size limit middleware (TICKET-010).
|
||||
// Covers under/over/exact limit, nil body, default size, GET requests,
|
||||
// and custom limits.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBodyLimit_UnderLimit(t *testing.T) {
|
||||
handler := NewBodyLimit(BodyLimitConfig{MaxBytes: 1024})(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected read error: %v", err)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(body)
|
||||
}),
|
||||
)
|
||||
|
||||
body := bytes.NewReader([]byte("small body"))
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", body)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyLimit_OverLimit(t *testing.T) {
|
||||
handler := NewBodyLimit(BodyLimitConfig{MaxBytes: 10})(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
// MaxBytesReader returns an error when limit exceeded
|
||||
http.Error(w, `{"error":"Request body too large"}`, http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
)
|
||||
|
||||
body := bytes.NewReader([]byte("this body exceeds ten bytes"))
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", body)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyLimit_ExactLimit(t *testing.T) {
|
||||
data := "exactly10!" // 10 bytes
|
||||
handler := NewBodyLimit(BodyLimitConfig{MaxBytes: 10})(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, `{"error":"Request body too large"}`, http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(body)
|
||||
}),
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", strings.NewReader(data))
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyLimit_NilBody(t *testing.T) {
|
||||
handler := NewBodyLimit(BodyLimitConfig{MaxBytes: 1024})(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyLimit_DefaultSize(t *testing.T) {
|
||||
// When MaxBytes is 0, should use default (1MB)
|
||||
mw := NewBodyLimit(BodyLimitConfig{MaxBytes: 0})
|
||||
|
||||
called := false
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
body := bytes.NewReader([]byte("test"))
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", body)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if !called {
|
||||
t.Error("handler was not called")
|
||||
}
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyLimit_GETRequest_NoBody(t *testing.T) {
|
||||
handler := NewBodyLimit(BodyLimitConfig{MaxBytes: 10})(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyLimit_ContentLengthZero(t *testing.T) {
|
||||
handler := NewBodyLimit(BodyLimitConfig{MaxBytes: 10})(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", nil)
|
||||
req.ContentLength = 0
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyLimit_CustomMaxBytes(t *testing.T) {
|
||||
// Test with 512KB limit
|
||||
const maxSize = 512 * 1024
|
||||
handler := NewBodyLimit(BodyLimitConfig{MaxBytes: maxSize})(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, `{"error":"Request body too large"}`, http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Length", string(rune(len(body))))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
)
|
||||
|
||||
// Create a body just under the limit
|
||||
bodyData := make([]byte, maxSize-1)
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(bodyData))
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d for body just under limit", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
@@ -203,8 +203,9 @@ type VerificationConfig struct {
|
||||
|
||||
// ServerConfig contains HTTP server configuration.
|
||||
type ServerConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Host string // Server host (default: 127.0.0.1). Set via CERTCTL_SERVER_HOST.
|
||||
Port int // Server port (default: 8080). Set via CERTCTL_SERVER_PORT.
|
||||
MaxBodySize int64 // Maximum request body size in bytes (default: 1MB). Set via CERTCTL_MAX_BODY_SIZE.
|
||||
}
|
||||
|
||||
// DatabaseConfig contains database connection configuration.
|
||||
@@ -301,8 +302,9 @@ type CORSConfig struct {
|
||||
func Load() (*Config, error) {
|
||||
cfg := &Config{
|
||||
Server: ServerConfig{
|
||||
Host: getEnv("CERTCTL_SERVER_HOST", "127.0.0.1"),
|
||||
Port: getEnvInt("CERTCTL_SERVER_PORT", 8080),
|
||||
Host: getEnv("CERTCTL_SERVER_HOST", "127.0.0.1"),
|
||||
Port: getEnvInt("CERTCTL_SERVER_PORT", 8080),
|
||||
MaxBodySize: getEnvInt64("CERTCTL_MAX_BODY_SIZE", 1024*1024), // 1MB default
|
||||
},
|
||||
Database: DatabaseConfig{
|
||||
URL: getEnv("CERTCTL_DATABASE_URL", "postgres://localhost/certctl"),
|
||||
@@ -471,6 +473,18 @@ func getEnvInt(key string, defaultValue int) int {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// getEnvInt64 reads an int64 environment variable with the given key and default value.
|
||||
func getEnvInt64(key string, defaultValue int64) int64 {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
intVal, err := strconv.ParseInt(value, 10, 64)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return intVal
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// getEnvDuration reads a time.Duration environment variable.
|
||||
// The value should be a valid Go duration string (e.g., "1h", "30s", "5m").
|
||||
func getEnvDuration(key string, defaultValue time.Duration) time.Duration {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,196 @@
|
||||
// Package postgres_test contains integration tests for PostgreSQL repository
|
||||
// implementations using testcontainers-go. Tests spin up a real PostgreSQL 16
|
||||
// container and use schema-per-test isolation for parallel safety.
|
||||
package postgres_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
)
|
||||
|
||||
// testDB holds a shared database connection for a test suite.
|
||||
// Each test gets its own schema (via search_path) for isolation.
|
||||
type testDB struct {
|
||||
db *sql.DB
|
||||
container testcontainers.Container
|
||||
}
|
||||
|
||||
// setupTestDB starts a PostgreSQL container and runs all migrations.
|
||||
// Call this once per test file via TestMain or a sync.Once.
|
||||
func setupTestDB(t *testing.T) *testDB {
|
||||
t.Helper()
|
||||
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
req := testcontainers.ContainerRequest{
|
||||
Image: "postgres:16-alpine",
|
||||
ExposedPorts: []string{"5432/tcp"},
|
||||
Env: map[string]string{
|
||||
"POSTGRES_DB": "certctl_test",
|
||||
"POSTGRES_USER": "certctl",
|
||||
"POSTGRES_PASSWORD": "certctl",
|
||||
},
|
||||
WaitingFor: wait.ForLog("database system is ready to accept connections").WithOccurrence(2),
|
||||
}
|
||||
|
||||
container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
||||
ContainerRequest: req,
|
||||
Started: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start postgres container: %v", err)
|
||||
}
|
||||
|
||||
host, err := container.Host(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get container host: %v", err)
|
||||
}
|
||||
|
||||
port, err := container.MappedPort(ctx, "5432")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get mapped port: %v", err)
|
||||
}
|
||||
|
||||
connStr := fmt.Sprintf("postgres://certctl:certctl@%s:%s/certctl_test?sslmode=disable", host, port.Port())
|
||||
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open database: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
t.Fatalf("failed to ping database: %v", err)
|
||||
}
|
||||
|
||||
// Run migrations
|
||||
migrationsPath := findMigrationsDir()
|
||||
if err := runMigrations(db, migrationsPath); err != nil {
|
||||
t.Fatalf("failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
return &testDB{db: db, container: container}
|
||||
}
|
||||
|
||||
// teardown stops the container and closes the connection.
|
||||
func (tdb *testDB) teardown(t *testing.T) {
|
||||
t.Helper()
|
||||
if tdb.db != nil {
|
||||
tdb.db.Close()
|
||||
}
|
||||
if tdb.container != nil {
|
||||
tdb.container.Terminate(context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
// freshSchema creates a new PostgreSQL schema for test isolation
|
||||
// and returns a *sql.DB with search_path set to that schema.
|
||||
// Each test gets a unique schema so tests don't interfere with each other.
|
||||
func (tdb *testDB) freshSchema(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
|
||||
// Create a unique schema name from the test name
|
||||
schemaName := sanitizeSchemaName(t.Name())
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create schema
|
||||
_, err := tdb.db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schemaName))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create schema %s: %v", schemaName, err)
|
||||
}
|
||||
|
||||
// Set search_path for this connection to use the new schema
|
||||
_, err = tdb.db.ExecContext(ctx, fmt.Sprintf("SET search_path TO %s, public", schemaName))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set search_path: %v", err)
|
||||
}
|
||||
|
||||
// Run migrations in the new schema
|
||||
migrationsPath := findMigrationsDir()
|
||||
if err := runMigrationsWithSearchPath(tdb.db, migrationsPath, schemaName); err != nil {
|
||||
t.Fatalf("failed to run migrations in schema %s: %v", schemaName, err)
|
||||
}
|
||||
|
||||
// Register cleanup
|
||||
t.Cleanup(func() {
|
||||
tdb.db.ExecContext(context.Background(), fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName))
|
||||
})
|
||||
|
||||
return tdb.db
|
||||
}
|
||||
|
||||
// sanitizeSchemaName converts a test name to a valid PostgreSQL schema name.
|
||||
func sanitizeSchemaName(name string) string {
|
||||
name = strings.ToLower(name)
|
||||
name = strings.ReplaceAll(name, "/", "_")
|
||||
name = strings.ReplaceAll(name, " ", "_")
|
||||
name = strings.ReplaceAll(name, "-", "_")
|
||||
name = strings.ReplaceAll(name, ".", "_")
|
||||
// Truncate to 63 chars (PG limit)
|
||||
if len(name) > 60 {
|
||||
name = name[:60]
|
||||
}
|
||||
return "test_" + name
|
||||
}
|
||||
|
||||
// findMigrationsDir walks up from the test file to find the migrations/ directory.
|
||||
func findMigrationsDir() string {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
dir := filepath.Dir(filename)
|
||||
|
||||
// Walk up to find the project root (where migrations/ lives)
|
||||
for i := 0; i < 10; i++ {
|
||||
candidate := filepath.Join(dir, "migrations")
|
||||
if _, err := os.Stat(candidate); err == nil {
|
||||
return candidate
|
||||
}
|
||||
dir = filepath.Dir(dir)
|
||||
}
|
||||
|
||||
// Fallback: try relative from working directory
|
||||
return "../../../../migrations"
|
||||
}
|
||||
|
||||
// runMigrations reads and executes all .up.sql migration files.
|
||||
func runMigrations(db *sql.DB, migrationsPath string) error {
|
||||
files, err := os.ReadDir(migrationsPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read migrations directory %s: %w", migrationsPath, err)
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if !file.IsDir() && strings.HasSuffix(file.Name(), ".up.sql") {
|
||||
content, err := os.ReadFile(filepath.Join(migrationsPath, file.Name()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read migration %s: %w", file.Name(), err)
|
||||
}
|
||||
if _, err := db.Exec(string(content)); err != nil {
|
||||
return fmt.Errorf("failed to execute migration %s: %w", file.Name(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runMigrationsWithSearchPath runs migrations within a specific schema.
|
||||
func runMigrationsWithSearchPath(db *sql.DB, migrationsPath string, schema string) error {
|
||||
// Set search_path before running migrations
|
||||
if _, err := db.Exec(fmt.Sprintf("SET search_path TO %s, public", schema)); err != nil {
|
||||
return fmt.Errorf("failed to set search_path: %w", err)
|
||||
}
|
||||
return runMigrations(db, migrationsPath)
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
)
|
||||
|
||||
// CAOperationsSvc provides CA operations: CRL generation and OCSP response signing.
|
||||
// This service handles revocation status queries and certificate lifecycle operations
|
||||
// related to the certificate authority.
|
||||
type CAOperationsSvc struct {
|
||||
revocationRepo repository.RevocationRepository
|
||||
certRepo repository.CertificateRepository
|
||||
profileRepo repository.CertificateProfileRepository
|
||||
issuerRegistry map[string]IssuerConnector
|
||||
}
|
||||
|
||||
// NewCAOperationsSvc creates a new CA operations service.
|
||||
func NewCAOperationsSvc(
|
||||
revocationRepo repository.RevocationRepository,
|
||||
certRepo repository.CertificateRepository,
|
||||
profileRepo repository.CertificateProfileRepository,
|
||||
) *CAOperationsSvc {
|
||||
return &CAOperationsSvc{
|
||||
revocationRepo: revocationRepo,
|
||||
certRepo: certRepo,
|
||||
profileRepo: profileRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// SetIssuerRegistry sets the issuer registry for CRL and OCSP operations.
|
||||
func (s *CAOperationsSvc) SetIssuerRegistry(registry map[string]IssuerConnector) {
|
||||
s.issuerRegistry = registry
|
||||
}
|
||||
|
||||
// GenerateDERCRL generates a DER-encoded X.509 CRL for the given issuer.
|
||||
// Short-lived certificates (profile TTL < 1 hour) are excluded from the CRL.
|
||||
func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) {
|
||||
if s.revocationRepo == nil {
|
||||
return nil, fmt.Errorf("revocation repository not configured")
|
||||
}
|
||||
if s.issuerRegistry == nil {
|
||||
return nil, fmt.Errorf("issuer registry not configured")
|
||||
}
|
||||
|
||||
issuerConn, ok := s.issuerRegistry[issuerID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("issuer not found: %s", issuerID)
|
||||
}
|
||||
|
||||
revocations, err := s.revocationRepo.ListAll(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list revocations: %w", err)
|
||||
}
|
||||
|
||||
// Filter to this issuer and convert to CRL entries.
|
||||
// Short-lived certificates (profile TTL < 1 hour) are excluded — expiry is sufficient revocation.
|
||||
var entries []CRLEntry
|
||||
for _, rev := range revocations {
|
||||
if rev.IssuerID != issuerID {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check short-lived exemption: look up the cert's profile
|
||||
if s.profileRepo != nil && s.certRepo != nil {
|
||||
cert, err := s.certRepo.Get(context.Background(), rev.CertificateID)
|
||||
if err == nil && cert.CertificateProfileID != "" {
|
||||
profile, err := s.profileRepo.Get(context.Background(), cert.CertificateProfileID)
|
||||
if err == nil && profile.IsShortLived() {
|
||||
slog.Debug("skipping short-lived cert from CRL",
|
||||
"certificate_id", rev.CertificateID,
|
||||
"profile_id", cert.CertificateProfileID)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse serial number from hex string
|
||||
serial := new(big.Int)
|
||||
serial.SetString(rev.SerialNumber, 16)
|
||||
|
||||
entries = append(entries, CRLEntry{
|
||||
SerialNumber: serial,
|
||||
RevokedAt: rev.RevokedAt,
|
||||
ReasonCode: domain.CRLReasonCode(domain.RevocationReason(rev.Reason)),
|
||||
})
|
||||
}
|
||||
|
||||
return issuerConn.GenerateCRL(context.Background(), entries)
|
||||
}
|
||||
|
||||
// GetOCSPResponse generates a signed OCSP response for the given certificate serial.
|
||||
func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) {
|
||||
if s.revocationRepo == nil {
|
||||
return nil, fmt.Errorf("revocation repository not configured")
|
||||
}
|
||||
if s.issuerRegistry == nil {
|
||||
return nil, fmt.Errorf("issuer registry not configured")
|
||||
}
|
||||
|
||||
issuerConn, ok := s.issuerRegistry[issuerID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("issuer not found: %s", issuerID)
|
||||
}
|
||||
|
||||
serial := new(big.Int)
|
||||
serial.SetString(serialHex, 16)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Short-lived cert exemption: if the cert's profile has TTL < 1 hour,
|
||||
// always return "good" — expiry is sufficient revocation for short-lived certs.
|
||||
if s.profileRepo != nil && s.certRepo != nil {
|
||||
// Look up cert by serial through revocation table
|
||||
rev, _ := s.revocationRepo.GetBySerial(context.Background(), serialHex)
|
||||
if rev != nil {
|
||||
cert, err := s.certRepo.Get(context.Background(), rev.CertificateID)
|
||||
if err == nil && cert.CertificateProfileID != "" {
|
||||
profile, err := s.profileRepo.Get(context.Background(), cert.CertificateProfileID)
|
||||
if err == nil && profile.IsShortLived() {
|
||||
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{
|
||||
CertSerial: serial,
|
||||
CertStatus: 0, // good — short-lived exemption
|
||||
ThisUpdate: now,
|
||||
NextUpdate: now.Add(1 * time.Hour),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this serial is revoked
|
||||
rev, err := s.revocationRepo.GetBySerial(context.Background(), serialHex)
|
||||
if err != nil {
|
||||
// Not revoked — return "good" status
|
||||
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{
|
||||
CertSerial: serial,
|
||||
CertStatus: 0, // good
|
||||
ThisUpdate: now,
|
||||
NextUpdate: now.Add(1 * time.Hour),
|
||||
})
|
||||
}
|
||||
|
||||
// Revoked
|
||||
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{
|
||||
CertSerial: serial,
|
||||
CertStatus: 1, // revoked
|
||||
RevokedAt: rev.RevokedAt,
|
||||
RevocationReason: domain.CRLReasonCode(domain.RevocationReason(rev.Reason)),
|
||||
ThisUpdate: now,
|
||||
NextUpdate: now.Add(1 * time.Hour),
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
// Tests for CAOperationsSvc, the focused sub-service that handles CRL generation
|
||||
// and OCSP response signing extracted from CertificateService (TICKET-007).
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
// helper to create a CAOperationsSvc for testing
|
||||
func newCAOperationsSvcTest() (*CAOperationsSvc, *mockRevocationRepo, *mockCertRepo) {
|
||||
revocationRepo := newMockRevocationRepository()
|
||||
certRepo := newMockCertificateRepository()
|
||||
profileRepo := newMockProfileRepository()
|
||||
|
||||
caSvc := NewCAOperationsSvc(revocationRepo, certRepo, profileRepo)
|
||||
caSvc.SetIssuerRegistry(map[string]IssuerConnector{
|
||||
"iss-local": &mockIssuerConnector{},
|
||||
})
|
||||
|
||||
return caSvc, revocationRepo, certRepo
|
||||
}
|
||||
|
||||
func TestCAOperationsSvc_GenerateDERCRL_Success(t *testing.T) {
|
||||
caSvc, revocationRepo, _ := newCAOperationsSvcTest()
|
||||
|
||||
// Add some revoked certificates to the repo
|
||||
now := time.Now()
|
||||
revocationRepo.Revocations = []*domain.CertificateRevocation{
|
||||
{
|
||||
SerialNumber: "SERIAL-001",
|
||||
CertificateID: "cert-1",
|
||||
IssuerID: "iss-local",
|
||||
Reason: "keyCompromise",
|
||||
RevokedAt: now.Add(-24 * time.Hour),
|
||||
RevokedBy: "admin",
|
||||
},
|
||||
{
|
||||
SerialNumber: "SERIAL-002",
|
||||
CertificateID: "cert-2",
|
||||
IssuerID: "iss-local",
|
||||
Reason: "superseded",
|
||||
RevokedAt: now.Add(-12 * time.Hour),
|
||||
RevokedBy: "admin",
|
||||
},
|
||||
}
|
||||
|
||||
crl, err := caSvc.GenerateDERCRL("iss-local")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if crl == nil {
|
||||
t.Fatal("expected non-nil CRL")
|
||||
}
|
||||
|
||||
if len(crl) == 0 {
|
||||
t.Fatal("expected non-empty CRL")
|
||||
}
|
||||
|
||||
t.Logf("DER CRL generated successfully: %d bytes", len(crl))
|
||||
}
|
||||
|
||||
func TestCAOperationsSvc_GenerateDERCRL_EmptyCRL(t *testing.T) {
|
||||
caSvc, revocationRepo, _ := newCAOperationsSvcTest()
|
||||
|
||||
// No revoked certs for this issuer
|
||||
revocationRepo.Revocations = []*domain.CertificateRevocation{}
|
||||
|
||||
crl, err := caSvc.GenerateDERCRL("iss-local")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if crl == nil {
|
||||
t.Fatal("expected non-nil CRL even when empty")
|
||||
}
|
||||
|
||||
if len(crl) == 0 {
|
||||
t.Fatal("expected non-empty CRL bytes (at least the CRL structure)")
|
||||
}
|
||||
|
||||
t.Logf("Empty DER CRL generated successfully: %d bytes", len(crl))
|
||||
}
|
||||
|
||||
func TestCAOperationsSvc_GetOCSPResponse_Good(t *testing.T) {
|
||||
caSvc, _, certRepo := newCAOperationsSvcTest()
|
||||
|
||||
// Add a non-revoked certificate
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "cert-ocsp-good",
|
||||
CommonName: "good.example.com",
|
||||
IssuerID: "iss-local",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
certRepo.AddCert(cert)
|
||||
|
||||
version := &domain.CertificateVersion{
|
||||
ID: "ver-ocsp-good",
|
||||
CertificateID: "cert-ocsp-good",
|
||||
SerialNumber: "OCSP-GOOD-001",
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
certRepo.Versions["cert-ocsp-good"] = []*domain.CertificateVersion{version}
|
||||
|
||||
// Request OCSP response for good cert
|
||||
resp, err := caSvc.GetOCSPResponse("iss-local", "OCSP-GOOD-001")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if resp == nil || len(resp) == 0 {
|
||||
t.Fatal("expected non-empty OCSP response for good cert")
|
||||
}
|
||||
|
||||
t.Logf("OCSP response for good cert generated: %d bytes", len(resp))
|
||||
}
|
||||
|
||||
func TestCAOperationsSvc_GetOCSPResponse_Revoked(t *testing.T) {
|
||||
caSvc, revocationRepo, certRepo := newCAOperationsSvcTest()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Add a revoked certificate
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "cert-ocsp-revoked",
|
||||
CommonName: "revoked.example.com",
|
||||
IssuerID: "iss-local",
|
||||
Status: domain.CertificateStatusRevoked,
|
||||
RevokedAt: &now,
|
||||
RevocationReason: "keyCompromise",
|
||||
ExpiresAt: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
certRepo.AddCert(cert)
|
||||
|
||||
version := &domain.CertificateVersion{
|
||||
ID: "ver-ocsp-revoked",
|
||||
CertificateID: "cert-ocsp-revoked",
|
||||
SerialNumber: "OCSP-REVOKED-001",
|
||||
NotBefore: time.Now().Add(-24 * time.Hour),
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
certRepo.Versions["cert-ocsp-revoked"] = []*domain.CertificateVersion{version}
|
||||
|
||||
// Add revocation record
|
||||
revocationRepo.Revocations = []*domain.CertificateRevocation{
|
||||
{
|
||||
SerialNumber: "OCSP-REVOKED-001",
|
||||
CertificateID: "cert-ocsp-revoked",
|
||||
IssuerID: "iss-local",
|
||||
Reason: "keyCompromise",
|
||||
RevokedAt: now.Add(-24 * time.Hour),
|
||||
RevokedBy: "admin",
|
||||
},
|
||||
}
|
||||
|
||||
// Request OCSP response for revoked cert
|
||||
resp, err := caSvc.GetOCSPResponse("iss-local", "OCSP-REVOKED-001")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if resp == nil || len(resp) == 0 {
|
||||
t.Fatal("expected non-empty OCSP response for revoked cert")
|
||||
}
|
||||
|
||||
t.Logf("OCSP response for revoked cert generated: %d bytes", len(resp))
|
||||
}
|
||||
+29
-240
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
@@ -13,14 +12,12 @@ import (
|
||||
|
||||
// CertificateService provides business logic for certificate management.
|
||||
type CertificateService struct {
|
||||
certRepo repository.CertificateRepository
|
||||
targetRepo repository.TargetRepository
|
||||
revocationRepo repository.RevocationRepository
|
||||
profileRepo repository.CertificateProfileRepository
|
||||
policyService *PolicyService
|
||||
auditService *AuditService
|
||||
notificationSvc *NotificationService
|
||||
issuerRegistry map[string]IssuerConnector
|
||||
certRepo repository.CertificateRepository
|
||||
targetRepo repository.TargetRepository
|
||||
policyService *PolicyService
|
||||
auditService *AuditService
|
||||
revSvc *RevocationSvc
|
||||
caSvc *CAOperationsSvc
|
||||
}
|
||||
|
||||
// NewCertificateService creates a new certificate service.
|
||||
@@ -36,24 +33,14 @@ func NewCertificateService(
|
||||
}
|
||||
}
|
||||
|
||||
// SetRevocationRepo sets the revocation repository (called after construction to avoid init order issues).
|
||||
func (s *CertificateService) SetRevocationRepo(repo repository.RevocationRepository) {
|
||||
s.revocationRepo = repo
|
||||
// SetRevocationSvc sets the revocation service.
|
||||
func (s *CertificateService) SetRevocationSvc(svc *RevocationSvc) {
|
||||
s.revSvc = svc
|
||||
}
|
||||
|
||||
// SetNotificationService sets the notification service for revocation alerts.
|
||||
func (s *CertificateService) SetNotificationService(svc *NotificationService) {
|
||||
s.notificationSvc = svc
|
||||
}
|
||||
|
||||
// SetIssuerRegistry sets the issuer registry for issuer-level revocation.
|
||||
func (s *CertificateService) SetIssuerRegistry(registry map[string]IssuerConnector) {
|
||||
s.issuerRegistry = registry
|
||||
}
|
||||
|
||||
// SetProfileRepo sets the profile repository for short-lived cert exemption in CRL/OCSP.
|
||||
func (s *CertificateService) SetProfileRepo(repo repository.CertificateProfileRepository) {
|
||||
s.profileRepo = repo
|
||||
// SetCAOperationsSvc sets the CA operations service.
|
||||
func (s *CertificateService) SetCAOperationsSvc(svc *CAOperationsSvc) {
|
||||
s.caSvc = svc
|
||||
}
|
||||
|
||||
// SetTargetRepo sets the target repository for deployment queries.
|
||||
@@ -381,243 +368,45 @@ func (s *CertificateService) TriggerDeployment(certID string, targetID string) e
|
||||
return s.TriggerDeploymentWithActor(context.Background(), certID, "api")
|
||||
}
|
||||
|
||||
// RevokeCertificate revokes a certificate with the given reason.
|
||||
// Steps:
|
||||
// 1. Validate the certificate exists and is revocable
|
||||
// 2. Get the latest certificate version (for serial number)
|
||||
// 3. Update certificate status to Revoked
|
||||
// 4. Record revocation in certificate_revocations table
|
||||
// 5. Notify the issuer connector (best-effort)
|
||||
// 6. Record audit event
|
||||
// 7. Send revocation notification
|
||||
// RevokeCertificate revokes a certificate with the given reason (handler interface method).
|
||||
func (s *CertificateService) RevokeCertificate(certID string, reason string) error {
|
||||
return s.RevokeCertificateWithActor(context.Background(), certID, reason, "api")
|
||||
}
|
||||
|
||||
// RevokeCertificateWithActor performs revocation with actor tracking.
|
||||
// Delegates to RevocationSvc.
|
||||
func (s *CertificateService) RevokeCertificateWithActor(ctx context.Context, certID string, reason string, actor string) error {
|
||||
// 1. Validate certificate exists and is revocable
|
||||
cert, err := s.certRepo.Get(ctx, certID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch certificate: %w", err)
|
||||
if s.revSvc == nil {
|
||||
return fmt.Errorf("revocation service not configured")
|
||||
}
|
||||
|
||||
if cert.Status == domain.CertificateStatusRevoked {
|
||||
return fmt.Errorf("certificate is already revoked")
|
||||
}
|
||||
if cert.Status == domain.CertificateStatusArchived {
|
||||
return fmt.Errorf("cannot revoke archived certificate")
|
||||
}
|
||||
|
||||
// Validate reason code
|
||||
if reason == "" {
|
||||
reason = string(domain.RevocationReasonUnspecified)
|
||||
}
|
||||
if !domain.IsValidRevocationReason(reason) {
|
||||
return fmt.Errorf("invalid revocation reason: %s", reason)
|
||||
}
|
||||
|
||||
// 2. Get latest certificate version for serial number
|
||||
version, err := s.certRepo.GetLatestVersion(ctx, certID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get certificate version: %w", err)
|
||||
}
|
||||
|
||||
// 3. Update certificate status to Revoked
|
||||
now := time.Now()
|
||||
cert.Status = domain.CertificateStatusRevoked
|
||||
cert.RevokedAt = &now
|
||||
cert.RevocationReason = reason
|
||||
cert.UpdatedAt = now
|
||||
if err := s.certRepo.Update(ctx, cert); err != nil {
|
||||
return fmt.Errorf("failed to update certificate status: %w", err)
|
||||
}
|
||||
|
||||
// 4. Record revocation in certificate_revocations table (for CRL generation)
|
||||
if s.revocationRepo != nil {
|
||||
revocation := &domain.CertificateRevocation{
|
||||
ID: generateID("rev"),
|
||||
CertificateID: certID,
|
||||
SerialNumber: version.SerialNumber,
|
||||
Reason: reason,
|
||||
RevokedBy: actor,
|
||||
RevokedAt: now,
|
||||
IssuerID: cert.IssuerID,
|
||||
CreatedAt: now,
|
||||
}
|
||||
if err := s.revocationRepo.Create(ctx, revocation); err != nil {
|
||||
slog.Error("failed to record revocation for CRL", "error", err, "certificate_id", certID)
|
||||
// Don't fail the overall revocation — the cert status is already updated
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Notify the issuer connector (best-effort)
|
||||
if s.issuerRegistry != nil {
|
||||
if issuerConn, ok := s.issuerRegistry[cert.IssuerID]; ok {
|
||||
if err := issuerConn.RevokeCertificate(ctx, version.SerialNumber, reason); err != nil {
|
||||
slog.Error("failed to notify issuer of revocation",
|
||||
"error", err,
|
||||
"issuer_id", cert.IssuerID,
|
||||
"serial", version.SerialNumber)
|
||||
// Best-effort — don't fail the overall revocation
|
||||
} else if s.revocationRepo != nil {
|
||||
// Mark issuer as notified
|
||||
revocations, _ := s.revocationRepo.ListByCertificate(ctx, certID)
|
||||
for _, rev := range revocations {
|
||||
if rev.SerialNumber == version.SerialNumber {
|
||||
_ = s.revocationRepo.MarkIssuerNotified(ctx, rev.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Record audit event
|
||||
if err := s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser,
|
||||
"certificate_revoked", "certificate", certID,
|
||||
map[string]interface{}{
|
||||
"common_name": cert.CommonName,
|
||||
"serial": version.SerialNumber,
|
||||
"reason": reason,
|
||||
}); err != nil {
|
||||
slog.Error("failed to record audit event", "error", err)
|
||||
}
|
||||
|
||||
// 7. Send revocation notification
|
||||
if s.notificationSvc != nil {
|
||||
if err := s.notificationSvc.SendRevocationNotification(ctx, cert, reason); err != nil {
|
||||
slog.Error("failed to send revocation notification", "error", err, "certificate_id", certID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return s.revSvc.RevokeCertificateWithActor(ctx, certID, reason, actor)
|
||||
}
|
||||
|
||||
// GetRevokedCertificates returns all revoked certificate records (for CRL generation).
|
||||
// Delegates to RevocationSvc.
|
||||
func (s *CertificateService) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) {
|
||||
if s.revocationRepo == nil {
|
||||
return nil, fmt.Errorf("revocation repository not configured")
|
||||
if s.revSvc == nil {
|
||||
return nil, fmt.Errorf("revocation service not configured")
|
||||
}
|
||||
return s.revocationRepo.ListAll(context.Background())
|
||||
return s.revSvc.GetRevokedCertificates()
|
||||
}
|
||||
|
||||
// GenerateDERCRL generates a DER-encoded X.509 CRL for the given issuer.
|
||||
// Short-lived certificates (profile TTL < 1 hour) are excluded from the CRL.
|
||||
// Delegates to CAOperationsSvc.
|
||||
func (s *CertificateService) GenerateDERCRL(issuerID string) ([]byte, error) {
|
||||
if s.revocationRepo == nil {
|
||||
return nil, fmt.Errorf("revocation repository not configured")
|
||||
if s.caSvc == nil {
|
||||
return nil, fmt.Errorf("CA operations service not configured")
|
||||
}
|
||||
if s.issuerRegistry == nil {
|
||||
return nil, fmt.Errorf("issuer registry not configured")
|
||||
}
|
||||
|
||||
issuerConn, ok := s.issuerRegistry[issuerID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("issuer not found: %s", issuerID)
|
||||
}
|
||||
|
||||
revocations, err := s.revocationRepo.ListAll(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list revocations: %w", err)
|
||||
}
|
||||
|
||||
// Filter to this issuer and convert to CRL entries.
|
||||
// Short-lived certificates (profile TTL < 1 hour) are excluded — expiry is sufficient revocation.
|
||||
var entries []CRLEntry
|
||||
for _, rev := range revocations {
|
||||
if rev.IssuerID != issuerID {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check short-lived exemption: look up the cert's profile
|
||||
if s.profileRepo != nil && s.certRepo != nil {
|
||||
cert, err := s.certRepo.Get(context.Background(), rev.CertificateID)
|
||||
if err == nil && cert.CertificateProfileID != "" {
|
||||
profile, err := s.profileRepo.Get(context.Background(), cert.CertificateProfileID)
|
||||
if err == nil && profile.IsShortLived() {
|
||||
slog.Debug("skipping short-lived cert from CRL",
|
||||
"certificate_id", rev.CertificateID,
|
||||
"profile_id", cert.CertificateProfileID)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse serial number from hex string
|
||||
serial := new(big.Int)
|
||||
serial.SetString(rev.SerialNumber, 16)
|
||||
|
||||
entries = append(entries, CRLEntry{
|
||||
SerialNumber: serial,
|
||||
RevokedAt: rev.RevokedAt,
|
||||
ReasonCode: domain.CRLReasonCode(domain.RevocationReason(rev.Reason)),
|
||||
})
|
||||
}
|
||||
|
||||
return issuerConn.GenerateCRL(context.Background(), entries)
|
||||
return s.caSvc.GenerateDERCRL(issuerID)
|
||||
}
|
||||
|
||||
// GetOCSPResponse generates a signed OCSP response for the given certificate serial.
|
||||
// Delegates to CAOperationsSvc.
|
||||
func (s *CertificateService) GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) {
|
||||
if s.revocationRepo == nil {
|
||||
return nil, fmt.Errorf("revocation repository not configured")
|
||||
if s.caSvc == nil {
|
||||
return nil, fmt.Errorf("CA operations service not configured")
|
||||
}
|
||||
if s.issuerRegistry == nil {
|
||||
return nil, fmt.Errorf("issuer registry not configured")
|
||||
}
|
||||
|
||||
issuerConn, ok := s.issuerRegistry[issuerID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("issuer not found: %s", issuerID)
|
||||
}
|
||||
|
||||
serial := new(big.Int)
|
||||
serial.SetString(serialHex, 16)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Short-lived cert exemption: if the cert's profile has TTL < 1 hour,
|
||||
// always return "good" — expiry is sufficient revocation for short-lived certs.
|
||||
if s.profileRepo != nil && s.certRepo != nil {
|
||||
// Look up cert by serial through revocation table
|
||||
rev, _ := s.revocationRepo.GetBySerial(context.Background(), serialHex)
|
||||
if rev != nil {
|
||||
cert, err := s.certRepo.Get(context.Background(), rev.CertificateID)
|
||||
if err == nil && cert.CertificateProfileID != "" {
|
||||
profile, err := s.profileRepo.Get(context.Background(), cert.CertificateProfileID)
|
||||
if err == nil && profile.IsShortLived() {
|
||||
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{
|
||||
CertSerial: serial,
|
||||
CertStatus: 0, // good — short-lived exemption
|
||||
ThisUpdate: now,
|
||||
NextUpdate: now.Add(1 * time.Hour),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this serial is revoked
|
||||
rev, err := s.revocationRepo.GetBySerial(context.Background(), serialHex)
|
||||
if err != nil {
|
||||
// Not revoked — return "good" status
|
||||
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{
|
||||
CertSerial: serial,
|
||||
CertStatus: 0, // good
|
||||
ThisUpdate: now,
|
||||
NextUpdate: now.Add(1 * time.Hour),
|
||||
})
|
||||
}
|
||||
|
||||
// Revoked
|
||||
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{
|
||||
CertSerial: serial,
|
||||
CertStatus: 1, // revoked
|
||||
RevokedAt: rev.RevokedAt,
|
||||
RevocationReason: domain.CRLReasonCode(domain.RevocationReason(rev.Reason)),
|
||||
ThisUpdate: now,
|
||||
NextUpdate: now.Add(1 * time.Hour),
|
||||
})
|
||||
return s.caSvc.GetOCSPResponse(issuerID, serialHex)
|
||||
}
|
||||
|
||||
// GetCertificateDeployments returns all deployment targets for a certificate (M20).
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
)
|
||||
|
||||
// RevocationSvc provides revocation-related business logic.
|
||||
// It handles certificate revocation, revocation notifications, and issuer coordination.
|
||||
type RevocationSvc struct {
|
||||
certRepo repository.CertificateRepository
|
||||
revocationRepo repository.RevocationRepository
|
||||
auditService *AuditService
|
||||
notificationSvc *NotificationService
|
||||
issuerRegistry map[string]IssuerConnector
|
||||
}
|
||||
|
||||
// NewRevocationSvc creates a new revocation service.
|
||||
func NewRevocationSvc(
|
||||
certRepo repository.CertificateRepository,
|
||||
revocationRepo repository.RevocationRepository,
|
||||
auditService *AuditService,
|
||||
) *RevocationSvc {
|
||||
return &RevocationSvc{
|
||||
certRepo: certRepo,
|
||||
revocationRepo: revocationRepo,
|
||||
auditService: auditService,
|
||||
}
|
||||
}
|
||||
|
||||
// SetNotificationService sets the notification service for revocation alerts.
|
||||
func (s *RevocationSvc) SetNotificationService(svc *NotificationService) {
|
||||
s.notificationSvc = svc
|
||||
}
|
||||
|
||||
// SetIssuerRegistry sets the issuer registry for issuer-level revocation.
|
||||
func (s *RevocationSvc) SetIssuerRegistry(registry map[string]IssuerConnector) {
|
||||
s.issuerRegistry = registry
|
||||
}
|
||||
|
||||
// RevokeCertificateWithActor performs revocation with actor tracking.
|
||||
// Steps:
|
||||
// 1. Validate the certificate exists and is revocable
|
||||
// 2. Get the latest certificate version (for serial number)
|
||||
// 3. Update certificate status to Revoked
|
||||
// 4. Record revocation in certificate_revocations table
|
||||
// 5. Notify the issuer connector (best-effort)
|
||||
// 6. Record audit event
|
||||
// 7. Send revocation notification
|
||||
func (s *RevocationSvc) RevokeCertificateWithActor(ctx context.Context, certID string, reason string, actor string) error {
|
||||
// 1. Validate certificate exists and is revocable
|
||||
cert, err := s.certRepo.Get(ctx, certID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch certificate: %w", err)
|
||||
}
|
||||
|
||||
if cert.Status == domain.CertificateStatusRevoked {
|
||||
return fmt.Errorf("certificate is already revoked")
|
||||
}
|
||||
if cert.Status == domain.CertificateStatusArchived {
|
||||
return fmt.Errorf("cannot revoke archived certificate")
|
||||
}
|
||||
|
||||
// Validate reason code
|
||||
if reason == "" {
|
||||
reason = string(domain.RevocationReasonUnspecified)
|
||||
}
|
||||
if !domain.IsValidRevocationReason(reason) {
|
||||
return fmt.Errorf("invalid revocation reason: %s", reason)
|
||||
}
|
||||
|
||||
// 2. Get latest certificate version for serial number
|
||||
version, err := s.certRepo.GetLatestVersion(ctx, certID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get certificate version: %w", err)
|
||||
}
|
||||
|
||||
// 3. Update certificate status to Revoked
|
||||
now := time.Now()
|
||||
cert.Status = domain.CertificateStatusRevoked
|
||||
cert.RevokedAt = &now
|
||||
cert.RevocationReason = reason
|
||||
cert.UpdatedAt = now
|
||||
if err := s.certRepo.Update(ctx, cert); err != nil {
|
||||
return fmt.Errorf("failed to update certificate status: %w", err)
|
||||
}
|
||||
|
||||
// 4. Record revocation in certificate_revocations table (for CRL generation)
|
||||
if s.revocationRepo != nil {
|
||||
revocation := &domain.CertificateRevocation{
|
||||
ID: generateID("rev"),
|
||||
CertificateID: certID,
|
||||
SerialNumber: version.SerialNumber,
|
||||
Reason: reason,
|
||||
RevokedBy: actor,
|
||||
RevokedAt: now,
|
||||
IssuerID: cert.IssuerID,
|
||||
CreatedAt: now,
|
||||
}
|
||||
if err := s.revocationRepo.Create(ctx, revocation); err != nil {
|
||||
slog.Error("failed to record revocation for CRL", "error", err, "certificate_id", certID)
|
||||
// Don't fail the overall revocation — the cert status is already updated
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Notify the issuer connector (best-effort)
|
||||
if s.issuerRegistry != nil {
|
||||
if issuerConn, ok := s.issuerRegistry[cert.IssuerID]; ok {
|
||||
if err := issuerConn.RevokeCertificate(ctx, version.SerialNumber, reason); err != nil {
|
||||
slog.Error("failed to notify issuer of revocation",
|
||||
"error", err,
|
||||
"issuer_id", cert.IssuerID,
|
||||
"serial", version.SerialNumber)
|
||||
// Best-effort — don't fail the overall revocation
|
||||
} else if s.revocationRepo != nil {
|
||||
// Mark issuer as notified
|
||||
revocations, _ := s.revocationRepo.ListByCertificate(ctx, certID)
|
||||
for _, rev := range revocations {
|
||||
if rev.SerialNumber == version.SerialNumber {
|
||||
_ = s.revocationRepo.MarkIssuerNotified(ctx, rev.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Record audit event
|
||||
if err := s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser,
|
||||
"certificate_revoked", "certificate", certID,
|
||||
map[string]interface{}{
|
||||
"common_name": cert.CommonName,
|
||||
"serial": version.SerialNumber,
|
||||
"reason": reason,
|
||||
}); err != nil {
|
||||
slog.Error("failed to record audit event", "error", err)
|
||||
}
|
||||
|
||||
// 7. Send revocation notification
|
||||
if s.notificationSvc != nil {
|
||||
if err := s.notificationSvc.SendRevocationNotification(ctx, cert, reason); err != nil {
|
||||
slog.Error("failed to send revocation notification", "error", err, "certificate_id", certID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRevokedCertificates returns all revoked certificate records (for CRL generation).
|
||||
func (s *RevocationSvc) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) {
|
||||
if s.revocationRepo == nil {
|
||||
return nil, fmt.Errorf("revocation repository not configured")
|
||||
}
|
||||
return s.revocationRepo.ListAll(context.Background())
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
// Tests for RevocationSvc, the focused sub-service that handles certificate
|
||||
// revocation logic extracted from CertificateService (TICKET-007).
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
// helper to create a RevocationSvc for testing
|
||||
func newRevocationSvcTest() (*RevocationSvc, *mockCertRepo, *mockRevocationRepo, *mockAuditRepo) {
|
||||
certRepo := newMockCertificateRepository()
|
||||
revocationRepo := newMockRevocationRepository()
|
||||
auditRepo := newMockAuditRepository()
|
||||
|
||||
auditService := NewAuditService(auditRepo)
|
||||
revSvc := NewRevocationSvc(certRepo, revocationRepo, auditService)
|
||||
revSvc.SetIssuerRegistry(map[string]IssuerConnector{
|
||||
"iss-local": &mockIssuerConnector{},
|
||||
})
|
||||
|
||||
return revSvc, certRepo, revocationRepo, auditRepo
|
||||
}
|
||||
|
||||
func TestRevocationSvc_RevokeCertificateWithActor_Success(t *testing.T) {
|
||||
revSvc, certRepo, revocationRepo, auditRepo := newRevocationSvcTest()
|
||||
|
||||
// Set up test data
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "cert-1",
|
||||
CommonName: "example.com",
|
||||
IssuerID: "iss-local",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: time.Now().AddDate(0, 6, 0),
|
||||
}
|
||||
certRepo.AddCert(cert)
|
||||
|
||||
// Add a certificate version with a serial number
|
||||
version := &domain.CertificateVersion{
|
||||
ID: "ver-1",
|
||||
CertificateID: "cert-1",
|
||||
SerialNumber: "ABC123",
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
certRepo.Versions["cert-1"] = []*domain.CertificateVersion{version}
|
||||
|
||||
// Revoke
|
||||
err := revSvc.RevokeCertificateWithActor(context.Background(), "cert-1", "keyCompromise", "admin")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify certificate status changed
|
||||
updated, _ := certRepo.Get(context.Background(), "cert-1")
|
||||
if updated.Status != domain.CertificateStatusRevoked {
|
||||
t.Errorf("expected status Revoked, got %s", updated.Status)
|
||||
}
|
||||
if updated.RevokedAt == nil {
|
||||
t.Error("expected RevokedAt to be set")
|
||||
}
|
||||
if updated.RevocationReason != "keyCompromise" {
|
||||
t.Errorf("expected reason keyCompromise, got %s", updated.RevocationReason)
|
||||
}
|
||||
|
||||
// Verify revocation record created
|
||||
if len(revocationRepo.Revocations) != 1 {
|
||||
t.Fatalf("expected 1 revocation record, got %d", len(revocationRepo.Revocations))
|
||||
}
|
||||
rev := revocationRepo.Revocations[0]
|
||||
if rev.SerialNumber != "ABC123" {
|
||||
t.Errorf("expected serial ABC123, got %s", rev.SerialNumber)
|
||||
}
|
||||
if rev.Reason != "keyCompromise" {
|
||||
t.Errorf("expected reason keyCompromise, got %s", rev.Reason)
|
||||
}
|
||||
if rev.RevokedBy != "admin" {
|
||||
t.Errorf("expected revokedBy admin, got %s", rev.RevokedBy)
|
||||
}
|
||||
|
||||
// Verify audit event recorded
|
||||
if len(auditRepo.Events) == 0 {
|
||||
t.Error("expected audit event to be recorded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevocationSvc_RevokeCertificateWithActor_AlreadyRevoked(t *testing.T) {
|
||||
revSvc, certRepo, _, _ := newRevocationSvcTest()
|
||||
|
||||
now := time.Now()
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "cert-3",
|
||||
CommonName: "already-revoked.com",
|
||||
IssuerID: "iss-local",
|
||||
Status: domain.CertificateStatusRevoked,
|
||||
RevokedAt: &now,
|
||||
RevocationReason: "keyCompromise",
|
||||
ExpiresAt: time.Now().AddDate(0, 6, 0),
|
||||
}
|
||||
certRepo.AddCert(cert)
|
||||
|
||||
err := revSvc.RevokeCertificateWithActor(context.Background(), "cert-3", "superseded", "admin")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for already revoked certificate")
|
||||
}
|
||||
if err.Error() != "certificate is already revoked" {
|
||||
t.Errorf("expected 'already revoked' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevocationSvc_GetRevokedCertificates_Success(t *testing.T) {
|
||||
revSvc, _, revocationRepo, _ := newRevocationSvcTest()
|
||||
|
||||
// Pre-populate revocation records
|
||||
revocationRepo.Revocations = []*domain.CertificateRevocation{
|
||||
{ID: "rev-1", CertificateID: "cert-1", SerialNumber: "SER-1", Reason: "keyCompromise", RevokedAt: time.Now()},
|
||||
{ID: "rev-2", CertificateID: "cert-2", SerialNumber: "SER-2", Reason: "superseded", RevokedAt: time.Now()},
|
||||
}
|
||||
|
||||
revocations, err := revSvc.GetRevokedCertificates()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
if len(revocations) != 2 {
|
||||
t.Errorf("expected 2 revocations, got %d", len(revocations))
|
||||
}
|
||||
}
|
||||
@@ -14,15 +14,27 @@ func newRevocationTestService() (*CertificateService, *mockCertRepo, *mockRevoca
|
||||
auditRepo := newMockAuditRepository()
|
||||
policyRepo := newMockPolicyRepository()
|
||||
revocationRepo := newMockRevocationRepository()
|
||||
profileRepo := newMockProfileRepository()
|
||||
|
||||
auditService := NewAuditService(auditRepo)
|
||||
policyService := NewPolicyService(policyRepo, auditService)
|
||||
certService := NewCertificateService(certRepo, policyService, auditService)
|
||||
certService.SetRevocationRepo(revocationRepo)
|
||||
certService.SetIssuerRegistry(map[string]IssuerConnector{
|
||||
|
||||
// Create RevocationSvc
|
||||
revSvc := NewRevocationSvc(certRepo, revocationRepo, auditService)
|
||||
revSvc.SetIssuerRegistry(map[string]IssuerConnector{
|
||||
"iss-local": &mockIssuerConnector{},
|
||||
})
|
||||
|
||||
// Create CAOperationsSvc
|
||||
caSvc := NewCAOperationsSvc(revocationRepo, certRepo, profileRepo)
|
||||
caSvc.SetIssuerRegistry(map[string]IssuerConnector{
|
||||
"iss-local": &mockIssuerConnector{},
|
||||
})
|
||||
|
||||
certService := NewCertificateService(certRepo, policyService, auditService)
|
||||
certService.SetRevocationSvc(revSvc)
|
||||
certService.SetCAOperationsSvc(caSvc)
|
||||
|
||||
return certService, certRepo, revocationRepo, auditRepo
|
||||
}
|
||||
|
||||
@@ -229,9 +241,9 @@ func TestRevokeCertificate_NoVersion(t *testing.T) {
|
||||
func TestRevokeCertificate_WithIssuerNotification(t *testing.T) {
|
||||
svc, certRepo, revocationRepo, _ := newRevocationTestService()
|
||||
|
||||
// Wire up issuer registry with mock
|
||||
// Wire up issuer registry on RevocationSvc with mock
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
svc.SetIssuerRegistry(map[string]IssuerConnector{
|
||||
svc.revSvc.SetIssuerRegistry(map[string]IssuerConnector{
|
||||
"iss-local": mockIssuer,
|
||||
})
|
||||
|
||||
@@ -264,10 +276,10 @@ func TestRevokeCertificate_WithIssuerNotification(t *testing.T) {
|
||||
func TestRevokeCertificate_WithNotificationService(t *testing.T) {
|
||||
svc, certRepo, _, _ := newRevocationTestService()
|
||||
|
||||
// Wire up notification service
|
||||
// Wire up notification service on RevocationSvc
|
||||
notifRepo := newMockNotificationRepository()
|
||||
notifService := NewNotificationService(notifRepo, make(map[string]Notifier))
|
||||
svc.SetNotificationService(notifService)
|
||||
svc.revSvc.SetNotificationService(notifService)
|
||||
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "cert-8",
|
||||
|
||||
Reference in New Issue
Block a user