Files
certctl/internal/repository/postgres/certificate.go
T
shankar0123 a579a84c7f feat: M11a — certificate profiles, crypto policy enforcement, short-lived cert expiry
Add certificate profiles as named enrollment templates that control allowed
key algorithms, max TTL, permitted EKUs, required SAN patterns, and optional
SPIFFE URI SANs. CSR submissions are validated against profile rules at
signing time (key type + minimum size). Short-lived certs (TTL < 1 hour)
auto-expire via a new scheduler loop — expiry acts as revocation, no
CRL/OCSP needed.

New files:
- Migration 000003: certificate_profiles table, FK columns on
  managed_certificates/renewal_policies, key metadata on certificate_versions
- domain/profile.go: CertificateProfile + KeyAlgorithmRule structs
- repository/postgres/profile.go: full CRUD with JSONB marshaling
- service/profile.go: ProfileService with validation + audit logging
- service/crypto_validation.go: CSR-against-profile validation (RSA/ECDSA/Ed25519)
- handler/profiles.go: 5 HTTP endpoints under /api/v1/profiles
- web/src/pages/ProfilesPage.tsx: profiles management page

Modified:
- renewal.go: CSR validation in CompleteAgentCSRRenewal, ExpireShortLivedCertificates
- scheduler.go: 30s short-lived expiry check loop
- certificate.go (repo): nullable profile FK, key metadata on versions
- main.go: profile repo/service/handler wiring, 8-param NewRenewalService
- router.go: 12-param RegisterHandlers with profile routes
- seed_demo.sql: 4 demo profiles (standard, mtls, short-lived, high-security)
- Frontend: types, API client, routing, sidebar nav

Tests: 40 new tests across handler (15), service (13), crypto validation (12)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-20 20:39:49 -04:00

364 lines
11 KiB
Go

