feat: M11a — certificate profiles, crypto policy enforcement, short-lived cert expiry

Add certificate profiles as named enrollment templates that control allowed
key algorithms, max TTL, permitted EKUs, required SAN patterns, and optional
SPIFFE URI SANs. CSR submissions are validated against profile rules at
signing time (key type + minimum size). Short-lived certs (TTL < 1 hour)
auto-expire via a new scheduler loop — expiry acts as revocation, no
CRL/OCSP needed.

New files:
- Migration 000003: certificate_profiles table, FK columns on
  managed_certificates/renewal_policies, key metadata on certificate_versions
- domain/profile.go: CertificateProfile + KeyAlgorithmRule structs
- repository/postgres/profile.go: full CRUD with JSONB marshaling
- service/profile.go: ProfileService with validation + audit logging
- service/crypto_validation.go: CSR-against-profile validation (RSA/ECDSA/Ed25519)
- handler/profiles.go: 5 HTTP endpoints under /api/v1/profiles
- web/src/pages/ProfilesPage.tsx: profiles management page

Modified:
- renewal.go: CSR validation in CompleteAgentCSRRenewal, ExpireShortLivedCertificates
- scheduler.go: 30s short-lived expiry check loop
- certificate.go (repo): nullable profile FK, key metadata on versions
- main.go: profile repo/service/handler wiring, 8-param NewRenewalService
- router.go: 12-param RegisterHandlers with profile routes
- seed_demo.sql: 4 demo profiles (standard, mtls, short-lived, high-security)
- Frontend: types, API client, routing, sidebar nav

