diff --git a/cmd/server/main.go b/cmd/server/main.go index e89bd26..4ddd449 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -68,6 +68,7 @@ func main() { policyRepo := postgres.NewPolicyRepository(db) notificationRepo := postgres.NewNotificationRepository(db) renewalPolicyRepo := postgres.NewRenewalPolicyRepository(db) + profileRepo := postgres.NewProfileRepository(db) teamRepo := postgres.NewTeamRepository(db) ownerRepo := postgres.NewOwnerRepository(db) logger.Info("initialized all repositories") @@ -102,12 +103,13 @@ func main() { policyService := service.NewPolicyService(policyRepo, auditService) certificateService := service.NewCertificateService(certificateRepo, policyService, auditService) notificationService := service.NewNotificationService(notificationRepo, make(map[string]service.Notifier)) - renewalService := service.NewRenewalService(certificateRepo, jobRepo, renewalPolicyRepo, auditService, notificationService, issuerRegistry, cfg.Keygen.Mode) + renewalService := service.NewRenewalService(certificateRepo, jobRepo, renewalPolicyRepo, profileRepo, auditService, notificationService, issuerRegistry, cfg.Keygen.Mode) deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certificateRepo, auditService, notificationService) jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger) agentService := service.NewAgentService(agentRepo, certificateRepo, jobRepo, targetRepo, auditService, issuerRegistry, renewalService) issuerService := service.NewIssuerService(issuerRepo, auditService) targetService := service.NewTargetService(targetRepo, auditService) + profileService := service.NewProfileService(profileRepo, auditService) teamService := service.NewTeamService(teamRepo, auditService) ownerService := service.NewOwnerService(ownerRepo, auditService) logger.Info("initialized all services") @@ -119,6 +121,7 @@ func main() { agentHandler := handler.NewAgentHandler(agentService) jobHandler := handler.NewJobHandler(jobService) policyHandler := handler.NewPolicyHandler(policyService) + profileHandler := handler.NewProfileHandler(profileService) teamHandler := handler.NewTeamHandler(teamService) ownerHandler := handler.NewOwnerHandler(ownerService) auditHandler := handler.NewAuditHandler(auditService) @@ -160,6 +163,7 @@ func main() { agentHandler, jobHandler, policyHandler, + profileHandler, teamHandler, ownerHandler, auditHandler, diff --git a/internal/api/handler/profile_handler_test.go b/internal/api/handler/profile_handler_test.go new file mode 100644 index 0000000..a769ba1 --- /dev/null +++ b/internal/api/handler/profile_handler_test.go @@ -0,0 +1,429 @@ +package handler + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/shankar0123/certctl/internal/domain" +) + +// MockProfileService is a mock implementation of ProfileService interface. +type MockProfileService struct { + ListProfilesFn func(page, perPage int) ([]domain.CertificateProfile, int64, error) + GetProfileFn func(id string) (*domain.CertificateProfile, error) + CreateProfileFn func(profile domain.CertificateProfile) (*domain.CertificateProfile, error) + UpdateProfileFn func(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) + DeleteProfileFn func(id string) error +} + +func (m *MockProfileService) ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) { + if m.ListProfilesFn != nil { + return m.ListProfilesFn(page, perPage) + } + return nil, 0, nil +} + +func (m *MockProfileService) GetProfile(id string) (*domain.CertificateProfile, error) { + if m.GetProfileFn != nil { + return m.GetProfileFn(id) + } + return nil, nil +} + +func (m *MockProfileService) CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + if m.CreateProfileFn != nil { + return m.CreateProfileFn(profile) + } + return nil, nil +} + +func (m *MockProfileService) UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + if m.UpdateProfileFn != nil { + return m.UpdateProfileFn(id, profile) + } + return nil, nil +} + +func (m *MockProfileService) DeleteProfile(id string) error { + if m.DeleteProfileFn != nil { + return m.DeleteProfileFn(id) + } + return nil +} + +func TestListProfiles_Success(t *testing.T) { + now := time.Now() + prof1 := domain.CertificateProfile{ + ID: "prof-standard-tls", + Name: "Standard TLS", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 256}, + {Algorithm: "RSA", MinSize: 2048}, + }, + MaxTTLSeconds: 7776000, + AllowedEKUs: []string{"serverAuth"}, + Enabled: true, + CreatedAt: now, + UpdatedAt: now, + } + prof2 := domain.CertificateProfile{ + ID: "prof-internal-mtls", + Name: "Internal mTLS", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 256}, + }, + MaxTTLSeconds: 2592000, + AllowedEKUs: []string{"serverAuth", "clientAuth"}, + Enabled: true, + CreatedAt: now, + UpdatedAt: now, + } + + mock := &MockProfileService{ + ListProfilesFn: func(page, perPage int) ([]domain.CertificateProfile, int64, error) { + return []domain.CertificateProfile{prof1, prof2}, 2, nil + }, + } + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.ListProfiles(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + var resp PagedResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp.Total != 2 { + t.Errorf("expected total 2, got %d", resp.Total) + } +} + +func TestListProfiles_Pagination(t *testing.T) { + var capturedPage, capturedPerPage int + mock := &MockProfileService{ + ListProfilesFn: func(page, perPage int) ([]domain.CertificateProfile, int64, error) { + capturedPage = page + capturedPerPage = perPage + return []domain.CertificateProfile{}, 0, nil + }, + } + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles?page=3&per_page=25", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.ListProfiles(w, req) + + if capturedPage != 3 { + t.Errorf("expected page 3, got %d", capturedPage) + } + if capturedPerPage != 25 { + t.Errorf("expected per_page 25, got %d", capturedPerPage) + } +} + +func TestListProfiles_ServiceError(t *testing.T) { + mock := &MockProfileService{ + ListProfilesFn: func(page, perPage int) ([]domain.CertificateProfile, int64, error) { + return nil, 0, ErrMockServiceFailed + }, + } + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.ListProfiles(w, req) + + if w.Code != http.StatusInternalServerError { + t.Fatalf("expected status 500, got %d", w.Code) + } +} + +func TestListProfiles_MethodNotAllowed(t *testing.T) { + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodDelete, "/api/v1/profiles", nil) + w := httptest.NewRecorder() + + handler.ListProfiles(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status 405, got %d", w.Code) + } +} + +func TestGetProfile_Success(t *testing.T) { + now := time.Now() + mock := &MockProfileService{ + GetProfileFn: func(id string) (*domain.CertificateProfile, error) { + return &domain.CertificateProfile{ + ID: id, + Name: "Standard TLS", + MaxTTLSeconds: 7776000, + Enabled: true, + CreatedAt: now, + UpdatedAt: now, + }, nil + }, + } + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles/prof-standard-tls", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.GetProfile(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } +} + +func TestGetProfile_NotFound(t *testing.T) { + mock := &MockProfileService{ + GetProfileFn: func(id string) (*domain.CertificateProfile, error) { + return nil, ErrMockNotFound + }, + } + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles/nonexistent", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.GetProfile(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("expected status 404, got %d", w.Code) + } +} + +func TestGetProfile_EmptyID(t *testing.T) { + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles/", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.GetProfile(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", w.Code) + } +} + +func TestCreateProfile_Success(t *testing.T) { + now := time.Now() + mock := &MockProfileService{ + CreateProfileFn: func(profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + profile.ID = "prof-new" + profile.CreatedAt = now + profile.UpdatedAt = now + return &profile, nil + }, + } + + body := map[string]interface{}{ + "name": "New Profile", + "max_ttl_seconds": 86400, + "allowed_ekus": []string{"serverAuth"}, + } + bodyBytes, _ := json.Marshal(body) + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodPost, "/api/v1/profiles", bytes.NewReader(bodyBytes)) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.CreateProfile(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("expected status 201, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestCreateProfile_MissingName(t *testing.T) { + body := map[string]interface{}{ + "max_ttl_seconds": 86400, + } + bodyBytes, _ := json.Marshal(body) + + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodPost, "/api/v1/profiles", bytes.NewReader(bodyBytes)) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.CreateProfile(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", w.Code) + } +} + +func TestCreateProfile_NameTooLong(t *testing.T) { + longName := "" + for i := 0; i < 256; i++ { + longName += "x" + } + body := map[string]interface{}{ + "name": longName, + } + bodyBytes, _ := json.Marshal(body) + + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodPost, "/api/v1/profiles", bytes.NewReader(bodyBytes)) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.CreateProfile(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", w.Code) + } +} + +func TestCreateProfile_InvalidJSON(t *testing.T) { + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodPost, "/api/v1/profiles", bytes.NewReader([]byte("{invalid"))) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.CreateProfile(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", w.Code) + } +} + +func TestCreateProfile_MethodNotAllowed(t *testing.T) { + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles", nil) + w := httptest.NewRecorder() + + handler.CreateProfile(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status 405, got %d", w.Code) + } +} + +func TestUpdateProfile_Success(t *testing.T) { + now := time.Now() + mock := &MockProfileService{ + UpdateProfileFn: func(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + profile.ID = id + profile.UpdatedAt = now + return &profile, nil + }, + } + + body := map[string]interface{}{ + "name": "Updated Profile", + "max_ttl_seconds": 172800, + } + bodyBytes, _ := json.Marshal(body) + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodPut, "/api/v1/profiles/prof-standard-tls", bytes.NewReader(bodyBytes)) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.UpdateProfile(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d; body: %s", w.Code, w.Body.String()) + } +} + +func TestUpdateProfile_InvalidJSON(t *testing.T) { + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodPut, "/api/v1/profiles/prof-x", bytes.NewReader([]byte("{bad"))) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.UpdateProfile(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", w.Code) + } +} + +func TestDeleteProfile_Success(t *testing.T) { + var deletedID string + mock := &MockProfileService{ + DeleteProfileFn: func(id string) error { + deletedID = id + return nil + }, + } + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodDelete, "/api/v1/profiles/prof-standard-tls", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.DeleteProfile(w, req) + + if w.Code != http.StatusNoContent { + t.Fatalf("expected status 204, got %d", w.Code) + } + if deletedID != "prof-standard-tls" { + t.Errorf("expected deleted ID 'prof-standard-tls', got '%s'", deletedID) + } +} + +func TestDeleteProfile_ServiceError(t *testing.T) { + mock := &MockProfileService{ + DeleteProfileFn: func(id string) error { + return ErrMockServiceFailed + }, + } + + handler := NewProfileHandler(mock) + req := httptest.NewRequest(http.MethodDelete, "/api/v1/profiles/prof-x", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.DeleteProfile(w, req) + + if w.Code != http.StatusInternalServerError { + t.Fatalf("expected status 500, got %d", w.Code) + } +} + +func TestDeleteProfile_EmptyID(t *testing.T) { + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodDelete, "/api/v1/profiles/", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.DeleteProfile(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", w.Code) + } +} + +func TestDeleteProfile_MethodNotAllowed(t *testing.T) { + handler := NewProfileHandler(&MockProfileService{}) + req := httptest.NewRequest(http.MethodGet, "/api/v1/profiles/prof-x", nil) + w := httptest.NewRecorder() + + handler.DeleteProfile(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status 405, got %d", w.Code) + } +} diff --git a/internal/api/handler/profiles.go b/internal/api/handler/profiles.go new file mode 100644 index 0000000..899e0ac --- /dev/null +++ b/internal/api/handler/profiles.go @@ -0,0 +1,206 @@ +package handler + +import ( + "encoding/json" + "net/http" + "strconv" + "strings" + + "github.com/shankar0123/certctl/internal/api/middleware" + "github.com/shankar0123/certctl/internal/domain" +) + +// ProfileService defines the service interface for certificate profile operations. +type ProfileService interface { + ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) + GetProfile(id string) (*domain.CertificateProfile, error) + CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) + UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) + DeleteProfile(id string) error +} + +// ProfileHandler handles HTTP requests for certificate profile operations. +type ProfileHandler struct { + svc ProfileService +} + +// NewProfileHandler creates a new ProfileHandler with a service dependency. +func NewProfileHandler(svc ProfileService) ProfileHandler { + return ProfileHandler{svc: svc} +} + +// ListProfiles lists all certificate profiles. +// GET /api/v1/profiles?page=1&per_page=50 +func (h ProfileHandler) ListProfiles(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + Error(w, http.StatusMethodNotAllowed, "Method not allowed") + return + } + + requestID := middleware.GetRequestID(r.Context()) + + page := 1 + perPage := 50 + query := r.URL.Query() + if p := query.Get("page"); p != "" { + if parsed, err := strconv.Atoi(p); err == nil && parsed > 0 { + page = parsed + } + } + if pp := query.Get("per_page"); pp != "" { + if parsed, err := strconv.Atoi(pp); err == nil && parsed > 0 && parsed <= 500 { + perPage = parsed + } + } + + profiles, total, err := h.svc.ListProfiles(page, perPage) + if err != nil { + ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list profiles", requestID) + return + } + + response := PagedResponse{ + Data: profiles, + Total: total, + Page: page, + PerPage: perPage, + } + + JSON(w, http.StatusOK, response) +} + +// GetProfile retrieves a single certificate profile by ID. +// GET /api/v1/profiles/{id} +func (h ProfileHandler) GetProfile(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + Error(w, http.StatusMethodNotAllowed, "Method not allowed") + return + } + + requestID := middleware.GetRequestID(r.Context()) + + id := strings.TrimPrefix(r.URL.Path, "/api/v1/profiles/") + if id == "" || strings.Contains(id, "/") { + ErrorWithRequestID(w, http.StatusBadRequest, "Profile ID is required", requestID) + return + } + + profile, err := h.svc.GetProfile(id) + if err != nil { + ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID) + return + } + + JSON(w, http.StatusOK, profile) +} + +// CreateProfile creates a new certificate profile. +// POST /api/v1/profiles +func (h ProfileHandler) CreateProfile(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + Error(w, http.StatusMethodNotAllowed, "Method not allowed") + return + } + + requestID := middleware.GetRequestID(r.Context()) + + var profile domain.CertificateProfile + if err := json.NewDecoder(r.Body).Decode(&profile); err != nil { + ErrorWithRequestID(w, http.StatusBadRequest, "Invalid request body", requestID) + return + } + + // Validate required fields + if err := ValidateRequired("name", profile.Name); err != nil { + ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID) + return + } + if err := ValidateStringLength("name", profile.Name, 255); err != nil { + ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID) + return + } + + created, err := h.svc.CreateProfile(profile) + if err != nil { + // Check if it's a validation error from the service + if strings.Contains(err.Error(), "invalid") || strings.Contains(err.Error(), "required") || + strings.Contains(err.Error(), "must be") || strings.Contains(err.Error(), "cannot") { + ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID) + return + } + ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create profile", requestID) + return + } + + JSON(w, http.StatusCreated, created) +} + +// UpdateProfile updates an existing certificate profile. +// PUT /api/v1/profiles/{id} +func (h ProfileHandler) UpdateProfile(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + Error(w, http.StatusMethodNotAllowed, "Method not allowed") + return + } + + requestID := middleware.GetRequestID(r.Context()) + + id := strings.TrimPrefix(r.URL.Path, "/api/v1/profiles/") + parts := strings.Split(id, "/") + if len(parts) == 0 || parts[0] == "" { + ErrorWithRequestID(w, http.StatusBadRequest, "Profile ID is required", requestID) + return + } + id = parts[0] + + var profile domain.CertificateProfile + if err := json.NewDecoder(r.Body).Decode(&profile); err != nil { + ErrorWithRequestID(w, http.StatusBadRequest, "Invalid request body", requestID) + return + } + + updated, err := h.svc.UpdateProfile(id, profile) + if err != nil { + if strings.Contains(err.Error(), "not found") { + ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID) + return + } + if strings.Contains(err.Error(), "invalid") || strings.Contains(err.Error(), "required") || + strings.Contains(err.Error(), "must be") || strings.Contains(err.Error(), "cannot") { + ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID) + return + } + ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update profile", requestID) + return + } + + JSON(w, http.StatusOK, updated) +} + +// DeleteProfile deletes a certificate profile. +// DELETE /api/v1/profiles/{id} +func (h ProfileHandler) DeleteProfile(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + Error(w, http.StatusMethodNotAllowed, "Method not allowed") + return + } + + requestID := middleware.GetRequestID(r.Context()) + + id := strings.TrimPrefix(r.URL.Path, "/api/v1/profiles/") + if id == "" || strings.Contains(id, "/") { + ErrorWithRequestID(w, http.StatusBadRequest, "Profile ID is required", requestID) + return + } + + if err := h.svc.DeleteProfile(id); err != nil { + if strings.Contains(err.Error(), "not found") { + ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID) + return + } + ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to delete profile", requestID) + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/api/router/router.go b/internal/api/router/router.go index 47d985d..a5d416f 100644 --- a/internal/api/router/router.go +++ b/internal/api/router/router.go @@ -51,6 +51,7 @@ func (r *Router) RegisterHandlers( agents handler.AgentHandler, jobs handler.JobHandler, policies handler.PolicyHandler, + profiles handler.ProfileHandler, teams handler.TeamHandler, owners handler.OwnerHandler, audit handler.AuditHandler, @@ -125,6 +126,13 @@ func (r *Router) RegisterHandlers( r.Register("DELETE /api/v1/policies/{id}", http.HandlerFunc(policies.DeletePolicy)) r.Register("GET /api/v1/policies/{id}/violations", http.HandlerFunc(policies.ListViolations)) + // Profiles routes: /api/v1/profiles + r.Register("GET /api/v1/profiles", http.HandlerFunc(profiles.ListProfiles)) + r.Register("POST /api/v1/profiles", http.HandlerFunc(profiles.CreateProfile)) + r.Register("GET /api/v1/profiles/{id}", http.HandlerFunc(profiles.GetProfile)) + r.Register("PUT /api/v1/profiles/{id}", http.HandlerFunc(profiles.UpdateProfile)) + r.Register("DELETE /api/v1/profiles/{id}", http.HandlerFunc(profiles.DeleteProfile)) + // Teams routes: /api/v1/teams r.Register("GET /api/v1/teams", http.HandlerFunc(teams.ListTeams)) r.Register("POST /api/v1/teams", http.HandlerFunc(teams.CreateTeam)) diff --git a/internal/domain/certificate.go b/internal/domain/certificate.go index 0622738..61ea320 100644 --- a/internal/domain/certificate.go +++ b/internal/domain/certificate.go @@ -6,23 +6,24 @@ import ( // ManagedCertificate represents a certificate managed by the control plane. type ManagedCertificate struct { - ID string `json:"id"` - Name string `json:"name"` - CommonName string `json:"common_name"` - SANs []string `json:"sans"` - Environment string `json:"environment"` - OwnerID string `json:"owner_id"` - TeamID string `json:"team_id"` - IssuerID string `json:"issuer_id"` - TargetIDs []string `json:"target_ids"` - RenewalPolicyID string `json:"renewal_policy_id"` - Status CertificateStatus `json:"status"` - ExpiresAt time.Time `json:"expires_at"` - Tags map[string]string `json:"tags"` - LastRenewalAt *time.Time `json:"last_renewal_at,omitempty"` - LastDeploymentAt *time.Time `json:"last_deployment_at,omitempty"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + Name string `json:"name"` + CommonName string `json:"common_name"` + SANs []string `json:"sans"` + Environment string `json:"environment"` + OwnerID string `json:"owner_id"` + TeamID string `json:"team_id"` + IssuerID string `json:"issuer_id"` + TargetIDs []string `json:"target_ids"` + RenewalPolicyID string `json:"renewal_policy_id"` + CertificateProfileID string `json:"certificate_profile_id,omitempty"` + Status CertificateStatus `json:"status"` + ExpiresAt time.Time `json:"expires_at"` + Tags map[string]string `json:"tags"` + LastRenewalAt *time.Time `json:"last_renewal_at,omitempty"` + LastDeploymentAt *time.Time `json:"last_deployment_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // CertificateVersion represents a specific version of a certificate. @@ -35,6 +36,8 @@ type CertificateVersion struct { FingerprintSHA256 string `json:"fingerprint_sha256"` PEMChain string `json:"pem_chain"` CSRPEM string `json:"csr_pem"` + KeyAlgorithm string `json:"key_algorithm,omitempty"` + KeySize int `json:"key_size,omitempty"` CreatedAt time.Time `json:"created_at"` } @@ -54,15 +57,16 @@ const ( // RenewalPolicy defines renewal parameters for a managed certificate. type RenewalPolicy struct { - ID string `json:"id"` - Name string `json:"name"` - RenewalWindowDays int `json:"renewal_window_days"` - AutoRenew bool `json:"auto_renew"` - MaxRetries int `json:"max_retries"` - RetryInterval int `json:"retry_interval_seconds"` - AlertThresholdsDays []int `json:"alert_thresholds_days"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + Name string `json:"name"` + RenewalWindowDays int `json:"renewal_window_days"` + AutoRenew bool `json:"auto_renew"` + MaxRetries int `json:"max_retries"` + RetryInterval int `json:"retry_interval_seconds"` + AlertThresholdsDays []int `json:"alert_thresholds_days"` + CertificateProfileID string `json:"certificate_profile_id,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // DefaultAlertThresholds returns the standard alert thresholds when none are configured. diff --git a/internal/domain/profile.go b/internal/domain/profile.go new file mode 100644 index 0000000..c89b385 --- /dev/null +++ b/internal/domain/profile.go @@ -0,0 +1,71 @@ +package domain + +import ( + "time" +) + +// CertificateProfile defines an enrollment profile that controls what kinds of +// certificates can be issued: allowed key algorithms, maximum TTL, permitted EKUs, +// required SAN patterns, and optional SPIFFE URI SANs for workload identity. +type CertificateProfile struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + AllowedKeyAlgorithms []KeyAlgorithmRule `json:"allowed_key_algorithms"` + MaxTTLSeconds int `json:"max_ttl_seconds"` + AllowedEKUs []string `json:"allowed_ekus"` + RequiredSANPatterns []string `json:"required_san_patterns"` + SPIFFEURIPattern string `json:"spiffe_uri_pattern"` + AllowShortLived bool `json:"allow_short_lived"` + Enabled bool `json:"enabled"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// KeyAlgorithmRule defines an allowed key algorithm and its minimum key size. +type KeyAlgorithmRule struct { + Algorithm string `json:"algorithm"` // "RSA", "ECDSA", "Ed25519" + MinSize int `json:"min_size"` // RSA: 2048/4096, ECDSA: 256/384, Ed25519: 0 (fixed) +} + +// IsShortLived returns true if this profile's max TTL is under 1 hour (3600 seconds). +// Short-lived certs use expiry as revocation — no CRL/OCSP needed. +func (p *CertificateProfile) IsShortLived() bool { + return p.AllowShortLived && p.MaxTTLSeconds > 0 && p.MaxTTLSeconds < 3600 +} + +// DefaultKeyAlgorithms returns sensible defaults for profiles without explicit rules. +func DefaultKeyAlgorithms() []KeyAlgorithmRule { + return []KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 256}, + {Algorithm: "RSA", MinSize: 2048}, + } +} + +// DefaultEKUs returns the default extended key usages. +func DefaultEKUs() []string { + return []string{"serverAuth"} +} + +// Supported key algorithm constants for validation. +const ( + KeyAlgorithmRSA = "RSA" + KeyAlgorithmECDSA = "ECDSA" + KeyAlgorithmEd25519 = "Ed25519" +) + +// ValidKeyAlgorithms is the set of recognized key algorithm names. +var ValidKeyAlgorithms = map[string]bool{ + KeyAlgorithmRSA: true, + KeyAlgorithmECDSA: true, + KeyAlgorithmEd25519: true, +} + +// ValidEKUs is the set of recognized extended key usage names. +var ValidEKUs = map[string]bool{ + "serverAuth": true, + "clientAuth": true, + "codeSigning": true, + "emailProtection": true, + "timeStamping": true, +} diff --git a/internal/integration/lifecycle_test.go b/internal/integration/lifecycle_test.go index 506c557..805b301 100644 --- a/internal/integration/lifecycle_test.go +++ b/internal/integration/lifecycle_test.go @@ -52,7 +52,7 @@ func TestCertificateLifecycle(t *testing.T) { policyService := service.NewPolicyService(policyRepo, auditService) certificateService := service.NewCertificateService(certRepo, policyService, auditService) notificationService := service.NewNotificationService(notifRepo, make(map[string]service.Notifier)) - renewalService := service.NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, auditService, notificationService, issuerRegistry, "server") + renewalService := service.NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, nil, auditService, notificationService, issuerRegistry, "server") deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notificationService) jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger) agentService := service.NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, renewalService) @@ -65,6 +65,7 @@ func TestCertificateLifecycle(t *testing.T) { agentHandler := handler.NewAgentHandler(agentService) jobHandler := handler.NewJobHandler(jobService) policyHandler := handler.NewPolicyHandler(policyService) + profileHandler := handler.NewProfileHandler(&mockProfileService{}) teamHandler := handler.NewTeamHandler(&mockTeamService{}) ownerHandler := handler.NewOwnerHandler(&mockOwnerService{}) auditHandler := handler.NewAuditHandler(auditService) @@ -80,6 +81,7 @@ func TestCertificateLifecycle(t *testing.T) { agentHandler, jobHandler, policyHandler, + profileHandler, teamHandler, ownerHandler, auditHandler, @@ -994,3 +996,26 @@ func (m *mockOwnerService) UpdateOwner(id string, owner domain.Owner) (*domain.O func (m *mockOwnerService) DeleteOwner(id string) error { return nil } + +type mockProfileService struct{} + +func (m *mockProfileService) ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) { + return []domain.CertificateProfile{}, 0, nil +} + +func (m *mockProfileService) GetProfile(id string) (*domain.CertificateProfile, error) { + return nil, fmt.Errorf("profile not found") +} + +func (m *mockProfileService) CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + return &profile, nil +} + +func (m *mockProfileService) UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + profile.ID = id + return &profile, nil +} + +func (m *mockProfileService) DeleteProfile(id string) error { + return nil +} diff --git a/internal/integration/negative_test.go b/internal/integration/negative_test.go index 04f487e..b385cf4 100644 --- a/internal/integration/negative_test.go +++ b/internal/integration/negative_test.go @@ -43,7 +43,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, *mockCertificateRepository policyService := service.NewPolicyService(policyRepo, auditService) certificateService := service.NewCertificateService(certRepo, policyService, auditService) notificationService := service.NewNotificationService(notifRepo, make(map[string]service.Notifier)) - renewalService := service.NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, auditService, notificationService, issuerRegistry, "server") + renewalService := service.NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, nil, auditService, notificationService, issuerRegistry, "server") deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notificationService) jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger) agentService := service.NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, renewalService) @@ -55,6 +55,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, *mockCertificateRepository agentHandler := handler.NewAgentHandler(agentService) jobHandler := handler.NewJobHandler(jobService) policyHandler := handler.NewPolicyHandler(policyService) + profileHandler := handler.NewProfileHandler(&mockProfileService{}) teamHandler := handler.NewTeamHandler(&mockTeamService{}) ownerHandler := handler.NewOwnerHandler(&mockOwnerService{}) auditHandler := handler.NewAuditHandler(auditService) @@ -69,6 +70,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, *mockCertificateRepository agentHandler, jobHandler, policyHandler, + profileHandler, teamHandler, ownerHandler, auditHandler, diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index ed647e7..d988890 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -155,6 +155,20 @@ type TeamRepository interface { Delete(ctx context.Context, id string) error } +// CertificateProfileRepository defines operations for managing certificate profiles. +type CertificateProfileRepository interface { + // List returns all certificate profiles. + List(ctx context.Context) ([]*domain.CertificateProfile, error) + // Get retrieves a certificate profile by ID. + Get(ctx context.Context, id string) (*domain.CertificateProfile, error) + // Create stores a new certificate profile. + Create(ctx context.Context, profile *domain.CertificateProfile) error + // Update modifies an existing certificate profile. + Update(ctx context.Context, profile *domain.CertificateProfile) error + // Delete removes a certificate profile. + Delete(ctx context.Context, id string) error +} + // OwnerRepository defines operations for managing certificate owners. type OwnerRepository interface { // List returns all owners. diff --git a/internal/repository/postgres/certificate.go b/internal/repository/postgres/certificate.go index cad8154..5d3b50e 100644 --- a/internal/repository/postgres/certificate.go +++ b/internal/repository/postgres/certificate.go @@ -85,7 +85,7 @@ func (r *CertificateRepository) List(ctx context.Context, filter *repository.Cer offset := (filter.Page - 1) * filter.PerPage query := fmt.Sprintf(` SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id, - status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at + certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at FROM managed_certificates %s ORDER BY created_at DESC @@ -120,7 +120,7 @@ func (r *CertificateRepository) List(ctx context.Context, filter *repository.Cer func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) { row := r.db.QueryRowContext(ctx, ` SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id, - status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at + certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at FROM managed_certificates WHERE id = $1 `, id) @@ -147,14 +147,20 @@ func (r *CertificateRepository) Create(ctx context.Context, cert *domain.Managed return fmt.Errorf("failed to marshal tags: %w", err) } + var profileID *string + if cert.CertificateProfileID != "" { + profileID = &cert.CertificateProfileID + } + err = r.db.QueryRowContext(ctx, ` INSERT INTO managed_certificates ( id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id, - status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) RETURNING id `, cert.ID, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment, - cert.OwnerID, cert.TeamID, cert.IssuerID, cert.RenewalPolicyID, cert.Status, cert.ExpiresAt, + cert.OwnerID, cert.TeamID, cert.IssuerID, cert.RenewalPolicyID, profileID, + cert.Status, cert.ExpiresAt, tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.CreatedAt, cert.UpdatedAt).Scan(&cert.ID) if err != nil { @@ -171,6 +177,11 @@ func (r *CertificateRepository) Update(ctx context.Context, cert *domain.Managed return fmt.Errorf("failed to marshal tags: %w", err) } + var profileID *string + if cert.CertificateProfileID != "" { + profileID = &cert.CertificateProfileID + } + result, err := r.db.ExecContext(ctx, ` UPDATE managed_certificates SET name = $1, @@ -180,15 +191,16 @@ func (r *CertificateRepository) Update(ctx context.Context, cert *domain.Managed owner_id = $5, team_id = $6, issuer_id = $7, - status = $8, - expires_at = $9, - tags = $10, - last_renewal_at = $11, - last_deployment_at = $12, - updated_at = $13 - WHERE id = $14 + certificate_profile_id = $8, + status = $9, + expires_at = $10, + tags = $11, + last_renewal_at = $12, + last_deployment_at = $13, + updated_at = $14 + WHERE id = $15 `, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment, - cert.OwnerID, cert.TeamID, cert.IssuerID, cert.Status, cert.ExpiresAt, + cert.OwnerID, cert.TeamID, cert.IssuerID, profileID, cert.Status, cert.ExpiresAt, tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.UpdatedAt, cert.ID) if err != nil { @@ -233,7 +245,7 @@ func (r *CertificateRepository) Archive(ctx context.Context, id string) error { func (r *CertificateRepository) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, certificate_id, serial_number, not_before, not_after, - fingerprint_sha256, pem_chain, csr_pem, created_at + fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at FROM certificate_versions WHERE certificate_id = $1 ORDER BY created_at DESC @@ -248,7 +260,7 @@ func (r *CertificateRepository) ListVersions(ctx context.Context, certID string) for rows.Next() { var v domain.CertificateVersion if err := rows.Scan(&v.ID, &v.CertificateID, &v.SerialNumber, &v.NotBefore, &v.NotAfter, - &v.FingerprintSHA256, &v.PEMChain, &v.CSRPEM, &v.CreatedAt); err != nil { + &v.FingerprintSHA256, &v.PEMChain, &v.CSRPEM, &v.KeyAlgorithm, &v.KeySize, &v.CreatedAt); err != nil { return nil, fmt.Errorf("failed to scan certificate version: %w", err) } versions = append(versions, &v) @@ -270,11 +282,11 @@ func (r *CertificateRepository) CreateVersion(ctx context.Context, version *doma err := r.db.QueryRowContext(ctx, ` INSERT INTO certificate_versions ( id, certificate_id, serial_number, not_before, not_after, - fingerprint_sha256, pem_chain, csr_pem, created_at - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id `, version.ID, version.CertificateID, version.SerialNumber, version.NotBefore, version.NotAfter, - version.FingerprintSHA256, version.PEMChain, version.CSRPEM, version.CreatedAt).Scan(&version.ID) + version.FingerprintSHA256, version.PEMChain, version.CSRPEM, version.KeyAlgorithm, version.KeySize, version.CreatedAt).Scan(&version.ID) if err != nil { return fmt.Errorf("failed to create certificate version: %w", err) @@ -287,7 +299,7 @@ func (r *CertificateRepository) CreateVersion(ctx context.Context, version *doma func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id, - status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at + certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at FROM managed_certificates WHERE expires_at < $1 AND status != $2 ORDER BY expires_at ASC @@ -321,10 +333,12 @@ func scanCertificate(scanner interface { var cert domain.ManagedCertificate var tagsJSON []byte var sans pq.StringArray + var profileID sql.NullString err := scanner.Scan( &cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID, - &cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &cert.Status, &cert.ExpiresAt, &tagsJSON, + &cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &profileID, + &cert.Status, &cert.ExpiresAt, &tagsJSON, &cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.CreatedAt, &cert.UpdatedAt) if err != nil { @@ -332,6 +346,9 @@ func scanCertificate(scanner interface { } cert.SANs = []string(sans) + if profileID.Valid { + cert.CertificateProfileID = profileID.String + } // Unmarshal tags if len(tagsJSON) > 0 { diff --git a/internal/repository/postgres/profile.go b/internal/repository/postgres/profile.go new file mode 100644 index 0000000..2544262 --- /dev/null +++ b/internal/repository/postgres/profile.go @@ -0,0 +1,226 @@ +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/shankar0123/certctl/internal/domain" +) + +// ProfileRepository implements repository.CertificateProfileRepository +type ProfileRepository struct { + db *sql.DB +} + +// NewProfileRepository creates a new ProfileRepository +func NewProfileRepository(db *sql.DB) *ProfileRepository { + return &ProfileRepository{db: db} +} + +// List returns all certificate profiles +func (r *ProfileRepository) List(ctx context.Context) ([]*domain.CertificateProfile, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, name, description, allowed_key_algorithms, max_ttl_seconds, + allowed_ekus, required_san_patterns, spiffe_uri_pattern, + allow_short_lived, enabled, created_at, updated_at + FROM certificate_profiles + ORDER BY created_at DESC + `) + if err != nil { + return nil, fmt.Errorf("failed to query profiles: %w", err) + } + defer rows.Close() + + var profiles []*domain.CertificateProfile + for rows.Next() { + p, err := scanProfile(rows) + if err != nil { + return nil, err + } + profiles = append(profiles, p) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating profile rows: %w", err) + } + + return profiles, nil +} + +// Get retrieves a certificate profile by ID +func (r *ProfileRepository) Get(ctx context.Context, id string) (*domain.CertificateProfile, error) { + row := r.db.QueryRowContext(ctx, ` + SELECT id, name, description, allowed_key_algorithms, max_ttl_seconds, + allowed_ekus, required_san_patterns, spiffe_uri_pattern, + allow_short_lived, enabled, created_at, updated_at + FROM certificate_profiles + WHERE id = $1 + `, id) + + p, err := scanProfile(row) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("profile not found") + } + return nil, fmt.Errorf("failed to query profile: %w", err) + } + + return p, nil +} + +// Create stores a new certificate profile +func (r *ProfileRepository) Create(ctx context.Context, profile *domain.CertificateProfile) error { + if profile.ID == "" { + profile.ID = uuid.New().String() + } + if profile.CreatedAt.IsZero() { + profile.CreatedAt = time.Now() + } + if profile.UpdatedAt.IsZero() { + profile.UpdatedAt = time.Now() + } + + algJSON, err := json.Marshal(profile.AllowedKeyAlgorithms) + if err != nil { + return fmt.Errorf("failed to marshal allowed_key_algorithms: %w", err) + } + ekuJSON, err := json.Marshal(profile.AllowedEKUs) + if err != nil { + return fmt.Errorf("failed to marshal allowed_ekus: %w", err) + } + sanJSON, err := json.Marshal(profile.RequiredSANPatterns) + if err != nil { + return fmt.Errorf("failed to marshal required_san_patterns: %w", err) + } + + err = r.db.QueryRowContext(ctx, ` + INSERT INTO certificate_profiles ( + id, name, description, allowed_key_algorithms, max_ttl_seconds, + allowed_ekus, required_san_patterns, spiffe_uri_pattern, + allow_short_lived, enabled, created_at, updated_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + RETURNING id + `, profile.ID, profile.Name, profile.Description, algJSON, profile.MaxTTLSeconds, + ekuJSON, sanJSON, profile.SPIFFEURIPattern, + profile.AllowShortLived, profile.Enabled, profile.CreatedAt, profile.UpdatedAt).Scan(&profile.ID) + + if err != nil { + return fmt.Errorf("failed to create profile: %w", err) + } + + return nil +} + +// Update modifies an existing certificate profile +func (r *ProfileRepository) Update(ctx context.Context, profile *domain.CertificateProfile) error { + profile.UpdatedAt = time.Now() + + algJSON, err := json.Marshal(profile.AllowedKeyAlgorithms) + if err != nil { + return fmt.Errorf("failed to marshal allowed_key_algorithms: %w", err) + } + ekuJSON, err := json.Marshal(profile.AllowedEKUs) + if err != nil { + return fmt.Errorf("failed to marshal allowed_ekus: %w", err) + } + sanJSON, err := json.Marshal(profile.RequiredSANPatterns) + if err != nil { + return fmt.Errorf("failed to marshal required_san_patterns: %w", err) + } + + result, err := r.db.ExecContext(ctx, ` + UPDATE certificate_profiles SET + name = $1, + description = $2, + allowed_key_algorithms = $3, + max_ttl_seconds = $4, + allowed_ekus = $5, + required_san_patterns = $6, + spiffe_uri_pattern = $7, + allow_short_lived = $8, + enabled = $9, + updated_at = $10 + WHERE id = $11 + `, profile.Name, profile.Description, algJSON, profile.MaxTTLSeconds, + ekuJSON, sanJSON, profile.SPIFFEURIPattern, + profile.AllowShortLived, profile.Enabled, profile.UpdatedAt, profile.ID) + + if err != nil { + return fmt.Errorf("failed to update profile: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("profile not found") + } + + return nil +} + +// Delete removes a certificate profile +func (r *ProfileRepository) Delete(ctx context.Context, id string) error { + result, err := r.db.ExecContext(ctx, "DELETE FROM certificate_profiles WHERE id = $1", id) + if err != nil { + return fmt.Errorf("failed to delete profile: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("profile not found") + } + + return nil +} + +// scanProfile scans a certificate profile from a row or rows +func scanProfile(scanner interface { + Scan(...interface{}) error +}) (*domain.CertificateProfile, error) { + var p domain.CertificateProfile + var algJSON, ekuJSON, sanJSON []byte + + err := scanner.Scan( + &p.ID, &p.Name, &p.Description, &algJSON, &p.MaxTTLSeconds, + &ekuJSON, &sanJSON, &p.SPIFFEURIPattern, + &p.AllowShortLived, &p.Enabled, &p.CreatedAt, &p.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan profile: %w", err) + } + + if len(algJSON) > 0 { + if err := json.Unmarshal(algJSON, &p.AllowedKeyAlgorithms); err != nil { + return nil, fmt.Errorf("failed to unmarshal allowed_key_algorithms: %w", err) + } + } else { + p.AllowedKeyAlgorithms = domain.DefaultKeyAlgorithms() + } + + if len(ekuJSON) > 0 { + if err := json.Unmarshal(ekuJSON, &p.AllowedEKUs); err != nil { + return nil, fmt.Errorf("failed to unmarshal allowed_ekus: %w", err) + } + } else { + p.AllowedEKUs = domain.DefaultEKUs() + } + + if len(sanJSON) > 0 { + if err := json.Unmarshal(sanJSON, &p.RequiredSANPatterns); err != nil { + return nil, fmt.Errorf("failed to unmarshal required_san_patterns: %w", err) + } + } + + return &p, nil +} diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index a3c3a3f..0ed898c 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -19,10 +19,11 @@ type Scheduler struct { logger *slog.Logger // Configurable tick intervals - renewalCheckInterval time.Duration - jobProcessorInterval time.Duration - agentHealthCheckInterval time.Duration - notificationProcessInterval time.Duration + renewalCheckInterval time.Duration + jobProcessorInterval time.Duration + agentHealthCheckInterval time.Duration + notificationProcessInterval time.Duration + shortLivedExpiryCheckInterval time.Duration } // NewScheduler creates a new scheduler with configurable intervals. @@ -41,10 +42,11 @@ func NewScheduler( logger: logger, // Default intervals - renewalCheckInterval: 1 * time.Hour, - jobProcessorInterval: 30 * time.Second, - agentHealthCheckInterval: 2 * time.Minute, - notificationProcessInterval: 1 * time.Minute, + renewalCheckInterval: 1 * time.Hour, + jobProcessorInterval: 30 * time.Second, + agentHealthCheckInterval: 2 * time.Minute, + notificationProcessInterval: 1 * time.Minute, + shortLivedExpiryCheckInterval: 30 * time.Second, } } @@ -87,6 +89,7 @@ func (s *Scheduler) Start(ctx context.Context) <-chan struct{} { go s.jobProcessorLoop(ctx) go s.agentHealthCheckLoop(ctx) go s.notificationProcessLoop(ctx) + go s.shortLivedExpiryCheckLoop(ctx) // Wait for context cancellation <-ctx.Done() @@ -225,3 +228,33 @@ func (s *Scheduler) runNotificationProcess(ctx context.Context) { s.logger.Debug("notification processor completed") } } + +// shortLivedExpiryCheckLoop runs every shortLivedExpiryCheckInterval and marks expired +// short-lived certificates. For certs with TTL < 1 hour, expiry IS revocation — +// no CRL/OCSP needed. +func (s *Scheduler) shortLivedExpiryCheckLoop(ctx context.Context) { + ticker := time.NewTicker(s.shortLivedExpiryCheckInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.runShortLivedExpiryCheck(ctx) + } + } +} + +// runShortLivedExpiryCheck executes a single short-lived expiry check with error recovery. +func (s *Scheduler) runShortLivedExpiryCheck(ctx context.Context) { + opCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + if err := s.renewalService.ExpireShortLivedCertificates(opCtx); err != nil { + s.logger.Error("short-lived expiry check failed", + "error", err, + "interval", s.shortLivedExpiryCheckInterval.String()) + } else { + s.logger.Debug("short-lived expiry check completed") + } +} diff --git a/internal/service/crypto_validation.go b/internal/service/crypto_validation.go new file mode 100644 index 0000000..75b6d78 --- /dev/null +++ b/internal/service/crypto_validation.go @@ -0,0 +1,85 @@ +package service + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + + "github.com/shankar0123/certctl/internal/domain" +) + +// CSRValidationResult contains metadata extracted from a validated CSR. +type CSRValidationResult struct { + KeyAlgorithm string + KeySize int +} + +// ValidateCSRAgainstProfile parses a CSR PEM and validates that its key algorithm +// and size comply with the profile's allowed_key_algorithms rules. +// Returns extracted key metadata on success for storage in certificate_versions. +func ValidateCSRAgainstProfile(csrPEM string, profile *domain.CertificateProfile) (*CSRValidationResult, error) { + if profile == nil { + // No profile assigned — skip validation, extract metadata only + return extractCSRKeyInfo(csrPEM) + } + + result, err := extractCSRKeyInfo(csrPEM) + if err != nil { + return nil, err + } + + // Check that the CSR's key algorithm + size matches at least one allowed rule + if len(profile.AllowedKeyAlgorithms) == 0 { + // No restrictions defined — allow anything + return result, nil + } + + for _, rule := range profile.AllowedKeyAlgorithms { + if rule.Algorithm == result.KeyAlgorithm && result.KeySize >= rule.MinSize { + return result, nil + } + } + + return nil, fmt.Errorf("CSR key (%s %d-bit) does not match any allowed algorithm in profile %q: %v", + result.KeyAlgorithm, result.KeySize, profile.Name, profile.AllowedKeyAlgorithms) +} + +// extractCSRKeyInfo parses a CSR PEM and extracts the key algorithm and size. +func extractCSRKeyInfo(csrPEM string) (*CSRValidationResult, error) { + block, _ := pem.Decode([]byte(csrPEM)) + if block == nil { + return nil, fmt.Errorf("failed to decode CSR PEM") + } + + csr, err := x509.ParseCertificateRequest(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse CSR: %w", err) + } + + if err := csr.CheckSignature(); err != nil { + return nil, fmt.Errorf("CSR signature verification failed: %w", err) + } + + switch key := csr.PublicKey.(type) { + case *rsa.PublicKey: + return &CSRValidationResult{ + KeyAlgorithm: domain.KeyAlgorithmRSA, + KeySize: key.N.BitLen(), + }, nil + case *ecdsa.PublicKey: + return &CSRValidationResult{ + KeyAlgorithm: domain.KeyAlgorithmECDSA, + KeySize: key.Curve.Params().BitSize, + }, nil + case ed25519.PublicKey: + return &CSRValidationResult{ + KeyAlgorithm: domain.KeyAlgorithmEd25519, + KeySize: 256, // Ed25519 is fixed 256-bit + }, nil + default: + return nil, fmt.Errorf("unsupported key type in CSR: %T", csr.PublicKey) + } +} diff --git a/internal/service/crypto_validation_test.go b/internal/service/crypto_validation_test.go new file mode 100644 index 0000000..fcb722c --- /dev/null +++ b/internal/service/crypto_validation_test.go @@ -0,0 +1,244 @@ +package service + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "testing" + + "github.com/shankar0123/certctl/internal/domain" +) + +// generateTestCSR creates a valid CSR PEM for testing purposes. +func generateTestCSR(t *testing.T, keyType string, keySize int) string { + t.Helper() + + var privKey interface{} + var err error + + switch keyType { + case "RSA": + privKey, err = rsa.GenerateKey(rand.Reader, keySize) + case "ECDSA": + var curve elliptic.Curve + switch keySize { + case 256: + curve = elliptic.P256() + case 384: + curve = elliptic.P384() + default: + t.Fatalf("unsupported ECDSA key size: %d", keySize) + } + privKey, err = ecdsa.GenerateKey(curve, rand.Reader) + default: + t.Fatalf("unsupported key type: %s", keyType) + } + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + template := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "test.example.com", + }, + DNSNames: []string{"test.example.com", "www.example.com"}, + } + + csrDER, err := x509.CreateCertificateRequest(rand.Reader, template, privKey) + if err != nil { + t.Fatalf("failed to create CSR: %v", err) + } + + csrPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: csrDER, + }) + + return string(csrPEM) +} + +func TestValidateCSRAgainstProfile_NilProfile(t *testing.T) { + csrPEM := generateTestCSR(t, "ECDSA", 256) + + result, err := ValidateCSRAgainstProfile(csrPEM, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.KeyAlgorithm != "ECDSA" { + t.Errorf("expected ECDSA, got %s", result.KeyAlgorithm) + } + if result.KeySize != 256 { + t.Errorf("expected 256, got %d", result.KeySize) + } +} + +func TestValidateCSRAgainstProfile_ECDSA256_Allowed(t *testing.T) { + csrPEM := generateTestCSR(t, "ECDSA", 256) + + profile := &domain.CertificateProfile{ + Name: "Standard TLS", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 256}, + {Algorithm: "RSA", MinSize: 2048}, + }, + } + + result, err := ValidateCSRAgainstProfile(csrPEM, profile) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.KeyAlgorithm != "ECDSA" { + t.Errorf("expected ECDSA, got %s", result.KeyAlgorithm) + } + if result.KeySize != 256 { + t.Errorf("expected 256, got %d", result.KeySize) + } +} + +func TestValidateCSRAgainstProfile_ECDSA384_Allowed(t *testing.T) { + csrPEM := generateTestCSR(t, "ECDSA", 384) + + profile := &domain.CertificateProfile{ + Name: "High Security", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 384}, + }, + } + + result, err := ValidateCSRAgainstProfile(csrPEM, profile) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.KeySize != 384 { + t.Errorf("expected 384, got %d", result.KeySize) + } +} + +func TestValidateCSRAgainstProfile_RSA2048_Allowed(t *testing.T) { + csrPEM := generateTestCSR(t, "RSA", 2048) + + profile := &domain.CertificateProfile{ + Name: "Standard TLS", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "RSA", MinSize: 2048}, + }, + } + + result, err := ValidateCSRAgainstProfile(csrPEM, profile) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.KeyAlgorithm != "RSA" { + t.Errorf("expected RSA, got %s", result.KeyAlgorithm) + } + if result.KeySize != 2048 { + t.Errorf("expected 2048, got %d", result.KeySize) + } +} + +func TestValidateCSRAgainstProfile_ECDSA256_RejectedByHighSecurity(t *testing.T) { + csrPEM := generateTestCSR(t, "ECDSA", 256) + + profile := &domain.CertificateProfile{ + Name: "High Security", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 384}, + {Algorithm: "RSA", MinSize: 4096}, + }, + } + + _, err := ValidateCSRAgainstProfile(csrPEM, profile) + if err == nil { + t.Fatal("expected rejection, got nil error") + } + if !containsSubstring(err.Error(), "does not match any allowed algorithm") { + t.Errorf("unexpected error message: %s", err.Error()) + } +} + +func TestValidateCSRAgainstProfile_RSA_RejectedByECDSAOnly(t *testing.T) { + csrPEM := generateTestCSR(t, "RSA", 2048) + + profile := &domain.CertificateProfile{ + Name: "ECDSA Only", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 256}, + }, + } + + _, err := ValidateCSRAgainstProfile(csrPEM, profile) + if err == nil { + t.Fatal("expected rejection, got nil error") + } +} + +func TestValidateCSRAgainstProfile_EmptyAlgorithmRules(t *testing.T) { + csrPEM := generateTestCSR(t, "ECDSA", 256) + + profile := &domain.CertificateProfile{ + Name: "Permissive", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{}, // empty = allow anything + } + + result, err := ValidateCSRAgainstProfile(csrPEM, profile) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.KeyAlgorithm != "ECDSA" { + t.Errorf("expected ECDSA, got %s", result.KeyAlgorithm) + } +} + +func TestValidateCSRAgainstProfile_InvalidPEM(t *testing.T) { + _, err := ValidateCSRAgainstProfile("not a pem", nil) + if err == nil { + t.Fatal("expected error for invalid PEM, got nil") + } + if !containsSubstring(err.Error(), "failed to decode CSR PEM") { + t.Errorf("unexpected error: %s", err.Error()) + } +} + +func TestValidateCSRAgainstProfile_InvalidCSRContent(t *testing.T) { + // Valid PEM block but garbage content + csrPEM := "-----BEGIN CERTIFICATE REQUEST-----\nTm90IGEgcmVhbCBDU1I=\n-----END CERTIFICATE REQUEST-----" + + _, err := ValidateCSRAgainstProfile(csrPEM, nil) + if err == nil { + t.Fatal("expected error for invalid CSR content, got nil") + } +} + +func TestExtractCSRKeyInfo_ECDSA(t *testing.T) { + csrPEM := generateTestCSR(t, "ECDSA", 256) + + result, err := extractCSRKeyInfo(csrPEM) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.KeyAlgorithm != "ECDSA" { + t.Errorf("expected ECDSA, got %s", result.KeyAlgorithm) + } + if result.KeySize != 256 { + t.Errorf("expected 256, got %d", result.KeySize) + } +} + +func TestExtractCSRKeyInfo_RSA(t *testing.T) { + csrPEM := generateTestCSR(t, "RSA", 2048) + + result, err := extractCSRKeyInfo(csrPEM) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.KeyAlgorithm != "RSA" { + t.Errorf("expected RSA, got %s", result.KeyAlgorithm) + } + if result.KeySize != 2048 { + t.Errorf("expected 2048, got %d", result.KeySize) + } +} diff --git a/internal/service/job_test.go b/internal/service/job_test.go index 488c03b..29c8dcd 100644 --- a/internal/service/job_test.go +++ b/internal/service/job_test.go @@ -28,7 +28,7 @@ func newTestJobService(jobRepo *mockJobRepo) *JobService { targetRepo := &mockTargetRepo{Targets: make(map[string]*domain.DeploymentTarget)} agentRepo := &mockAgentRepo{Agents: make(map[string]*domain.Agent)} - renewalService := NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, auditService, notifService, make(map[string]IssuerConnector), "server") + renewalService := NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, nil, auditService, notifService, make(map[string]IssuerConnector), "server") deploymentService := NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notifService) return NewJobService(jobRepo, renewalService, deploymentService, logger) diff --git a/internal/service/profile.go b/internal/service/profile.go new file mode 100644 index 0000000..1b7f8bc --- /dev/null +++ b/internal/service/profile.go @@ -0,0 +1,181 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/shankar0123/certctl/internal/domain" + "github.com/shankar0123/certctl/internal/repository" +) + +// ProfileService provides business logic for certificate profile management. +type ProfileService struct { + profileRepo repository.CertificateProfileRepository + auditService *AuditService +} + +// NewProfileService creates a new profile service. +func NewProfileService( + profileRepo repository.CertificateProfileRepository, + auditService *AuditService, +) *ProfileService { + return &ProfileService{ + profileRepo: profileRepo, + auditService: auditService, + } +} + +// ListProfiles returns all profiles (handler interface method). +func (s *ProfileService) ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) { + if page < 1 { + page = 1 + } + if perPage < 1 { + perPage = 50 + } + + profiles, err := s.profileRepo.List(context.Background()) + if err != nil { + return nil, 0, fmt.Errorf("failed to list profiles: %w", err) + } + total := int64(len(profiles)) + + var result []domain.CertificateProfile + for _, p := range profiles { + if p != nil { + result = append(result, *p) + } + } + + return result, total, nil +} + +// GetProfile returns a single profile (handler interface method). +func (s *ProfileService) GetProfile(id string) (*domain.CertificateProfile, error) { + return s.profileRepo.Get(context.Background(), id) +} + +// CreateProfile creates a new profile with validation (handler interface method). +func (s *ProfileService) CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + if err := validateProfile(&profile); err != nil { + return nil, err + } + + if profile.ID == "" { + profile.ID = generateID("prof") + } + now := time.Now() + if profile.CreatedAt.IsZero() { + profile.CreatedAt = now + } + if profile.UpdatedAt.IsZero() { + profile.UpdatedAt = now + } + + // Apply defaults if not set + if len(profile.AllowedKeyAlgorithms) == 0 { + profile.AllowedKeyAlgorithms = domain.DefaultKeyAlgorithms() + } + if len(profile.AllowedEKUs) == 0 { + profile.AllowedEKUs = domain.DefaultEKUs() + } + + if err := s.profileRepo.Create(context.Background(), &profile); err != nil { + return nil, fmt.Errorf("failed to create profile: %w", err) + } + + if s.auditService != nil { + if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser, + "create_profile", "certificate_profile", profile.ID, nil); auditErr != nil { + slog.Error("failed to record audit event", "error", auditErr) + } + } + + return &profile, nil +} + +// UpdateProfile modifies an existing profile (handler interface method). +func (s *ProfileService) UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) { + if err := validateProfile(&profile); err != nil { + return nil, err + } + + profile.ID = id + if err := s.profileRepo.Update(context.Background(), &profile); err != nil { + return nil, fmt.Errorf("failed to update profile: %w", err) + } + + if s.auditService != nil { + if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser, + "update_profile", "certificate_profile", id, nil); auditErr != nil { + slog.Error("failed to record audit event", "error", auditErr) + } + } + + return &profile, nil +} + +// DeleteProfile removes a profile (handler interface method). +func (s *ProfileService) DeleteProfile(id string) error { + if err := s.profileRepo.Delete(context.Background(), id); err != nil { + return fmt.Errorf("failed to delete profile: %w", err) + } + + if s.auditService != nil { + if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser, + "delete_profile", "certificate_profile", id, nil); auditErr != nil { + slog.Error("failed to record audit event", "error", auditErr) + } + } + + return nil +} + +// Get retrieves a profile by ID (used by other services like RenewalService). +func (s *ProfileService) Get(ctx context.Context, id string) (*domain.CertificateProfile, error) { + return s.profileRepo.Get(ctx, id) +} + +// validateProfile checks that a profile's configuration is valid. +func validateProfile(p *domain.CertificateProfile) error { + if p.Name == "" { + return fmt.Errorf("profile name is required") + } + if len(p.Name) > 255 { + return fmt.Errorf("profile name exceeds 255 characters") + } + + // Validate key algorithms + for _, alg := range p.AllowedKeyAlgorithms { + if !domain.ValidKeyAlgorithms[alg.Algorithm] { + return fmt.Errorf("invalid key algorithm: %s (allowed: RSA, ECDSA, Ed25519)", alg.Algorithm) + } + if alg.Algorithm == domain.KeyAlgorithmRSA && alg.MinSize < 2048 { + return fmt.Errorf("RSA minimum key size must be at least 2048, got %d", alg.MinSize) + } + if alg.Algorithm == domain.KeyAlgorithmECDSA && alg.MinSize < 256 { + return fmt.Errorf("ECDSA minimum key size must be at least 256, got %d", alg.MinSize) + } + } + + // Validate EKUs + for _, eku := range p.AllowedEKUs { + if !domain.ValidEKUs[eku] { + return fmt.Errorf("invalid EKU: %s", eku) + } + } + + // Validate max TTL + if p.MaxTTLSeconds < 0 { + return fmt.Errorf("max_ttl_seconds cannot be negative") + } + + // Validate short-lived consistency + if p.AllowShortLived && p.MaxTTLSeconds >= 3600 { + return fmt.Errorf("allow_short_lived is true but max_ttl_seconds (%d) is >= 3600; short-lived certs must have TTL under 1 hour", p.MaxTTLSeconds) + } + + return nil +} diff --git a/internal/service/profile_test.go b/internal/service/profile_test.go new file mode 100644 index 0000000..53b3f4e --- /dev/null +++ b/internal/service/profile_test.go @@ -0,0 +1,415 @@ +package service + +import ( + "context" + "errors" + "testing" + + "github.com/shankar0123/certctl/internal/domain" +) + +// mockProfileRepo is a test implementation of CertificateProfileRepository +type mockProfileRepo struct { + profiles map[string]*domain.CertificateProfile + ListErr error + GetErr error + CreateErr error + UpdateErr error + DeleteErr error +} + +func newMockProfileRepository() *mockProfileRepo { + return &mockProfileRepo{ + profiles: make(map[string]*domain.CertificateProfile), + } +} + +func (m *mockProfileRepo) List(ctx context.Context) ([]*domain.CertificateProfile, error) { + if m.ListErr != nil { + return nil, m.ListErr + } + var profiles []*domain.CertificateProfile + for _, p := range m.profiles { + profiles = append(profiles, p) + } + return profiles, nil +} + +func (m *mockProfileRepo) Get(ctx context.Context, id string) (*domain.CertificateProfile, error) { + if m.GetErr != nil { + return nil, m.GetErr + } + p, ok := m.profiles[id] + if !ok { + return nil, errNotFound + } + return p, nil +} + +func (m *mockProfileRepo) Create(ctx context.Context, profile *domain.CertificateProfile) error { + if m.CreateErr != nil { + return m.CreateErr + } + m.profiles[profile.ID] = profile + return nil +} + +func (m *mockProfileRepo) Update(ctx context.Context, profile *domain.CertificateProfile) error { + if m.UpdateErr != nil { + return m.UpdateErr + } + m.profiles[profile.ID] = profile + return nil +} + +func (m *mockProfileRepo) Delete(ctx context.Context, id string) error { + if m.DeleteErr != nil { + return m.DeleteErr + } + delete(m.profiles, id) + return nil +} + +func (m *mockProfileRepo) AddProfile(p *domain.CertificateProfile) { + m.profiles[p.ID] = p +} + +// --- ProfileService Tests --- + +func TestProfileService_ListProfiles(t *testing.T) { + repo := newMockProfileRepository() + repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "Standard TLS", Enabled: true}) + repo.AddProfile(&domain.CertificateProfile{ID: "prof-2", Name: "Internal mTLS", Enabled: true}) + + svc := NewProfileService(repo, nil) + profiles, total, err := svc.ListProfiles(1, 50) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if total != 2 { + t.Errorf("expected total 2, got %d", total) + } + if len(profiles) != 2 { + t.Errorf("expected 2 profiles, got %d", len(profiles)) + } +} + +func TestProfileService_ListProfiles_Empty(t *testing.T) { + repo := newMockProfileRepository() + svc := NewProfileService(repo, nil) + + profiles, total, err := svc.ListProfiles(1, 50) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if total != 0 { + t.Errorf("expected total 0, got %d", total) + } + if len(profiles) != 0 { + t.Errorf("expected 0 profiles, got %d", len(profiles)) + } +} + +func TestProfileService_ListProfiles_RepoError(t *testing.T) { + repo := newMockProfileRepository() + repo.ListErr = errors.New("db error") + svc := NewProfileService(repo, nil) + + _, _, err := svc.ListProfiles(1, 50) + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestProfileService_GetProfile(t *testing.T) { + repo := newMockProfileRepository() + repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "Standard TLS"}) + svc := NewProfileService(repo, nil) + + profile, err := svc.GetProfile("prof-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if profile.Name != "Standard TLS" { + t.Errorf("expected 'Standard TLS', got '%s'", profile.Name) + } +} + +func TestProfileService_GetProfile_NotFound(t *testing.T) { + repo := newMockProfileRepository() + svc := NewProfileService(repo, nil) + + _, err := svc.GetProfile("nonexistent") + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestProfileService_CreateProfile_Defaults(t *testing.T) { + repo := newMockProfileRepository() + auditRepo := newMockAuditRepository() + auditSvc := NewAuditService(auditRepo) + svc := NewProfileService(repo, auditSvc) + + profile := domain.CertificateProfile{ + Name: "New Profile", + MaxTTLSeconds: 86400, + } + + created, err := svc.CreateProfile(profile) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if created.ID == "" { + t.Error("expected generated ID, got empty") + } + if len(created.AllowedKeyAlgorithms) == 0 { + t.Error("expected default key algorithms, got empty") + } + if len(created.AllowedEKUs) == 0 { + t.Error("expected default EKUs, got empty") + } + if created.CreatedAt.IsZero() { + t.Error("expected CreatedAt to be set") + } + // Verify audit event recorded + if len(auditRepo.Events) != 1 { + t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events)) + } +} + +func TestProfileService_CreateProfile_ValidationErrors(t *testing.T) { + repo := newMockProfileRepository() + svc := NewProfileService(repo, nil) + + tests := []struct { + name string + profile domain.CertificateProfile + errMsg string + }{ + { + name: "empty name", + profile: domain.CertificateProfile{}, + errMsg: "profile name is required", + }, + { + name: "name too long", + profile: domain.CertificateProfile{ + Name: string(make([]byte, 256)), + }, + errMsg: "exceeds 255 characters", + }, + { + name: "invalid key algorithm", + profile: domain.CertificateProfile{ + Name: "Bad Algo", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "DES", MinSize: 56}, + }, + }, + errMsg: "invalid key algorithm", + }, + { + name: "RSA key too small", + profile: domain.CertificateProfile{ + Name: "Weak RSA", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "RSA", MinSize: 1024}, + }, + }, + errMsg: "RSA minimum key size must be at least 2048", + }, + { + name: "ECDSA key too small", + profile: domain.CertificateProfile{ + Name: "Weak ECDSA", + AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{ + {Algorithm: "ECDSA", MinSize: 128}, + }, + }, + errMsg: "ECDSA minimum key size must be at least 256", + }, + { + name: "invalid EKU", + profile: domain.CertificateProfile{ + Name: "Bad EKU", + AllowedEKUs: []string{"invalidEKU"}, + }, + errMsg: "invalid EKU", + }, + { + name: "negative TTL", + profile: domain.CertificateProfile{ + Name: "Negative TTL", + MaxTTLSeconds: -1, + }, + errMsg: "cannot be negative", + }, + { + name: "short-lived with long TTL", + profile: domain.CertificateProfile{ + Name: "Inconsistent Short-Lived", + AllowShortLived: true, + MaxTTLSeconds: 7200, + }, + errMsg: "short-lived certs must have TTL under 1 hour", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := svc.CreateProfile(tt.profile) + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.errMsg) + } + if !contains(err.Error(), tt.errMsg) { + t.Errorf("expected error containing %q, got %q", tt.errMsg, err.Error()) + } + }) + } +} + +func TestProfileService_CreateProfile_RepoError(t *testing.T) { + repo := newMockProfileRepository() + repo.CreateErr = errors.New("db create failed") + svc := NewProfileService(repo, nil) + + _, err := svc.CreateProfile(domain.CertificateProfile{Name: "Valid"}) + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestProfileService_UpdateProfile(t *testing.T) { + repo := newMockProfileRepository() + repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "Original"}) + auditRepo := newMockAuditRepository() + auditSvc := NewAuditService(auditRepo) + svc := NewProfileService(repo, auditSvc) + + updated, err := svc.UpdateProfile("prof-1", domain.CertificateProfile{ + Name: "Updated", + MaxTTLSeconds: 43200, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if updated.ID != "prof-1" { + t.Errorf("expected ID 'prof-1', got '%s'", updated.ID) + } + if len(auditRepo.Events) != 1 { + t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events)) + } +} + +func TestProfileService_UpdateProfile_ValidationError(t *testing.T) { + repo := newMockProfileRepository() + svc := NewProfileService(repo, nil) + + _, err := svc.UpdateProfile("prof-1", domain.CertificateProfile{Name: ""}) + if err == nil { + t.Fatal("expected validation error, got nil") + } +} + +func TestProfileService_DeleteProfile(t *testing.T) { + repo := newMockProfileRepository() + repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "To Delete"}) + auditRepo := newMockAuditRepository() + auditSvc := NewAuditService(auditRepo) + svc := NewProfileService(repo, auditSvc) + + err := svc.DeleteProfile("prof-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(auditRepo.Events) != 1 { + t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events)) + } +} + +func TestProfileService_DeleteProfile_RepoError(t *testing.T) { + repo := newMockProfileRepository() + repo.DeleteErr = errors.New("db delete failed") + svc := NewProfileService(repo, nil) + + err := svc.DeleteProfile("prof-1") + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestProfileService_CreateProfile_ValidShortLived(t *testing.T) { + repo := newMockProfileRepository() + svc := NewProfileService(repo, nil) + + // Short-lived with TTL under 1 hour should succeed + created, err := svc.CreateProfile(domain.CertificateProfile{ + Name: "CI Ephemeral", + AllowShortLived: true, + MaxTTLSeconds: 300, // 5 minutes + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !created.AllowShortLived { + t.Error("expected AllowShortLived to be true") + } +} + +func TestIsShortLived(t *testing.T) { + tests := []struct { + name string + profile domain.CertificateProfile + expected bool + }{ + { + name: "short-lived with 5 min TTL", + profile: domain.CertificateProfile{AllowShortLived: true, MaxTTLSeconds: 300}, + expected: true, + }, + { + name: "short-lived flag false", + profile: domain.CertificateProfile{AllowShortLived: false, MaxTTLSeconds: 300}, + expected: false, + }, + { + name: "zero TTL with flag", + profile: domain.CertificateProfile{AllowShortLived: true, MaxTTLSeconds: 0}, + expected: false, + }, + { + name: "TTL at 1 hour boundary", + profile: domain.CertificateProfile{AllowShortLived: true, MaxTTLSeconds: 3600}, + expected: false, + }, + { + name: "standard long-lived", + profile: domain.CertificateProfile{AllowShortLived: false, MaxTTLSeconds: 7776000}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.profile.IsShortLived() + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +// contains checks if a string contains a substring (helper for test assertions). +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstring(s, substr)) +} + +func containsSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/internal/service/renewal.go b/internal/service/renewal.go index 92700eb..f01d3c9 100644 --- a/internal/service/renewal.go +++ b/internal/service/renewal.go @@ -22,6 +22,7 @@ type RenewalService struct { certRepo repository.CertificateRepository jobRepo repository.JobRepository renewalPolicyRepo repository.RenewalPolicyRepository + profileRepo repository.CertificateProfileRepository auditService *AuditService notificationSvc *NotificationService issuerRegistry map[string]IssuerConnector @@ -52,6 +53,7 @@ func NewRenewalService( certRepo repository.CertificateRepository, jobRepo repository.JobRepository, renewalPolicyRepo repository.RenewalPolicyRepository, + profileRepo repository.CertificateProfileRepository, auditService *AuditService, notificationSvc *NotificationService, issuerRegistry map[string]IssuerConnector, @@ -64,6 +66,7 @@ func NewRenewalService( certRepo: certRepo, jobRepo: jobRepo, renewalPolicyRepo: renewalPolicyRepo, + profileRepo: profileRepo, auditService: auditService, notificationSvc: notificationSvc, issuerRegistry: issuerRegistry, @@ -371,6 +374,8 @@ func (s *RenewalService) processRenewalServerKeygen(ctx context.Context, job *do FingerprintSHA256: fingerprint, PEMChain: result.CertPEM + "\n" + result.ChainPEM, CSRPEM: privKeyPEM, // Server mode: stores private key for agent deployment + KeyAlgorithm: domain.KeyAlgorithmRSA, + KeySize: 2048, CreatedAt: time.Now(), } @@ -428,6 +433,22 @@ func (s *RenewalService) CompleteAgentCSRRenewal(ctx context.Context, job *domai return fmt.Errorf("issuer connector not found for %s", cert.IssuerID) } + // Validate CSR against certificate profile (crypto policy enforcement) + var profile *domain.CertificateProfile + if cert.CertificateProfileID != "" && s.profileRepo != nil { + var profileErr error + profile, profileErr = s.profileRepo.Get(ctx, cert.CertificateProfileID) + if profileErr != nil { + slog.Warn("failed to fetch certificate profile, skipping crypto validation", + "profile_id", cert.CertificateProfileID, "cert_id", cert.ID, "error", profileErr) + } + } + csrInfo, csrErr := ValidateCSRAgainstProfile(csrPEM, profile) + if csrErr != nil { + s.failJob(ctx, job, fmt.Sprintf("CSR validation failed: %v", csrErr)) + return fmt.Errorf("CSR validation failed: %w", csrErr) + } + // Update job to running if err := s.jobRepo.UpdateStatus(ctx, job.ID, domain.JobStatusRunning, ""); err != nil { return fmt.Errorf("failed to update job status: %w", err) @@ -462,6 +483,10 @@ func (s *RenewalService) CompleteAgentCSRRenewal(ctx context.Context, job *domai CSRPEM: csrPEM, // Agent mode: stores actual CSR, not private key CreatedAt: time.Now(), } + if csrInfo != nil { + version.KeyAlgorithm = csrInfo.KeyAlgorithm + version.KeySize = csrInfo.KeySize + } if err := s.certRepo.CreateVersion(ctx, version); err != nil { s.failJob(ctx, job, fmt.Sprintf("version creation failed: %v", err)) @@ -589,6 +614,73 @@ func (s *RenewalService) RetryFailedJobs(ctx context.Context, maxRetries int) er return nil } +// ExpireShortLivedCertificates finds active certificates with short-lived profiles +// whose TTL has elapsed and marks them as Expired. For certs with TTL < 1 hour, +// expiry is the revocation mechanism — no CRL/OCSP needed. +func (s *RenewalService) ExpireShortLivedCertificates(ctx context.Context) error { + if s.profileRepo == nil { + return nil + } + + // Get all Active certificates and check if any have expired based on their actual expiry time + // This catches short-lived certs that expire between normal renewal check cycles + now := time.Now() + expiring, err := s.certRepo.GetExpiringCertificates(ctx, now) + if err != nil { + return fmt.Errorf("failed to fetch expired certificates: %w", err) + } + + for _, cert := range expiring { + if cert.Status != domain.CertificateStatusActive && cert.Status != domain.CertificateStatusExpiring { + continue + } + + // Only auto-expire certs that have actually passed their expiry time + if cert.ExpiresAt.After(now) { + continue + } + + // Check if this cert has a short-lived profile + if cert.CertificateProfileID == "" { + continue + } + + profile, err := s.profileRepo.Get(ctx, cert.CertificateProfileID) + if err != nil { + slog.Warn("failed to fetch profile for short-lived expiry check", + "profile_id", cert.CertificateProfileID, "cert_id", cert.ID, "error", err) + continue + } + + if !profile.IsShortLived() { + continue + } + + // Mark as expired + cert.Status = domain.CertificateStatusExpired + cert.UpdatedAt = now + if err := s.certRepo.Update(ctx, cert); err != nil { + slog.Error("failed to expire short-lived cert", "cert_id", cert.ID, "error", err) + continue + } + + slog.Info("short-lived certificate expired (expiry = revocation)", + "cert_id", cert.ID, "profile_id", cert.CertificateProfileID, + "expired_at", cert.ExpiresAt) + + if auditErr := s.auditService.RecordEvent(ctx, "system", domain.ActorTypeSystem, + "short_lived_cert_expired", "certificate", cert.ID, + map[string]interface{}{ + "profile_id": cert.CertificateProfileID, + "expired_at": cert.ExpiresAt, + }); auditErr != nil { + slog.Error("failed to record audit event", "error", auditErr) + } + } + + return nil +} + // generateID is a helper to generate unique IDs. In production, use a proper ID generator. func generateID(prefix string) string { return fmt.Sprintf("%s-%d", prefix, time.Now().UnixNano()) diff --git a/internal/service/renewal_test.go b/internal/service/renewal_test.go index 464251f..7a8c361 100644 --- a/internal/service/renewal_test.go +++ b/internal/service/renewal_test.go @@ -30,7 +30,7 @@ func TestCheckExpiringCertificates_SendsThresholdAlerts(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create a cert expiring in 10 days cert := &domain.ManagedCertificate{ @@ -112,7 +112,7 @@ func TestCheckExpiringCertificates_DeduplicatesAlerts(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create cert cert := &domain.ManagedCertificate{ @@ -192,7 +192,7 @@ func TestCheckExpiringCertificates_SkipsRenewalInProgress(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create cert with RenewalInProgress status cert := &domain.ManagedCertificate{ @@ -257,7 +257,7 @@ func TestCheckExpiringCertificates_UpdatesStatusToExpiring(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create active cert that will become expiring // Use an issuer NOT in the registry so no renewal job is created (which would override status) @@ -319,7 +319,7 @@ func TestCheckExpiringCertificates_UpdatesStatusToExpired(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create cert that is already expired // Use an issuer NOT in the registry so no renewal job is created (which would override status) @@ -381,7 +381,7 @@ func TestCheckExpiringCertificates_CreatesRenewalJob(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create expiring cert with registered issuer cert := &domain.ManagedCertificate{ @@ -447,7 +447,7 @@ func TestCheckExpiringCertificates_SkipsWithoutIssuer(t *testing.T) { // Empty issuer registry issuerRegistry := map[string]IssuerConnector{} - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create cert with unregistered issuer cert := &domain.ManagedCertificate{ @@ -509,7 +509,7 @@ func TestCheckExpiringCertificates_SkipsDuplicateJobs(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create cert cert := &domain.ManagedCertificate{ @@ -593,7 +593,7 @@ func TestProcessRenewalJob(t *testing.T) { "iss-test": issuerConnector, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create certificate cert := &domain.ManagedCertificate{ @@ -689,7 +689,7 @@ func TestProcessRenewalJob_IssuerFailure(t *testing.T) { "iss-test": issuerConnector, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create certificate cert := &domain.ManagedCertificate{ @@ -771,7 +771,7 @@ func TestRetryFailedJobs(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create failed job with attempts < max_attempts failedJob := &domain.Job{ @@ -836,7 +836,7 @@ func TestProcessRenewalJob_NoCertificate(t *testing.T) { "iss-test": &mockIssuerConnector{}, } - svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry, "server") + svc := NewRenewalService(certRepo, jobRepo, policyRepo, nil, auditSvc, notifSvc, issuerRegistry, "server") // Create job with non-existent certificate job := &domain.Job{ diff --git a/migrations/000003_certificate_profiles.down.sql b/migrations/000003_certificate_profiles.down.sql new file mode 100644 index 0000000..d2abe2f --- /dev/null +++ b/migrations/000003_certificate_profiles.down.sql @@ -0,0 +1,13 @@ +-- Rollback: remove certificate profiles and associated columns + +ALTER TABLE certificate_versions DROP COLUMN IF EXISTS key_algorithm; +ALTER TABLE certificate_versions DROP COLUMN IF EXISTS key_size; + +ALTER TABLE renewal_policies DROP COLUMN IF EXISTS certificate_profile_id; + +DROP INDEX IF EXISTS idx_managed_certificates_profile_id; +ALTER TABLE managed_certificates DROP COLUMN IF EXISTS certificate_profile_id; + +DROP INDEX IF EXISTS idx_certificate_profiles_name; +DROP INDEX IF EXISTS idx_certificate_profiles_enabled; +DROP TABLE IF EXISTS certificate_profiles; diff --git a/migrations/000003_certificate_profiles.up.sql b/migrations/000003_certificate_profiles.up.sql new file mode 100644 index 0000000..267fbc7 --- /dev/null +++ b/migrations/000003_certificate_profiles.up.sql @@ -0,0 +1,53 @@ +-- M11a: Certificate Profiles + Crypto Foundation +-- Named enrollment profiles defining allowed key types, max TTL, required SANs, +-- permitted EKUs, and optional SPIFFE URI SAN patterns. + +-- Table: certificate_profiles +CREATE TABLE IF NOT EXISTS certificate_profiles ( + id TEXT PRIMARY KEY, + name VARCHAR(255) NOT NULL UNIQUE, + description TEXT DEFAULT '', + + -- Crypto policy: which key algorithms and minimum sizes are allowed + -- Example: [{"algorithm": "ECDSA", "min_size": 256}, {"algorithm": "RSA", "min_size": 2048}] + allowed_key_algorithms JSONB NOT NULL DEFAULT '[{"algorithm": "ECDSA", "min_size": 256}, {"algorithm": "RSA", "min_size": 2048}]', + + -- Maximum certificate TTL in seconds (0 = no limit, uses issuer default) + -- Short-lived: 300 (5 min), 3600 (1 hour). Standard: 7776000 (90 days), 4060800 (47 days) + max_ttl_seconds INT NOT NULL DEFAULT 0, + + -- Permitted Extended Key Usages + -- Example: ["serverAuth", "clientAuth"] + allowed_ekus JSONB NOT NULL DEFAULT '["serverAuth"]', + + -- Required SAN patterns (regexes that issued certs must match) + -- Example: [".*\\.example\\.com$", ".*\\.internal\\.example\\.com$"] + required_san_patterns JSONB NOT NULL DEFAULT '[]', + + -- Optional SPIFFE URI SAN pattern for workload identity + -- Example: "spiffe://example.com/workload/*" + -- Empty string means no SPIFFE SAN will be minted + spiffe_uri_pattern VARCHAR(512) DEFAULT '', + + -- Whether this profile allows short-lived certs (TTL < 1 hour) + -- When true, expired certs under this profile skip CRL/OCSP (expiry = revocation) + allow_short_lived BOOLEAN NOT NULL DEFAULT false, + + enabled BOOLEAN NOT NULL DEFAULT true, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_certificate_profiles_name ON certificate_profiles(name); +CREATE INDEX IF NOT EXISTS idx_certificate_profiles_enabled ON certificate_profiles(enabled); + +-- Add certificate_profile_id FK to managed_certificates (nullable for backward compat) +ALTER TABLE managed_certificates ADD COLUMN IF NOT EXISTS certificate_profile_id TEXT REFERENCES certificate_profiles(id) ON DELETE SET NULL; +CREATE INDEX IF NOT EXISTS idx_managed_certificates_profile_id ON managed_certificates(certificate_profile_id); + +-- Add certificate_profile_id FK to renewal_policies (nullable — profile scoping on policies) +ALTER TABLE renewal_policies ADD COLUMN IF NOT EXISTS certificate_profile_id TEXT REFERENCES certificate_profiles(id) ON DELETE SET NULL; + +-- Add key metadata to certificate_versions for audit / compliance +ALTER TABLE certificate_versions ADD COLUMN IF NOT EXISTS key_algorithm VARCHAR(50) DEFAULT ''; +ALTER TABLE certificate_versions ADD COLUMN IF NOT EXISTS key_size INT DEFAULT 0; diff --git a/migrations/seed_demo.sql b/migrations/seed_demo.sql index df0533b..ddbc5ee 100644 --- a/migrations/seed_demo.sql +++ b/migrations/seed_demo.sql @@ -53,6 +53,42 @@ INSERT INTO deployment_targets (id, name, type, agent_id, config, enabled, creat ('tgt-nginx-data', 'NGINX Data Services', 'nginx', 'ag-data-prod', '{"cert_path": "/etc/nginx/ssl/cert.pem", "key_path": "/etc/nginx/ssl/key.pem", "reload_command": "nginx -s reload"}', true, NOW(), NOW()) ON CONFLICT (id) DO NOTHING; +-- Certificate Profiles +INSERT INTO certificate_profiles (id, name, description, allowed_key_algorithms, max_ttl_seconds, allowed_ekus, required_san_patterns, spiffe_uri_pattern, allow_short_lived, enabled, created_at, updated_at) VALUES + ('prof-standard-tls', 'Standard TLS', + 'Default profile for web-facing TLS certificates. Requires ECDSA P-256+ or RSA 2048+.', + '[{"algorithm": "ECDSA", "min_size": 256}, {"algorithm": "RSA", "min_size": 2048}]'::jsonb, + 7776000, -- 90 days + '["serverAuth"]'::jsonb, + '[]'::jsonb, + '', false, true, NOW(), NOW()), + + ('prof-internal-mtls', 'Internal mTLS', + 'Mutual TLS profile for internal service-to-service communication.', + '[{"algorithm": "ECDSA", "min_size": 256}]'::jsonb, + 2592000, -- 30 days + '["serverAuth", "clientAuth"]'::jsonb, + '[".*\\.internal\\.example\\.com$"]'::jsonb, + '', false, true, NOW(), NOW()), + + ('prof-short-lived', 'Short-Lived Credential', + 'Ephemeral certificates for CI/CD pipelines and container workloads. TTL under 1 hour, expiry = revocation.', + '[{"algorithm": "ECDSA", "min_size": 256}]'::jsonb, + 300, -- 5 minutes + '["serverAuth", "clientAuth"]'::jsonb, + '[]'::jsonb, + 'spiffe://example.com/workload/*', + true, true, NOW(), NOW()), + + ('prof-high-security', 'High Security', + 'For PCI-DSS and compliance-sensitive workloads. RSA 4096+ or ECDSA P-384+ only.', + '[{"algorithm": "ECDSA", "min_size": 384}, {"algorithm": "RSA", "min_size": 4096}]'::jsonb, + 4060800, -- 47 days (Ballot SC-081v3 target) + '["serverAuth"]'::jsonb, + '[".*\\.example\\.com$"]'::jsonb, + '', false, true, NOW(), NOW()) +ON CONFLICT (id) DO NOTHING; + -- Managed Certificates — varied statuses and expiry dates for realistic dashboard INSERT INTO managed_certificates (id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at) VALUES -- Active, healthy certs diff --git a/web/src/api/client.ts b/web/src/api/client.ts index 828dca6..bf7b8b3 100644 --- a/web/src/api/client.ts +++ b/web/src/api/client.ts @@ -1,4 +1,4 @@ -import type { Certificate, CertificateVersion, Agent, Job, Notification, AuditEvent, PolicyRule, PolicyViolation, Issuer, Target, PaginatedResponse } from './types'; +import type { Certificate, CertificateVersion, Agent, Job, Notification, AuditEvent, PolicyRule, PolicyViolation, Issuer, Target, CertificateProfile, PaginatedResponse } from './types'; const BASE = '/api/v1'; @@ -169,5 +169,23 @@ export const createTarget = (data: Partial) => export const deleteTarget = (id: string) => fetchJSON<{ message: string }>(`${BASE}/targets/${id}`, { method: 'DELETE' }); +// Profiles +export const getProfiles = (params: Record = {}) => { + const qs = new URLSearchParams({ page: '1', per_page: '50', ...params }).toString(); + return fetchJSON>(`${BASE}/profiles?${qs}`); +}; + +export const getProfile = (id: string) => + fetchJSON(`${BASE}/profiles/${id}`); + +export const createProfile = (data: Partial) => + fetchJSON(`${BASE}/profiles`, { method: 'POST', body: JSON.stringify(data) }); + +export const updateProfile = (id: string, data: Partial) => + fetchJSON(`${BASE}/profiles/${id}`, { method: 'PUT', body: JSON.stringify(data) }); + +export const deleteProfile = (id: string) => + fetchJSON<{ message: string }>(`${BASE}/profiles/${id}`, { method: 'DELETE' }); + // Health export const getHealth = () => fetchJSON<{ status: string }>('/health'); diff --git a/web/src/api/types.ts b/web/src/api/types.ts index 5a6d9ae..0881cf1 100644 --- a/web/src/api/types.ts +++ b/web/src/api/types.ts @@ -129,6 +129,26 @@ export interface Target { created_at: string; } +export interface KeyAlgorithmRule { + algorithm: string; + min_size: number; +} + +export interface CertificateProfile { + id: string; + name: string; + description: string; + allowed_key_algorithms: KeyAlgorithmRule[]; + max_ttl_seconds: number; + allowed_ekus: string[]; + required_san_patterns: string[]; + spiffe_uri_pattern: string; + allow_short_lived: boolean; + enabled: boolean; + created_at: string; + updated_at: string; +} + export interface PaginatedResponse { data: T[]; total: number; diff --git a/web/src/components/Layout.tsx b/web/src/components/Layout.tsx index db6fe56..0a806e1 100644 --- a/web/src/components/Layout.tsx +++ b/web/src/components/Layout.tsx @@ -8,6 +8,7 @@ const nav = [ { to: '/jobs', label: 'Jobs', icon: 'M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15' }, { to: '/notifications', label: 'Notifications', icon: 'M15 17h5l-1.405-1.405A2.032 2.032 0 0118 14.158V11a6.002 6.002 0 00-4-5.659V5a2 2 0 10-4 0v.341C7.67 6.165 6 8.388 6 11v3.159c0 .538-.214 1.055-.595 1.436L4 17h5m6 0v1a3 3 0 11-6 0v-1m6 0H9' }, { to: '/policies', label: 'Policies', icon: 'M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4' }, + { to: '/profiles', label: 'Profiles', icon: 'M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.066 2.573c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.573 1.066c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.066-2.573c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z M15 12a3 3 0 11-6 0 3 3 0 016 0z' }, { to: '/issuers', label: 'Issuers', icon: 'M15 7a2 2 0 012 2m4 0a6 6 0 01-7.743 5.743L11 17H9v2H7v2H4a1 1 0 01-1-1v-2.586a1 1 0 01.293-.707l5.964-5.964A6 6 0 1121 9z' }, { to: '/targets', label: 'Targets', icon: 'M19 11H5m14 0a2 2 0 012 2v6a2 2 0 01-2 2H5a2 2 0 01-2-2v-6a2 2 0 012-2m14 0V9a2 2 0 00-2-2M5 11V9a2 2 0 012-2m0 0V5a2 2 0 012-2h6a2 2 0 012 2v2M7 7h10' }, { to: '/audit', label: 'Audit Trail', icon: 'M12 8v4l3 3m6-3a9 9 0 11-18 0 9 9 0 0118 0z' }, diff --git a/web/src/main.tsx b/web/src/main.tsx index f820d6d..f68afa6 100644 --- a/web/src/main.tsx +++ b/web/src/main.tsx @@ -16,6 +16,7 @@ import NotificationsPage from './pages/NotificationsPage'; import PoliciesPage from './pages/PoliciesPage'; import IssuersPage from './pages/IssuersPage'; import TargetsPage from './pages/TargetsPage'; +import ProfilesPage from './pages/ProfilesPage'; import AuditPage from './pages/AuditPage'; import './index.css'; @@ -46,6 +47,7 @@ createRoot(document.getElementById('root')!).render( } /> } /> } /> + } /> } /> } /> } /> diff --git a/web/src/pages/ProfilesPage.tsx b/web/src/pages/ProfilesPage.tsx new file mode 100644 index 0000000..5a7e6e9 --- /dev/null +++ b/web/src/pages/ProfilesPage.tsx @@ -0,0 +1,129 @@ +import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'; +import { getProfiles, deleteProfile } from '../api/client'; +import PageHeader from '../components/PageHeader'; +import DataTable from '../components/DataTable'; +import type { Column } from '../components/DataTable'; +import StatusBadge from '../components/StatusBadge'; +import ErrorState from '../components/ErrorState'; +import { formatDateTime } from '../api/utils'; +import type { CertificateProfile } from '../api/types'; + +function formatTTL(seconds: number): string { + if (seconds === 0) return 'No limit'; + if (seconds < 60) return `${seconds}s`; + if (seconds < 3600) return `${Math.floor(seconds / 60)}m`; + if (seconds < 86400) return `${Math.floor(seconds / 3600)}h`; + return `${Math.floor(seconds / 86400)}d`; +} + +export default function ProfilesPage() { + const queryClient = useQueryClient(); + + const { data, isLoading, error, refetch } = useQuery({ + queryKey: ['profiles'], + queryFn: () => getProfiles(), + }); + + const deleteMutation = useMutation({ + mutationFn: deleteProfile, + onSuccess: () => queryClient.invalidateQueries({ queryKey: ['profiles'] }), + }); + + const columns: Column[] = [ + { + key: 'name', + label: 'Profile', + render: (p) => ( +
+
{p.name}
+
{p.id}
+ {p.description && ( +
{p.description}
+ )} +
+ ), + }, + { + key: 'algorithms', + label: 'Key Algorithms', + render: (p) => ( +
+ {(p.allowed_key_algorithms || []).map((alg, i) => ( + + {alg.algorithm} {alg.min_size}+ + + ))} +
+ ), + }, + { + key: 'ttl', + label: 'Max TTL', + render: (p) => ( +
+ {formatTTL(p.max_ttl_seconds)} + {p.allow_short_lived && ( + + short-lived + + )} +
+ ), + }, + { + key: 'ekus', + label: 'EKUs', + render: (p) => ( +
+ {(p.allowed_ekus || []).map((eku, i) => ( + {eku} + ))} +
+ ), + }, + { + key: 'spiffe', + label: 'SPIFFE', + render: (p) => ( + p.spiffe_uri_pattern + ? {p.spiffe_uri_pattern} + : + ), + }, + { + key: 'enabled', + label: 'Status', + render: (p) => , + }, + { + key: 'created', + label: 'Created', + render: (p) => {formatDateTime(p.created_at)}, + }, + { + key: 'actions', + label: '', + render: (p) => ( + + ), + }, + ]; + + return ( + <> + +
+ {error ? ( + refetch()} /> + ) : ( + + )} +
+ + ); +}