mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-09 15:18:58 +00:00
feat(M11c): crypto policy enforcement — CSR validation, MaxTTL caps, key metadata
Enforce certificate profile crypto constraints across all 5 issuance paths (renewal, agent CSR, EST, SCEP). ValidateCSRAgainstProfile() rejects CSRs with key algorithm/size that don't match profile rules. MaxTTL enforcement caps certificate validity per issuer connector (Local CA, Vault, step-ca enforce directly; ACME/DigiCert/Sectigo pass through). Key algorithm and size are now persisted in certificate_versions for audit compliance. 16 new tests (12 service-layer + 4 Local CA connector). Removes hardcoded version number from GUI sidebar. Documentation updated across architecture, features, connectors, and README. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -51,10 +51,11 @@ type RenewalInfoResult struct {
|
||||
|
||||
// IssuanceRequest contains the parameters for issuing a new certificate.
|
||||
type IssuanceRequest struct {
|
||||
CommonName string `json:"common_name"`
|
||||
SANs []string `json:"sans"`
|
||||
CSRPEM string `json:"csr_pem"`
|
||||
EKUs []string `json:"ekus,omitempty"` // e.g., "serverAuth", "clientAuth", "emailProtection"
|
||||
CommonName string `json:"common_name"`
|
||||
SANs []string `json:"sans"`
|
||||
CSRPEM string `json:"csr_pem"`
|
||||
EKUs []string `json:"ekus,omitempty"` // e.g., "serverAuth", "clientAuth", "emailProtection"
|
||||
MaxTTLSeconds int `json:"max_ttl_seconds,omitempty"` // 0 = no cap (use issuer default)
|
||||
}
|
||||
|
||||
// IssuanceResult contains the result of a successful certificate issuance.
|
||||
@@ -69,11 +70,12 @@ type IssuanceResult struct {
|
||||
|
||||
// RenewalRequest contains the parameters for renewing a certificate.
|
||||
type RenewalRequest struct {
|
||||
CommonName string `json:"common_name"`
|
||||
SANs []string `json:"sans"`
|
||||
CSRPEM string `json:"csr_pem"`
|
||||
EKUs []string `json:"ekus,omitempty"` // e.g., "serverAuth", "clientAuth", "emailProtection"
|
||||
OrderID *string `json:"order_id,omitempty"`
|
||||
CommonName string `json:"common_name"`
|
||||
SANs []string `json:"sans"`
|
||||
CSRPEM string `json:"csr_pem"`
|
||||
EKUs []string `json:"ekus,omitempty"` // e.g., "serverAuth", "clientAuth", "emailProtection"
|
||||
MaxTTLSeconds int `json:"max_ttl_seconds,omitempty"` // 0 = no cap (use issuer default)
|
||||
OrderID *string `json:"order_id,omitempty"`
|
||||
}
|
||||
|
||||
// RevocationRequest contains the parameters for revoking a certificate.
|
||||
|
||||
@@ -184,8 +184,8 @@ func (c *Connector) IssueCertificate(ctx context.Context, request issuer.Issuanc
|
||||
return nil, fmt.Errorf("CSR signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
// Generate certificate with EKUs from request
|
||||
cert, certPEM, serial, err := c.generateCertificate(csr, request.SANs, request.EKUs)
|
||||
// Generate certificate with EKUs and MaxTTL from request
|
||||
cert, certPEM, serial, err := c.generateCertificate(csr, request.SANs, request.EKUs, request.MaxTTLSeconds)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to generate certificate", "error", err)
|
||||
return nil, fmt.Errorf("certificate generation failed: %w", err)
|
||||
@@ -242,8 +242,8 @@ func (c *Connector) RenewCertificate(ctx context.Context, request issuer.Renewal
|
||||
return nil, fmt.Errorf("CSR signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
// Generate certificate with EKUs from request
|
||||
cert, certPEM, serial, err := c.generateCertificate(csr, request.SANs, request.EKUs)
|
||||
// Generate certificate with EKUs and MaxTTL from request
|
||||
cert, certPEM, serial, err := c.generateCertificate(csr, request.SANs, request.EKUs, request.MaxTTLSeconds)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to generate certificate", "error", err)
|
||||
return nil, fmt.Errorf("certificate generation failed: %w", err)
|
||||
@@ -468,7 +468,8 @@ func parsePrivateKey(block *pem.Block) (crypto.Signer, error) {
|
||||
// generateCertificate creates an X.509 certificate signed by the local CA.
|
||||
// It uses the CSR subject and adds any additional SANs from the request.
|
||||
// If ekus is non-empty, those EKUs are used instead of the default serverAuth+clientAuth.
|
||||
func (c *Connector) generateCertificate(csr *x509.CertificateRequest, additionalSANs []string, ekus []string) (*x509.Certificate, string, string, error) {
|
||||
// If maxTTLSeconds > 0, the certificate validity is capped to that duration.
|
||||
func (c *Connector) generateCertificate(csr *x509.CertificateRequest, additionalSANs []string, ekus []string, maxTTLSeconds int) (*x509.Certificate, string, string, error) {
|
||||
// Generate random serial number
|
||||
serialNum, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 159))
|
||||
if err != nil {
|
||||
@@ -512,11 +513,21 @@ func (c *Connector) generateCertificate(csr *x509.CertificateRequest, additional
|
||||
|
||||
// Create certificate template
|
||||
now := time.Now()
|
||||
notAfter := now.AddDate(0, 0, c.config.ValidityDays)
|
||||
|
||||
// Cap validity to MaxTTLSeconds if profile specifies a maximum
|
||||
if maxTTLSeconds > 0 {
|
||||
maxNotAfter := now.Add(time.Duration(maxTTLSeconds) * time.Second)
|
||||
if maxNotAfter.Before(notAfter) {
|
||||
notAfter = maxNotAfter
|
||||
}
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNum,
|
||||
Subject: csr.Subject,
|
||||
NotBefore: now,
|
||||
NotAfter: now.AddDate(0, 0, c.config.ValidityDays),
|
||||
NotAfter: notAfter,
|
||||
KeyUsage: keyUsage,
|
||||
ExtKeyUsage: resolvedEKUs,
|
||||
DNSNames: dnsNames,
|
||||
|
||||
@@ -870,6 +870,156 @@ func TestGenerateCRL_SubCA(t *testing.T) {
|
||||
t.Log("SubCA CRL generated successfully")
|
||||
}
|
||||
|
||||
// M11c: MaxTTL enforcement tests
|
||||
|
||||
func TestIssueCertificate_MaxTTL_CapsValidity(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
config := &local.Config{
|
||||
CACommonName: "Test CA",
|
||||
ValidityDays: 365, // would normally be 1 year
|
||||
}
|
||||
connector := local.New(config, logger)
|
||||
|
||||
_, csrPEM, err := generateTestCSR("maxttl.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate CSR: %v", err)
|
||||
}
|
||||
|
||||
// MaxTTLSeconds = 3600 (1 hour) should cap the 365-day validity
|
||||
req := issuer.IssuanceRequest{
|
||||
CommonName: "maxttl.example.com",
|
||||
SANs: []string{"maxttl.example.com"},
|
||||
CSRPEM: csrPEM,
|
||||
MaxTTLSeconds: 3600,
|
||||
}
|
||||
|
||||
result, err := connector.IssueCertificate(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueCertificate failed: %v", err)
|
||||
}
|
||||
|
||||
// Cert validity should be ~1 hour, not 365 days
|
||||
duration := result.NotAfter.Sub(result.NotBefore)
|
||||
if duration > 2*time.Hour {
|
||||
t.Errorf("expected validity ≤1h, got %v", duration)
|
||||
}
|
||||
if duration < 30*time.Minute {
|
||||
t.Errorf("expected validity ≥30m, got %v (too short)", duration)
|
||||
}
|
||||
|
||||
t.Logf("MaxTTL capped: validity=%v (NotBefore=%v, NotAfter=%v)", duration, result.NotBefore, result.NotAfter)
|
||||
}
|
||||
|
||||
func TestIssueCertificate_MaxTTL_ZeroMeansNoCap(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
config := &local.Config{
|
||||
CACommonName: "Test CA",
|
||||
ValidityDays: 30,
|
||||
}
|
||||
connector := local.New(config, logger)
|
||||
|
||||
_, csrPEM, err := generateTestCSR("nocap.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate CSR: %v", err)
|
||||
}
|
||||
|
||||
req := issuer.IssuanceRequest{
|
||||
CommonName: "nocap.example.com",
|
||||
SANs: []string{"nocap.example.com"},
|
||||
CSRPEM: csrPEM,
|
||||
MaxTTLSeconds: 0, // no cap
|
||||
}
|
||||
|
||||
result, err := connector.IssueCertificate(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueCertificate failed: %v", err)
|
||||
}
|
||||
|
||||
// Should get ~30 days as configured
|
||||
duration := result.NotAfter.Sub(result.NotBefore)
|
||||
if duration < 29*24*time.Hour {
|
||||
t.Errorf("expected ~30 day validity without MaxTTL cap, got %v", duration)
|
||||
}
|
||||
|
||||
t.Logf("No MaxTTL cap: validity=%v", duration)
|
||||
}
|
||||
|
||||
func TestIssueCertificate_MaxTTL_LargerThanValidityDays_NoCap(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
config := &local.Config{
|
||||
CACommonName: "Test CA",
|
||||
ValidityDays: 30,
|
||||
}
|
||||
connector := local.New(config, logger)
|
||||
|
||||
_, csrPEM, err := generateTestCSR("larger.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate CSR: %v", err)
|
||||
}
|
||||
|
||||
// MaxTTL = 365 days, but ValidityDays = 30. The shorter one wins.
|
||||
req := issuer.IssuanceRequest{
|
||||
CommonName: "larger.example.com",
|
||||
SANs: []string{"larger.example.com"},
|
||||
CSRPEM: csrPEM,
|
||||
MaxTTLSeconds: 365 * 24 * 3600, // 365 days
|
||||
}
|
||||
|
||||
result, err := connector.IssueCertificate(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueCertificate failed: %v", err)
|
||||
}
|
||||
|
||||
// Should still be ~30 days (ValidityDays wins when shorter)
|
||||
duration := result.NotAfter.Sub(result.NotBefore)
|
||||
if duration > 31*24*time.Hour {
|
||||
t.Errorf("expected ~30 day validity (ValidityDays wins), got %v", duration)
|
||||
}
|
||||
|
||||
t.Logf("MaxTTL larger than ValidityDays: validity=%v (ValidityDays wins)", duration)
|
||||
}
|
||||
|
||||
func TestRenewCertificate_MaxTTL_CapsValidity(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
config := &local.Config{
|
||||
CACommonName: "Test CA",
|
||||
ValidityDays: 365,
|
||||
}
|
||||
connector := local.New(config, logger)
|
||||
|
||||
_, csrPEM, err := generateTestCSR("renew-maxttl.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate CSR: %v", err)
|
||||
}
|
||||
|
||||
req := issuer.RenewalRequest{
|
||||
CommonName: "renew-maxttl.example.com",
|
||||
SANs: []string{"renew-maxttl.example.com"},
|
||||
CSRPEM: csrPEM,
|
||||
MaxTTLSeconds: 7200, // 2 hours
|
||||
}
|
||||
|
||||
result, err := connector.RenewCertificate(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("RenewCertificate failed: %v", err)
|
||||
}
|
||||
|
||||
duration := result.NotAfter.Sub(result.NotBefore)
|
||||
if duration > 3*time.Hour {
|
||||
t.Errorf("expected validity ≤2h for renewal MaxTTL, got %v", duration)
|
||||
}
|
||||
|
||||
t.Logf("Renewal MaxTTL capped: validity=%v", duration)
|
||||
}
|
||||
|
||||
func TestSignOCSPResponse_SubCA(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -148,6 +148,14 @@ func (c *Connector) IssueCertificate(ctx context.Context, request issuer.Issuanc
|
||||
"common_name", request.CommonName,
|
||||
"san_count", len(request.SANs))
|
||||
|
||||
// MaxTTLSeconds is advisory for script-based issuers — the sign script controls validity.
|
||||
// Log a warning so operators know the profile TTL cap isn't enforced server-side.
|
||||
if request.MaxTTLSeconds > 0 {
|
||||
c.logger.Warn("MaxTTLSeconds specified but OpenSSL/custom CA delegates signing to external script; TTL cap is advisory only",
|
||||
"max_ttl_seconds", request.MaxTTLSeconds,
|
||||
"common_name", request.CommonName)
|
||||
}
|
||||
|
||||
// Write CSR to a temporary file
|
||||
csrFile, err := c.writeTempFile([]byte(request.CSRPEM), "csr-")
|
||||
if err != nil {
|
||||
|
||||
@@ -201,10 +201,19 @@ func (c *Connector) IssueCertificate(ctx context.Context, request issuer.Issuanc
|
||||
CsrPEM: request.CSRPEM,
|
||||
OTT: ott,
|
||||
}
|
||||
if c.config.ValidityDays > 0 {
|
||||
if c.config.ValidityDays > 0 || request.MaxTTLSeconds > 0 {
|
||||
now := time.Now()
|
||||
signReq.NotBefore = now
|
||||
signReq.NotAfter = now.AddDate(0, 0, c.config.ValidityDays)
|
||||
if c.config.ValidityDays > 0 {
|
||||
signReq.NotAfter = now.AddDate(0, 0, c.config.ValidityDays)
|
||||
}
|
||||
// Cap validity to MaxTTLSeconds if profile specifies a maximum
|
||||
if request.MaxTTLSeconds > 0 {
|
||||
maxNotAfter := now.Add(time.Duration(request.MaxTTLSeconds) * time.Second)
|
||||
if signReq.NotAfter.IsZero() || maxNotAfter.Before(signReq.NotAfter) {
|
||||
signReq.NotAfter = maxNotAfter
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(signReq)
|
||||
@@ -266,9 +275,10 @@ func (c *Connector) RenewCertificate(ctx context.Context, request issuer.Renewal
|
||||
"san_count", len(request.SANs))
|
||||
|
||||
return c.IssueCertificate(ctx, issuer.IssuanceRequest{
|
||||
CommonName: request.CommonName,
|
||||
SANs: request.SANs,
|
||||
CSRPEM: request.CSRPEM,
|
||||
CommonName: request.CommonName,
|
||||
SANs: request.SANs,
|
||||
CSRPEM: request.CSRPEM,
|
||||
MaxTTLSeconds: request.MaxTTLSeconds,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -160,11 +160,17 @@ func (c *Connector) IssueCertificate(ctx context.Context, request issuer.Issuanc
|
||||
"common_name", request.CommonName,
|
||||
"san_count", len(request.SANs))
|
||||
|
||||
// Determine TTL — cap to MaxTTLSeconds from profile if specified
|
||||
ttl := c.config.TTL
|
||||
if request.MaxTTLSeconds > 0 {
|
||||
ttl = fmt.Sprintf("%ds", request.MaxTTLSeconds)
|
||||
}
|
||||
|
||||
// Build the sign request body
|
||||
signBody := map[string]interface{}{
|
||||
"csr": request.CSRPEM,
|
||||
"common_name": request.CommonName,
|
||||
"ttl": c.config.TTL,
|
||||
"ttl": ttl,
|
||||
}
|
||||
|
||||
if len(request.SANs) > 0 {
|
||||
@@ -267,10 +273,11 @@ func (c *Connector) RenewCertificate(ctx context.Context, request issuer.Renewal
|
||||
"san_count", len(request.SANs))
|
||||
|
||||
return c.IssueCertificate(ctx, issuer.IssuanceRequest{
|
||||
CommonName: request.CommonName,
|
||||
SANs: request.SANs,
|
||||
CSRPEM: request.CSRPEM,
|
||||
EKUs: request.EKUs,
|
||||
CommonName: request.CommonName,
|
||||
SANs: request.SANs,
|
||||
CSRPEM: request.CSRPEM,
|
||||
EKUs: request.EKUs,
|
||||
MaxTTLSeconds: request.MaxTTLSeconds,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -349,7 +349,7 @@ func (r *CertificateRepository) Archive(ctx context.Context, id string) error {
|
||||
func (r *CertificateRepository) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, certificate_id, serial_number, not_before, not_after,
|
||||
fingerprint_sha256, pem_chain, csr_pem, created_at
|
||||
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
|
||||
FROM certificate_versions
|
||||
WHERE certificate_id = $1
|
||||
ORDER BY created_at DESC
|
||||
@@ -364,11 +364,15 @@ func (r *CertificateRepository) ListVersions(ctx context.Context, certID string)
|
||||
for rows.Next() {
|
||||
var v domain.CertificateVersion
|
||||
var csrPEM sql.NullString
|
||||
var keyAlgo sql.NullString
|
||||
var keySize sql.NullInt64
|
||||
if err := rows.Scan(&v.ID, &v.CertificateID, &v.SerialNumber, &v.NotBefore, &v.NotAfter,
|
||||
&v.FingerprintSHA256, &v.PEMChain, &csrPEM, &v.CreatedAt); err != nil {
|
||||
&v.FingerprintSHA256, &v.PEMChain, &csrPEM, &keyAlgo, &keySize, &v.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan certificate version: %w", err)
|
||||
}
|
||||
v.CSRPEM = csrPEM.String
|
||||
v.KeyAlgorithm = keyAlgo.String
|
||||
v.KeySize = int(keySize.Int64)
|
||||
versions = append(versions, &v)
|
||||
}
|
||||
|
||||
@@ -388,11 +392,11 @@ func (r *CertificateRepository) CreateVersion(ctx context.Context, version *doma
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO certificate_versions (
|
||||
id, certificate_id, serial_number, not_before, not_after,
|
||||
fingerprint_sha256, pem_chain, csr_pem, created_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
RETURNING id
|
||||
`, version.ID, version.CertificateID, version.SerialNumber, version.NotBefore, version.NotAfter,
|
||||
version.FingerprintSHA256, version.PEMChain, version.CSRPEM, version.CreatedAt).Scan(&version.ID)
|
||||
version.FingerprintSHA256, version.PEMChain, version.CSRPEM, version.KeyAlgorithm, version.KeySize, version.CreatedAt).Scan(&version.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create certificate version: %w", err)
|
||||
@@ -436,16 +440,20 @@ func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, bef
|
||||
func (r *CertificateRepository) GetLatestVersion(ctx context.Context, certID string) (*domain.CertificateVersion, error) {
|
||||
var v domain.CertificateVersion
|
||||
var csrPEM sql.NullString
|
||||
var keyAlgo sql.NullString
|
||||
var keySize sql.NullInt64
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, certificate_id, serial_number, not_before, not_after,
|
||||
fingerprint_sha256, pem_chain, csr_pem, created_at
|
||||
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
|
||||
FROM certificate_versions
|
||||
WHERE certificate_id = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
`, certID).Scan(&v.ID, &v.CertificateID, &v.SerialNumber, &v.NotBefore, &v.NotAfter,
|
||||
&v.FingerprintSHA256, &v.PEMChain, &csrPEM, &v.CreatedAt)
|
||||
&v.FingerprintSHA256, &v.PEMChain, &csrPEM, &keyAlgo, &keySize, &v.CreatedAt)
|
||||
v.CSRPEM = csrPEM.String
|
||||
v.KeyAlgorithm = keyAlgo.String
|
||||
v.KeySize = int(keySize.Int64)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get latest certificate version: %w", err)
|
||||
|
||||
@@ -165,14 +165,29 @@ func (s *AgentService) SubmitCSR(ctx context.Context, agentID string, certID str
|
||||
// Fallback: direct issuer signing (no AwaitingCSR job — ad-hoc CSR submission)
|
||||
connector, ok := s.issuerRegistry.Get(cert.IssuerID)
|
||||
if ok {
|
||||
// Resolve EKUs from the certificate profile if available
|
||||
// Resolve profile for EKU resolution and crypto policy enforcement
|
||||
var ekus []string
|
||||
var profile *domain.CertificateProfile
|
||||
if cert.CertificateProfileID != "" && s.profileRepo != nil {
|
||||
if profile, profileErr := s.profileRepo.Get(ctx, cert.CertificateProfileID); profileErr == nil && profile != nil {
|
||||
if p, profileErr := s.profileRepo.Get(ctx, cert.CertificateProfileID); profileErr == nil && p != nil {
|
||||
profile = p
|
||||
ekus = profile.AllowedEKUs
|
||||
}
|
||||
}
|
||||
result, err := connector.IssueCertificate(ctx, cert.CommonName, cert.SANs, string(csrPEM), ekus)
|
||||
|
||||
// Validate CSR key algorithm/size against profile (crypto policy enforcement)
|
||||
csrInfo, csrErr := ValidateCSRAgainstProfile(string(csrPEM), profile)
|
||||
if csrErr != nil {
|
||||
return fmt.Errorf("CSR validation failed: %w", csrErr)
|
||||
}
|
||||
|
||||
// Resolve MaxTTL from profile
|
||||
var maxTTLSeconds int
|
||||
if profile != nil {
|
||||
maxTTLSeconds = profile.MaxTTLSeconds
|
||||
}
|
||||
|
||||
result, err := connector.IssueCertificate(ctx, cert.CommonName, cert.SANs, string(csrPEM), ekus, maxTTLSeconds)
|
||||
if err != nil {
|
||||
return fmt.Errorf("issuer signing failed: %w", err)
|
||||
}
|
||||
@@ -188,6 +203,10 @@ func (s *AgentService) SubmitCSR(ctx context.Context, agentID string, certID str
|
||||
CSRPEM: string(csrPEM),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if csrInfo != nil {
|
||||
version.KeyAlgorithm = csrInfo.KeyAlgorithm
|
||||
version.KeySize = csrInfo.KeySize
|
||||
}
|
||||
|
||||
if err := s.certRepo.CreateVersion(ctx, version); err != nil {
|
||||
return fmt.Errorf("failed to store certificate version: %w", err)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
|
||||
func TestRegisterAgent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agentRepo := &mockAgentRepo{
|
||||
@@ -484,7 +485,7 @@ func TestSubmitCSR(t *testing.T) {
|
||||
|
||||
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, nil)
|
||||
|
||||
csrPEM := "-----BEGIN CERTIFICATE REQUEST-----\ntest-csr\n-----END CERTIFICATE REQUEST-----"
|
||||
csrPEM := generateTestCSR(t, "ECDSA", 256)
|
||||
err := agentService.SubmitCSR(ctx, "agent-001", "cert-001", []byte(csrPEM))
|
||||
if err != nil {
|
||||
t.Fatalf("SubmitCSR failed: %v", err)
|
||||
|
||||
+32
-2
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
)
|
||||
|
||||
// ESTService implements the EST (RFC 7030) enrollment protocol.
|
||||
@@ -20,6 +21,7 @@ type ESTService struct {
|
||||
auditService *AuditService
|
||||
logger *slog.Logger
|
||||
profileID string // optional: constrain enrollments to a specific profile
|
||||
profileRepo repository.CertificateProfileRepository
|
||||
}
|
||||
|
||||
// NewESTService creates a new ESTService for the given issuer connector.
|
||||
@@ -37,6 +39,11 @@ func (s *ESTService) SetProfileID(profileID string) {
|
||||
s.profileID = profileID
|
||||
}
|
||||
|
||||
// SetProfileRepo sets the profile repository for crypto policy enforcement during enrollment.
|
||||
func (s *ESTService) SetProfileRepo(repo repository.CertificateProfileRepository) {
|
||||
s.profileRepo = repo
|
||||
}
|
||||
|
||||
// GetCACerts returns the PEM-encoded CA certificate chain for this EST server.
|
||||
// RFC 7030 Section 4.1: /cacerts distributes the current CA certificates.
|
||||
func (s *ESTService) GetCACerts(ctx context.Context) (string, error) {
|
||||
@@ -109,15 +116,38 @@ func (s *ESTService) processEnrollment(ctx context.Context, csrPEM string, audit
|
||||
sans = append(sans, uri.String())
|
||||
}
|
||||
|
||||
// Validate CSR key algorithm/size against profile (crypto policy enforcement)
|
||||
var profile *domain.CertificateProfile
|
||||
var ekus []string
|
||||
if s.profileID != "" && s.profileRepo != nil {
|
||||
if p, profileErr := s.profileRepo.Get(ctx, s.profileID); profileErr == nil && p != nil {
|
||||
profile = p
|
||||
ekus = profile.AllowedEKUs
|
||||
}
|
||||
}
|
||||
if _, csrErr := ValidateCSRAgainstProfile(csrPEM, profile); csrErr != nil {
|
||||
s.logger.Error("EST enrollment rejected: crypto policy violation",
|
||||
"action", auditAction,
|
||||
"common_name", commonName,
|
||||
"error", csrErr)
|
||||
return nil, fmt.Errorf("EST enrollment rejected: %w", csrErr)
|
||||
}
|
||||
|
||||
s.logger.Info("EST enrollment request",
|
||||
"action", auditAction,
|
||||
"common_name", commonName,
|
||||
"sans", strings.Join(sans, ","),
|
||||
"issuer", s.issuerID)
|
||||
|
||||
// Resolve MaxTTL from profile
|
||||
var maxTTLSeconds int
|
||||
if profile != nil {
|
||||
maxTTLSeconds = profile.MaxTTLSeconds
|
||||
}
|
||||
|
||||
// Issue the certificate via the configured issuer connector
|
||||
// EST enrollments use default EKUs (nil = serverAuth + clientAuth fallback in connector)
|
||||
result, err := s.issuer.IssueCertificate(ctx, commonName, sans, csrPEM, nil)
|
||||
// EST enrollments use profile EKUs if available, otherwise default (serverAuth + clientAuth fallback)
|
||||
result, err := s.issuer.IssueCertificate(ctx, commonName, sans, csrPEM, ekus, maxTTLSeconds)
|
||||
if err != nil {
|
||||
s.logger.Error("EST enrollment failed",
|
||||
"action", auditAction,
|
||||
|
||||
@@ -20,12 +20,13 @@ func NewIssuerConnectorAdapter(c issuer.Connector) IssuerConnector {
|
||||
|
||||
// IssueCertificate delegates to the underlying connector's IssueCertificate method,
|
||||
// translating between service-layer and connector-layer types.
|
||||
func (a *IssuerConnectorAdapter) IssueCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string) (*IssuanceResult, error) {
|
||||
func (a *IssuerConnectorAdapter) IssueCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string, maxTTLSeconds int) (*IssuanceResult, error) {
|
||||
result, err := a.connector.IssueCertificate(ctx, issuer.IssuanceRequest{
|
||||
CommonName: commonName,
|
||||
SANs: sans,
|
||||
CSRPEM: csrPEM,
|
||||
EKUs: ekus,
|
||||
CommonName: commonName,
|
||||
SANs: sans,
|
||||
CSRPEM: csrPEM,
|
||||
EKUs: ekus,
|
||||
MaxTTLSeconds: maxTTLSeconds,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -41,12 +42,13 @@ func (a *IssuerConnectorAdapter) IssueCertificate(ctx context.Context, commonNam
|
||||
|
||||
// RenewCertificate delegates to the underlying connector's RenewCertificate method,
|
||||
// translating between service-layer and connector-layer types.
|
||||
func (a *IssuerConnectorAdapter) RenewCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string) (*IssuanceResult, error) {
|
||||
func (a *IssuerConnectorAdapter) RenewCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string, maxTTLSeconds int) (*IssuanceResult, error) {
|
||||
result, err := a.connector.RenewCertificate(ctx, issuer.RenewalRequest{
|
||||
CommonName: commonName,
|
||||
SANs: sans,
|
||||
CSRPEM: csrPEM,
|
||||
EKUs: ekus,
|
||||
CommonName: commonName,
|
||||
SANs: sans,
|
||||
CSRPEM: csrPEM,
|
||||
EKUs: ekus,
|
||||
MaxTTLSeconds: maxTTLSeconds,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -140,7 +140,7 @@ func TestIssuerConnectorAdapter_IssueCertificate_Success(t *testing.T) {
|
||||
|
||||
adapter := NewIssuerConnectorAdapter(mock)
|
||||
|
||||
result, err := adapter.IssueCertificate(ctx, "example.com", []string{"www.example.com"}, "-----BEGIN CERTIFICATE REQUEST-----\nCSR\n-----END CERTIFICATE REQUEST-----", nil)
|
||||
result, err := adapter.IssueCertificate(ctx, "example.com", []string{"www.example.com"}, "-----BEGIN CERTIFICATE REQUEST-----\nCSR\n-----END CERTIFICATE REQUEST-----", nil, 0)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("IssueCertificate failed: %v", err)
|
||||
@@ -177,7 +177,7 @@ func TestIssuerConnectorAdapter_IssueCertificate_Error(t *testing.T) {
|
||||
|
||||
adapter := NewIssuerConnectorAdapter(mock)
|
||||
|
||||
result, err := adapter.IssueCertificate(ctx, "example.com", []string{}, "csr", nil)
|
||||
result, err := adapter.IssueCertificate(ctx, "example.com", []string{}, "csr", nil, 0)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
@@ -211,7 +211,7 @@ func TestIssuerConnectorAdapter_IssueCertificate_RequestTranslation(t *testing.T
|
||||
sans := []string{"www.test.example.com", "api.test.example.com"}
|
||||
csrPEM := "-----BEGIN CERTIFICATE REQUEST-----\nCSR\n-----END CERTIFICATE REQUEST-----"
|
||||
|
||||
_, err := adapter.IssueCertificate(ctx, commonName, sans, csrPEM, nil)
|
||||
_, err := adapter.IssueCertificate(ctx, commonName, sans, csrPEM, nil, 0)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("IssueCertificate failed: %v", err)
|
||||
@@ -261,7 +261,7 @@ func TestIssuerConnectorAdapter_RenewCertificate_Success(t *testing.T) {
|
||||
|
||||
adapter := NewIssuerConnectorAdapter(mock)
|
||||
|
||||
result, err := adapter.RenewCertificate(ctx, "example.com", []string{"www.example.com"}, "-----BEGIN CERTIFICATE REQUEST-----\nCSR\n-----END CERTIFICATE REQUEST-----", nil)
|
||||
result, err := adapter.RenewCertificate(ctx, "example.com", []string{"www.example.com"}, "-----BEGIN CERTIFICATE REQUEST-----\nCSR\n-----END CERTIFICATE REQUEST-----", nil, 0)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("RenewCertificate failed: %v", err)
|
||||
@@ -298,7 +298,7 @@ func TestIssuerConnectorAdapter_RenewCertificate_Error(t *testing.T) {
|
||||
|
||||
adapter := NewIssuerConnectorAdapter(mock)
|
||||
|
||||
result, err := adapter.RenewCertificate(ctx, "example.com", []string{}, "csr", nil)
|
||||
result, err := adapter.RenewCertificate(ctx, "example.com", []string{}, "csr", nil, 0)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
@@ -332,7 +332,7 @@ func TestIssuerConnectorAdapter_RenewCertificate_RequestTranslation(t *testing.T
|
||||
sans := []string{"www.renew.example.com"}
|
||||
csrPEM := "-----BEGIN CERTIFICATE REQUEST-----\nRENEW-CSR\n-----END CERTIFICATE REQUEST-----"
|
||||
|
||||
_, err := adapter.RenewCertificate(ctx, commonName, sans, csrPEM, nil)
|
||||
_, err := adapter.RenewCertificate(ctx, commonName, sans, csrPEM, nil, 0)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("RenewCertificate failed: %v", err)
|
||||
|
||||
@@ -0,0 +1,400 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
// m11cProfileRepo wraps the existing mockProfileRepo from profile_test.go with AddProfile helper.
|
||||
// We reuse the existing mock and just create instances with pre-populated profiles.
|
||||
func newM11cProfileRepo() *mockProfileRepo {
|
||||
return &mockProfileRepo{
|
||||
profiles: make(map[string]*domain.CertificateProfile),
|
||||
}
|
||||
}
|
||||
|
||||
// --- EST Crypto Policy Enforcement Tests ---
|
||||
|
||||
func TestESTService_CryptoValidation_RejectsWeakKey(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
svc := NewESTService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})))
|
||||
|
||||
// Profile requiring ECDSA P-384 minimum
|
||||
profileRepo := newM11cProfileRepo()
|
||||
profileRepo.profiles["prof-high-sec"] = &domain.CertificateProfile{
|
||||
ID: "prof-high-sec",
|
||||
Name: "High Security",
|
||||
AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{
|
||||
{Algorithm: "ECDSA", MinSize: 384},
|
||||
},
|
||||
}
|
||||
svc.SetProfileID("prof-high-sec")
|
||||
svc.SetProfileRepo(profileRepo)
|
||||
|
||||
// P-256 CSR should be rejected by P-384 minimum
|
||||
csrPEM := generateCSRPEM(t, "weak.example.com", nil)
|
||||
|
||||
_, err := svc.SimpleEnroll(context.Background(), csrPEM)
|
||||
if err == nil {
|
||||
t.Fatal("expected rejection for ECDSA P-256 against P-384 minimum")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "EST enrollment rejected") {
|
||||
t.Errorf("expected 'EST enrollment rejected' in error, got: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "does not match any allowed algorithm") {
|
||||
t.Errorf("expected algorithm mismatch message, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestESTService_CryptoValidation_AcceptsStrongKey(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
auditRepo := newMockAuditRepository()
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
svc := NewESTService("iss-local", mockIssuer, auditSvc, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})))
|
||||
|
||||
// Profile allows P-256+
|
||||
profileRepo := newM11cProfileRepo()
|
||||
profileRepo.profiles["prof-standard"] = &domain.CertificateProfile{
|
||||
ID: "prof-standard",
|
||||
Name: "Standard TLS",
|
||||
AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{
|
||||
{Algorithm: "ECDSA", MinSize: 256},
|
||||
},
|
||||
}
|
||||
svc.SetProfileID("prof-standard")
|
||||
svc.SetProfileRepo(profileRepo)
|
||||
|
||||
csrPEM := generateCSRPEM(t, "strong.example.com", nil)
|
||||
|
||||
result, err := svc.SimpleEnroll(context.Background(), csrPEM)
|
||||
if err != nil {
|
||||
t.Fatalf("expected success for ECDSA P-256 against P-256 minimum: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestESTService_MaxTTL_ForwardedToIssuer(t *testing.T) {
|
||||
// Track what the mock issuer receives
|
||||
var capturedMaxTTL int
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
// Override IssueCertificate to capture maxTTLSeconds
|
||||
// We'll use a capturing mock instead
|
||||
capturingMock := &capturingIssuerConnector{}
|
||||
|
||||
svc := NewESTService("iss-local", capturingMock, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})))
|
||||
|
||||
profileRepo := newM11cProfileRepo()
|
||||
profileRepo.profiles["prof-short"] = &domain.CertificateProfile{
|
||||
ID: "prof-short",
|
||||
Name: "Short Lived",
|
||||
MaxTTLSeconds: 3600, // 1 hour
|
||||
}
|
||||
svc.SetProfileID("prof-short")
|
||||
svc.SetProfileRepo(profileRepo)
|
||||
|
||||
csrPEM := generateCSRPEM(t, "short.example.com", nil)
|
||||
|
||||
_, err := svc.SimpleEnroll(context.Background(), csrPEM)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
capturedMaxTTL = capturingMock.lastMaxTTLSeconds
|
||||
if capturedMaxTTL != 3600 {
|
||||
t.Errorf("expected maxTTLSeconds=3600 forwarded to issuer, got %d", capturedMaxTTL)
|
||||
}
|
||||
|
||||
_ = mockIssuer // suppress unused
|
||||
}
|
||||
|
||||
// --- SCEP Crypto Policy Enforcement Tests ---
|
||||
|
||||
func TestSCEPService_CryptoValidation_RejectsWeakKey(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
svc := NewSCEPService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "")
|
||||
|
||||
// Profile requiring ECDSA P-384 minimum
|
||||
profileRepo := newM11cProfileRepo()
|
||||
profileRepo.profiles["prof-high-sec"] = &domain.CertificateProfile{
|
||||
ID: "prof-high-sec",
|
||||
Name: "High Security",
|
||||
AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{
|
||||
{Algorithm: "ECDSA", MinSize: 384},
|
||||
},
|
||||
}
|
||||
svc.SetProfileID("prof-high-sec")
|
||||
svc.SetProfileRepo(profileRepo)
|
||||
|
||||
// P-256 CSR should be rejected
|
||||
csrPEM := generateCSRPEM(t, "device.example.com", nil)
|
||||
|
||||
_, err := svc.PKCSReq(context.Background(), csrPEM, "", "txn-001")
|
||||
if err == nil {
|
||||
t.Fatal("expected rejection for ECDSA P-256 against P-384 minimum")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "SCEP enrollment rejected") {
|
||||
t.Errorf("expected 'SCEP enrollment rejected' in error, got: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "does not match any allowed algorithm") {
|
||||
t.Errorf("expected algorithm mismatch message, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCEPService_CryptoValidation_AcceptsStrongKey(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
auditRepo := newMockAuditRepository()
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
svc := NewSCEPService("iss-local", mockIssuer, auditSvc, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "")
|
||||
|
||||
profileRepo := newM11cProfileRepo()
|
||||
profileRepo.profiles["prof-standard"] = &domain.CertificateProfile{
|
||||
ID: "prof-standard",
|
||||
Name: "Standard TLS",
|
||||
AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{
|
||||
{Algorithm: "ECDSA", MinSize: 256},
|
||||
},
|
||||
}
|
||||
svc.SetProfileID("prof-standard")
|
||||
svc.SetProfileRepo(profileRepo)
|
||||
|
||||
csrPEM := generateCSRPEM(t, "device-ok.example.com", nil)
|
||||
|
||||
result, err := svc.PKCSReq(context.Background(), csrPEM, "", "txn-002")
|
||||
if err != nil {
|
||||
t.Fatalf("expected success: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCEPService_MaxTTL_ForwardedToIssuer(t *testing.T) {
|
||||
capturingMock := &capturingIssuerConnector{}
|
||||
|
||||
svc := NewSCEPService("iss-local", capturingMock, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "")
|
||||
|
||||
profileRepo := newM11cProfileRepo()
|
||||
profileRepo.profiles["prof-device"] = &domain.CertificateProfile{
|
||||
ID: "prof-device",
|
||||
Name: "Device Cert",
|
||||
MaxTTLSeconds: 86400, // 24 hours
|
||||
}
|
||||
svc.SetProfileID("prof-device")
|
||||
svc.SetProfileRepo(profileRepo)
|
||||
|
||||
csrPEM := generateCSRPEM(t, "mdm-device.example.com", nil)
|
||||
|
||||
_, err := svc.PKCSReq(context.Background(), csrPEM, "", "txn-003")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if capturingMock.lastMaxTTLSeconds != 86400 {
|
||||
t.Errorf("expected maxTTLSeconds=86400 forwarded to issuer, got %d", capturingMock.lastMaxTTLSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Adapter MaxTTL Forwarding Tests ---
|
||||
|
||||
func TestIssuerConnectorAdapter_IssueCertificate_MaxTTLForwarded(t *testing.T) {
|
||||
mock := &mockConnectorLayerIssuer{}
|
||||
adapter := NewIssuerConnectorAdapter(mock)
|
||||
|
||||
_, err := adapter.IssueCertificate(context.Background(), "test.example.com", nil, "csr", nil, 7200)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if mock.lastIssueReq == nil {
|
||||
t.Fatal("expected request to be recorded")
|
||||
}
|
||||
if mock.lastIssueReq.MaxTTLSeconds != 7200 {
|
||||
t.Errorf("expected MaxTTLSeconds=7200, got %d", mock.lastIssueReq.MaxTTLSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIssuerConnectorAdapter_RenewCertificate_MaxTTLForwarded(t *testing.T) {
|
||||
mock := &mockConnectorLayerIssuer{}
|
||||
adapter := NewIssuerConnectorAdapter(mock)
|
||||
|
||||
_, err := adapter.RenewCertificate(context.Background(), "renew.example.com", nil, "csr", nil, 14400)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if mock.lastRenewReq == nil {
|
||||
t.Fatal("expected request to be recorded")
|
||||
}
|
||||
if mock.lastRenewReq.MaxTTLSeconds != 14400 {
|
||||
t.Errorf("expected MaxTTLSeconds=14400, got %d", mock.lastRenewReq.MaxTTLSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIssuerConnectorAdapter_IssueCertificate_ZeroMaxTTL(t *testing.T) {
|
||||
mock := &mockConnectorLayerIssuer{}
|
||||
adapter := NewIssuerConnectorAdapter(mock)
|
||||
|
||||
_, err := adapter.IssueCertificate(context.Background(), "test.example.com", nil, "csr", nil, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if mock.lastIssueReq.MaxTTLSeconds != 0 {
|
||||
t.Errorf("expected MaxTTLSeconds=0 (no cap), got %d", mock.lastIssueReq.MaxTTLSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
// --- CreateVersion Key Metadata Persistence Tests ---
|
||||
|
||||
func TestCreateVersion_KeyMetadata_Persisted(t *testing.T) {
|
||||
certRepo := newMockCertificateRepository()
|
||||
|
||||
version := &domain.CertificateVersion{
|
||||
ID: "ver-001",
|
||||
CertificateID: "cert-001",
|
||||
SerialNumber: "serial-001",
|
||||
PEMChain: "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----",
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
KeyAlgorithm: "ECDSA",
|
||||
KeySize: 256,
|
||||
}
|
||||
|
||||
err := certRepo.CreateVersion(context.Background(), version)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateVersion failed: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve and verify key metadata was stored
|
||||
versions, err := certRepo.ListVersions(context.Background(), "cert-001")
|
||||
if err != nil {
|
||||
t.Fatalf("ListVersions failed: %v", err)
|
||||
}
|
||||
if len(versions) != 1 {
|
||||
t.Fatalf("expected 1 version, got %d", len(versions))
|
||||
}
|
||||
if versions[0].KeyAlgorithm != "ECDSA" {
|
||||
t.Errorf("expected KeyAlgorithm=ECDSA, got %s", versions[0].KeyAlgorithm)
|
||||
}
|
||||
if versions[0].KeySize != 256 {
|
||||
t.Errorf("expected KeySize=256, got %d", versions[0].KeySize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateVersion_RSAKeyMetadata_Persisted(t *testing.T) {
|
||||
certRepo := newMockCertificateRepository()
|
||||
|
||||
version := &domain.CertificateVersion{
|
||||
ID: "ver-002",
|
||||
CertificateID: "cert-002",
|
||||
SerialNumber: "serial-002",
|
||||
PEMChain: "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----",
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
KeyAlgorithm: "RSA",
|
||||
KeySize: 4096,
|
||||
}
|
||||
|
||||
err := certRepo.CreateVersion(context.Background(), version)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateVersion failed: %v", err)
|
||||
}
|
||||
|
||||
versions, err := certRepo.ListVersions(context.Background(), "cert-002")
|
||||
if err != nil {
|
||||
t.Fatalf("ListVersions failed: %v", err)
|
||||
}
|
||||
if versions[0].KeyAlgorithm != "RSA" {
|
||||
t.Errorf("expected KeyAlgorithm=RSA, got %s", versions[0].KeyAlgorithm)
|
||||
}
|
||||
if versions[0].KeySize != 4096 {
|
||||
t.Errorf("expected KeySize=4096, got %d", versions[0].KeySize)
|
||||
}
|
||||
}
|
||||
|
||||
// --- EST/SCEP without profile repo (graceful passthrough) ---
|
||||
|
||||
func TestESTService_NoProfileRepo_PassesThrough(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
svc := NewESTService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})))
|
||||
svc.SetProfileID("nonexistent-profile")
|
||||
// Deliberately NOT calling SetProfileRepo — should pass through without validation
|
||||
|
||||
csrPEM := generateCSRPEM(t, "no-profile.example.com", nil)
|
||||
|
||||
result, err := svc.SimpleEnroll(context.Background(), csrPEM)
|
||||
if err != nil {
|
||||
t.Fatalf("expected success when no profile repo set: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCEPService_NoProfileRepo_PassesThrough(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
svc := NewSCEPService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "")
|
||||
svc.SetProfileID("nonexistent-profile")
|
||||
|
||||
csrPEM := generateCSRPEM(t, "no-profile-scep.example.com", nil)
|
||||
|
||||
result, err := svc.PKCSReq(context.Background(), csrPEM, "", "txn-004")
|
||||
if err != nil {
|
||||
t.Fatalf("expected success when no profile repo set: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
}
|
||||
|
||||
// --- capturingIssuerConnector captures maxTTLSeconds for verification ---
|
||||
|
||||
type capturingIssuerConnector struct {
|
||||
lastMaxTTLSeconds int
|
||||
lastEKUs []string
|
||||
}
|
||||
|
||||
func (c *capturingIssuerConnector) IssueCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string, maxTTLSeconds int) (*IssuanceResult, error) {
|
||||
c.lastMaxTTLSeconds = maxTTLSeconds
|
||||
c.lastEKUs = ekus
|
||||
now := time.Now()
|
||||
return &IssuanceResult{
|
||||
Serial: "test-serial",
|
||||
CertPEM: "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----",
|
||||
ChainPEM: "-----BEGIN CERTIFICATE-----\nchain\n-----END CERTIFICATE-----",
|
||||
NotBefore: now,
|
||||
NotAfter: now.AddDate(1, 0, 0),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *capturingIssuerConnector) RenewCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string, maxTTLSeconds int) (*IssuanceResult, error) {
|
||||
return c.IssueCertificate(ctx, commonName, sans, csrPEM, ekus, maxTTLSeconds)
|
||||
}
|
||||
|
||||
func (c *capturingIssuerConnector) RevokeCertificate(ctx context.Context, serial string, reason string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *capturingIssuerConnector) GenerateCRL(ctx context.Context, entries []CRLEntry) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *capturingIssuerConnector) SignOCSPResponse(ctx context.Context, req OCSPSignRequest) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *capturingIssuerConnector) GetCACertPEM(ctx context.Context) (string, error) {
|
||||
return "-----BEGIN CERTIFICATE-----\nmock-ca\n-----END CERTIFICATE-----", nil
|
||||
}
|
||||
|
||||
func (c *capturingIssuerConnector) GetRenewalInfo(ctx context.Context, certPEM string) (*RenewalInfoResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -43,9 +43,11 @@ func (s *RenewalService) SetTargetRepo(repo repository.TargetRepository) {
|
||||
// inversion. Use IssuerConnectorAdapter to bridge between the two.
|
||||
type IssuerConnector interface {
|
||||
// IssueCertificate issues a new certificate using the provided CSR PEM.
|
||||
IssueCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string) (*IssuanceResult, error)
|
||||
// maxTTLSeconds caps the certificate validity period (0 = no cap, use issuer default).
|
||||
IssueCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string, maxTTLSeconds int) (*IssuanceResult, error)
|
||||
// RenewCertificate renews a certificate using the provided CSR PEM.
|
||||
RenewCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string) (*IssuanceResult, error)
|
||||
// maxTTLSeconds caps the certificate validity period (0 = no cap, use issuer default).
|
||||
RenewCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string, maxTTLSeconds int) (*IssuanceResult, error)
|
||||
// RevokeCertificate revokes a certificate by serial number with an optional reason.
|
||||
RevokeCertificate(ctx context.Context, serial string, reason string) error
|
||||
// GenerateCRL generates a DER-encoded X.509 CRL from the given revocation entries.
|
||||
@@ -444,16 +446,18 @@ func (s *RenewalService) processRenewalServerKeygen(ctx context.Context, job *do
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privKey),
|
||||
}))
|
||||
|
||||
// Resolve EKUs from the certificate profile
|
||||
// Resolve EKUs and MaxTTL from the certificate profile
|
||||
var ekus []string
|
||||
var maxTTLSeconds int
|
||||
if cert.CertificateProfileID != "" && s.profileRepo != nil {
|
||||
if profile, profileErr := s.profileRepo.Get(ctx, cert.CertificateProfileID); profileErr == nil && profile != nil {
|
||||
ekus = profile.AllowedEKUs
|
||||
maxTTLSeconds = profile.MaxTTLSeconds
|
||||
}
|
||||
}
|
||||
|
||||
// Call issuer connector to renew
|
||||
result, err := connector.RenewCertificate(ctx, cert.CommonName, cert.SANs, csrPEM, ekus)
|
||||
result, err := connector.RenewCertificate(ctx, cert.CommonName, cert.SANs, csrPEM, ekus, maxTTLSeconds)
|
||||
if err != nil {
|
||||
s.failJob(ctx, job, fmt.Sprintf("issuer renewal failed: %v", err))
|
||||
if notifErr := s.notificationSvc.SendRenewalNotification(ctx, cert, false, err); notifErr != nil {
|
||||
@@ -560,14 +564,18 @@ func (s *RenewalService) CompleteAgentCSRRenewal(ctx context.Context, job *domai
|
||||
return fmt.Errorf("failed to update job status: %w", err)
|
||||
}
|
||||
|
||||
// Resolve EKUs from the certificate profile (for S/MIME, email certs, etc.)
|
||||
// Resolve EKUs and MaxTTL from the certificate profile (for S/MIME, email certs, etc.)
|
||||
var ekus []string
|
||||
if profile != nil && len(profile.AllowedEKUs) > 0 {
|
||||
ekus = profile.AllowedEKUs
|
||||
var maxTTLSeconds int
|
||||
if profile != nil {
|
||||
if len(profile.AllowedEKUs) > 0 {
|
||||
ekus = profile.AllowedEKUs
|
||||
}
|
||||
maxTTLSeconds = profile.MaxTTLSeconds
|
||||
}
|
||||
|
||||
// Sign the agent-submitted CSR via issuer
|
||||
result, err := connector.RenewCertificate(ctx, cert.CommonName, cert.SANs, csrPEM, ekus)
|
||||
result, err := connector.RenewCertificate(ctx, cert.CommonName, cert.SANs, csrPEM, ekus, maxTTLSeconds)
|
||||
if err != nil {
|
||||
s.failJob(ctx, job, fmt.Sprintf("issuer signing failed: %v", err))
|
||||
if notifErr := s.notificationSvc.SendRenewalNotification(ctx, cert, false, err); notifErr != nil {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
)
|
||||
|
||||
// SCEPService implements the SCEP (RFC 8894) enrollment protocol.
|
||||
@@ -20,6 +21,7 @@ type SCEPService struct {
|
||||
auditService *AuditService
|
||||
logger *slog.Logger
|
||||
profileID string // optional: constrain enrollments to a specific profile
|
||||
profileRepo repository.CertificateProfileRepository
|
||||
challengePassword string // shared secret for enrollment authentication
|
||||
}
|
||||
|
||||
@@ -39,6 +41,11 @@ func (s *SCEPService) SetProfileID(profileID string) {
|
||||
s.profileID = profileID
|
||||
}
|
||||
|
||||
// SetProfileRepo sets the profile repository for crypto policy enforcement during enrollment.
|
||||
func (s *SCEPService) SetProfileRepo(repo repository.CertificateProfileRepository) {
|
||||
s.profileRepo = repo
|
||||
}
|
||||
|
||||
// GetCACaps returns the capabilities of this SCEP server.
|
||||
// RFC 8894 Section 3.5.2: GetCACaps returns a list of capabilities, one per line.
|
||||
func (s *SCEPService) GetCACaps(ctx context.Context) string {
|
||||
@@ -111,6 +118,24 @@ func (s *SCEPService) processEnrollment(ctx context.Context, csrPEM string, tran
|
||||
sans = append(sans, uri.String())
|
||||
}
|
||||
|
||||
// Validate CSR key algorithm/size against profile (crypto policy enforcement)
|
||||
var profile *domain.CertificateProfile
|
||||
var ekus []string
|
||||
if s.profileID != "" && s.profileRepo != nil {
|
||||
if p, profileErr := s.profileRepo.Get(ctx, s.profileID); profileErr == nil && p != nil {
|
||||
profile = p
|
||||
ekus = profile.AllowedEKUs
|
||||
}
|
||||
}
|
||||
if _, csrErr := ValidateCSRAgainstProfile(csrPEM, profile); csrErr != nil {
|
||||
s.logger.Error("SCEP enrollment rejected: crypto policy violation",
|
||||
"action", auditAction,
|
||||
"common_name", commonName,
|
||||
"transaction_id", transactionID,
|
||||
"error", csrErr)
|
||||
return nil, fmt.Errorf("SCEP enrollment rejected: %w", csrErr)
|
||||
}
|
||||
|
||||
s.logger.Info("SCEP enrollment request",
|
||||
"action", auditAction,
|
||||
"common_name", commonName,
|
||||
@@ -118,9 +143,15 @@ func (s *SCEPService) processEnrollment(ctx context.Context, csrPEM string, tran
|
||||
"transaction_id", transactionID,
|
||||
"issuer", s.issuerID)
|
||||
|
||||
// Resolve MaxTTL from profile
|
||||
var maxTTLSeconds int
|
||||
if profile != nil {
|
||||
maxTTLSeconds = profile.MaxTTLSeconds
|
||||
}
|
||||
|
||||
// Issue the certificate via the configured issuer connector
|
||||
// SCEP enrollments use default EKUs (nil = serverAuth + clientAuth fallback in connector)
|
||||
result, err := s.issuer.IssueCertificate(ctx, commonName, sans, csrPEM, nil)
|
||||
// SCEP enrollments use profile EKUs if available, otherwise default (serverAuth + clientAuth fallback)
|
||||
result, err := s.issuer.IssueCertificate(ctx, commonName, sans, csrPEM, ekus, maxTTLSeconds)
|
||||
if err != nil {
|
||||
s.logger.Error("SCEP enrollment failed",
|
||||
"action", auditAction,
|
||||
|
||||
@@ -713,7 +713,7 @@ type mockIssuerConnector struct {
|
||||
getRenewalInfoErr error
|
||||
}
|
||||
|
||||
func (m *mockIssuerConnector) IssueCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string) (*IssuanceResult, error) {
|
||||
func (m *mockIssuerConnector) IssueCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string, maxTTLSeconds int) (*IssuanceResult, error) {
|
||||
if m.Err != nil {
|
||||
return nil, m.Err
|
||||
}
|
||||
@@ -730,11 +730,11 @@ func (m *mockIssuerConnector) IssueCertificate(ctx context.Context, commonName s
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockIssuerConnector) RenewCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string) (*IssuanceResult, error) {
|
||||
func (m *mockIssuerConnector) RenewCertificate(ctx context.Context, commonName string, sans []string, csrPEM string, ekus []string, maxTTLSeconds int) (*IssuanceResult, error) {
|
||||
if m.Err != nil {
|
||||
return nil, m.Err
|
||||
}
|
||||
return m.IssueCertificate(ctx, commonName, sans, csrPEM, ekus)
|
||||
return m.IssueCertificate(ctx, commonName, sans, csrPEM, ekus, maxTTLSeconds)
|
||||
}
|
||||
|
||||
func (m *mockIssuerConnector) RevokeCertificate(ctx context.Context, serial string, reason string) error {
|
||||
|
||||
Reference in New Issue
Block a user