feat: M11a — certificate profiles, crypto policy enforcement, short-lived cert expiry

Add certificate profiles as named enrollment templates that control allowed
key algorithms, max TTL, permitted EKUs, required SAN patterns, and optional
SPIFFE URI SANs. CSR submissions are validated against profile rules at
signing time (key type + minimum size). Short-lived certs (TTL < 1 hour)
auto-expire via a new scheduler loop — expiry acts as revocation, no
CRL/OCSP needed.

New files:
- Migration 000003: certificate_profiles table, FK columns on
  managed_certificates/renewal_policies, key metadata on certificate_versions
- domain/profile.go: CertificateProfile + KeyAlgorithmRule structs
- repository/postgres/profile.go: full CRUD with JSONB marshaling
- service/profile.go: ProfileService with validation + audit logging
- service/crypto_validation.go: CSR-against-profile validation (RSA/ECDSA/Ed25519)
- handler/profiles.go: 5 HTTP endpoints under /api/v1/profiles
- web/src/pages/ProfilesPage.tsx: profiles management page

Modified:
- renewal.go: CSR validation in CompleteAgentCSRRenewal, ExpireShortLivedCertificates
- scheduler.go: 30s short-lived expiry check loop
- certificate.go (repo): nullable profile FK, key metadata on versions
- main.go: profile repo/service/handler wiring, 8-param NewRenewalService
- router.go: 12-param RegisterHandlers with profile routes
- seed_demo.sql: 4 demo profiles (standard, mtls, short-lived, high-security)
- Frontend: types, API client, routing, sidebar nav