package postgres
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/lib/pq"
	"github.com/shankar0123/certctl/internal/domain"
	"github.com/shankar0123/certctl/internal/repository"
)
// CertificateRepository implements repository.CertificateRepository on top of
// a database/sql connection pool. Its methods read and write the
// managed_certificates and certificate_versions tables.
type CertificateRepository struct {
// db is the handle every query in this repository is issued through.
db *sql.DB
}
// NewCertificateRepository builds a CertificateRepository that issues all of
// its queries through db. The handle is stored as-is; it is not validated here.
func NewCertificateRepository(db *sql.DB) *CertificateRepository {
	repo := new(CertificateRepository)
	repo.db = db
	return repo
}
// List returns one page of managed certificates matching filter, together
// with the total number of matching rows (ignoring pagination).
//
// A nil filter is treated as an empty one. Pagination defaults are written
// back into the supplied filter: a Page below 1 becomes 1, and a PerPage below
// 1 or above 500 becomes 50. Each non-empty filter field (Status, Environment,
// OwnerID, TeamID, IssuerID) adds an equality predicate; predicates are
// combined with AND. Rows are ordered by created_at, newest first.
//
// It returns a wrapped error if the count query, the page query, row scanning
// or row iteration fails.
func (r *CertificateRepository) List(ctx context.Context, filter *repository.CertificateFilter) ([]*domain.ManagedCertificate, int, error) {
	if filter == nil {
		filter = &repository.CertificateFilter{}
	}

	// Normalize pagination. PerPage must be strictly positive: a negative
	// value would otherwise be sent to PostgreSQL as a negative LIMIT, which
	// the server rejects.
	if filter.Page < 1 {
		filter.Page = 1
	}
	if filter.PerPage < 1 || filter.PerPage > 500 {
		filter.PerPage = 50
	}

	// Column names are fixed literals; only the values travel as bind
	// parameters, so nothing caller-supplied is spliced into the SQL text.
	predicates := []struct {
		column string
		active bool
		value  interface{}
	}{
		{"status", filter.Status != "", filter.Status},
		{"environment", filter.Environment != "", filter.Environment},
		{"owner_id", filter.OwnerID != "", filter.OwnerID},
		{"team_id", filter.TeamID != "", filter.TeamID},
		{"issuer_id", filter.IssuerID != "", filter.IssuerID},
	}

	var (
		whereConditions []string
		args            []interface{}
	)
	for _, p := range predicates {
		if !p.active {
			continue
		}
		args = append(args, p.value)
		// Placeholders are 1-based, so the index is the arg count after append.
		whereConditions = append(whereConditions, fmt.Sprintf("%s = $%d", p.column, len(args)))
	}

	whereClause := ""
	if len(whereConditions) > 0 {
		whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
	}

	// Total number of matching rows, independent of pagination.
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM managed_certificates %s", whereClause)
	var total int
	if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("failed to count certificates: %w", err)
	}

	// LIMIT and OFFSET take the two placeholders after the filter values.
	limitPos := len(args) + 1
	offset := (filter.Page - 1) * filter.PerPage
	query := fmt.Sprintf(`
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
FROM managed_certificates
%s
ORDER BY created_at DESC
LIMIT $%d OFFSET $%d
`, whereClause, limitPos, limitPos+1)
	args = append(args, filter.PerPage, offset)

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to query certificates: %w", err)
	}
	defer rows.Close()

	var certs []*domain.ManagedCertificate
	for rows.Next() {
		cert, err := scanCertificate(rows)
		if err != nil {
			return nil, 0, err
		}
		certs = append(certs, cert)
	}
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("error iterating certificate rows: %w", err)
	}
	return certs, total, nil
}
// Get retrieves the managed certificate whose id column equals id.
//
// It returns a "certificate not found" error when no row matches, and a
// wrapped error for any other query or scan failure.
func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
	row := r.db.QueryRowContext(ctx, `
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
FROM managed_certificates
WHERE id = $1
`, id)

	cert, err := scanCertificate(row)
	if err != nil {
		// scanCertificate wraps the error returned by Scan, so a direct ==
		// comparison with sql.ErrNoRows never matches. errors.Is walks the
		// wrap chain and makes the not-found branch reachable.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, fmt.Errorf("certificate not found")
		}
		return nil, fmt.Errorf("failed to query certificate: %w", err)
	}
	return cert, nil
}
// Create inserts cert as a new row in managed_certificates.
//
// When cert.ID is empty a fresh UUID is generated and stored on cert before
// the insert. Tags are persisted as JSON, SANs as a PostgreSQL array, and an
// empty CertificateProfileID is written as NULL. The id returned by the
// database is scanned back into cert.ID.
//
// It returns a wrapped error if the tags cannot be marshaled or the insert
// fails.
func (r *CertificateRepository) Create(ctx context.Context, cert *domain.ManagedCertificate) error {
	const insertSQL = `
INSERT INTO managed_certificates (
id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
RETURNING id
`

	if cert.ID == "" {
		cert.ID = uuid.New().String()
	}

	encodedTags, err := json.Marshal(cert.Tags)
	if err != nil {
		return fmt.Errorf("failed to marshal tags: %w", err)
	}

	// The profile reference is optional: an empty ID is bound as NULL.
	profileRef := sql.NullString{
		String: cert.CertificateProfileID,
		Valid:  cert.CertificateProfileID != "",
	}

	// Values are listed in the same order as the column list above.
	values := []interface{}{
		cert.ID,
		cert.Name,
		cert.CommonName,
		pq.Array(cert.SANs),
		cert.Environment,
		cert.OwnerID,
		cert.TeamID,
		cert.IssuerID,
		cert.RenewalPolicyID,
		profileRef,
		cert.Status,
		cert.ExpiresAt,
		encodedTags,
		cert.LastRenewalAt,
		cert.LastDeploymentAt,
		cert.CreatedAt,
		cert.UpdatedAt,
	}

	if err := r.db.QueryRowContext(ctx, insertSQL, values...).Scan(&cert.ID); err != nil {
		return fmt.Errorf("failed to create certificate: %w", err)
	}
	return nil
}
// Update overwrites the mutable columns of the managed_certificates row whose
// id equals cert.ID with the values carried by cert.
//
// Tags are persisted as JSON, SANs as a PostgreSQL array, and an empty
// CertificateProfileID is written as NULL. updated_at is taken from
// cert.UpdatedAt; it is not set here.
//
// NOTE(review): renewal_policy_id is not part of the SET list, so this method
// never changes a certificate's renewal policy — confirm that is intentional.
//
// It returns "certificate not found" when no row was affected, and a wrapped
// error if marshaling, the statement, or reading the affected-row count fails.
func (r *CertificateRepository) Update(ctx context.Context, cert *domain.ManagedCertificate) error {
	encodedTags, err := json.Marshal(cert.Tags)
	if err != nil {
		return fmt.Errorf("failed to marshal tags: %w", err)
	}

	// A nil pointer is bound as NULL, which is how "no profile" is stored.
	var profileRef *string
	if id := cert.CertificateProfileID; id != "" {
		profileRef = &id
	}

	// Values are listed in placeholder order ($1..$15).
	params := []interface{}{
		cert.Name,
		cert.CommonName,
		pq.Array(cert.SANs),
		cert.Environment,
		cert.OwnerID,
		cert.TeamID,
		cert.IssuerID,
		profileRef,
		cert.Status,
		cert.ExpiresAt,
		encodedTags,
		cert.LastRenewalAt,
		cert.LastDeploymentAt,
		cert.UpdatedAt,
		cert.ID,
	}

	result, err := r.db.ExecContext(ctx, `
UPDATE managed_certificates SET
name = $1,
common_name = $2,
sans = $3,
environment = $4,
owner_id = $5,
team_id = $6,
issuer_id = $7,
certificate_profile_id = $8,
status = $9,
expires_at = $10,
tags = $11,
last_renewal_at = $12,
last_deployment_at = $13,
updated_at = $14
WHERE id = $15
`, params...)
	if err != nil {
		return fmt.Errorf("failed to update certificate: %w", err)
	}

	switch affected, err := result.RowsAffected(); {
	case err != nil:
		return fmt.Errorf("failed to get rows affected: %w", err)
	case affected == 0:
		return fmt.Errorf("certificate not found")
	}
	return nil
}
// Archive sets the status of the certificate identified by id to
// domain.CertificateStatusArchived and stamps updated_at with the current
// time. The row itself is kept.
//
// It returns "certificate not found" when no row was affected, and a wrapped
// error if the statement or reading the affected-row count fails.
func (r *CertificateRepository) Archive(ctx context.Context, id string) error {
	const archiveSQL = `
UPDATE managed_certificates SET status = $1, updated_at = $2 WHERE id = $3
`

	result, err := r.db.ExecContext(ctx, archiveSQL, domain.CertificateStatusArchived, time.Now(), id)
	if err != nil {
		return fmt.Errorf("failed to archive certificate: %w", err)
	}

	affected, err := result.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("failed to get rows affected: %w", err)
	case affected == 0:
		return fmt.Errorf("certificate not found")
	default:
		return nil
	}
}
// ListVersions returns every certificate_versions row whose certificate_id
// equals certID, ordered by created_at with the newest version first. The
// result is nil when the certificate has no versions.
//
// NOTE(review): csr_pem, key_algorithm and key_size are scanned straight into
// the struct fields; if any of those columns can be NULL and the fields are
// not NULL-capable types, Scan will fail — confirm against the schema.
//
// It returns a wrapped error if the query, a row scan, or row iteration fails.
func (r *CertificateRepository) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) {
	const selectSQL = `
SELECT id, certificate_id, serial_number, not_before, not_after,
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
FROM certificate_versions
WHERE certificate_id = $1
ORDER BY created_at DESC
`

	rows, err := r.db.QueryContext(ctx, selectSQL, certID)
	if err != nil {
		return nil, fmt.Errorf("failed to query certificate versions: %w", err)
	}
	defer rows.Close()

	var versions []*domain.CertificateVersion
	for rows.Next() {
		// A fresh value per row, so the stored pointers never alias.
		version := new(domain.CertificateVersion)
		scanErr := rows.Scan(
			&version.ID,
			&version.CertificateID,
			&version.SerialNumber,
			&version.NotBefore,
			&version.NotAfter,
			&version.FingerprintSHA256,
			&version.PEMChain,
			&version.CSRPEM,
			&version.KeyAlgorithm,
			&version.KeySize,
			&version.CreatedAt,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("failed to scan certificate version: %w", scanErr)
		}
		versions = append(versions, version)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating version rows: %w", iterErr)
	}
	return versions, nil
}
// CreateVersion inserts version as a new row in certificate_versions.
//
// When version.ID is empty a fresh UUID is generated and stored on version
// before the insert. The id returned by the database is scanned back into
// version.ID.
//
// It returns a wrapped error if the insert fails.
func (r *CertificateRepository) CreateVersion(ctx context.Context, version *domain.CertificateVersion) error {
	const insertSQL = `
INSERT INTO certificate_versions (
id, certificate_id, serial_number, not_before, not_after,
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id
`

	if version.ID == "" {
		version.ID = uuid.New().String()
	}

	// Values are listed in the same order as the column list above.
	values := []interface{}{
		version.ID,
		version.CertificateID,
		version.SerialNumber,
		version.NotBefore,
		version.NotAfter,
		version.FingerprintSHA256,
		version.PEMChain,
		version.CSRPEM,
		version.KeyAlgorithm,
		version.KeySize,
		version.CreatedAt,
	}

	if err := r.db.QueryRowContext(ctx, insertSQL, values...).Scan(&version.ID); err != nil {
		return fmt.Errorf("failed to create certificate version: %w", err)
	}
	return nil
}
// GetExpiringCertificates returns every managed certificate whose expires_at
// is strictly earlier than before and whose status is not
// domain.CertificateStatusArchived, ordered by expires_at with the soonest
// expiry first. The result is nil when nothing matches.
//
// It returns a wrapped error if the query or row iteration fails, and the
// error from scanCertificate if a row cannot be scanned.
func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) {
	const selectSQL = `
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
FROM managed_certificates
WHERE expires_at < $1 AND status != $2
ORDER BY expires_at ASC
`

	rows, queryErr := r.db.QueryContext(ctx, selectSQL, before, domain.CertificateStatusArchived)
	if queryErr != nil {
		return nil, fmt.Errorf("failed to query expiring certificates: %w", queryErr)
	}
	defer rows.Close()

	var expiring []*domain.ManagedCertificate
	for rows.Next() {
		cert, scanErr := scanCertificate(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		expiring = append(expiring, cert)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating expiring certificate rows: %w", iterErr)
	}
	return expiring, nil
}
// scanCertificate reads one managed_certificates row from scanner into a new
// domain.ManagedCertificate. scanner may be a *sql.Row or *sql.Rows; the
// selected columns must be in this order: id, name, common_name, sans,
// environment, owner_id, team_id, issuer_id, renewal_policy_id,
// certificate_profile_id, status, expires_at, tags, last_renewal_at,
// last_deployment_at, created_at, updated_at.
//
// A NULL certificate_profile_id leaves CertificateProfileID empty. Tags is
// always a non-nil map on success, whether the column was empty, held JSON
// null, or held an object.
//
// The error from Scan is wrapped with %w, so callers must use errors.Is to
// detect sql.ErrNoRows. A tags value that is not valid JSON for the map also
// yields a wrapped error.
func scanCertificate(scanner interface {
	Scan(...interface{}) error
}) (*domain.ManagedCertificate, error) {
	var (
		cert      domain.ManagedCertificate
		tagsJSON  []byte
		sans      pq.StringArray
		profileID sql.NullString
	)

	err := scanner.Scan(
		&cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID,
		&cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &profileID,
		&cert.Status, &cert.ExpiresAt, &tagsJSON,
		&cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.CreatedAt, &cert.UpdatedAt)
	if err != nil {
		return nil, fmt.Errorf("failed to scan certificate: %w", err)
	}

	cert.SANs = []string(sans)
	if profileID.Valid {
		cert.CertificateProfileID = profileID.String
	}

	if len(tagsJSON) > 0 {
		if err := json.Unmarshal(tagsJSON, &cert.Tags); err != nil {
			return nil, fmt.Errorf("failed to unmarshal tags: %w", err)
		}
	}
	// Unmarshaling JSON null leaves the map nil, and an empty column skips
	// unmarshaling entirely. Normalize both to an empty map so callers can
	// write to Tags without a nil check.
	if cert.Tags == nil {
		cert.Tags = make(map[string]string)
	}
	return &cert, nil
}