mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-07 15:51:30 +00:00
a579a84c7f
Add certificate profiles as named enrollment templates that control allowed key algorithms, max TTL, permitted EKUs, required SAN patterns, and optional SPIFFE URI SANs. CSR submissions are validated against profile rules at signing time (key type + minimum size). Short-lived certs (TTL < 1 hour) auto-expire via a new scheduler loop — expiry acts as revocation, no CRL/OCSP needed. New files: - Migration 000003: certificate_profiles table, FK columns on managed_certificates/renewal_policies, key metadata on certificate_versions - domain/profile.go: CertificateProfile + KeyAlgorithmRule structs - repository/postgres/profile.go: full CRUD with JSONB marshaling - service/profile.go: ProfileService with validation + audit logging - service/crypto_validation.go: CSR-against-profile validation (RSA/ECDSA/Ed25519) - handler/profiles.go: 5 HTTP endpoints under /api/v1/profiles - web/src/pages/ProfilesPage.tsx: profiles management page Modified: - renewal.go: CSR validation in CompleteAgentCSRRenewal, ExpireShortLivedCertificates - scheduler.go: 30s short-lived expiry check loop - certificate.go (repo): nullable profile FK, key metadata on versions - main.go: profile repo/service/handler wiring, 8-param NewRenewalService - router.go: 12-param RegisterHandlers with profile routes - seed_demo.sql: 4 demo profiles (standard, mtls, short-lived, high-security) - Frontend: types, API client, routing, sidebar nav Tests: 40 new tests across handler (15), service (13), crypto validation (12) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
364 lines
11 KiB
Go
364 lines
11 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/lib/pq"
|
|
"github.com/shankar0123/certctl/internal/domain"
|
|
"github.com/shankar0123/certctl/internal/repository"
|
|
)
|
|
|
|
// CertificateRepository implements repository.CertificateRepository
|
|
type CertificateRepository struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
// NewCertificateRepository creates a new CertificateRepository
|
|
func NewCertificateRepository(db *sql.DB) *CertificateRepository {
|
|
return &CertificateRepository{db: db}
|
|
}
|
|
|
|
// List returns a paginated list of certificates matching the filter criteria
|
|
func (r *CertificateRepository) List(ctx context.Context, filter *repository.CertificateFilter) ([]*domain.ManagedCertificate, int, error) {
|
|
if filter == nil {
|
|
filter = &repository.CertificateFilter{}
|
|
}
|
|
|
|
// Set defaults
|
|
if filter.Page < 1 {
|
|
filter.Page = 1
|
|
}
|
|
if filter.PerPage == 0 || filter.PerPage > 500 {
|
|
filter.PerPage = 50
|
|
}
|
|
|
|
// Build WHERE clause
|
|
var whereConditions []string
|
|
var args []interface{}
|
|
argCount := 1
|
|
|
|
if filter.Status != "" {
|
|
whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argCount))
|
|
args = append(args, filter.Status)
|
|
argCount++
|
|
}
|
|
if filter.Environment != "" {
|
|
whereConditions = append(whereConditions, fmt.Sprintf("environment = $%d", argCount))
|
|
args = append(args, filter.Environment)
|
|
argCount++
|
|
}
|
|
if filter.OwnerID != "" {
|
|
whereConditions = append(whereConditions, fmt.Sprintf("owner_id = $%d", argCount))
|
|
args = append(args, filter.OwnerID)
|
|
argCount++
|
|
}
|
|
if filter.TeamID != "" {
|
|
whereConditions = append(whereConditions, fmt.Sprintf("team_id = $%d", argCount))
|
|
args = append(args, filter.TeamID)
|
|
argCount++
|
|
}
|
|
if filter.IssuerID != "" {
|
|
whereConditions = append(whereConditions, fmt.Sprintf("issuer_id = $%d", argCount))
|
|
args = append(args, filter.IssuerID)
|
|
argCount++
|
|
}
|
|
|
|
whereClause := ""
|
|
if len(whereConditions) > 0 {
|
|
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
|
|
}
|
|
|
|
// Get total count
|
|
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM managed_certificates %s", whereClause)
|
|
var total int
|
|
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
|
return nil, 0, fmt.Errorf("failed to count certificates: %w", err)
|
|
}
|
|
|
|
// Get paginated results
|
|
offset := (filter.Page - 1) * filter.PerPage
|
|
query := fmt.Sprintf(`
|
|
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
|
|
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
|
|
FROM managed_certificates
|
|
%s
|
|
ORDER BY created_at DESC
|
|
LIMIT $%d OFFSET $%d
|
|
`, whereClause, argCount, argCount+1)
|
|
|
|
args = append(args, filter.PerPage, offset)
|
|
|
|
rows, err := r.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("failed to query certificates: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var certs []*domain.ManagedCertificate
|
|
for rows.Next() {
|
|
cert, err := scanCertificate(rows)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
certs = append(certs, cert)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, 0, fmt.Errorf("error iterating certificate rows: %w", err)
|
|
}
|
|
|
|
return certs, total, nil
|
|
}
|
|
|
|
// Get retrieves a certificate by ID
|
|
func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
|
|
row := r.db.QueryRowContext(ctx, `
|
|
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
|
|
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
|
|
FROM managed_certificates
|
|
WHERE id = $1
|
|
`, id)
|
|
|
|
cert, err := scanCertificate(row)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("certificate not found")
|
|
}
|
|
return nil, fmt.Errorf("failed to query certificate: %w", err)
|
|
}
|
|
|
|
return cert, nil
|
|
}
|
|
|
|
// Create stores a new certificate
|
|
func (r *CertificateRepository) Create(ctx context.Context, cert *domain.ManagedCertificate) error {
|
|
if cert.ID == "" {
|
|
cert.ID = uuid.New().String()
|
|
}
|
|
|
|
tagsJSON, err := json.Marshal(cert.Tags)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal tags: %w", err)
|
|
}
|
|
|
|
var profileID *string
|
|
if cert.CertificateProfileID != "" {
|
|
profileID = &cert.CertificateProfileID
|
|
}
|
|
|
|
err = r.db.QueryRowContext(ctx, `
|
|
INSERT INTO managed_certificates (
|
|
id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
|
|
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
|
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
|
|
RETURNING id
|
|
`, cert.ID, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment,
|
|
cert.OwnerID, cert.TeamID, cert.IssuerID, cert.RenewalPolicyID, profileID,
|
|
cert.Status, cert.ExpiresAt,
|
|
tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.CreatedAt, cert.UpdatedAt).Scan(&cert.ID)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create certificate: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Update modifies an existing certificate
|
|
func (r *CertificateRepository) Update(ctx context.Context, cert *domain.ManagedCertificate) error {
|
|
tagsJSON, err := json.Marshal(cert.Tags)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal tags: %w", err)
|
|
}
|
|
|
|
var profileID *string
|
|
if cert.CertificateProfileID != "" {
|
|
profileID = &cert.CertificateProfileID
|
|
}
|
|
|
|
result, err := r.db.ExecContext(ctx, `
|
|
UPDATE managed_certificates SET
|
|
name = $1,
|
|
common_name = $2,
|
|
sans = $3,
|
|
environment = $4,
|
|
owner_id = $5,
|
|
team_id = $6,
|
|
issuer_id = $7,
|
|
certificate_profile_id = $8,
|
|
status = $9,
|
|
expires_at = $10,
|
|
tags = $11,
|
|
last_renewal_at = $12,
|
|
last_deployment_at = $13,
|
|
updated_at = $14
|
|
WHERE id = $15
|
|
`, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment,
|
|
cert.OwnerID, cert.TeamID, cert.IssuerID, profileID, cert.Status, cert.ExpiresAt,
|
|
tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.UpdatedAt, cert.ID)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update certificate: %w", err)
|
|
}
|
|
|
|
rows, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
if rows == 0 {
|
|
return fmt.Errorf("certificate not found")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Archive marks a certificate as archived
|
|
func (r *CertificateRepository) Archive(ctx context.Context, id string) error {
|
|
result, err := r.db.ExecContext(ctx, `
|
|
UPDATE managed_certificates SET status = $1, updated_at = $2 WHERE id = $3
|
|
`, domain.CertificateStatusArchived, time.Now(), id)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to archive certificate: %w", err)
|
|
}
|
|
|
|
rows, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
if rows == 0 {
|
|
return fmt.Errorf("certificate not found")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ListVersions returns all versions of a certificate
|
|
func (r *CertificateRepository) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) {
|
|
rows, err := r.db.QueryContext(ctx, `
|
|
SELECT id, certificate_id, serial_number, not_before, not_after,
|
|
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
|
|
FROM certificate_versions
|
|
WHERE certificate_id = $1
|
|
ORDER BY created_at DESC
|
|
`, certID)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query certificate versions: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var versions []*domain.CertificateVersion
|
|
for rows.Next() {
|
|
var v domain.CertificateVersion
|
|
if err := rows.Scan(&v.ID, &v.CertificateID, &v.SerialNumber, &v.NotBefore, &v.NotAfter,
|
|
&v.FingerprintSHA256, &v.PEMChain, &v.CSRPEM, &v.KeyAlgorithm, &v.KeySize, &v.CreatedAt); err != nil {
|
|
return nil, fmt.Errorf("failed to scan certificate version: %w", err)
|
|
}
|
|
versions = append(versions, &v)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("error iterating version rows: %w", err)
|
|
}
|
|
|
|
return versions, nil
|
|
}
|
|
|
|
// CreateVersion stores a new certificate version
|
|
func (r *CertificateRepository) CreateVersion(ctx context.Context, version *domain.CertificateVersion) error {
|
|
if version.ID == "" {
|
|
version.ID = uuid.New().String()
|
|
}
|
|
|
|
err := r.db.QueryRowContext(ctx, `
|
|
INSERT INTO certificate_versions (
|
|
id, certificate_id, serial_number, not_before, not_after,
|
|
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
|
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
|
RETURNING id
|
|
`, version.ID, version.CertificateID, version.SerialNumber, version.NotBefore, version.NotAfter,
|
|
version.FingerprintSHA256, version.PEMChain, version.CSRPEM, version.KeyAlgorithm, version.KeySize, version.CreatedAt).Scan(&version.ID)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create certificate version: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetExpiringCertificates returns certificates expiring before the given time
|
|
func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) {
|
|
rows, err := r.db.QueryContext(ctx, `
|
|
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
|
|
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
|
|
FROM managed_certificates
|
|
WHERE expires_at < $1 AND status != $2
|
|
ORDER BY expires_at ASC
|
|
`, before, domain.CertificateStatusArchived)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query expiring certificates: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var certs []*domain.ManagedCertificate
|
|
for rows.Next() {
|
|
cert, err := scanCertificate(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
certs = append(certs, cert)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("error iterating expiring certificate rows: %w", err)
|
|
}
|
|
|
|
return certs, nil
|
|
}
|
|
|
|
// scanCertificate scans a certificate from a row or rows
|
|
func scanCertificate(scanner interface {
|
|
Scan(...interface{}) error
|
|
}) (*domain.ManagedCertificate, error) {
|
|
var cert domain.ManagedCertificate
|
|
var tagsJSON []byte
|
|
var sans pq.StringArray
|
|
var profileID sql.NullString
|
|
|
|
err := scanner.Scan(
|
|
&cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID,
|
|
&cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &profileID,
|
|
&cert.Status, &cert.ExpiresAt, &tagsJSON,
|
|
&cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.CreatedAt, &cert.UpdatedAt)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan certificate: %w", err)
|
|
}
|
|
|
|
cert.SANs = []string(sans)
|
|
if profileID.Valid {
|
|
cert.CertificateProfileID = profileID.String
|
|
}
|
|
|
|
// Unmarshal tags
|
|
if len(tagsJSON) > 0 {
|
|
if err := json.Unmarshal(tagsJSON, &cert.Tags); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal tags: %w", err)
|
|
}
|
|
} else {
|
|
cert.Tags = make(map[string]string)
|
|
}
|
|
|
|
return &cert, nil
|
|
}
|