From a579a84c7f9d4f68ace52ef57e86ff95e76f6568 Mon Sep 17 00:00:00 2001 From: shankar0123 Date: Fri, 20 Mar 2026 20:39:49 -0400 Subject: [PATCH] =?UTF-8?q?feat:=20M11a=20=E2=80=94=20certificate=20profil?= =?UTF-8?q?es,=20crypto=20policy=20enforcement,=20short-lived=20cert=20exp?= =?UTF-8?q?iry?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add certificate profiles as named enrollment templates that control allowed key algorithms, max TTL, permitted EKUs, required SAN patterns, and optional SPIFFE URI SANs. CSR submissions are validated against profile rules at signing time (key type + minimum size). Short-lived certs (TTL < 1 hour) auto-expire via a new scheduler loop — expiry acts as revocation, no CRL/OCSP needed. New files: - Migration 000003: certificate_profiles table, FK columns on managed_certificates/renewal_policies, key metadata on certificate_versions - domain/profile.go: CertificateProfile + KeyAlgorithmRule structs - repository/postgres/profile.go: full CRUD with JSONB marshaling - service/profile.go: ProfileService with validation + audit logging - service/crypto_validation.go: CSR-against-profile validation (RSA/ECDSA/Ed25519) - handler/profiles.go: 5 HTTP endpoints under /api/v1/profiles - web/src/pages/ProfilesPage.tsx: profiles management page Modified: - renewal.go: CSR validation in CompleteAgentCSRRenewal, ExpireShortLivedCertificates - scheduler.go: 30s short-lived expiry check loop - certificate.go (repo): nullable profile FK, key metadata on versions - main.go: profile repo/service/handler wiring, 8-param NewRenewalService - router.go: 12-param RegisterHandlers with profile routes - seed_demo.sql: 4 demo profiles (standard, mtls, short-lived, high-security) - Frontend: types, API client, routing, sidebar nav Tests: 40 new tests across handler (15), service (13), crypto validation (12) Co-Authored-By: Claude Opus 4.6 --- cmd/server/main.go | 6 +- internal/api/handler/profile_handler_test.go | 429 ++++++++++++++++++ internal/api/handler/profiles.go | 206 +++++++++ internal/api/router/router.go | 8 + internal/domain/certificate.go | 56 +-- internal/domain/profile.go | 71 +++ internal/integration/lifecycle_test.go | 27 +- internal/integration/negative_test.go | 4 +- internal/repository/interfaces.go | 14 + internal/repository/postgres/certificate.go | 57 ++- internal/repository/postgres/profile.go | 226 +++++++++ internal/scheduler/scheduler.go | 49 +- internal/service/crypto_validation.go | 85 ++++ internal/service/crypto_validation_test.go | 244 ++++++++++ internal/service/job_test.go | 2 +- internal/service/profile.go | 181 ++++++++ internal/service/profile_test.go | 415 +++++++++++++++++ internal/service/renewal.go | 92 ++++ internal/service/renewal_test.go | 24 +- .../000003_certificate_profiles.down.sql | 13 + migrations/000003_certificate_profiles.up.sql | 53 +++ migrations/seed_demo.sql | 36 ++ web/src/api/client.ts | 20 +- web/src/api/types.ts | 20 + web/src/components/Layout.tsx | 1 + web/src/main.tsx | 2 + web/src/pages/ProfilesPage.tsx | 129 ++++++ 27 files changed, 2399 insertions(+), 71 deletions(-) create mode 100644 internal/api/handler/profile_handler_test.go create mode 100644 internal/api/handler/profiles.go create mode 100644 internal/domain/profile.go create mode 100644 internal/repository/postgres/profile.go create mode 100644 internal/service/crypto_validation.go create mode 100644 internal/service/crypto_validation_test.go create mode 100644 internal/service/profile.go create mode 100644 internal/service/profile_test.go create mode 100644 migrations/000003_certificate_profiles.down.sql create mode 100644 migrations/000003_certificate_profiles.up.sql create mode 100644 web/src/pages/ProfilesPage.tsx diff --git a/cmd/server/main.go b/cmd/server/main.go index e89bd26..4ddd449 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -68,6 +68,7 @@ func main() { policyRepo := postgres.NewPolicyRepository(db) notificationRepo := postgres.NewNotificationRepository(db) renewalPolicyRepo := postgres.NewRenewalPolicyRepository(db) + profileRepo := postgres.NewProfileRepository(db) teamRepo := postgres.NewTeamRepository(db) ownerRepo := postgres.NewOwnerRepository(db) logger.Info("initialized all repositories") @@ -102,12 +103,13 @@ func main() { policyService := service.NewPolicyService(policyRepo, auditService) certificateService := service.NewCertificateService(certificateRepo, policyService, auditService) notificationService := service.NewNotificationService(notificationRepo, make(map[string]service.Notifier)) - renewalService := service.NewRenewalService(certificateRepo, jobRepo, renewalPolicyRepo, auditService, notificationService, issuerRegistry, cfg.Keygen.Mode) + renewalService := service.NewRenewalService(certificateRepo, jobRepo, renewalPolicyRepo, profileRepo, auditService, notificationService, issuerRegistry, cfg.Keygen.Mode) deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certificateRepo, auditService, notificationService) jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger) agentService := service.NewAgentService(agentRepo, certificateRepo, jobRepo, targetRepo, auditService, issuerRegistry, renewalService) issuerService := service.NewIssuerService(issuerRepo, auditService) targetService := service.NewTargetService(targetRepo, auditService) + profileService := service.NewProfileService(profileRepo, auditService) teamService := service.NewTeamService(teamRepo, auditService) ownerService := service.NewOwnerService(ownerRepo, auditService) logger.Info("initialized all services") @@ -119,6 +121,7 @@ func main() { agentHandler := handler.NewAgentHandler(agentService) jobHandler := handler.NewJobHandler(jobService) policyHandler := handler.NewPolicyHandler(policyService) + profileHandler := handler.NewProfileHandler(profileService) teamHandler := handler.NewTeamHandler(teamService) ownerHandler := handler.NewOwnerHandler(ownerService) auditHandler := handler.NewAuditHandler(auditService) @@ -160,6 +163,7 @@ func main() { agentHandler, jobHandler, policyHandler, + profileHandler, teamHandler, ownerHandler, auditHandler, diff --git a/internal/api/handler/profile_handler_test.go b/internal/api/handler/profile_handler_test.go new file mode 100644 index 0000000..a769ba1 --- /dev/null +++ b/internal/api/handler/profile_handler_test.go @@ -0,0 +1,429 @@ +package handler + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/shankar0123/certctl/internal/domain" +) + +// MockProfileService is a mock implementation of ProfileService interface. +type MockProfileService struct { + ListProfilesFn func(page, perPage int) ([]domain.CertificateProfile, int64, error) + GetProfileFn func(id string) (*domain.CertificateProfile, error) + CreateProfileFn func(profile domain.CertificateProfile) (*domain.CertificateProfile, error) + UpdateProfileFn func(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) + DeleteProfileFn func(id string) error +} + +func (m *MockProfileService) ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) { + if m.ListProfilesFn != nil { + return m.ListProfilesFn(page, perPage) + } + return nil, 0, nil +} + +func (m *MockProfileService) GetProfile(id string) (*domain.CertificateProfile, error) { + if m.GetProfileFn != nil { + return m.GetProfileFn(id) + } + return nil, nil +} + +func (m *MockProfileService) CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + if m.CreateProfileFn != nil { + return m.CreateProfileFn(profile) + } + return nil, nil +} + +func (m *MockProfileService) UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + if m.UpdateProfileFn != nil { + return m.UpdateProfileFn(id, profile) + } + return nil, nil +} + +func (m *MockProfileService) DeleteProfile(id string) error { + if m.DeleteProfileFn != nil { + return m.DeleteProfileFn(id) + } + return nil +} + +func TestListProfiles_Success(t *testing.T) { + now := time.Now() + prof1 := domain.CertificateProfile{ + ID: "prof-standard-tls", + Name: "Standard TLS", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 256}, + {Algorithm: "RSA", MinSize: 2048}, + }, + MaxTTLSeconds: 7776000, + AllowedEKUs: []string{"serverAuth"}, + Enabled: true, + CreatedAt: now, + UpdatedAt: now, + } + prof2 := domain.CertificateProfile{ + ID: "prof-internal-mtls", + Name: "Internal mTLS", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 256}, + }, + MaxTTLSeconds: 2592000, + AllowedEKUs: []string{"serverAuth", "clientAuth"}, + Enabled: true, + CreatedAt: now, + UpdatedAt: now, + } + + mock := &MockProfileService{ + ListProfilesFn: func(page, perPage int) ([]domain.CertificateProfile, int64, error) { + return []domain.CertificateProfile{prof1, prof2}, 2, nil + }, + } + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.ListProfiles(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + var resp PagedResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp.Total != 2 { + t.Errorf("expected total 2, got %d", resp.Total) + } +} + +func TestListProfiles_Pagination(t *testing.T) { + var capturedPage, capturedPerPage int + mock := &MockProfileService{ + ListProfilesFn: func(page, perPage int) ([]domain.CertificateProfile, int64, error) { + capturedPage = page + capturedPerPage = perPage + return []domain.CertificateProfile{}, 0, nil + }, + } + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles?page=3&per_page=25", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.ListProfiles(w, req) + + if capturedPage != 3 { + t.Errorf("expected page 3, got %d", capturedPage) + } + if capturedPerPage != 25 { + t.Errorf("expected per_page 25, got %d", capturedPerPage) + } +} + +func TestListProfiles_ServiceError(t *testing.T) { + mock := &MockProfileService{ + ListProfilesFn: func(page, perPage int) ([]domain.CertificateProfile, int64, error) { + return nil, 0, ErrMockServiceFailed + }, + } + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.ListProfiles(w, req) + + if w.Code != http.StatusInternalServerError { + t.Fatalf("expected status 500, got %d", w.Code) + } +} + +func TestListProfiles_MethodNotAllowed(t *testing.T) { + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodDelete, "/api/v1/profiles", nil) + w := httptest.NewRecorder() + + handler.ListProfiles(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status 405, got %d", w.Code) + } +} + +func TestGetProfile_Success(t *testing.T) { + now := time.Now() + mock := &MockProfileService{ + GetProfileFn: func(id string) (*domain.CertificateProfile, error) { + return &domain.CertificateProfile{ + ID: id, + Name: "Standard TLS", + MaxTTLSeconds: 7776000, + Enabled: true, + CreatedAt: now, + UpdatedAt: now, + }, nil + }, + } + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles/prof-standard-tls", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.GetProfile(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } +} + +func TestGetProfile_NotFound(t *testing.T) { + mock := &MockProfileService{ + GetProfileFn: func(id string) (*domain.CertificateProfile, error) { + return nil, ErrMockNotFound + }, + } + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles/nonexistent", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.GetProfile(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("expected status 404, got %d", w.Code) + } +} + +func TestGetProfile_EmptyID(t *testing.T) { + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles/", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.GetProfile(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", w.Code) + } +} + +func TestCreateProfile_Success(t *testing.T) { + now := time.Now() + mock := &MockProfileService{ + CreateProfileFn: func(profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + profile.ID = "prof-new" + profile.CreatedAt = now + profile.UpdatedAt = now + return &profile, nil + }, + } + + body := map[string]interface{}{ + "name": "New Profile", + "max_ttl_seconds": 86400, + "allowed_ekus": []string{"serverAuth"}, + } + bodyBytes, _ := json.Marshal(body) + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodPost, "/api/v1/profiles", bytes.NewReader(bodyBytes)) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.CreateProfile(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("expected status 201, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestCreateProfile_MissingName(t *testing.T) { + body := map[string]interface{}{ + "max_ttl_seconds": 86400, + } + bodyBytes, _ := json.Marshal(body) + + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodPost, "/api/v1/profiles", bytes.NewReader(bodyBytes)) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.CreateProfile(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", w.Code) + } +} + +func TestCreateProfile_NameTooLong(t *testing.T) { + longName := "" + for i := 0; i < 256; i++ { + longName += "x" + } + body := map[string]interface{}{ + "name": longName, + } + bodyBytes, _ := json.Marshal(body) + + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodPost, "/api/v1/profiles", bytes.NewReader(bodyBytes)) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.CreateProfile(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", w.Code) + } +} + +func TestCreateProfile_InvalidJSON(t *testing.T) { + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodPost, "/api/v1/profiles", bytes.NewReader([]byte("{invalid"))) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.CreateProfile(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", w.Code) + } +} + +func TestCreateProfile_MethodNotAllowed(t *testing.T) { + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles", nil) + w := httptest.NewRecorder() + + handler.CreateProfile(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status 405, got %d", w.Code) + } +} + +func TestUpdateProfile_Success(t *testing.T) { + now := time.Now() + mock := &MockProfileService{ + UpdateProfileFn: func(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + profile.ID = id + profile.UpdatedAt = now + return &profile, nil + }, + } + + body := map[string]interface{}{ + "name": "Updated Profile", + "max_ttl_seconds": 172800, + } + bodyBytes, _ := json.Marshal(body) + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodPut, "/api/v1/profiles/prof-standard-tls", bytes.NewReader(bodyBytes)) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.UpdateProfile(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestUpdateProfile_InvalidJSON(t *testing.T) { + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodPut, "/api/v1/profiles/prof-x", bytes.NewReader([]byte("{bad"))) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.UpdateProfile(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", w.Code) + } +} + +func TestDeleteProfile_Success(t *testing.T) { + var deletedID string + mock := &MockProfileService{ + DeleteProfileFn: func(id string) error { + deletedID = id + return nil + }, + } + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodDelete, "/api/v1/profiles/prof-standard-tls", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.DeleteProfile(w, req) + + if w.Code != http.StatusNoContent { + t.Fatalf("expected status 204, got %d", w.Code) + } + if deletedID != "prof-standard-tls" { + t.Errorf("expected deleted ID 'prof-standard-tls', got '%s'", deletedID) + } +} + +func TestDeleteProfile_ServiceError(t *testing.T) { + mock := &MockProfileService{ + DeleteProfileFn: func(id string) error { + return ErrMockServiceFailed + }, + } + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodDelete, "/api/v1/profiles/prof-x", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.DeleteProfile(w, req) + + if w.Code != http.StatusInternalServerError { + t.Fatalf("expected status 500, got %d", w.Code) + } +} + +func TestDeleteProfile_EmptyID(t *testing.T) { + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodDelete, "/api/v1/profiles/", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.DeleteProfile(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", w.Code) + } +} + +func TestDeleteProfile_MethodNotAllowed(t *testing.T) { + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles/prof-x", nil) + w := httptest.NewRecorder() + + handler.DeleteProfile(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status 405, got %d", w.Code) + } +} diff --git a/internal/api/handler/profiles.go b/internal/api/handler/profiles.go new file mode 100644 index 0000000..899e0ac --- /dev/null +++ b/internal/api/handler/profiles.go @@ -0,0 +1,206 @@ +package handler + +import ( + "encoding/json" + "net/http" + "strconv" + "strings" + + "github.com/shankar0123/certctl/internal/api/middleware" + "github.com/shankar0123/certctl/internal/domain" +) + +// ProfileService defines the service interface for certificate profile operations. +type ProfileService interface { + ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) + GetProfile(id string) (*domain.CertificateProfile, error) + CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) + UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) + DeleteProfile(id string) error +} + +// ProfileHandler handles HTTP requests for certificate profile operations. +type ProfileHandler struct { + svc ProfileService +} + +// NewProfileHandler creates a new ProfileHandler with a service dependency. +func NewProfileHandler(svc ProfileService) ProfileHandler { + return ProfileHandler{svc: svc} +} + +// ListProfiles lists all certificate profiles. +// GET /api/v1/profiles?page=1&per_page=50 +func (h ProfileHandler) ListProfiles(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + Error(w, http.StatusMethodNotAllowed, "Method not allowed") + return + } + + requestID := middleware.GetRequestID(r.Context()) + + page := 1 + perPage := 50 + query := r.URL.Query() + if p := query.Get("page"); p != "" { + if parsed, err := strconv.Atoi(p); err == nil && parsed > 0 { + page = parsed + } + } + if pp := query.Get("per_page"); pp != "" { + if parsed, err := strconv.Atoi(pp); err == nil && parsed > 0 && parsed <= 500 { + perPage = parsed + } + } + + profiles, total, err := h.svc.ListProfiles(page, perPage) + if err != nil { + ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list profiles", requestID) + return + } + + response := PagedResponse{ + Data: profiles, + Total: total, + Page: page, + PerPage: perPage, + } + + JSON(w, http.StatusOK, response) +} + +// GetProfile retrieves a single certificate profile by ID. +// GET /api/v1/profiles/{id} +func (h ProfileHandler) GetProfile(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + Error(w, http.StatusMethodNotAllowed, "Method not allowed") + return + } + + requestID := middleware.GetRequestID(r.Context()) + + id := strings.TrimPrefix(r.URL.Path, "/api/v1/profiles/") + if id == "" || strings.Contains(id, "/") { + ErrorWithRequestID(w, http.StatusBadRequest, "Profile ID is required", requestID) + return + } + + profile, err := h.svc.GetProfile(id) + if err != nil { + ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID) + return + } + + JSON(w, http.StatusOK, profile) +} + +// CreateProfile creates a new certificate profile. +// POST /api/v1/profiles +func (h ProfileHandler) CreateProfile(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + Error(w, http.StatusMethodNotAllowed, "Method not allowed") + return + } + + requestID := middleware.GetRequestID(r.Context()) + + var profile domain.CertificateProfile + if err := json.NewDecoder(r.Body).Decode(&profile); err != nil { + ErrorWithRequestID(w, http.StatusBadRequest, "Invalid request body", requestID) + return + } + + // Validate required fields + if err := ValidateRequired("name", profile.Name); err != nil { + ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID) + return + } + if err := ValidateStringLength("name", profile.Name, 255); err != nil { + ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID) + return + } + + created, err := h.svc.CreateProfile(profile) + if err != nil { + // Check if it's a validation error from the service + if strings.Contains(err.Error(), "invalid") || strings.Contains(err.Error(), "required") || + strings.Contains(err.Error(), "must be") || strings.Contains(err.Error(), "cannot") { + ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID) + return + } + ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create profile", requestID) + return + } + + JSON(w, http.StatusCreated, created) +} + +// UpdateProfile updates an existing certificate profile. +// PUT /api/v1/profiles/{id} +func (h ProfileHandler) UpdateProfile(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + Error(w, http.StatusMethodNotAllowed, "Method not allowed") + return + } + + requestID := middleware.GetRequestID(r.Context()) + + id := strings.TrimPrefix(r.URL.Path, "/api/v1/profiles/") + parts := strings.Split(id, "/") + if len(parts) == 0 || parts[0] == "" { + ErrorWithRequestID(w, http.StatusBadRequest, "Profile ID is required", requestID) + return + } + id = parts[0] + + var profile domain.CertificateProfile + if err := json.NewDecoder(r.Body).Decode(&profile); err != nil { + ErrorWithRequestID(w, http.StatusBadRequest, "Invalid request body", requestID) + return + } + + updated, err := h.svc.UpdateProfile(id, profile) + if err != nil { + if strings.Contains(err.Error(), "not found") { + ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID) + return + } + if strings.Contains(err.Error(), "invalid") || strings.Contains(err.Error(), "required") || + strings.Contains(err.Error(), "must be") || strings.Contains(err.Error(), "cannot") { + ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID) + return + } + ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update profile", requestID) + return + } + + JSON(w, http.StatusOK, updated) +} + +// DeleteProfile deletes a certificate profile. +// DELETE /api/v1/profiles/{id} +func (h ProfileHandler) DeleteProfile(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + Error(w, http.StatusMethodNotAllowed, "Method not allowed") + return + } + + requestID := middleware.GetRequestID(r.Context()) + + id := strings.TrimPrefix(r.URL.Path, "/api/v1/profiles/") + if id == "" || strings.Contains(id, "/") { + ErrorWithRequestID(w, http.StatusBadRequest, "Profile ID is required", requestID) + return + } + + if err := h.svc.DeleteProfile(id); err != nil { + if strings.Contains(err.Error(), "not found") { + ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID) + return + } + ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to delete profile", requestID) + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/api/router/router.go b/internal/api/router/router.go index 47d985d..a5d416f 100644 --- a/internal/api/router/router.go +++ b/internal/api/router/router.go @@ -51,6 +51,7 @@ func (r *Router) RegisterHandlers( agents handler.AgentHandler, jobs handler.JobHandler, policies handler.PolicyHandler, + profiles handler.ProfileHandler, teams handler.TeamHandler, owners handler.OwnerHandler, audit handler.AuditHandler, @@ -125,6 +126,13 @@ func (r *Router) RegisterHandlers( r.Register("DELETE /api/v1/policies/{id}", http.HandlerFunc(policies.DeletePolicy)) r.Register("GET /api/v1/policies/{id}/violations", http.HandlerFunc(policies.ListViolations)) + // Profiles routes: /api/v1/profiles + r.Register("GET /api/v1/profiles", http.HandlerFunc(profiles.ListProfiles)) + r.Register("POST /api/v1/profiles", http.HandlerFunc(profiles.CreateProfile)) + r.Register("GET /api/v1/profiles/{id}", http.HandlerFunc(profiles.GetProfile)) + r.Register("PUT /api/v1/profiles/{id}", http.HandlerFunc(profiles.UpdateProfile)) + r.Register("DELETE /api/v1/profiles/{id}", http.HandlerFunc(profiles.DeleteProfile)) + // Teams routes: /api/v1/teams r.Register("GET /api/v1/teams", http.HandlerFunc(teams.ListTeams)) r.Register("POST /api/v1/teams", http.HandlerFunc(teams.CreateTeam)) diff --git a/internal/domain/certificate.go b/internal/domain/certificate.go index 0622738..61ea320 100644 --- a/internal/domain/certificate.go +++ b/internal/domain/certificate.go @@ -6,23 +6,24 @@ import ( // ManagedCertificate represents a certificate managed by the control plane. type ManagedCertificate struct { - ID string `json:"id"` - Name string `json:"name"` - CommonName string `json:"common_name"` - SANs []string `json:"sans"` - Environment string `json:"environment"` - OwnerID string `json:"owner_id"` - TeamID string `json:"team_id"` - IssuerID string `json:"issuer_id"` - TargetIDs []string `json:"target_ids"` - RenewalPolicyID string `json:"renewal_policy_id"` - Status CertificateStatus `json:"status"` - ExpiresAt time.Time `json:"expires_at"` - Tags map[string]string `json:"tags"` - LastRenewalAt *time.Time `json:"last_renewal_at,omitempty"` - LastDeploymentAt *time.Time `json:"last_deployment_at,omitempty"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + Name string `json:"name"` + CommonName string `json:"common_name"` + SANs []string `json:"sans"` + Environment string `json:"environment"` + OwnerID string `json:"owner_id"` + TeamID string `json:"team_id"` + IssuerID string `json:"issuer_id"` + TargetIDs []string `json:"target_ids"` + RenewalPolicyID string `json:"renewal_policy_id"` + CertificateProfileID string `json:"certificate_profile_id,omitempty"` + Status CertificateStatus `json:"status"` + ExpiresAt time.Time `json:"expires_at"` + Tags map[string]string `json:"tags"` + LastRenewalAt *time.Time `json:"last_renewal_at,omitempty"` + LastDeploymentAt *time.Time `json:"last_deployment_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // CertificateVersion represents a specific version of a certificate. @@ -35,6 +36,8 @@ type CertificateVersion struct { FingerprintSHA256 string `json:"fingerprint_sha256"` PEMChain string `json:"pem_chain"` CSRPEM string `json:"csr_pem"` + KeyAlgorithm string `json:"key_algorithm,omitempty"` + KeySize int `json:"key_size,omitempty"` CreatedAt time.Time `json:"created_at"` } @@ -54,15 +57,16 @@ const ( // RenewalPolicy defines renewal parameters for a managed certificate. type RenewalPolicy struct { - ID string `json:"id"` - Name string `json:"name"` - RenewalWindowDays int `json:"renewal_window_days"` - AutoRenew bool `json:"auto_renew"` - MaxRetries int `json:"max_retries"` - RetryInterval int `json:"retry_interval_seconds"` - AlertThresholdsDays []int `json:"alert_thresholds_days"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + Name string `json:"name"` + RenewalWindowDays int `json:"renewal_window_days"` + AutoRenew bool `json:"auto_renew"` + MaxRetries int `json:"max_retries"` + RetryInterval int `json:"retry_interval_seconds"` + AlertThresholdsDays []int `json:"alert_thresholds_days"` + CertificateProfileID string `json:"certificate_profile_id,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // DefaultAlertThresholds returns the standard alert thresholds when none are configured. diff --git a/internal/domain/profile.go b/internal/domain/profile.go new file mode 100644 index 0000000..c89b385 --- /dev/null +++ b/internal/domain/profile.go @@ -0,0 +1,71 @@ +package domain + +import ( + "time" +) + +// CertificateProfile defines an enrollment profile that controls what kinds of +// certificates can be issued: allowed key algorithms, maximum TTL, permitted EKUs, +// required SAN patterns, and optional SPIFFE URI SANs for workload identity. +type CertificateProfile struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + AllowedKeyAlgorithms []KeyAlgorithmRule `json:"allowed_key_algorithms"` + MaxTTLSeconds int `json:"max_ttl_seconds"` + AllowedEKUs []string `json:"allowed_ekus"` + RequiredSANPatterns []string `json:"required_san_patterns"` + SPIFFEURIPattern string `json:"spiffe_uri_pattern"` + AllowShortLived bool `json:"allow_short_lived"` + Enabled bool `json:"enabled"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// KeyAlgorithmRule defines an allowed key algorithm and its minimum key size. +type KeyAlgorithmRule struct { + Algorithm string `json:"algorithm"` // "RSA", "ECDSA", "Ed25519" + MinSize int `json:"min_size"` // RSA: 2048/4096, ECDSA: 256/384, Ed25519: 0 (fixed) +} + +// IsShortLived returns true if this profile's max TTL is under 1 hour (3600 seconds). +// Short-lived certs use expiry as revocation — no CRL/OCSP needed. +func (p *CertificateProfile) IsShortLived() bool { + return p.AllowShortLived && p.MaxTTLSeconds > 0 && p.MaxTTLSeconds < 3600 +} + +// DefaultKeyAlgorithms returns sensible defaults for profiles without explicit rules. +func DefaultKeyAlgorithms() []KeyAlgorithmRule { + return []KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 256}, + {Algorithm: "RSA", MinSize: 2048}, + } +} + +// DefaultEKUs returns the default extended key usages. +func DefaultEKUs() []string { + return []string{"serverAuth"} +} + +// Supported key algorithm constants for validation. +const ( + KeyAlgorithmRSA = "RSA" + KeyAlgorithmECDSA = "ECDSA" + KeyAlgorithmEd25519 = "Ed25519" +) + +// ValidKeyAlgorithms is the set of recognized key algorithm names. +var ValidKeyAlgorithms = map[string]bool{ + KeyAlgorithmRSA: true, + KeyAlgorithmECDSA: true, + KeyAlgorithmEd25519: true, +} + +// ValidEKUs is the set of recognized extended key usage names. +var ValidEKUs = map[string]bool{ + "serverAuth": true, + "clientAuth": true, + "codeSigning": true, + "emailProtection": true, + "timeStamping": true, +} diff --git a/internal/integration/lifecycle_test.go b/internal/integration/lifecycle_test.go index 506c557..805b301 100644 --- a/internal/integration/lifecycle_test.go +++ b/internal/integration/lifecycle_test.go @@ -52,7 +52,7 @@ func TestCertificateLifecycle(t *testing.T) { policyService := service.NewPolicyService(policyRepo, auditService) certificateService := service.NewCertificateService(certRepo, policyService, auditService) notificationService := service.NewNotificationService(notifRepo, make(map[string]service.Notifier)) - renewalService := service.NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, auditService, notificationService, issuerRegistry, "server") + renewalService := service.NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, nil, auditService, notificationService, issuerRegistry, "server") deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notificationService) jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger) agentService := service.NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, renewalService) @@ -65,6 +65,7 @@ func TestCertificateLifecycle(t *testing.T) { agentHandler := handler.NewAgentHandler(agentService) jobHandler := handler.NewJobHandler(jobService) policyHandler := handler.NewPolicyHandler(policyService) + profileHandler := handler.NewProfileHandler(&mockProfileService{}) teamHandler := handler.NewTeamHandler(&mockTeamService{}) ownerHandler := handler.NewOwnerHandler(&mockOwnerService{}) auditHandler := handler.NewAuditHandler(auditService) @@ -80,6 +81,7 @@ func TestCertificateLifecycle(t *testing.T) { agentHandler, jobHandler, policyHandler, + profileHandler, teamHandler, ownerHandler, auditHandler, @@ -994,3 +996,26 @@ func (m *mockOwnerService) UpdateOwner(id string, owner domain.Owner) (*domain.O func (m *mockOwnerService) DeleteOwner(id string) error { return nil } + +type mockProfileService struct{} + +func (m *mockProfileService) ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) { + return []domain.CertificateProfile{}, 0, nil +} + +func (m *mockProfileService) GetProfile(id string) (*domain.CertificateProfile, error) { + return nil, fmt.Errorf("profile not found") +} + +func (m *mockProfileService) CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + return &profile, nil +} + +func (m *mockProfileService) UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + profile.ID = id + return &profile, nil +} + +func (m *mockProfileService) DeleteProfile(id string) error { + return nil +} diff --git a/internal/integration/negative_test.go b/internal/integration/negative_test.go index 04f487e..b385cf4 100644 --- a/internal/integration/negative_test.go +++ b/internal/integration/negative_test.go @@ -43,7 +43,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, *mockCertificateRepository policyService := service.NewPolicyService(policyRepo, auditService) certificateService := service.NewCertificateService(certRepo, policyService, auditService) notificationService := service.NewNotificationService(notifRepo, make(map[string]service.Notifier)) - renewalService := service.NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, auditService, notificationService, issuerRegistry, "server") + renewalService := service.NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, nil, auditService, notificationService, issuerRegistry, "server") deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notificationService) jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger) agentService := service.NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, renewalService) @@ -55,6 +55,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, *mockCertificateRepository agentHandler := handler.NewAgentHandler(agentService) jobHandler := handler.NewJobHandler(jobService) policyHandler := handler.NewPolicyHandler(policyService) + profileHandler := handler.NewProfileHandler(&mockProfileService{}) teamHandler := handler.NewTeamHandler(&mockTeamService{}) ownerHandler := handler.NewOwnerHandler(&mockOwnerService{}) auditHandler := handler.NewAuditHandler(auditService) @@ -69,6 +70,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, *mockCertificateRepository agentHandler, jobHandler, policyHandler, + profileHandler, teamHandler, ownerHandler, auditHandler, diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index ed647e7..d988890 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -155,6 +155,20 @@ type TeamRepository interface { Delete(ctx context.Context, id string) error } +// CertificateProfileRepository defines operations for managing certificate profiles. +type CertificateProfileRepository interface { + // List returns all certificate profiles. + List(ctx context.Context) ([]*domain.CertificateProfile, error) + // Get retrieves a certificate profile by ID. + Get(ctx context.Context, id string) (*domain.CertificateProfile, error) + // Create stores a new certificate profile. + Create(ctx context.Context, profile *domain.CertificateProfile) error + // Update modifies an existing certificate profile. + Update(ctx context.Context, profile *domain.CertificateProfile) error + // Delete removes a certificate profile. + Delete(ctx context.Context, id string) error +} + // OwnerRepository defines operations for managing certificate owners. type OwnerRepository interface { // List returns all owners. diff --git a/internal/repository/postgres/certificate.go b/internal/repository/postgres/certificate.go index cad8154..5d3b50e 100644 --- a/internal/repository/postgres/certificate.go +++ b/internal/repository/postgres/certificate.go @@ -85,7 +85,7 @@ func (r *CertificateRepository) List(ctx context.Context, filter *repository.Cer offset := (filter.Page - 1) * filter.PerPage query := fmt.Sprintf(` SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id, - status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at + certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at FROM managed_certificates %s ORDER BY created_at DESC @@ -120,7 +120,7 @@ func (r *CertificateRepository) List(ctx context.Context, filter *repository.Cer func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) { row := r.db.QueryRowContext(ctx, ` SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id, - status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at + certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at FROM managed_certificates WHERE id = $1 `, id) @@ -147,14 +147,20 @@ func (r *CertificateRepository) Create(ctx context.Context, cert *domain.Managed return fmt.Errorf("failed to marshal tags: %w", err) } + var profileID *string + if cert.CertificateProfileID != "" { + profileID = &cert.CertificateProfileID + } + err = r.db.QueryRowContext(ctx, ` INSERT INTO managed_certificates ( id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id, - status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) RETURNING id `, cert.ID, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment, - cert.OwnerID, cert.TeamID, cert.IssuerID, cert.RenewalPolicyID, cert.Status, cert.ExpiresAt, + cert.OwnerID, cert.TeamID, cert.IssuerID, cert.RenewalPolicyID, profileID, + cert.Status, cert.ExpiresAt, tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.CreatedAt, cert.UpdatedAt).Scan(&cert.ID) if err != nil { @@ -171,6 +177,11 @@ func (r *CertificateRepository) Update(ctx context.Context, cert *domain.Managed return fmt.Errorf("failed to marshal tags: %w", err) } + var profileID *string + if cert.CertificateProfileID != "" { + profileID = &cert.CertificateProfileID + } + result, err := r.db.ExecContext(ctx, ` UPDATE managed_certificates SET name = $1, @@ -180,15 +191,16 @@ func (r *CertificateRepository) Update(ctx context.Context, cert *domain.Managed owner_id = $5, team_id = $6, issuer_id = $7, - status = $8, - expires_at = $9, - tags = $10, - last_renewal_at = $11, - last_deployment_at = $12, - updated_at = $13 - WHERE id = $14 + certificate_profile_id = $8, + status = $9, + expires_at = $10, + tags = $11, + last_renewal_at = $12, + last_deployment_at = $13, + updated_at = $14 + WHERE id = $15 `, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment, - cert.OwnerID, cert.TeamID, cert.IssuerID, cert.Status, cert.ExpiresAt, + cert.OwnerID, cert.TeamID, cert.IssuerID, profileID, cert.Status, cert.ExpiresAt, tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.UpdatedAt, cert.ID) if err != nil { @@ -233,7 +245,7 @@ func (r *CertificateRepository) Archive(ctx context.Context, id string) error { func (r *CertificateRepository) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, certificate_id, serial_number, not_before, not_after, - fingerprint_sha256, pem_chain, csr_pem, created_at + fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at FROM certificate_versions WHERE certificate_id = $1 ORDER BY created_at DESC @@ -248,7 +260,7 @@ func (r *CertificateRepository) ListVersions(ctx context.Context, certID string) for rows.Next() { var v domain.CertificateVersion if err := rows.Scan(&v.ID, &v.CertificateID, &v.SerialNumber, &v.NotBefore, &v.NotAfter, - &v.FingerprintSHA256, &v.PEMChain, &v.CSRPEM, &v.CreatedAt); err != nil { + &v.FingerprintSHA256, &v.PEMChain, &v.CSRPEM, &v.KeyAlgorithm, &v.KeySize, &v.CreatedAt); err != nil { return nil, fmt.Errorf("failed to scan certificate version: %w", err) } versions = append(versions, &v) @@ -270,11 +282,11 @@ func (r *CertificateRepository) CreateVersion(ctx context.Context, version *doma err := r.db.QueryRowContext(ctx, ` INSERT INTO certificate_versions ( id, certificate_id, serial_number, not_before, not_after, - fingerprint_sha256, pem_chain, csr_pem, created_at - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id `, version.ID, version.CertificateID, version.SerialNumber, version.NotBefore, version.NotAfter, - version.FingerprintSHA256, version.PEMChain, version.CSRPEM, version.CreatedAt).Scan(&version.ID) + version.FingerprintSHA256, version.PEMChain, version.CSRPEM, version.KeyAlgorithm, version.KeySize, version.CreatedAt).Scan(&version.ID) if err != nil { return fmt.Errorf("failed to create certificate version: %w", err) @@ -287,7 +299,7 @@ func (r *CertificateRepository) CreateVersion(ctx context.Context, version *doma func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id, - status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at + certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at FROM managed_certificates WHERE expires_at < $1 AND status != $2 ORDER BY expires_at ASC @@ -321,10 +333,12 @@ func scanCertificate(scanner interface { var cert domain.ManagedCertificate var tagsJSON []byte var sans pq.StringArray + var profileID sql.NullString err := scanner.Scan( &cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID, - &cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &cert.Status, &cert.ExpiresAt, &tagsJSON, + &cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &profileID, + &cert.Status, &cert.ExpiresAt, &tagsJSON, &cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.CreatedAt, &cert.UpdatedAt) if err != nil { @@ -332,6 +346,9 @@ func scanCertificate(scanner interface { } cert.SANs = []string(sans) + if profileID.Valid { + cert.CertificateProfileID = profileID.String + } // Unmarshal tags if len(tagsJSON) > 0 { diff --git a/internal/repository/postgres/profile.go b/internal/repository/postgres/profile.go new file mode 100644 index 0000000..2544262 --- /dev/null +++ b/internal/repository/postgres/profile.go @@ -0,0 +1,226 @@ +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/shankar0123/certctl/internal/domain" +) + +// ProfileRepository implements repository.CertificateProfileRepository +type ProfileRepository struct { + db *sql.DB +} + +// NewProfileRepository creates a new ProfileRepository +func NewProfileRepository(db *sql.DB) *ProfileRepository { + return &ProfileRepository{db: db} +} + +// List returns all certificate profiles +func (r *ProfileRepository) List(ctx context.Context) ([]*domain.CertificateProfile, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, name, description, allowed_key_algorithms, max_ttl_seconds, + allowed_ekus, required_san_patterns, spiffe_uri_pattern, + allow_short_lived, enabled, created_at, updated_at + FROM certificate_profiles + ORDER BY created_at DESC + `) + if err != nil { + return nil, fmt.Errorf("failed to query profiles: %w", err) + } + defer rows.Close() + + var profiles []*domain.CertificateProfile + for rows.Next() { + p, err := scanProfile(rows) + if err != nil { + return nil, err + } + profiles = append(profiles, p) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating profile rows: %w", err) + } + + return profiles, nil +} + +// Get retrieves a certificate profile by ID +func (r *ProfileRepository) Get(ctx context.Context, id string) (*domain.CertificateProfile, error) { + row := r.db.QueryRowContext(ctx, ` + SELECT id, name, description, allowed_key_algorithms, max_ttl_seconds, + allowed_ekus, required_san_patterns, spiffe_uri_pattern, + allow_short_lived, enabled, created_at, updated_at + FROM certificate_profiles + WHERE id = $1 + `, id) + + p, err := scanProfile(row) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("profile not found") + } + return nil, fmt.Errorf("failed to query profile: %w", err) + } + + return p, nil +} + +// Create stores a new certificate profile +func (r *ProfileRepository) Create(ctx context.Context, profile *domain.CertificateProfile) error { + if profile.ID == "" { + profile.ID = uuid.New().String() + } + if profile.CreatedAt.IsZero() { + profile.CreatedAt = time.Now() + } + if profile.UpdatedAt.IsZero() { + profile.UpdatedAt = time.Now() + } + + algJSON, err := json.Marshal(profile.AllowedKeyAlgorithms) + if err != nil { + return fmt.Errorf("failed to marshal allowed_key_algorithms: %w", err) + } + ekuJSON, err := json.Marshal(profile.AllowedEKUs) + if err != nil { + return fmt.Errorf("failed to marshal allowed_ekus: %w", err) + } + sanJSON, err := json.Marshal(profile.RequiredSANPatterns) + if err != nil { + return fmt.Errorf("failed to marshal required_san_patterns: %w", err) + } + + err = r.db.QueryRowContext(ctx, ` + INSERT INTO certificate_profiles ( + id, name, description, allowed_key_algorithms, max_ttl_seconds, + allowed_ekus, required_san_patterns, spiffe_uri_pattern, + allow_short_lived, enabled, created_at, updated_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + RETURNING id + `, profile.ID, profile.Name, profile.Description, algJSON, profile.MaxTTLSeconds, + ekuJSON, sanJSON, profile.SPIFFEURIPattern, + profile.AllowShortLived, profile.Enabled, profile.CreatedAt, profile.UpdatedAt).Scan(&profile.ID) + + if err != nil { + return fmt.Errorf("failed to create profile: %w", err) + } + + return nil +} + +// Update modifies an existing certificate profile +func (r *ProfileRepository) Update(ctx context.Context, profile *domain.CertificateProfile) error { + profile.UpdatedAt = time.Now() + + algJSON, err := json.Marshal(profile.AllowedKeyAlgorithms) + if err != nil { + return fmt.Errorf("failed to marshal allowed_key_algorithms: %w", err) + } + ekuJSON, err := json.Marshal(profile.AllowedEKUs) + if err != nil { + return fmt.Errorf("failed to marshal allowed_ekus: %w", err) + } + sanJSON, err := json.Marshal(profile.RequiredSANPatterns) + if err != nil { + return fmt.Errorf("failed to marshal required_san_patterns: %w", err) + } + + result, err := r.db.ExecContext(ctx, ` + UPDATE certificate_profiles SET + name = $1, + description = $2, + allowed_key_algorithms = $3, + max_ttl_seconds = $4, + allowed_ekus = $5, + required_san_patterns = $6, + spiffe_uri_pattern = $7, + allow_short_lived = $8, + enabled = $9, + updated_at = $10 + WHERE id = $11 + `, profile.Name, profile.Description, algJSON, profile.MaxTTLSeconds, + ekuJSON, sanJSON, profile.SPIFFEURIPattern, + profile.AllowShortLived, profile.Enabled, profile.UpdatedAt, profile.ID) + + if err != nil { + return fmt.Errorf("failed to update profile: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("profile not found") + } + + return nil +} + +// Delete removes a certificate profile +func (r *ProfileRepository) Delete(ctx context.Context, id string) error { + result, err := r.db.ExecContext(ctx, "DELETE FROM certificate_profiles WHERE id = $1", id) + if err != nil { + return fmt.Errorf("failed to delete profile: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("profile not found") + } + + return nil +} + +// scanProfile scans a certificate profile from a row or rows +func scanProfile(scanner interface { + Scan(...interface{}) error +}) (*domain.CertificateProfile, error) { + var p domain.CertificateProfile + var algJSON, ekuJSON, sanJSON []byte + + err := scanner.Scan( + &p.ID, &p.Name, &p.Description, &algJSON, &p.MaxTTLSeconds, + &ekuJSON, &sanJSON, &p.SPIFFEURIPattern, + &p.AllowShortLived, &p.Enabled, &p.CreatedAt, &p.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan profile: %w", err) + } + + if len(algJSON) > 0 { + if err := json.Unmarshal(algJSON, &p.AllowedKeyAlgorithms); err != nil { + return nil, fmt.Errorf("failed to unmarshal allowed_key_algorithms: %w", err) + } + } else { + p.AllowedKeyAlgorithms = domain.DefaultKeyAlgorithms() + } + + if len(ekuJSON) > 0 { + if err := json.Unmarshal(ekuJSON, &p.AllowedEKUs); err != nil { + return nil, fmt.Errorf("failed to unmarshal allowed_ekus: %w", err) + } + } else { + p.AllowedEKUs = domain.DefaultEKUs() + } + + if len(sanJSON) > 0 { + if err := json.Unmarshal(sanJSON, &p.RequiredSANPatterns); err != nil { + return nil, fmt.Errorf("failed to unmarshal required_san_patterns: %w", err) + } + } + + return &p, nil +} diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index a3c3a3f..0ed898c 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -19,10 +19,11 @@ type Scheduler struct { logger *slog.Logger // Configurable tick intervals - renewalCheckInterval time.Duration - jobProcessorInterval time.Duration - agentHealthCheckInterval time.Duration - notificationProcessInterval time.Duration + renewalCheckInterval time.Duration + jobProcessorInterval time.Duration + agentHealthCheckInterval time.Duration + notificationProcessInterval time.Duration + shortLivedExpiryCheckInterval time.Duration } // NewScheduler creates a new scheduler with configurable intervals. @@ -41,10 +42,11 @@ func NewScheduler( logger: logger, // Default intervals - renewalCheckInterval: 1 * time.Hour, - jobProcessorInterval: 30 * time.Second, - agentHealthCheckInterval: 2 * time.Minute, - notificationProcessInterval: 1 * time.Minute, + renewalCheckInterval: 1 * time.Hour, + jobProcessorInterval: 30 * time.Second, + agentHealthCheckInterval: 2 * time.Minute, + notificationProcessInterval: 1 * time.Minute, + shortLivedExpiryCheckInterval: 30 * time.Second, } } @@ -87,6 +89,7 @@ func (s *Scheduler) Start(ctx context.Context) <-chan struct{} { go s.jobProcessorLoop(ctx) go s.agentHealthCheckLoop(ctx) go s.notificationProcessLoop(ctx) + go s.shortLivedExpiryCheckLoop(ctx) // Wait for context cancellation <-ctx.Done() @@ -225,3 +228,33 @@ func (s *Scheduler) runNotificationProcess(ctx context.Context) { s.logger.Debug("notification processor completed") } } + +// shortLivedExpiryCheckLoop runs every shortLivedExpiryCheckInterval and marks expired +// short-lived certificates. For certs with TTL < 1 hour, expiry IS revocation — +// no CRL/OCSP needed. +func (s *Scheduler) shortLivedExpiryCheckLoop(ctx context.Context) { + ticker := time.NewTicker(s.shortLivedExpiryCheckInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.runShortLivedExpiryCheck(ctx) + } + } +} + +// runShortLivedExpiryCheck executes a single short-lived expiry check with error recovery. +func (s *Scheduler) runShortLivedExpiryCheck(ctx context.Context) { + opCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + if err := s.renewalService.ExpireShortLivedCertificates(opCtx); err != nil { + s.logger.Error("short-lived expiry check failed", + "error", err, + "interval", s.shortLivedExpiryCheckInterval.String()) + } else { + s.logger.Debug("short-lived expiry check completed") + } +} diff --git a/internal/service/crypto_validation.go b/internal/service/crypto_validation.go new file mode 100644 index 0000000..75b6d78 --- /dev/null +++ b/internal/service/crypto_validation.go @@ -0,0 +1,85 @@ +package service + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + + "github.com/shankar0123/certctl/internal/domain" +) + +// CSRValidationResult contains metadata extracted from a validated CSR. +type CSRValidationResult struct { + KeyAlgorithm string + KeySize int +} + +// ValidateCSRAgainstProfile parses a CSR PEM and validates that its key algorithm +// and size comply with the profile's allowed_key_algorithms rules. +// Returns extracted key metadata on success for storage in certificate_versions. +func ValidateCSRAgainstProfile(csrPEM string, profile *domain.CertificateProfile) (*CSRValidationResult, error) { + if profile == nil { + // No profile assigned — skip validation, extract metadata only + return extractCSRKeyInfo(csrPEM) + } + + result, err := extractCSRKeyInfo(csrPEM) + if err != nil { + return nil, err + } + + // Check that the CSR's key algorithm + size matches at least one allowed rule + if len(profile.AllowedKeyAlgorithms) == 0 { + // No restrictions defined — allow anything + return result, nil + } + + for _, rule := range profile.AllowedKeyAlgorithms { + if rule.Algorithm == result.KeyAlgorithm && result.KeySize >= rule.MinSize { + return result, nil + } + } + + return nil, fmt.Errorf("CSR key (%s %d-bit) does not match any allowed algorithm in profile %q: %v", + result.KeyAlgorithm, result.KeySize, profile.Name, profile.AllowedKeyAlgorithms) +} + +// extractCSRKeyInfo parses a CSR PEM and extracts the key algorithm and size. +func extractCSRKeyInfo(csrPEM string) (*CSRValidationResult, error) { + block, _ := pem.Decode([]byte(csrPEM)) + if block == nil { + return nil, fmt.Errorf("failed to decode CSR PEM") + } + + csr, err := x509.ParseCertificateRequest(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse CSR: %w", err) + } + + if err := csr.CheckSignature(); err != nil { + return nil, fmt.Errorf("CSR signature verification failed: %w", err) + } + + switch key := csr.PublicKey.(type) { + case *rsa.PublicKey: + return &CSRValidationResult{ + KeyAlgorithm: domain.KeyAlgorithmRSA, + KeySize: key.N.BitLen(), + }, nil + case *ecdsa.PublicKey: + return &CSRValidationResult{ + KeyAlgorithm: domain.KeyAlgorithmECDSA, + KeySize: key.Curve.Params().BitSize, + }, nil + case ed25519.PublicKey: + return &CSRValidationResult{ + KeyAlgorithm: domain.KeyAlgorithmEd25519, + KeySize: 256, // Ed25519 is fixed 256-bit + }, nil + default: + return nil, fmt.Errorf("unsupported key type in CSR: %T", csr.PublicKey) + } +} diff --git a/internal/service/crypto_validation_test.go b/internal/service/crypto_validation_test.go new file mode 100644 index 0000000..fcb722c --- /dev/null +++ b/internal/service/crypto_validation_test.go @@ -0,0 +1,244 @@ +package service + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "testing" + + "github.com/shankar0123/certctl/internal/domain" +) + +// generateTestCSR creates a valid CSR PEM for testing purposes. +func generateTestCSR(t *testing.T, keyType string, keySize int) string { + t.Helper() + + var privKey interface{} + var err error + + switch keyType { + case "RSA": + privKey, err = rsa.GenerateKey(rand.Reader, keySize) + case "ECDSA": + var curve elliptic.Curve + switch keySize { + case 256: + curve = elliptic.P256() + case 384: + curve = elliptic.P384() + default: + t.Fatalf("unsupported ECDSA key size: %d", keySize) + } + privKey, err = ecdsa.GenerateKey(curve, rand.Reader) + default: + t.Fatalf("unsupported key type: %s", keyType) + } + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + template := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "test.example.com", + }, + DNSNames: []string{"test.example.com", "www.example.com"}, + } + + csrDER, err := x509.CreateCertificateRequest(rand.Reader, template, privKey) + if err != nil { + t.Fatalf("failed to create CSR: %v", err) + } + + csrPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: csrDER, + }) + + return string(csrPEM) +} + +func TestValidateCSRAgainstProfile_NilProfile(t *testing.T) { + csrPEM := generateTestCSR(t, "ECDSA", 256) + + result, err := ValidateCSRAgainstProfile(csrPEM, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.KeyAlgorithm != "ECDSA" { + t.Errorf("expected ECDSA, got %s", result.KeyAlgorithm) + } + if result.KeySize != 256 { + t.Errorf("expected 256, got %d", result.KeySize) + } +} + +func TestValidateCSRAgainstProfile_ECDSA256_Allowed(t *testing.T) { + csrPEM := generateTestCSR(t, "ECDSA", 256) + + profile := &domain.CertificateProfile{ + Name: "Standard TLS", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 256}, + {Algorithm: "RSA", MinSize: 2048}, + }, + } + + result, err := ValidateCSRAgainstProfile(csrPEM, profile) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.KeyAlgorithm != "ECDSA" { + t.Errorf("expected ECDSA, got %s", result.KeyAlgorithm) + } + if result.KeySize != 256 { + t.Errorf("expected 256, got %d", result.KeySize) + } +} + +func TestValidateCSRAgainstProfile_ECDSA384_Allowed(t *testing.T) { + csrPEM := generateTestCSR(t, "ECDSA", 384) + + profile := &domain.CertificateProfile{ + Name: "High Security", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 384}, + }, + } + + result, err := ValidateCSRAgainstProfile(csrPEM, profile) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.KeySize != 384 { + t.Errorf("expected 384, got %d", result.KeySize) + } +} + +func TestValidateCSRAgainstProfile_RSA2048_Allowed(t *testing.T) { + csrPEM := generateTestCSR(t, "RSA", 2048) + + profile := &domain.CertificateProfile{ + Name: "Standard TLS", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "RSA", MinSize: 2048}, + }, + } + + result, err := ValidateCSRAgainstProfile(csrPEM, profile) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.KeyAlgorithm != "RSA" { + t.Errorf("expected RSA, got %s", result.KeyAlgorithm) + } + if result.KeySize != 2048 { + t.Errorf("expected 2048, got %d", result.KeySize) + } +} + +func TestValidateCSRAgainstProfile_ECDSA256_RejectedByHighSecurity(t *testing.T) { + csrPEM := generateTestCSR(t, "ECDSA", 256) + + profile := &domain.CertificateProfile{ + Name: "High Security", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 384}, + {Algorithm: "RSA", MinSize: 4096}, + }, + } + + _, err := ValidateCSRAgainstProfile(csrPEM, profile) + if err == nil { + t.Fatal("expected rejection, got nil error") + } + if !containsSubstring(err.Error(), "does not match any allowed algorithm") { + t.Errorf("unexpected error message: %s", err.Error()) + } +} + +func TestValidateCSRAgainstProfile_RSA_RejectedByECDSAOnly(t *testing.T) { + csrPEM := generateTestCSR(t, "RSA", 2048) + + profile := &domain.CertificateProfile{ + Name: "ECDSA Only", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 256}, + }, + } + + _, err := ValidateCSRAgainstProfile(csrPEM, profile) + if err == nil { + t.Fatal("expected rejection, got nil error") + } +} + +func TestValidateCSRAgainstProfile_EmptyAlgorithmRules(t *testing.T) { + csrPEM := generateTestCSR(t, "ECDSA", 256) + + profile := &domain.CertificateProfile{ + Name: "Permissive", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{}, // empty = allow anything + } + + result, err := ValidateCSRAgainstProfile(csrPEM, profile) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.KeyAlgorithm != "ECDSA" { + t.Errorf("expected ECDSA, got %s", result.KeyAlgorithm) + } +} + +func TestValidateCSRAgainstProfile_InvalidPEM(t *testing.T) { + _, err := ValidateCSRAgainstProfile("not a pem", nil) + if err == nil { + t.Fatal("expected error for invalid PEM, got nil") + } + if !containsSubstring(err.Error(), "failed to decode CSR PEM") { + t.Errorf("unexpected error: %s", err.Error()) + } +} + +func TestValidateCSRAgainstProfile_InvalidCSRContent(t *testing.T) { + // Valid PEM block but garbage content + csrPEM := "-----BEGIN CERTIFICATE REQUEST-----\nTm90IGEgcmVhbCBDU1I=\n-----END CERTIFICATE REQUEST-----" + + _, err := ValidateCSRAgainstProfile(csrPEM, nil) + if err == nil { + t.Fatal("expected error for invalid CSR content, got nil") + } +} + +func TestExtractCSRKeyInfo_ECDSA(t *testing.T) { + csrPEM := generateTestCSR(t, "ECDSA", 256) + + result, err := extractCSRKeyInfo(csrPEM) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.KeyAlgorithm != "ECDSA" { + t.Errorf("expected ECDSA, got %s", result.KeyAlgorithm) + } + if result.KeySize != 256 { + t.Errorf("expected 256, got %d", result.KeySize) + } +} + +func TestExtractCSRKeyInfo_RSA(t *testing.T) { + csrPEM := generateTestCSR(t, "RSA", 2048) + + result, err := extractCSRKeyInfo(csrPEM) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.KeyAlgorithm != "RSA" { + t.Errorf("expected RSA, got %s", result.KeyAlgorithm) + } + if result.KeySize != 2048 { + t.Errorf("expected 2048, got %d", result.KeySize) + } +} diff --git a/internal/service/job_test.go b/internal/service/job_test.go index 488c03b..29c8dcd 100644 --- a/internal/service/job_test.go +++ b/internal/service/job_test.go @@ -28,7 +28,7 @@ func newTestJobService(jobRepo *mockJobRepo) *JobService { targetRepo := &mockTargetRepo{Targets: make(map[string]*domain.DeploymentTarget)} agentRepo := &mockAgentRepo{Agents: make(map[string]*domain.Agent)} - renewalService := NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, auditService, notifService, make(map[string]IssuerConnector), "server") + renewalService := NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, nil, auditService, notifService, make(map[string]IssuerConnector), "server") deploymentService := NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notifService) return NewJobService(jobRepo, renewalService, deploymentService, logger) diff --git a/internal/service/profile.go b/internal/service/profile.go new file mode 100644 index 0000000..1b7f8bc --- /dev/null +++ b/internal/service/profile.go @@ -0,0 +1,181 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/shankar0123/certctl/internal/domain" + "github.com/shankar0123/certctl/internal/repository" +) + +// ProfileService provides business logic for certificate profile management. +type ProfileService struct { + profileRepo repository.CertificateProfileRepository + auditService *AuditService +} + +// NewProfileService creates a new profile service. +func NewProfileService( + profileRepo repository.CertificateProfileRepository, + auditService *AuditService, +) *ProfileService { + return &ProfileService{ + profileRepo: profileRepo, + auditService: auditService, + } +} + +// ListProfiles returns all profiles (handler interface method). +func (s *ProfileService) ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) { + if page < 1 { + page = 1 + } + if perPage < 1 { + perPage = 50 + } + + profiles, err := s.profileRepo.List(context.Background()) + if err != nil { + return nil, 0, fmt.Errorf("failed to list profiles: %w", err) + } + total := int64(len(profiles)) + + var result []domain.CertificateProfile + for _, p := range profiles { + if p != nil { + result = append(result, *p) + } + } + + return result, total, nil +} + +// GetProfile returns a single profile (handler interface method). +func (s *ProfileService) GetProfile(id string) (*domain.CertificateProfile, error) { + return s.profileRepo.Get(context.Background(), id) +} + +// CreateProfile creates a new profile with validation (handler interface method). +func (s *ProfileService) CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + if err := validateProfile(&profile); err != nil { + return nil, err + } + + if profile.ID == "" { + profile.ID = generateID("prof") + } + now := time.Now() + if profile.CreatedAt.IsZero() { + profile.CreatedAt = now + } + if profile.UpdatedAt.IsZero() { + profile.UpdatedAt = now + } + + // Apply defaults if not set + if len(profile.AllowedKeyAlgorithms) == 0 { + profile.AllowedKeyAlgorithms = domain.DefaultKeyAlgorithms() + } + if len(profile.AllowedEKUs) == 0 { + profile.AllowedEKUs = domain.DefaultEKUs() + } + + if err := s.profileRepo.Create(context.Background(), &profile); err != nil { + return nil, fmt.Errorf("failed to create profile: %w", err) + } + + if s.auditService != nil { + if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser, + "create_profile", "certificate_profile", profile.ID, nil); auditErr != nil { + slog.Error("failed to record audit event", "error", auditErr) + } + } + + return &profile, nil +} + +// UpdateProfile modifies an existing profile (handler interface method). +func (s *ProfileService) UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + if err := validateProfile(&profile); err != nil { + return nil, err + } + + profile.ID = id + if err := s.profileRepo.Update(context.Background(), &profile); err != nil { + return nil, fmt.Errorf("failed to update profile: %w", err) + } + + if s.auditService != nil { + if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser, + "update_profile", "certificate_profile", id, nil); auditErr != nil { + slog.Error("failed to record audit event", "error", auditErr) + } + } + + return &profile, nil +} + +// DeleteProfile removes a profile (handler interface method). +func (s *ProfileService) DeleteProfile(id string) error { + if err := s.profileRepo.Delete(context.Background(), id); err != nil { + return fmt.Errorf("failed to delete profile: %w", err) + } + + if s.auditService != nil { + if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser, + "delete_profile", "certificate_profile", id, nil); auditErr != nil { + slog.Error("failed to record audit event", "error", auditErr) + } + } + + return nil +} + +// Get retrieves a profile by ID (used by other services like RenewalService). +func (s *ProfileService) Get(ctx context.Context, id string) (*domain.CertificateProfile, error) { + return s.profileRepo.Get(ctx, id) +} + +// validateProfile checks that a profile's configuration is valid. +func validateProfile(p *domain.CertificateProfile) error { + if p.Name == "" { + return fmt.Errorf("profile name is required") + } + if len(p.Name) > 255 { + return fmt.Errorf("profile name exceeds 255 characters") + } + + // Validate key algorithms + for _, alg := range p.AllowedKeyAlgorithms { + if !domain.ValidKeyAlgorithms[alg.Algorithm] { + return fmt.Errorf("invalid key algorithm: %s (allowed: RSA, ECDSA, Ed25519)", alg.Algorithm) + } + if alg.Algorithm == domain.KeyAlgorithmRSA && alg.MinSize < 2048 { + return fmt.Errorf("RSA minimum key size must be at least 2048, got %d", alg.MinSize) + } + if alg.Algorithm == domain.KeyAlgorithmECDSA && alg.MinSize < 256 { + return fmt.Errorf("ECDSA minimum key size must be at least 256, got %d", alg.MinSize) + } + } + + // Validate EKUs + for _, eku := range p.AllowedEKUs { + if !domain.ValidEKUs[eku] { + return fmt.Errorf("invalid EKU: %s", eku) + } + } + + // Validate max TTL + if p.MaxTTLSeconds < 0 { + return fmt.Errorf("max_ttl_seconds cannot be negative") + } + + // Validate short-lived consistency + if p.AllowShortLived && p.MaxTTLSeconds >= 3600 { + return fmt.Errorf("allow_short_lived is true but max_ttl_seconds (%d) is >= 3600; short-lived certs must have TTL under 1 hour", p.MaxTTLSeconds) + } + + return nil +} diff --git a/internal/service/profile_test.go b/internal/service/profile_test.go new file mode 100644 index 0000000..53b3f4e --- /dev/null +++ b/internal/service/profile_test.go @@ -0,0 +1,415 @@ +package service + +import ( + "context" + "errors" + "testing" + + "github.com/shankar0123/certctl/internal/domain" +) + +// mockProfileRepo is a test implementation of CertificateProfileRepository +type mockProfileRepo struct { + profiles map[string]*domain.CertificateProfile + ListErr error + GetErr error + CreateErr error + UpdateErr error + DeleteErr error +} + +func newMockProfileRepository() *mockProfileRepo { + return &mockProfileRepo{ + profiles: make(map[string]*domain.CertificateProfile), + } +} + +func (m *mockProfileRepo) List(ctx context.Context) ([]*domain.CertificateProfile, error) { + if m.ListErr != nil { + return nil, m.ListErr + } + var profiles []*domain.CertificateProfile + for _, p := range m.profiles { + profiles = append(profiles, p) + } + return profiles, nil +} + +func (m *mockProfileRepo) Get(ctx context.Context, id string) (*domain.CertificateProfile, error) { + if m.GetErr != nil { + return nil, m.GetErr + } + p, ok := m.profiles[id] + if !ok { + return nil, errNotFound + } + return p, nil +} + +func (m *mockProfileRepo) Create(ctx context.Context, profile *domain.CertificateProfile) error { + if m.CreateErr != nil { + return m.CreateErr + } + m.profiles[profile.ID] = profile + return nil +} + +func (m *mockProfileRepo) Update(ctx context.Context, profile *domain.CertificateProfile) error { + if m.UpdateErr != nil { + return m.UpdateErr + } + m.profiles[profile.ID] = profile + return nil +} + +func (m *mockProfileRepo) Delete(ctx context.Context, id string) error { + if m.DeleteErr != nil { + return m.DeleteErr + } + delete(m.profiles, id) + return nil +} + +func (m *mockProfileRepo) AddProfile(p *domain.CertificateProfile) { + m.profiles[p.ID] = p +} + +// --- ProfileService Tests --- + +func TestProfileService_ListProfiles(t *testing.T) { + repo := newMockProfileRepository() + repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "Standard TLS", Enabled: true}) + repo.AddProfile(&domain.CertificateProfile{ID: "prof-2", Name: "Internal mTLS", Enabled: true}) + + svc := NewProfileService(repo, nil) + profiles, total, err := svc.ListProfiles(1, 50) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if total != 2 { + t.Errorf("expected total 2, got %d", total) + } + if len(profiles) != 2 { + t.Errorf("expected 2 profiles, got %d", len(profiles)) + } +} + +func TestProfileService_ListProfiles_Empty(t *testing.T) { + repo := newMockProfileRepository() + svc := NewProfileService(repo, nil) + + profiles, total, err := svc.ListProfiles(1, 50) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if total != 0 { + t.Errorf("expected total 0, got %d", total) + } + if len(profiles) != 0 { + t.Errorf("expected 0 profiles, got %d", len(profiles)) + } +} + +func TestProfileService_ListProfiles_RepoError(t *testing.T) { + repo := newMockProfileRepository() + repo.ListErr = errors.New("db error") + svc := NewProfileService(repo, nil) + + _, _, err := svc.ListProfiles(1, 50) + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestProfileService_GetProfile(t *testing.T) { + repo := newMockProfileRepository() + repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "Standard TLS"}) + svc := NewProfileService(repo, nil) + + profile, err := svc.GetProfile("prof-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if profile.Name != "Standard TLS" { + t.Errorf("expected 'Standard TLS', got '%s'", profile.Name) + } +} + +func TestProfileService_GetProfile_NotFound(t *testing.T) { + repo := newMockProfileRepository() + svc := NewProfileService(repo, nil) + + _, err := svc.GetProfile("nonexistent") + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestProfileService_CreateProfile_Defaults(t *testing.T) { + repo := newMockProfileRepository() + auditRepo := newMockAuditRepository() + auditSvc := NewAuditService(auditRepo) + svc := NewProfileService(repo, auditSvc) + + profile := domain.CertificateProfile{ + Name: "New Profile", + MaxTTLSeconds: 86400, + } + + created, err := svc.CreateProfile(profile) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if created.ID == "" { + t.Error("expected generated ID, got empty") + } + if len(created.AllowedKeyAlgorithms) == 0 { + t.Error("expected default key algorithms, got empty") + } + if len(created.AllowedEKUs) == 0 { + t.Error("expected default EKUs, got empty") + } + if created.CreatedAt.IsZero() { + t.Error("expected CreatedAt to be set") + } + // Verify audit event recorded + if len(auditRepo.Events) != 1 { + t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events)) + } +} + +func TestProfileService_CreateProfile_ValidationErrors(t *testing.T) { + repo := newMockProfileRepository() + svc := NewProfileService(repo, nil) + + tests := []struct { + name string + profile domain.CertificateProfile + errMsg string + }{ + { + name: "empty name", + profile: domain.CertificateProfile{}, + errMsg: "profile name is required", + }, + { + name: "name too long", + profile: domain.CertificateProfile{ + Name: string(make([]byte, 256)), + }, + errMsg: "exceeds 255 characters", + }, + { + name: "invalid key algorithm", + profile: domain.CertificateProfile{ + Name: "Bad Algo", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "DES", MinSize: 56}, + }, + }, + errMsg: "invalid key algorithm", + }, + { + name: "RSA key too small", + profile: domain.CertificateProfile{ + Name: "Weak RSA", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "RSA", MinSize: 1024}, + }, + }, + errMsg: "RSA minimum key size must be at least 2048", + }, + { + name: "ECDSA key too small", + profile: domain.CertificateProfile{ + Name: "Weak ECDSA", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 128}, + }, + }, + errMsg: "ECDSA minimum key size must be at least 256", + }, + { + name: "invalid EKU", + profile: domain.CertificateProfile{ + Name: "Bad EKU", + AllowedEKUs: []string{"invalidEKU"}, + }, + errMsg: "invalid EKU", + }, + { + name: "negative TTL", + profile: domain.CertificateProfile{ + Name: "Negative TTL", + MaxTTLSeconds: -1, + }, + errMsg: "cannot be negative", + }, + { + name: "short-lived with long TTL", + profile: domain.CertificateProfile{ + Name: "Inconsistent Short-Lived", + AllowShortLived: true, + MaxTTLSeconds: 7200, + }, + errMsg: "short-lived certs must have TTL under 1 hour", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := svc.CreateProfile(tt.profile) + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.errMsg) + } + if !contains(err.Error(), tt.errMsg) { + t.Errorf("expected error containing %q, got %q", tt.errMsg, err.Error()) + } + }) + } +} + +func TestProfileService_CreateProfile_RepoError(t *testing.T) { + repo := newMockProfileRepository() + repo.CreateErr = errors.New("db create failed") + svc := NewProfileService(repo, nil) + + _, err := svc.CreateProfile(domain.CertificateProfile{Name: "Valid"}) + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestProfileService_UpdateProfile(t *testing.T) { + repo := newMockProfileRepository() + repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "Original"}) + auditRepo := newMockAuditRepository() + auditSvc := NewAuditService(auditRepo) + svc := NewProfileService(repo, auditSvc) + + updated, err := svc.UpdateProfile("prof-1", domain.CertificateProfile{ + Name: "Updated", + MaxTTLSeconds: 43200, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if updated.ID != "prof-1" { + t.Errorf("expected ID 'prof-1', got '%s'", updated.ID) + } + if len(auditRepo.Events) != 1 { + t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events)) + } +} + +func TestProfileService_UpdateProfile_ValidationError(t *testing.T) { + repo := newMockProfileRepository() + svc := NewProfileService(repo, nil) + + _, err := svc.UpdateProfile("prof-1", domain.CertificateProfile{Name: ""}) + if err == nil { + t.Fatal("expected validation error, got nil") + } +} + +func TestProfileService_DeleteProfile(t *testing.T) { + repo := newMockProfileRepository() + repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "To Delete"}) + auditRepo := newMockAuditRepository() + auditSvc := NewAuditService(auditRepo) + svc := NewProfileService(repo, auditSvc) + + err := svc.DeleteProfile("prof-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(auditRepo.Events) != 1 { + t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events)) + } +} + +func TestProfileService_DeleteProfile_RepoError(t *testing.T) { + repo := newMockProfileRepository() + repo.DeleteErr = errors.New("db delete failed") + svc := NewProfileService(repo, nil) + + err := svc.DeleteProfile("prof-1") + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestProfileService_CreateProfile_ValidShortLived(t *testing.T) { + repo := newMockProfileRepository() + svc := NewProfileService(repo, nil) + + // Short-lived with TTL under 1 hour should succeed + created, err := svc.CreateProfile(domain.CertificateProfile{ + Name: "CI Ephemeral", + AllowShortLived: true, + MaxTTLSeconds: 300, // 5 minutes + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !created.AllowShortLived { + t.Error("expected AllowShortLived to be true") + } +} + +func TestIsShortLived(t *testing.T) { + tests := []struct { + name string + profile domain.CertificateProfile + expected bool + }{ + { + name: "short-lived with 5 min TTL", + profile: domain.CertificateProfile{AllowShortLived: true, MaxTTLSeconds: 300}, + expected: true, + }, + { + name: "short-lived flag false", + profile: domain.CertificateProfile{AllowShortLived: false, MaxTTLSeconds: 300}, + expected: false, + }, + { + name: "zero TTL with flag", + profile: domain.CertificateProfile{AllowShortLived: true, MaxTTLSeconds: 0}, + expected: false, + }, + { + name: "TTL at 1 hour boundary", + profile: domain.CertificateProfile{AllowShortLived: true, MaxTTLSeconds: 3600}, + expected: false, + }, + { + name: "standard long-lived", + profile: domain.CertificateProfile{AllowShortLived: false, MaxTTLSeconds: 7776000}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.profile.IsShortLived() + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +// contains checks if a string contains a substring (helper for test assertions). +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstring(s, substr)) +} + +func containsSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/internal/service/renewal.go b/internal/service/renewal.go index 92700eb..f01d3c9 100644 --- a/internal/service/renewal.go +++ b/internal/service/renewal.go @@ -22,6 +22,7 @@ type RenewalService struct { certRepo repository.CertificateRepository jobRepo repository.JobRepository renewalPolicyRepo repository.RenewalPolicyRepository + profileRepo repository.CertificateProfileRepository auditService *AuditService notificationSvc *NotificationService issuerRegistry map[string]IssuerConnector @@ -52,6 +53,7 @@ func NewRenewalService( certRepo repository.CertificateRepository, jobRepo repository.JobRepository, renewalPolicyRepo repository.RenewalPolicyRepository, + profileRepo repository.CertificateProfileRepository, auditService *AuditService, notificationSvc *NotificationService, issuerRegistry map[string]IssuerConnector, @@ -64,6 +66,7 @@ func NewRenewalService( certRepo: certRepo, jobRepo: jobRepo, renewalPolicyRepo: renewalPolicyRepo, + profileRepo: profileRepo, auditService: auditService, notificationSvc: notificationSvc, issuerRegistry: issuerRegistry, @@ -371,6 +374,8 @@ func (s *RenewalService) processRenewalServerKeygen(ctx context.Context, job *do FingerprintSHA256: fingerprint, PEMChain: result.CertPEM + "\n" + result.ChainPEM, CSRPEM: privKeyPEM, // Server mode: stores private key for agent deployment + KeyAlgorithm: domain.KeyAlgorithmRSA, + KeySize: 2048, CreatedAt: time.Now(), } @@ -428,6 +433,22 @@ func (s *RenewalService) CompleteAgentCSRRenewal(ctx context.Context, job *domai return fmt.Errorf("issuer connector not found for %s", cert.IssuerID) } + // Validate CSR against certificate profile (crypto policy enforcement) + var profile *domain.CertificateProfile + if cert.CertificateProfileID != "" && s.profileRepo != nil { + var profileErr error + profile, profileErr = s.profileRepo.Get(ctx, cert.CertificateProfileID) + if profileErr != nil { + slog.Warn("failed to fetch certificate profile, skipping crypto validation", + "profile_id", cert.CertificateProfileID, "cert_id", cert.ID, "error", profileErr) + } + } + csrInfo, csrErr := ValidateCSRAgainstProfile(csrPEM, profile) + if csrErr != nil { + s.failJob(ctx, job, fmt.Sprintf("CSR validation failed: %v", csrErr)) + return fmt.Errorf("CSR validation failed: %w", csrErr) + } + // Update job to running if err := s.jobRepo.UpdateStatus(ctx, job.ID, domain.JobStatusRunning, ""); err != nil { return fmt.Errorf("failed to update job status: %w", err) @@ -462,6 +483,10 @@ func (s *RenewalService) CompleteAgentCSRRenewal(ctx context.Context, job *domai CSRPEM: csrPEM, // Agent mode: stores actual CSR, not private key CreatedAt: time.Now(), } + if csrInfo != nil { + version.KeyAlgorithm = csrInfo.KeyAlgorithm + version.KeySize = csrInfo.KeySize + } if err := s.certRepo.CreateVersion(ctx, version); err != nil { s.failJob(ctx, job, fmt.Sprintf("version creation failed: %v", err)) @@ -589,6 +614,73 @@ func (s *RenewalService) RetryFailedJobs(ctx context.Context, maxRetries int) er return nil } +// ExpireShortLivedCertificates finds active certificates with short-lived profiles +// whose TTL has elapsed and marks them as Expired. For certs with TTL < 1 hour, +// expiry is the revocation mechanism — no CRL/OCSP needed. +func (s *RenewalService) ExpireShortLivedCertificates(ctx context.Context) error { + if s.profileRepo == nil { + return nil + } + + // Get all Active certificates and check if any have expired based on their actual expiry time + // This catches short-lived certs that expire between normal renewal check cycles + now := time.Now() + expiring, err := s.certRepo.GetExpiringCertificates(ctx, now) + if err != nil { + return fmt.Errorf("failed to fetch expired certificates: %w", err) + } + + for _, cert := range expiring { + if cert.Status != domain.CertificateStatusActive && cert.Status != domain.CertificateStatusExpiring { + continue + } + + // Only auto-expire certs that have actually passed their expiry time + if cert.ExpiresAt.After(now) { + continue + } + + // Check if this cert has a short-lived profile + if cert.CertificateProfileID == "" { + continue + } + + profile, err := s.profileRepo.Get(ctx, cert.CertificateProfileID) + if err != nil { + slog.Warn("failed to fetch profile for short-lived expiry check", + "profile_id", cert.CertificateProfileID, "cert_id", cert.ID, "error", err) + continue + } + + if !profile.IsShortLived() { + continue + } + + // Mark as expired + cert.Status = domain.CertificateStatusExpired + cert.UpdatedAt = now + if err := s.certRepo.Update(ctx, cert); err != nil { + slog.Error("failed to expire short-lived cert", "cert_id", cert.ID, "error", err) + continue + } + + slog.Info("short-lived certificate expired (expiry = revocation)", + "cert_id", cert.ID, "profile_id", cert.CertificateProfileID, + "expired_at", cert.ExpiresAt) + + if auditErr := s.auditService.RecordEvent(ctx, "system", domain.ActorTypeSystem, + "short_lived_cert_expired", "certificate", cert.ID, + map[string]interface{}{ + "profile_id": cert.CertificateProfileID, + "expired_at": cert.ExpiresAt, + }); auditErr != nil { + slog.Error("failed to record audit event", "error", auditErr) + } + } + + return nil +} + // generateID is a helper to generate unique IDs. In production, use a proper ID generator. func generateID(prefix string) string { return fmt.Sprintf("%s-%d", prefix, time.Now().UnixNano()) diff --git a/internal/service/renewal_test.go b/internal/service/renewal_test.go index 464251f..7a8c361 100644 --- a/internal/service/renewal_test.go +++ b/internal/service/renewal_test.go @@ -30,7 +30,7 @@ func TestCheckExpiringCertificates_SendsThresholdAlerts(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create a cert expiring in 10 days cert := &domain.ManagedCertificate{ @@ -112,7 +112,7 @@ func TestCheckExpiringCertificates_DeduplicatesAlerts(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create cert cert := &domain.ManagedCertificate{ @@ -192,7 +192,7 @@ func TestCheckExpiringCertificates_SkipsRenewalInProgress(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create cert with RenewalInProgress status cert := &domain.ManagedCertificate{ @@ -257,7 +257,7 @@ func TestCheckExpiringCertificates_UpdatesStatusToExpiring(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create active cert that will become expiring // Use an issuer NOT in the registry so no renewal job is created (which would override status) @@ -319,7 +319,7 @@ func TestCheckExpiringCertificates_UpdatesStatusToExpired(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create cert that is already expired // Use an issuer NOT in the registry so no renewal job is created (which would override status) @@ -381,7 +381,7 @@ func TestCheckExpiringCertificates_CreatesRenewalJob(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create expiring cert with registered issuer cert := &domain.ManagedCertificate{ @@ -447,7 +447,7 @@ func TestCheckExpiringCertificates_SkipsWithoutIssuer(t *testing.T) { // Empty issuer registry issuerRegistry := map[string]IssuerConnector{} - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create cert with unregistered issuer cert := &domain.ManagedCertificate{ @@ -509,7 +509,7 @@ func TestCheckExpiringCertificates_SkipsDuplicateJobs(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create cert cert := &domain.ManagedCertificate{ @@ -593,7 +593,7 @@ func TestProcessRenewalJob(t *testing.T) { "iss-test": issuerConnector, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create certificate cert := &domain.ManagedCertificate{ @@ -689,7 +689,7 @@ func TestProcessRenewalJob_IssuerFailure(t *testing.T) { "iss-test": issuerConnector, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create certificate cert := &domain.ManagedCertificate{ @@ -771,7 +771,7 @@ func TestRetryFailedJobs(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create failed job with attempts < max_attempts failedJob := &domain.Job{ @@ -836,7 +836,7 @@ func TestProcessRenewalJob_NoCertificate(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create job with non-existent certificate job := &domain.Job{ diff --git a/migrations/000003_certificate_profiles.down.sql b/migrations/000003_certificate_profiles.down.sql new file mode 100644 index 0000000..d2abe2f --- /dev/null +++ b/migrations/000003_certificate_profiles.down.sql @@ -0,0 +1,13 @@ +-- Rollback: remove certificate profiles and associated columns + +ALTER TABLE certificate_versions DROP COLUMN IF EXISTS key_algorithm; +ALTER TABLE certificate_versions DROP COLUMN IF EXISTS key_size; + +ALTER TABLE renewal_policies DROP COLUMN IF EXISTS certificate_profile_id; + +DROP INDEX IF EXISTS idx_managed_certificates_profile_id; +ALTER TABLE managed_certificates DROP COLUMN IF EXISTS certificate_profile_id; + +DROP INDEX IF EXISTS idx_certificate_profiles_name; +DROP INDEX IF EXISTS idx_certificate_profiles_enabled; +DROP TABLE IF EXISTS certificate_profiles; diff --git a/migrations/000003_certificate_profiles.up.sql b/migrations/000003_certificate_profiles.up.sql new file mode 100644 index 0000000..267fbc7 --- /dev/null +++ b/migrations/000003_certificate_profiles.up.sql @@ -0,0 +1,53 @@ +-- M11a: Certificate Profiles + Crypto Foundation +-- Named enrollment profiles defining allowed key types, max TTL, required SANs, +-- permitted EKUs, and optional SPIFFE URI SAN patterns. + +-- Table: certificate_profiles +CREATE TABLE IF NOT EXISTS certificate_profiles ( + id TEXT PRIMARY KEY, + name VARCHAR(255) NOT NULL UNIQUE, + description TEXT DEFAULT '', + + -- Crypto policy: which key algorithms and minimum sizes are allowed + -- Example: [{"algorithm": "ECDSA", "min_size": 256}, {"algorithm": "RSA", "min_size": 2048}] + allowed_key_algorithms JSONB NOT NULL DEFAULT '[{"algorithm": "ECDSA", "min_size": 256}, {"algorithm": "RSA", "min_size": 2048}]', + + -- Maximum certificate TTL in seconds (0 = no limit, uses issuer default) + -- Short-lived: 300 (5 min), 3600 (1 hour). Standard: 7776000 (90 days), 4060800 (47 days) + max_ttl_seconds INT NOT NULL DEFAULT 0, + + -- Permitted Extended Key Usages + -- Example: ["serverAuth", "clientAuth"] + allowed_ekus JSONB NOT NULL DEFAULT '["serverAuth"]', + + -- Required SAN patterns (regexes that issued certs must match) + -- Example: [".*\\.example\\.com$", ".*\\.internal\\.example\\.com$"] + required_san_patterns JSONB NOT NULL DEFAULT '[]', + + -- Optional SPIFFE URI SAN pattern for workload identity + -- Example: "spiffe://example.com/workload/*" + -- Empty string means no SPIFFE SAN will be minted + spiffe_uri_pattern VARCHAR(512) DEFAULT '', + + -- Whether this profile allows short-lived certs (TTL < 1 hour) + -- When true, expired certs under this profile skip CRL/OCSP (expiry = revocation) + allow_short_lived BOOLEAN NOT NULL DEFAULT false, + + enabled BOOLEAN NOT NULL DEFAULT true, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_certificate_profiles_name ON certificate_profiles(name); +CREATE INDEX IF NOT EXISTS idx_certificate_profiles_enabled ON certificate_profiles(enabled); + +-- Add certificate_profile_id FK to managed_certificates (nullable for backward compat) +ALTER TABLE managed_certificates ADD COLUMN IF NOT EXISTS certificate_profile_id TEXT REFERENCES certificate_profiles(id) ON DELETE SET NULL; +CREATE INDEX IF NOT EXISTS idx_managed_certificates_profile_id ON managed_certificates(certificate_profile_id); + +-- Add certificate_profile_id FK to renewal_policies (nullable — profile scoping on policies) +ALTER TABLE renewal_policies ADD COLUMN IF NOT EXISTS certificate_profile_id TEXT REFERENCES certificate_profiles(id) ON DELETE SET NULL; + +-- Add key metadata to certificate_versions for audit / compliance +ALTER TABLE certificate_versions ADD COLUMN IF NOT EXISTS key_algorithm VARCHAR(50) DEFAULT ''; +ALTER TABLE certificate_versions ADD COLUMN IF NOT EXISTS key_size INT DEFAULT 0; diff --git a/migrations/seed_demo.sql b/migrations/seed_demo.sql index df0533b..ddbc5ee 100644 --- a/migrations/seed_demo.sql +++ b/migrations/seed_demo.sql @@ -53,6 +53,42 @@ INSERT INTO deployment_targets (id, name, type, agent_id, config, enabled, creat ('tgt-nginx-data', 'NGINX Data Services', 'nginx', 'ag-data-prod', '{"cert_path": "/etc/nginx/ssl/cert.pem", "key_path": "/etc/nginx/ssl/key.pem", "reload_command": "nginx -s reload"}', true, NOW(), NOW()) ON CONFLICT (id) DO NOTHING; +-- Certificate Profiles +INSERT INTO certificate_profiles (id, name, description, allowed_key_algorithms, max_ttl_seconds, allowed_ekus, required_san_patterns, spiffe_uri_pattern, allow_short_lived, enabled, created_at, updated_at) VALUES + ('prof-standard-tls', 'Standard TLS', + 'Default profile for web-facing TLS certificates. Requires ECDSA P-256+ or RSA 2048+.', + '[{"algorithm": "ECDSA", "min_size": 256}, {"algorithm": "RSA", "min_size": 2048}]'::jsonb, + 7776000, -- 90 days + '["serverAuth"]'::jsonb, + '[]'::jsonb, + '', false, true, NOW(), NOW()), + + ('prof-internal-mtls', 'Internal mTLS', + 'Mutual TLS profile for internal service-to-service communication.', + '[{"algorithm": "ECDSA", "min_size": 256}]'::jsonb, + 2592000, -- 30 days + '["serverAuth", "clientAuth"]'::jsonb, + '[".*\\.internal\\.example\\.com$"]'::jsonb, + '', false, true, NOW(), NOW()), + + ('prof-short-lived', 'Short-Lived Credential', + 'Ephemeral certificates for CI/CD pipelines and container workloads. TTL under 1 hour, expiry = revocation.', + '[{"algorithm": "ECDSA", "min_size": 256}]'::jsonb, + 300, -- 5 minutes + '["serverAuth", "clientAuth"]'::jsonb, + '[]'::jsonb, + 'spiffe://example.com/workload/*', + true, true, NOW(), NOW()), + + ('prof-high-security', 'High Security', + 'For PCI-DSS and compliance-sensitive workloads. RSA 4096+ or ECDSA P-384+ only.', + '[{"algorithm": "ECDSA", "min_size": 384}, {"algorithm": "RSA", "min_size": 4096}]'::jsonb, + 4060800, -- 47 days (Ballot SC-081v3 target) + '["serverAuth"]'::jsonb, + '[".*\\.example\\.com$"]'::jsonb, + '', false, true, NOW(), NOW()) +ON CONFLICT (id) DO NOTHING; + -- Managed Certificates — varied statuses and expiry dates for realistic dashboard INSERT INTO managed_certificates (id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at) VALUES -- Active, healthy certs diff --git a/web/src/api/client.ts b/web/src/api/client.ts index 828dca6..bf7b8b3 100644 --- a/web/src/api/client.ts +++ b/web/src/api/client.ts @@ -1,4 +1,4 @@ -import type { Certificate, CertificateVersion, Agent, Job, Notification, AuditEvent, PolicyRule, PolicyViolation, Issuer, Target, PaginatedResponse } from './types'; +import type { Certificate, CertificateVersion, Agent, Job, Notification, AuditEvent, PolicyRule, PolicyViolation, Issuer, Target, CertificateProfile, PaginatedResponse } from './types'; const BASE = '/api/v1'; @@ -169,5 +169,23 @@ export const createTarget = (data: Partial) => export const deleteTarget = (id: string) => fetchJSON<{ message: string }>(`${BASE}/targets/${id}`, { method: 'DELETE' }); +// Profiles +export const getProfiles = (params: Record = {}) => { + const qs = new URLSearchParams({ page: '1', per_page: '50', ...params }).toString(); + return fetchJSON>(`${BASE}/profiles?${qs}`); +}; + +export const getProfile = (id: string) => + fetchJSON(`${BASE}/profiles/${id}`); + +export const createProfile = (data: Partial) => + fetchJSON(`${BASE}/profiles`, { method: 'POST', body: JSON.stringify(data) }); + +export const updateProfile = (id: string, data: Partial) => + fetchJSON(`${BASE}/profiles/${id}`, { method: 'PUT', body: JSON.stringify(data) }); + +export const deleteProfile = (id: string) => + fetchJSON<{ message: string }>(`${BASE}/profiles/${id}`, { method: 'DELETE' }); + // Health export const getHealth = () => fetchJSON<{ status: string }>('/health'); diff --git a/web/src/api/types.ts b/web/src/api/types.ts index 5a6d9ae..0881cf1 100644 --- a/web/src/api/types.ts +++ b/web/src/api/types.ts @@ -129,6 +129,26 @@ export interface Target { created_at: string; } +export interface KeyAlgorithmRule { + algorithm: string; + min_size: number; +} + +export interface CertificateProfile { + id: string; + name: string; + description: string; + allowed_key_algorithms: KeyAlgorithmRule[]; + max_ttl_seconds: number; + allowed_ekus: string[]; + required_san_patterns: string[]; + spiffe_uri_pattern: string; + allow_short_lived: boolean; + enabled: boolean; + created_at: string; + updated_at: string; +} + export interface PaginatedResponse { data: T[]; total: number; diff --git a/web/src/components/Layout.tsx b/web/src/components/Layout.tsx index db6fe56..0a806e1 100644 --- a/web/src/components/Layout.tsx +++ b/web/src/components/Layout.tsx @@ -8,6 +8,7 @@ const nav = [ { to: '/jobs', label: 'Jobs', icon: 'M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15' }, { to: '/notifications', label: 'Notifications', icon: 'M15 17h5l-1.405-1.405A2.032 2.032 0 0118 14.158V11a6.002 6.002 0 00-4-5.659V5a2 2 0 10-4 0v.341C7.67 6.165 6 8.388 6 11v3.159c0 .538-.214 1.055-.595 1.436L4 17h5m6 0v1a3 3 0 11-6 0v-1m6 0H9' }, { to: '/policies', label: 'Policies', icon: 'M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4' }, + { to: '/profiles', label: 'Profiles', icon: 'M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.066 2.573c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.573 1.066c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.066-2.573c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z M15 12a3 3 0 11-6 0 3 3 0 016 0z' }, { to: '/issuers', label: 'Issuers', icon: 'M15 7a2 2 0 012 2m4 0a6 6 0 01-7.743 5.743L11 17H9v2H7v2H4a1 1 0 01-1-1v-2.586a1 1 0 01.293-.707l5.964-5.964A6 6 0 1121 9z' }, { to: '/targets', label: 'Targets', icon: 'M19 11H5m14 0a2 2 0 012 2v6a2 2 0 01-2 2H5a2 2 0 01-2-2v-6a2 2 0 012-2m14 0V9a2 2 0 00-2-2M5 11V9a2 2 0 012-2m0 0V5a2 2 0 012-2h6a2 2 0 012 2v2M7 7h10' }, { to: '/audit', label: 'Audit Trail', icon: 'M12 8v4l3 3m6-3a9 9 0 11-18 0 9 9 0 0118 0z' }, diff --git a/web/src/main.tsx b/web/src/main.tsx index f820d6d..f68afa6 100644 --- a/web/src/main.tsx +++ b/web/src/main.tsx @@ -16,6 +16,7 @@ import NotificationsPage from './pages/NotificationsPage'; import PoliciesPage from './pages/PoliciesPage'; import IssuersPage from './pages/IssuersPage'; import TargetsPage from './pages/TargetsPage'; +import ProfilesPage from './pages/ProfilesPage'; import AuditPage from './pages/AuditPage'; import './index.css'; @@ -46,6 +47,7 @@ createRoot(document.getElementById('root')!).render( } /> } /> } /> + } /> } /> } /> } /> diff --git a/web/src/pages/ProfilesPage.tsx b/web/src/pages/ProfilesPage.tsx new file mode 100644 index 0000000..5a7e6e9 --- /dev/null +++ b/web/src/pages/ProfilesPage.tsx @@ -0,0 +1,129 @@ +import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'; +import { getProfiles, deleteProfile } from '../api/client'; +import PageHeader from '../components/PageHeader'; +import DataTable from '../components/DataTable'; +import type { Column } from '../components/DataTable'; +import StatusBadge from '../components/StatusBadge'; +import ErrorState from '../components/ErrorState'; +import { formatDateTime } from '../api/utils'; +import type { CertificateProfile } from '../api/types'; + +function formatTTL(seconds: number): string { + if (seconds === 0) return 'No limit'; + if (seconds < 60) return `${seconds}s`; + if (seconds < 3600) return `${Math.floor(seconds / 60)}m`; + if (seconds < 86400) return `${Math.floor(seconds / 3600)}h`; + return `${Math.floor(seconds / 86400)}d`; +} + +export default function ProfilesPage() { + const queryClient = useQueryClient(); + + const { data, isLoading, error, refetch } = useQuery({ + queryKey: ['profiles'], + queryFn: () => getProfiles(), + }); + + const deleteMutation = useMutation({ + mutationFn: deleteProfile, + onSuccess: () => queryClient.invalidateQueries({ queryKey: ['profiles'] }), + }); + + const columns: Column[] = [ + { + key: 'name', + label: 'Profile', + render: (p) => ( +
+
{p.name}
+
{p.id}
+ {p.description && ( +
{p.description}
+ )} +
+ ), + }, + { + key: 'algorithms', + label: 'Key Algorithms', + render: (p) => ( +
+ {(p.allowed_key_algorithms || []).map((alg, i) => ( + + {alg.algorithm} {alg.min_size}+ + + ))} +
+ ), + }, + { + key: 'ttl', + label: 'Max TTL', + render: (p) => ( +
+ {formatTTL(p.max_ttl_seconds)} + {p.allow_short_lived && ( + + short-lived + + )} +
+ ), + }, + { + key: 'ekus', + label: 'EKUs', + render: (p) => ( +
+ {(p.allowed_ekus || []).map((eku, i) => ( + {eku} + ))} +
+ ), + }, + { + key: 'spiffe', + label: 'SPIFFE', + render: (p) => ( + p.spiffe_uri_pattern + ? {p.spiffe_uri_pattern} + : + ), + }, + { + key: 'enabled', + label: 'Status', + render: (p) => , + }, + { + key: 'created', + label: 'Created', + render: (p) => {formatDateTime(p.created_at)}, + }, + { + key: 'actions', + label: '', + render: (p) => ( + + ), + }, + ]; + + return ( + <> + +
+ {error ? ( + refetch()} /> + ) : ( + + )} +
+ + ); +}