Tests: 40 new tests across handler (15), service (13), crypto validation (12)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
shankar0123
2026-03-20 20:39:49 -04:00
parent 7450fcfb07
commit a579a84c7f
27 changed files with 2399 additions and 71 deletions
+85
View File
@@ -0,0 +1,85 @@
package service
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"github.com/shankar0123/certctl/internal/domain"
)
// CSRValidationResult contains metadata extracted from a validated CSR.
type CSRValidationResult struct {
KeyAlgorithm string
KeySize int
}
// ValidateCSRAgainstProfile parses a CSR PEM and validates that its key algorithm
// and size comply with the profile's allowed_key_algorithms rules.
// Returns extracted key metadata on success for storage in certificate_versions.
func ValidateCSRAgainstProfile(csrPEM string, profile *domain.CertificateProfile) (*CSRValidationResult, error) {
if profile == nil {
// No profile assigned — skip validation, extract metadata only
return extractCSRKeyInfo(csrPEM)
}
result, err := extractCSRKeyInfo(csrPEM)
if err != nil {
return nil, err
}
// Check that the CSR's key algorithm + size matches at least one allowed rule
if len(profile.AllowedKeyAlgorithms) == 0 {
// No restrictions defined — allow anything
return result, nil
}
for _, rule := range profile.AllowedKeyAlgorithms {
if rule.Algorithm == result.KeyAlgorithm && result.KeySize >= rule.MinSize {
return result, nil
}
}
return nil, fmt.Errorf("CSR key (%s %d-bit) does not match any allowed algorithm in profile %q: %v",
result.KeyAlgorithm, result.KeySize, profile.Name, profile.AllowedKeyAlgorithms)
}
// extractCSRKeyInfo parses a CSR PEM and extracts the key algorithm and size.
func extractCSRKeyInfo(csrPEM string) (*CSRValidationResult, error) {
block, _ := pem.Decode([]byte(csrPEM))
if block == nil {
return nil, fmt.Errorf("failed to decode CSR PEM")
}
csr, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse CSR: %w", err)
}
if err := csr.CheckSignature(); err != nil {
return nil, fmt.Errorf("CSR signature verification failed: %w", err)
}
switch key := csr.PublicKey.(type) {
case *rsa.PublicKey:
return &CSRValidationResult{
KeyAlgorithm: domain.KeyAlgorithmRSA,
KeySize: key.N.BitLen(),
}, nil
case *ecdsa.PublicKey:
return &CSRValidationResult{
KeyAlgorithm: domain.KeyAlgorithmECDSA,
KeySize: key.Curve.Params().BitSize,
}, nil
case ed25519.PublicKey:
return &CSRValidationResult{
KeyAlgorithm: domain.KeyAlgorithmEd25519,
KeySize: 256, // Ed25519 is fixed 256-bit
}, nil
default:
return nil, fmt.Errorf("unsupported key type in CSR: %T", csr.PublicKey)
}
}
+244
View File
@@ -0,0 +1,244 @@
package service
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"testing"
"github.com/shankar0123/certctl/internal/domain"
)
// generateTestCSR creates a valid CSR PEM for testing purposes.
func generateTestCSR(t *testing.T, keyType string, keySize int) string {
t.Helper()
var privKey interface{}
var err error
switch keyType {
case "RSA":
privKey, err = rsa.GenerateKey(rand.Reader, keySize)
case "ECDSA":
var curve elliptic.Curve
switch keySize {
case 256:
curve = elliptic.P256()
case 384:
curve = elliptic.P384()
default:
t.Fatalf("unsupported ECDSA key size: %d", keySize)
}
privKey, err = ecdsa.GenerateKey(curve, rand.Reader)
default:
t.Fatalf("unsupported key type: %s", keyType)
}
if err != nil {
t.Fatalf("failed to generate key: %v", err)
}
template := &x509.CertificateRequest{
Subject: pkix.Name{
CommonName: "test.example.com",
},
DNSNames: []string{"test.example.com", "www.example.com"},
}
csrDER, err := x509.CreateCertificateRequest(rand.Reader, template, privKey)
if err != nil {
t.Fatalf("failed to create CSR: %v", err)
}
csrPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csrDER,
})
return string(csrPEM)
}
func TestValidateCSRAgainstProfile_NilProfile(t *testing.T) {
csrPEM := generateTestCSR(t, "ECDSA", 256)
result, err := ValidateCSRAgainstProfile(csrPEM, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.KeyAlgorithm != "ECDSA" {
t.Errorf("expected ECDSA, got %s", result.KeyAlgorithm)
}
if result.KeySize != 256 {
t.Errorf("expected 256, got %d", result.KeySize)
}
}
func TestValidateCSRAgainstProfile_ECDSA256_Allowed(t *testing.T) {
csrPEM := generateTestCSR(t, "ECDSA", 256)
profile := &domain.CertificateProfile{
Name: "Standard TLS",
AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{
{Algorithm: "ECDSA", MinSize: 256},
{Algorithm: "RSA", MinSize: 2048},
},
}
result, err := ValidateCSRAgainstProfile(csrPEM, profile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.KeyAlgorithm != "ECDSA" {
t.Errorf("expected ECDSA, got %s", result.KeyAlgorithm)
}
if result.KeySize != 256 {
t.Errorf("expected 256, got %d", result.KeySize)
}
}
func TestValidateCSRAgainstProfile_ECDSA384_Allowed(t *testing.T) {
csrPEM := generateTestCSR(t, "ECDSA", 384)
profile := &domain.CertificateProfile{
Name: "High Security",
AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{
{Algorithm: "ECDSA", MinSize: 384},
},
}
result, err := ValidateCSRAgainstProfile(csrPEM, profile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.KeySize != 384 {
t.Errorf("expected 384, got %d", result.KeySize)
}
}
func TestValidateCSRAgainstProfile_RSA2048_Allowed(t *testing.T) {
csrPEM := generateTestCSR(t, "RSA", 2048)
profile := &domain.CertificateProfile{
Name: "Standard TLS",
AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{
{Algorithm: "RSA", MinSize: 2048},
},
}
result, err := ValidateCSRAgainstProfile(csrPEM, profile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.KeyAlgorithm != "RSA" {
t.Errorf("expected RSA, got %s", result.KeyAlgorithm)
}
if result.KeySize != 2048 {
t.Errorf("expected 2048, got %d", result.KeySize)
}
}
func TestValidateCSRAgainstProfile_ECDSA256_RejectedByHighSecurity(t *testing.T) {
csrPEM := generateTestCSR(t, "ECDSA", 256)
profile := &domain.CertificateProfile{
Name: "High Security",
AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{
{Algorithm: "ECDSA", MinSize: 384},
{Algorithm: "RSA", MinSize: 4096},
},
}
_, err := ValidateCSRAgainstProfile(csrPEM, profile)
if err == nil {
t.Fatal("expected rejection, got nil error")
}
if !containsSubstring(err.Error(), "does not match any allowed algorithm") {
t.Errorf("unexpected error message: %s", err.Error())
}
}
func TestValidateCSRAgainstProfile_RSA_RejectedByECDSAOnly(t *testing.T) {
csrPEM := generateTestCSR(t, "RSA", 2048)
profile := &domain.CertificateProfile{
Name: "ECDSA Only",
AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{
{Algorithm: "ECDSA", MinSize: 256},
},
}
_, err := ValidateCSRAgainstProfile(csrPEM, profile)
if err == nil {
t.Fatal("expected rejection, got nil error")
}
}
func TestValidateCSRAgainstProfile_EmptyAlgorithmRules(t *testing.T) {
csrPEM := generateTestCSR(t, "ECDSA", 256)
profile := &domain.CertificateProfile{
Name: "Permissive",
AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{}, // empty = allow anything
}
result, err := ValidateCSRAgainstProfile(csrPEM, profile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.KeyAlgorithm != "ECDSA" {
t.Errorf("expected ECDSA, got %s", result.KeyAlgorithm)
}
}
func TestValidateCSRAgainstProfile_InvalidPEM(t *testing.T) {
_, err := ValidateCSRAgainstProfile("not a pem", nil)
if err == nil {
t.Fatal("expected error for invalid PEM, got nil")
}
if !containsSubstring(err.Error(), "failed to decode CSR PEM") {
t.Errorf("unexpected error: %s", err.Error())
}
}
func TestValidateCSRAgainstProfile_InvalidCSRContent(t *testing.T) {
// Valid PEM block but garbage content
csrPEM := "-----BEGIN CERTIFICATE REQUEST-----\nTm90IGEgcmVhbCBDU1I=\n-----END CERTIFICATE REQUEST-----"
_, err := ValidateCSRAgainstProfile(csrPEM, nil)
if err == nil {
t.Fatal("expected error for invalid CSR content, got nil")
}
}
func TestExtractCSRKeyInfo_ECDSA(t *testing.T) {
csrPEM := generateTestCSR(t, "ECDSA", 256)
result, err := extractCSRKeyInfo(csrPEM)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.KeyAlgorithm != "ECDSA" {
t.Errorf("expected ECDSA, got %s", result.KeyAlgorithm)
}
if result.KeySize != 256 {
t.Errorf("expected 256, got %d", result.KeySize)
}
}
func TestExtractCSRKeyInfo_RSA(t *testing.T) {
csrPEM := generateTestCSR(t, "RSA", 2048)
result, err := extractCSRKeyInfo(csrPEM)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.KeyAlgorithm != "RSA" {
t.Errorf("expected RSA, got %s", result.KeyAlgorithm)
}
if result.KeySize != 2048 {
t.Errorf("expected 2048, got %d", result.KeySize)
}
}
+1 -1
View File
@@ -28,7 +28,7 @@ func newTestJobService(jobRepo *mockJobRepo) *JobService {
targetRepo := &mockTargetRepo{Targets: make(map[string]*domain.DeploymentTarget)}
agentRepo := &mockAgentRepo{Agents: make(map[string]*domain.Agent)}
renewalService := NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, auditService, notifService, make(map[string]IssuerConnector), "server")
renewalService := NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, nil, auditService, notifService, make(map[string]IssuerConnector), "server")
deploymentService := NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notifService)
return NewJobService(jobRepo, renewalService, deploymentService, logger)
+181
View File
@@ -0,0 +1,181 @@
package service
import (
"context"
"fmt"
"log/slog"
"time"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/repository"
)
// ProfileService provides business logic for certificate profile management.
type ProfileService struct {
profileRepo repository.CertificateProfileRepository
auditService *AuditService
}
// NewProfileService creates a new profile service.
func NewProfileService(
profileRepo repository.CertificateProfileRepository,
auditService *AuditService,
) *ProfileService {
return &ProfileService{
profileRepo: profileRepo,
auditService: auditService,
}
}
// ListProfiles returns all profiles (handler interface method).
func (s *ProfileService) ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) {
if page < 1 {
page = 1
}
if perPage < 1 {
perPage = 50
}
profiles, err := s.profileRepo.List(context.Background())
if err != nil {
return nil, 0, fmt.Errorf("failed to list profiles: %w", err)
}
total := int64(len(profiles))
var result []domain.CertificateProfile
for _, p := range profiles {
if p != nil {
result = append(result, *p)
}
}
return result, total, nil
}
// GetProfile returns a single profile (handler interface method).
func (s *ProfileService) GetProfile(id string) (*domain.CertificateProfile, error) {
return s.profileRepo.Get(context.Background(), id)
}
// CreateProfile creates a new profile with validation (handler interface method).
func (s *ProfileService) CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
if err := validateProfile(&profile); err != nil {
return nil, err
}
if profile.ID == "" {
profile.ID = generateID("prof")
}
now := time.Now()
if profile.CreatedAt.IsZero() {
profile.CreatedAt = now
}
if profile.UpdatedAt.IsZero() {
profile.UpdatedAt = now
}
// Apply defaults if not set
if len(profile.AllowedKeyAlgorithms) == 0 {
profile.AllowedKeyAlgorithms = domain.DefaultKeyAlgorithms()
}
if len(profile.AllowedEKUs) == 0 {
profile.AllowedEKUs = domain.DefaultEKUs()
}
if err := s.profileRepo.Create(context.Background(), &profile); err != nil {
return nil, fmt.Errorf("failed to create profile: %w", err)
}
if s.auditService != nil {
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser,
"create_profile", "certificate_profile", profile.ID, nil); auditErr != nil {
slog.Error("failed to record audit event", "error", auditErr)
}
}
return &profile, nil
}
// UpdateProfile modifies an existing profile (handler interface method).
func (s *ProfileService) UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
if err := validateProfile(&profile); err != nil {
return nil, err
}
profile.ID = id
if err := s.profileRepo.Update(context.Background(), &profile); err != nil {
return nil, fmt.Errorf("failed to update profile: %w", err)
}
if s.auditService != nil {
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser,
"update_profile", "certificate_profile", id, nil); auditErr != nil {
slog.Error("failed to record audit event", "error", auditErr)
}
}
return &profile, nil
}
// DeleteProfile removes a profile (handler interface method).
func (s *ProfileService) DeleteProfile(id string) error {
if err := s.profileRepo.Delete(context.Background(), id); err != nil {
return fmt.Errorf("failed to delete profile: %w", err)
}
if s.auditService != nil {
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser,
"delete_profile", "certificate_profile", id, nil); auditErr != nil {
slog.Error("failed to record audit event", "error", auditErr)
}
}
return nil
}
// Get retrieves a profile by ID (used by other services like RenewalService).
func (s *ProfileService) Get(ctx context.Context, id string) (*domain.CertificateProfile, error) {
return s.profileRepo.Get(ctx, id)
}
// validateProfile checks that a profile's configuration is valid.
func validateProfile(p *domain.CertificateProfile) error {
if p.Name == "" {
return fmt.Errorf("profile name is required")
}
if len(p.Name) > 255 {
return fmt.Errorf("profile name exceeds 255 characters")
}
// Validate key algorithms
for _, alg := range p.AllowedKeyAlgorithms {
if !domain.ValidKeyAlgorithms[alg.Algorithm] {
return fmt.Errorf("invalid key algorithm: %s (allowed: RSA, ECDSA, Ed25519)", alg.Algorithm)
}
if alg.Algorithm == domain.KeyAlgorithmRSA && alg.MinSize < 2048 {
return fmt.Errorf("RSA minimum key size must be at least 2048, got %d", alg.MinSize)
}
if alg.Algorithm == domain.KeyAlgorithmECDSA && alg.MinSize < 256 {
return fmt.Errorf("ECDSA minimum key size must be at least 256, got %d", alg.MinSize)
}
}
// Validate EKUs
for _, eku := range p.AllowedEKUs {
if !domain.ValidEKUs[eku] {
return fmt.Errorf("invalid EKU: %s", eku)
}
}
// Validate max TTL
if p.MaxTTLSeconds < 0 {
return fmt.Errorf("max_ttl_seconds cannot be negative")
}
// Validate short-lived consistency
if p.AllowShortLived && p.MaxTTLSeconds >= 3600 {
return fmt.Errorf("allow_short_lived is true but max_ttl_seconds (%d) is >= 3600; short-lived certs must have TTL under 1 hour", p.MaxTTLSeconds)
}
return nil
}
+415
View File
@@ -0,0 +1,415 @@
package service
import (
"context"
"errors"
"testing"
"github.com/shankar0123/certctl/internal/domain"
)
// mockProfileRepo is a test implementation of CertificateProfileRepository
type mockProfileRepo struct {
profiles map[string]*domain.CertificateProfile
ListErr error
GetErr error
CreateErr error
UpdateErr error
DeleteErr error
}
func newMockProfileRepository() *mockProfileRepo {
return &mockProfileRepo{
profiles: make(map[string]*domain.CertificateProfile),
}
}
func (m *mockProfileRepo) List(ctx context.Context) ([]*domain.CertificateProfile, error) {
if m.ListErr != nil {
return nil, m.ListErr
}
var profiles []*domain.CertificateProfile
for _, p := range m.profiles {
profiles = append(profiles, p)
}
return profiles, nil
}
func (m *mockProfileRepo) Get(ctx context.Context, id string) (*domain.CertificateProfile, error) {
if m.GetErr != nil {
return nil, m.GetErr
}
p, ok := m.profiles[id]
if !ok {
return nil, errNotFound
}
return p, nil
}
func (m *mockProfileRepo) Create(ctx context.Context, profile *domain.CertificateProfile) error {
if m.CreateErr != nil {
return m.CreateErr
}
m.profiles[profile.ID] = profile
return nil
}
func (m *mockProfileRepo) Update(ctx context.Context, profile *domain.CertificateProfile) error {
if m.UpdateErr != nil {
return m.UpdateErr
}
m.profiles[profile.ID] = profile
return nil
}
func (m *mockProfileRepo) Delete(ctx context.Context, id string) error {
if m.DeleteErr != nil {
return m.DeleteErr
}
delete(m.profiles, id)
return nil
}
func (m *mockProfileRepo) AddProfile(p *domain.CertificateProfile) {
m.profiles[p.ID] = p
}
// --- ProfileService Tests ---
func TestProfileService_ListProfiles(t *testing.T) {
repo := newMockProfileRepository()
repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "Standard TLS", Enabled: true})
repo.AddProfile(&domain.CertificateProfile{ID: "prof-2", Name: "Internal mTLS", Enabled: true})
svc := NewProfileService(repo, nil)
profiles, total, err := svc.ListProfiles(1, 50)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if total != 2 {
t.Errorf("expected total 2, got %d", total)
}
if len(profiles) != 2 {
t.Errorf("expected 2 profiles, got %d", len(profiles))
}
}
func TestProfileService_ListProfiles_Empty(t *testing.T) {
repo := newMockProfileRepository()
svc := NewProfileService(repo, nil)
profiles, total, err := svc.ListProfiles(1, 50)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if total != 0 {
t.Errorf("expected total 0, got %d", total)
}
if len(profiles) != 0 {
t.Errorf("expected 0 profiles, got %d", len(profiles))
}
}
func TestProfileService_ListProfiles_RepoError(t *testing.T) {
repo := newMockProfileRepository()
repo.ListErr = errors.New("db error")
svc := NewProfileService(repo, nil)
_, _, err := svc.ListProfiles(1, 50)
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestProfileService_GetProfile(t *testing.T) {
repo := newMockProfileRepository()
repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "Standard TLS"})
svc := NewProfileService(repo, nil)
profile, err := svc.GetProfile("prof-1")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if profile.Name != "Standard TLS" {
t.Errorf("expected 'Standard TLS', got '%s'", profile.Name)
}
}
func TestProfileService_GetProfile_NotFound(t *testing.T) {
repo := newMockProfileRepository()
svc := NewProfileService(repo, nil)
_, err := svc.GetProfile("nonexistent")
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestProfileService_CreateProfile_Defaults(t *testing.T) {
repo := newMockProfileRepository()
auditRepo := newMockAuditRepository()
auditSvc := NewAuditService(auditRepo)
svc := NewProfileService(repo, auditSvc)
profile := domain.CertificateProfile{
Name: "New Profile",
MaxTTLSeconds: 86400,
}
created, err := svc.CreateProfile(profile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if created.ID == "" {
t.Error("expected generated ID, got empty")
}
if len(created.AllowedKeyAlgorithms) == 0 {
t.Error("expected default key algorithms, got empty")
}
if len(created.AllowedEKUs) == 0 {
t.Error("expected default EKUs, got empty")
}
if created.CreatedAt.IsZero() {
t.Error("expected CreatedAt to be set")
}
// Verify audit event recorded
if len(auditRepo.Events) != 1 {
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
}
}
func TestProfileService_CreateProfile_ValidationErrors(t *testing.T) {
repo := newMockProfileRepository()
svc := NewProfileService(repo, nil)
tests := []struct {
name string
profile domain.CertificateProfile
errMsg string
}{
{
name: "empty name",
profile: domain.CertificateProfile{},
errMsg: "profile name is required",
},
{
name: "name too long",
profile: domain.CertificateProfile{
Name: string(make([]byte, 256)),
},
errMsg: "exceeds 255 characters",
},
{
name: "invalid key algorithm",
profile: domain.CertificateProfile{
Name: "Bad Algo",
AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{
{Algorithm: "DES", MinSize: 56},
},
},
errMsg: "invalid key algorithm",
},
{
name: "RSA key too small",
profile: domain.CertificateProfile{
Name: "Weak RSA",
AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{
{Algorithm: "RSA", MinSize: 1024},
},
},
errMsg: "RSA minimum key size must be at least 2048",
},
{
name: "ECDSA key too small",
profile: domain.CertificateProfile{
Name: "Weak ECDSA",
AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{
{Algorithm: "ECDSA", MinSize: 128},
},
},
errMsg: "ECDSA minimum key size must be at least 256",
},
{
name: "invalid EKU",
profile: domain.CertificateProfile{
Name: "Bad EKU",
AllowedEKUs: []string{"invalidEKU"},
},
errMsg: "invalid EKU",
},
{
name: "negative TTL",
profile: domain.CertificateProfile{
Name: "Negative TTL",
MaxTTLSeconds: -1,
},
errMsg: "cannot be negative",
},
{
name: "short-lived with long TTL",
profile: domain.CertificateProfile{
Name: "Inconsistent Short-Lived",
AllowShortLived: true,
MaxTTLSeconds: 7200,
},
errMsg: "short-lived certs must have TTL under 1 hour",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := svc.CreateProfile(tt.profile)
if err == nil {
t.Fatalf("expected error containing %q, got nil", tt.errMsg)
}
if !contains(err.Error(), tt.errMsg) {
t.Errorf("expected error containing %q, got %q", tt.errMsg, err.Error())
}
})
}
}
func TestProfileService_CreateProfile_RepoError(t *testing.T) {
repo := newMockProfileRepository()
repo.CreateErr = errors.New("db create failed")
svc := NewProfileService(repo, nil)
_, err := svc.CreateProfile(domain.CertificateProfile{Name: "Valid"})
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestProfileService_UpdateProfile(t *testing.T) {
repo := newMockProfileRepository()
repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "Original"})
auditRepo := newMockAuditRepository()
auditSvc := NewAuditService(auditRepo)
svc := NewProfileService(repo, auditSvc)
updated, err := svc.UpdateProfile("prof-1", domain.CertificateProfile{
Name: "Updated",
MaxTTLSeconds: 43200,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if updated.ID != "prof-1" {
t.Errorf("expected ID 'prof-1', got '%s'", updated.ID)
}
if len(auditRepo.Events) != 1 {
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
}
}
func TestProfileService_UpdateProfile_ValidationError(t *testing.T) {
repo := newMockProfileRepository()
svc := NewProfileService(repo, nil)
_, err := svc.UpdateProfile("prof-1", domain.CertificateProfile{Name: ""})
if err == nil {
t.Fatal("expected validation error, got nil")
}
}
func TestProfileService_DeleteProfile(t *testing.T) {
repo := newMockProfileRepository()
repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "To Delete"})
auditRepo := newMockAuditRepository()
auditSvc := NewAuditService(auditRepo)
svc := NewProfileService(repo, auditSvc)
err := svc.DeleteProfile("prof-1")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(auditRepo.Events) != 1 {
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
}
}
func TestProfileService_DeleteProfile_RepoError(t *testing.T) {
repo := newMockProfileRepository()
repo.DeleteErr = errors.New("db delete failed")
svc := NewProfileService(repo, nil)
err := svc.DeleteProfile("prof-1")
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestProfileService_CreateProfile_ValidShortLived(t *testing.T) {
repo := newMockProfileRepository()
svc := NewProfileService(repo, nil)
// Short-lived with TTL under 1 hour should succeed
created, err := svc.CreateProfile(domain.CertificateProfile{
Name: "CI Ephemeral",
AllowShortLived: true,
MaxTTLSeconds: 300, // 5 minutes
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !created.AllowShortLived {
t.Error("expected AllowShortLived to be true")
}
}
func TestIsShortLived(t *testing.T) {
tests := []struct {
name string
profile domain.CertificateProfile
expected bool
}{
{
name: "short-lived with 5 min TTL",
profile: domain.CertificateProfile{AllowShortLived: true, MaxTTLSeconds: 300},
expected: true,
},
{
name: "short-lived flag false",
profile: domain.CertificateProfile{AllowShortLived: false, MaxTTLSeconds: 300},
expected: false,
},
{
name: "zero TTL with flag",
profile: domain.CertificateProfile{AllowShortLived: true, MaxTTLSeconds: 0},
expected: false,
},
{
name: "TTL at 1 hour boundary",
profile: domain.CertificateProfile{AllowShortLived: true, MaxTTLSeconds: 3600},
expected: false,
},
{
name: "standard long-lived",
profile: domain.CertificateProfile{AllowShortLived: false, MaxTTLSeconds: 7776000},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.profile.IsShortLived()
if result != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, result)
}
})
}
}
// contains checks if a string contains a substring (helper for test assertions).
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstring(s, substr))
}
func containsSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
+92
View File
@@ -22,6 +22,7 @@ type RenewalService struct {
certRepo repository.CertificateRepository
jobRepo repository.JobRepository
renewalPolicyRepo repository.RenewalPolicyRepository
profileRepo repository.CertificateProfileRepository
auditService *AuditService
notificationSvc *NotificationService
issuerRegistry map[string]IssuerConnector
@@ -52,6 +53,7 @@ func NewRenewalService(
certRepo repository.CertificateRepository,
jobRepo repository.JobRepository,
renewalPolicyRepo repository.RenewalPolicyRepository,
profileRepo repository.CertificateProfileRepository,
auditService *AuditService,
notificationSvc *NotificationService,
issuerRegistry map[string]IssuerConnector,
@@ -64,6 +66,7 @@ func NewRenewalService(
certRepo: certRepo,
jobRepo: jobRepo,
renewalPolicyRepo: renewalPolicyRepo,
profileRepo: profileRepo,
auditService: auditService,
notificationSvc: notificationSvc,
issuerRegistry: issuerRegistry,
@@ -371,6 +374,8 @@ func (s *RenewalService) processRenewalServerKeygen(ctx context.Context, job *do
FingerprintSHA256: fingerprint,
PEMChain: result.CertPEM + "\n" + result.ChainPEM,
CSRPEM: privKeyPEM, // Server mode: stores private key for agent deployment
KeyAlgorithm: domain.KeyAlgorithmRSA,
KeySize: 2048,
CreatedAt: time.Now(),
}
@@ -428,6 +433,22 @@ func (s *RenewalService) CompleteAgentCSRRenewal(ctx context.Context, job *domai
return fmt.Errorf("issuer connector not found for %s", cert.IssuerID)
}
// Validate CSR against certificate profile (crypto policy enforcement)
var profile *domain.CertificateProfile
if cert.CertificateProfileID != "" && s.profileRepo != nil {
var profileErr error
profile, profileErr = s.profileRepo.Get(ctx, cert.CertificateProfileID)
if profileErr != nil {
slog.Warn("failed to fetch certificate profile, skipping crypto validation",
"profile_id", cert.CertificateProfileID, "cert_id", cert.ID, "error", profileErr)
}
}
csrInfo, csrErr := ValidateCSRAgainstProfile(csrPEM, profile)
if csrErr != nil {
s.failJob(ctx, job, fmt.Sprintf("CSR validation failed: %v", csrErr))
return fmt.Errorf("CSR validation failed: %w", csrErr)
}
// Update job to running
if err := s.jobRepo.UpdateStatus(ctx, job.ID, domain.JobStatusRunning, ""); err != nil {
return fmt.Errorf("failed to update job status: %w", err)
@@ -462,6 +483,10 @@ func (s *RenewalService) CompleteAgentCSRRenewal(ctx context.Context, job *domai
CSRPEM: csrPEM, // Agent mode: stores actual CSR, not private key
CreatedAt: time.Now(),
}
if csrInfo != nil {
version.KeyAlgorithm = csrInfo.KeyAlgorithm
version.KeySize = csrInfo.KeySize
}
if err := s.certRepo.CreateVersion(ctx, version); err != nil {
s.failJob(ctx, job, fmt.Sprintf("version creation failed: %v", err))
@@ -589,6 +614,73 @@ func (s *RenewalService) RetryFailedJobs(ctx context.Context, maxRetries int) er
return nil
}
// ExpireShortLivedCertificates finds active certificates with short-lived profiles
// whose TTL has elapsed and marks them as Expired. For certs with TTL < 1 hour,
// expiry is the revocation mechanism — no CRL/OCSP needed.
func (s *RenewalService) ExpireShortLivedCertificates(ctx context.Context) error {
if s.profileRepo == nil {
return nil
}
// Get all Active certificates and check if any have expired based on their actual expiry time
// This catches short-lived certs that expire between normal renewal check cycles
now := time.Now()
expiring, err := s.certRepo.GetExpiringCertificates(ctx, now)
if err != nil {
return fmt.Errorf("failed to fetch expired certificates: %w", err)
}
for _, cert := range expiring {
if cert.Status != domain.CertificateStatusActive && cert.Status != domain.CertificateStatusExpiring {
continue
}
// Only auto-expire certs that have actually passed their expiry time
if cert.ExpiresAt.After(now) {
continue
}
// Check if this cert has a short-lived profile
if cert.CertificateProfileID == "" {
continue
}
profile, err := s.profileRepo.Get(ctx, cert.CertificateProfileID)
if err != nil {
slog.Warn("failed to fetch profile for short-lived expiry check",
"profile_id", cert.CertificateProfileID, "cert_id", cert.ID, "error", err)
continue
}
if !profile.IsShortLived() {
continue
}
// Mark as expired
cert.Status = domain.CertificateStatusExpired
cert.UpdatedAt = now
if err := s.certRepo.Update(ctx, cert); err != nil {
slog.Error("failed to expire short-lived cert", "cert_id", cert.ID, "error", err)
continue
}
slog.Info("short-lived certificate expired (expiry = revocation)",
"cert_id", cert.ID, "profile_id", cert.CertificateProfileID,
"expired_at", cert.ExpiresAt)
if auditErr := s.auditService.RecordEvent(ctx, "system", domain.ActorTypeSystem,
"short_lived_cert_expired", "certificate", cert.ID,
map[string]interface{}{
"profile_id": cert.CertificateProfileID,
"expired_at": cert.ExpiresAt,
}); auditErr != nil {
slog.Error("failed to record audit event", "error", auditErr)
}
}
return nil
}
// generateID is a helper to generate unique IDs. In production, use a proper ID generator.
func generateID(prefix string) string {
return fmt.Sprintf("%s-%d", prefix, time.Now().UnixNano())
+12 -12
View File
@@ -30,7 +30,7 @@ func TestCheckExpiringCertificates_SendsThresholdAlerts(t *testing.T) {
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server")
svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server")
// Create a cert expiring in 10 days
cert := &domain.ManagedCertificate{
@@ -112,7 +112,7 @@ func TestCheckExpiringCertificates_DeduplicatesAlerts(t *testing.T) {
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server")
svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server")
// Create cert
cert := &domain.ManagedCertificate{
@@ -192,7 +192,7 @@ func TestCheckExpiringCertificates_SkipsRenewalInProgress(t *testing.T) {
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server")
svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server")
// Create cert with RenewalInProgress status
cert := &domain.ManagedCertificate{
@@ -257,7 +257,7 @@ func TestCheckExpiringCertificates_UpdatesStatusToExpiring(t *testing.T) {
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server")
svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server")
// Create active cert that will become expiring
// Use an issuer NOT in the registry so no renewal job is created (which would override status)
@@ -319,7 +319,7 @@ func TestCheckExpiringCertificates_UpdatesStatusToExpired(t *testing.T) {
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server")
svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server")
// Create cert that is already expired
// Use an issuer NOT in the registry so no renewal job is created (which would override status)
@@ -381,7 +381,7 @@ func TestCheckExpiringCertificates_CreatesRenewalJob(t *testing.T) {
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server")
svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server")
// Create expiring cert with registered issuer
cert := &domain.ManagedCertificate{
@@ -447,7 +447,7 @@ func TestCheckExpiringCertificates_SkipsWithoutIssuer(t *testing.T) {
// Empty issuer registry
issuerRegistry := map[string]IssuerConnector{}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server")
svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server")
// Create cert with unregistered issuer
cert := &domain.ManagedCertificate{
@@ -509,7 +509,7 @@ func TestCheckExpiringCertificates_SkipsDuplicateJobs(t *testing.T) {
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server")
svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server")
// Create cert
cert := &domain.ManagedCertificate{
@@ -593,7 +593,7 @@ func TestProcessRenewalJob(t *testing.T) {
"iss-test": issuerConnector,
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server")
svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server")
// Create certificate
cert := &domain.ManagedCertificate{
@@ -689,7 +689,7 @@ func TestProcessRenewalJob_IssuerFailure(t *testing.T) {
"iss-test": issuerConnector,
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server")
svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server")
// Create certificate
cert := &domain.ManagedCertificate{
@@ -771,7 +771,7 @@ func TestRetryFailedJobs(t *testing.T) {
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server")
svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server")
// Create failed job with attempts < max_attempts
failedJob := &domain.Job{
@@ -836,7 +836,7 @@ func TestProcessRenewalJob_NoCertificate(t *testing.T) {
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server")
svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server")
// Create job with non-existent certificate
job := &domain.Job{