Tests: 40 new tests across handler (15), service (13), crypto validation (12)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Shankar
2026-03-20 20:39:49 -04:00
parent 5dc34bde20
commit 1ef16984eb
27 changed files with 2399 additions and 71 deletions
+37 -20
View File
@@ -85,7 +85,7 @@ func (r *CertificateRepository) List(ctx context.Context, filter *repository.Cer
offset := (filter.Page - 1) * filter.PerPage
query := fmt.Sprintf(`
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
FROM managed_certificates
%s
ORDER BY created_at DESC
@@ -120,7 +120,7 @@ func (r *CertificateRepository) List(ctx context.Context, filter *repository.Cer
func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
row := r.db.QueryRowContext(ctx, `
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
FROM managed_certificates
WHERE id = $1
`, id)
@@ -147,14 +147,20 @@ func (r *CertificateRepository) Create(ctx context.Context, cert *domain.Managed
return fmt.Errorf("failed to marshal tags: %w", err)
}
var profileID *string
if cert.CertificateProfileID != "" {
profileID = &cert.CertificateProfileID
}
err = r.db.QueryRowContext(ctx, `
INSERT INTO managed_certificates (
id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
RETURNING id
`, cert.ID, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment,
cert.OwnerID, cert.TeamID, cert.IssuerID, cert.RenewalPolicyID, cert.Status, cert.ExpiresAt,
cert.OwnerID, cert.TeamID, cert.IssuerID, cert.RenewalPolicyID, profileID,
cert.Status, cert.ExpiresAt,
tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.CreatedAt, cert.UpdatedAt).Scan(&cert.ID)
if err != nil {
@@ -171,6 +177,11 @@ func (r *CertificateRepository) Update(ctx context.Context, cert *domain.Managed
return fmt.Errorf("failed to marshal tags: %w", err)
}
var profileID *string
if cert.CertificateProfileID != "" {
profileID = &cert.CertificateProfileID
}
result, err := r.db.ExecContext(ctx, `
UPDATE managed_certificates SET
name = $1,
@@ -180,15 +191,16 @@ func (r *CertificateRepository) Update(ctx context.Context, cert *domain.Managed
owner_id = $5,
team_id = $6,
issuer_id = $7,
status = $8,
expires_at = $9,
tags = $10,
last_renewal_at = $11,
last_deployment_at = $12,
updated_at = $13
WHERE id = $14
certificate_profile_id = $8,
status = $9,
expires_at = $10,
tags = $11,
last_renewal_at = $12,
last_deployment_at = $13,
updated_at = $14
WHERE id = $15
`, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment,
cert.OwnerID, cert.TeamID, cert.IssuerID, cert.Status, cert.ExpiresAt,
cert.OwnerID, cert.TeamID, cert.IssuerID, profileID, cert.Status, cert.ExpiresAt,
tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.UpdatedAt, cert.ID)
if err != nil {
@@ -233,7 +245,7 @@ func (r *CertificateRepository) Archive(ctx context.Context, id string) error {
func (r *CertificateRepository) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, certificate_id, serial_number, not_before, not_after,
fingerprint_sha256, pem_chain, csr_pem, created_at
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
FROM certificate_versions
WHERE certificate_id = $1
ORDER BY created_at DESC
@@ -248,7 +260,7 @@ func (r *CertificateRepository) ListVersions(ctx context.Context, certID string)
for rows.Next() {
var v domain.CertificateVersion
if err := rows.Scan(&v.ID, &v.CertificateID, &v.SerialNumber, &v.NotBefore, &v.NotAfter,
&v.FingerprintSHA256, &v.PEMChain, &v.CSRPEM, &v.CreatedAt); err != nil {
&v.FingerprintSHA256, &v.PEMChain, &v.CSRPEM, &v.KeyAlgorithm, &v.KeySize, &v.CreatedAt); err != nil {
return nil, fmt.Errorf("failed to scan certificate version: %w", err)
}
versions = append(versions, &v)
@@ -270,11 +282,11 @@ func (r *CertificateRepository) CreateVersion(ctx context.Context, version *doma
err := r.db.QueryRowContext(ctx, `
INSERT INTO certificate_versions (
id, certificate_id, serial_number, not_before, not_after,
fingerprint_sha256, pem_chain, csr_pem, created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id
`, version.ID, version.CertificateID, version.SerialNumber, version.NotBefore, version.NotAfter,
version.FingerprintSHA256, version.PEMChain, version.CSRPEM, version.CreatedAt).Scan(&version.ID)
version.FingerprintSHA256, version.PEMChain, version.CSRPEM, version.KeyAlgorithm, version.KeySize, version.CreatedAt).Scan(&version.ID)
if err != nil {
return fmt.Errorf("failed to create certificate version: %w", err)
@@ -287,7 +299,7 @@ func (r *CertificateRepository) CreateVersion(ctx context.Context, version *doma
func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
FROM managed_certificates
WHERE expires_at < $1 AND status != $2
ORDER BY expires_at ASC
@@ -321,10 +333,12 @@ func scanCertificate(scanner interface {
var cert domain.ManagedCertificate
var tagsJSON []byte
var sans pq.StringArray
var profileID sql.NullString
err := scanner.Scan(
&cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID,
&cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &cert.Status, &cert.ExpiresAt, &tagsJSON,
&cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &profileID,
&cert.Status, &cert.ExpiresAt, &tagsJSON,
&cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.CreatedAt, &cert.UpdatedAt)
if err != nil {
@@ -332,6 +346,9 @@ func scanCertificate(scanner interface {
}
cert.SANs = []string(sans)
if profileID.Valid {
cert.CertificateProfileID = profileID.String
}
// Unmarshal tags
if len(tagsJSON) > 0 {
+226
View File
@@ -0,0 +1,226 @@
package postgres
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/shankar0123/certctl/internal/domain"
)
// ProfileRepository implements repository.CertificateProfileRepository
type ProfileRepository struct {
db *sql.DB
}
// NewProfileRepository creates a new ProfileRepository
func NewProfileRepository(db *sql.DB) *ProfileRepository {
return &ProfileRepository{db: db}
}
// List returns all certificate profiles
func (r *ProfileRepository) List(ctx context.Context) ([]*domain.CertificateProfile, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, name, description, allowed_key_algorithms, max_ttl_seconds,
allowed_ekus, required_san_patterns, spiffe_uri_pattern,
allow_short_lived, enabled, created_at, updated_at
FROM certificate_profiles
ORDER BY created_at DESC
`)
if err != nil {
return nil, fmt.Errorf("failed to query profiles: %w", err)
}
defer rows.Close()
var profiles []*domain.CertificateProfile
for rows.Next() {
p, err := scanProfile(rows)
if err != nil {
return nil, err
}
profiles = append(profiles, p)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating profile rows: %w", err)
}
return profiles, nil
}
// Get retrieves a certificate profile by ID
func (r *ProfileRepository) Get(ctx context.Context, id string) (*domain.CertificateProfile, error) {
row := r.db.QueryRowContext(ctx, `
SELECT id, name, description, allowed_key_algorithms, max_ttl_seconds,
allowed_ekus, required_san_patterns, spiffe_uri_pattern,
allow_short_lived, enabled, created_at, updated_at
FROM certificate_profiles
WHERE id = $1
`, id)
p, err := scanProfile(row)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("profile not found")
}
return nil, fmt.Errorf("failed to query profile: %w", err)
}
return p, nil
}
// Create stores a new certificate profile
func (r *ProfileRepository) Create(ctx context.Context, profile *domain.CertificateProfile) error {
if profile.ID == "" {
profile.ID = uuid.New().String()
}
if profile.CreatedAt.IsZero() {
profile.CreatedAt = time.Now()
}
if profile.UpdatedAt.IsZero() {
profile.UpdatedAt = time.Now()
}
algJSON, err := json.Marshal(profile.AllowedKeyAlgorithms)
if err != nil {
return fmt.Errorf("failed to marshal allowed_key_algorithms: %w", err)
}
ekuJSON, err := json.Marshal(profile.AllowedEKUs)
if err != nil {
return fmt.Errorf("failed to marshal allowed_ekus: %w", err)
}
sanJSON, err := json.Marshal(profile.RequiredSANPatterns)
if err != nil {
return fmt.Errorf("failed to marshal required_san_patterns: %w", err)
}
err = r.db.QueryRowContext(ctx, `
INSERT INTO certificate_profiles (
id, name, description, allowed_key_algorithms, max_ttl_seconds,
allowed_ekus, required_san_patterns, spiffe_uri_pattern,
allow_short_lived, enabled, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
RETURNING id
`, profile.ID, profile.Name, profile.Description, algJSON, profile.MaxTTLSeconds,
ekuJSON, sanJSON, profile.SPIFFEURIPattern,
profile.AllowShortLived, profile.Enabled, profile.CreatedAt, profile.UpdatedAt).Scan(&profile.ID)
if err != nil {
return fmt.Errorf("failed to create profile: %w", err)
}
return nil
}
// Update modifies an existing certificate profile
func (r *ProfileRepository) Update(ctx context.Context, profile *domain.CertificateProfile) error {
profile.UpdatedAt = time.Now()
algJSON, err := json.Marshal(profile.AllowedKeyAlgorithms)
if err != nil {
return fmt.Errorf("failed to marshal allowed_key_algorithms: %w", err)
}
ekuJSON, err := json.Marshal(profile.AllowedEKUs)
if err != nil {
return fmt.Errorf("failed to marshal allowed_ekus: %w", err)
}
sanJSON, err := json.Marshal(profile.RequiredSANPatterns)
if err != nil {
return fmt.Errorf("failed to marshal required_san_patterns: %w", err)
}
result, err := r.db.ExecContext(ctx, `
UPDATE certificate_profiles SET
name = $1,
description = $2,
allowed_key_algorithms = $3,
max_ttl_seconds = $4,
allowed_ekus = $5,
required_san_patterns = $6,
spiffe_uri_pattern = $7,
allow_short_lived = $8,
enabled = $9,
updated_at = $10
WHERE id = $11
`, profile.Name, profile.Description, algJSON, profile.MaxTTLSeconds,
ekuJSON, sanJSON, profile.SPIFFEURIPattern,
profile.AllowShortLived, profile.Enabled, profile.UpdatedAt, profile.ID)
if err != nil {
return fmt.Errorf("failed to update profile: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rows == 0 {
return fmt.Errorf("profile not found")
}
return nil
}
// Delete removes a certificate profile
func (r *ProfileRepository) Delete(ctx context.Context, id string) error {
result, err := r.db.ExecContext(ctx, "DELETE FROM certificate_profiles WHERE id = $1", id)
if err != nil {
return fmt.Errorf("failed to delete profile: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rows == 0 {
return fmt.Errorf("profile not found")
}
return nil
}
// scanProfile scans a certificate profile from a row or rows
func scanProfile(scanner interface {
Scan(...interface{}) error
}) (*domain.CertificateProfile, error) {
var p domain.CertificateProfile
var algJSON, ekuJSON, sanJSON []byte
err := scanner.Scan(
&p.ID, &p.Name, &p.Description, &algJSON, &p.MaxTTLSeconds,
&ekuJSON, &sanJSON, &p.SPIFFEURIPattern,
&p.AllowShortLived, &p.Enabled, &p.CreatedAt, &p.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan profile: %w", err)
}
if len(algJSON) > 0 {
if err := json.Unmarshal(algJSON, &p.AllowedKeyAlgorithms); err != nil {
return nil, fmt.Errorf("failed to unmarshal allowed_key_algorithms: %w", err)
}
} else {
p.AllowedKeyAlgorithms = domain.DefaultKeyAlgorithms()
}
if len(ekuJSON) > 0 {
if err := json.Unmarshal(ekuJSON, &p.AllowedEKUs); err != nil {
return nil, fmt.Errorf("failed to unmarshal allowed_ekus: %w", err)
}
} else {
p.AllowedEKUs = domain.DefaultEKUs()
}
if len(sanJSON) > 0 {
if err := json.Unmarshal(sanJSON, &p.RequiredSANPatterns); err != nil {
return nil, fmt.Errorf("failed to unmarshal required_san_patterns: %w", err)
}
}
return &p, nil
}