Implement M4: comprehensive test coverage with 120 tests

Service layer (63 tests): certificate, agent, audit, job, notification,
policy, and renewal services with mock repositories covering threshold
alerting, deduplication, status transitions, and job processing.

Handler layer (46 tests): certificate and agent HTTP handlers using
httptest with mock service interfaces, covering success/error paths,
pagination, JSON marshaling, and path parameter extraction.

Integration (11 subtests): end-to-end certificate lifecycle test
exercising real services and Local CA issuer through HTTP API —
create cert, trigger renewal, process jobs, register agent, heartbeat,
verify audit trail.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
shankar0123
2026-03-15 00:25:01 -04:00
parent 1d1b89c9b5
commit 5553568495
14 changed files with 6767 additions and 1 deletions
+869
View File
@@ -0,0 +1,869 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
)
// MockAgentService is a mock implementation of AgentService interface.
type MockAgentService struct {
ListAgentsFn func(page, perPage int) ([]domain.Agent, int64, error)
GetAgentFn func(id string) (*domain.Agent, error)
RegisterAgentFn func(agent domain.Agent) (*domain.Agent, error)
HeartbeatFn func(agentID string) error
CSRSubmitFn func(agentID string, csrPEM string) (string, error)
CSRSubmitForCertFn func(agentID string, certID string, csrPEM string) (string, error)
CertificatePickupFn func(agentID, certID string) (string, error)
GetWorkFn func(agentID string) ([]domain.Job, error)
GetWorkWithTargetsFn func(agentID string) ([]domain.WorkItem, error)
UpdateJobStatusFn func(agentID string, jobID string, status string, errMsg string) error
}
func (m *MockAgentService) ListAgents(page, perPage int) ([]domain.Agent, int64, error) {
if m.ListAgentsFn != nil {
return m.ListAgentsFn(page, perPage)
}
return nil, 0, nil
}
func (m *MockAgentService) GetAgent(id string) (*domain.Agent, error) {
if m.GetAgentFn != nil {
return m.GetAgentFn(id)
}
return nil, nil
}
func (m *MockAgentService) RegisterAgent(agent domain.Agent) (*domain.Agent, error) {
if m.RegisterAgentFn != nil {
return m.RegisterAgentFn(agent)
}
return nil, nil
}
func (m *MockAgentService) Heartbeat(agentID string) error {
if m.HeartbeatFn != nil {
return m.HeartbeatFn(agentID)
}
return nil
}
func (m *MockAgentService) CSRSubmit(agentID string, csrPEM string) (string, error) {
if m.CSRSubmitFn != nil {
return m.CSRSubmitFn(agentID, csrPEM)
}
return "", nil
}
func (m *MockAgentService) CSRSubmitForCert(agentID string, certID string, csrPEM string) (string, error) {
if m.CSRSubmitForCertFn != nil {
return m.CSRSubmitForCertFn(agentID, certID, csrPEM)
}
return "", nil
}
func (m *MockAgentService) CertificatePickup(agentID, certID string) (string, error) {
if m.CertificatePickupFn != nil {
return m.CertificatePickupFn(agentID, certID)
}
return "", nil
}
func (m *MockAgentService) GetWork(agentID string) ([]domain.Job, error) {
if m.GetWorkFn != nil {
return m.GetWorkFn(agentID)
}
return nil, nil
}
func (m *MockAgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, error) {
if m.GetWorkWithTargetsFn != nil {
return m.GetWorkWithTargetsFn(agentID)
}
return nil, nil
}
func (m *MockAgentService) UpdateJobStatus(agentID string, jobID string, status string, errMsg string) error {
if m.UpdateJobStatusFn != nil {
return m.UpdateJobStatusFn(agentID, jobID, status, errMsg)
}
return nil
}
// Test ListAgents - success case
func TestListAgents_Success(t *testing.T) {
now := time.Now()
agent1 := domain.Agent{
ID: "a-prod-001",
Name: "Production Agent",
Hostname: "prod-server-01",
Status: domain.AgentStatusOnline,
LastHeartbeatAt: &now,
RegisteredAt: now,
}
agent2 := domain.Agent{
ID: "a-prod-002",
Name: "API Agent",
Hostname: "api-server-01",
Status: domain.AgentStatusOnline,
LastHeartbeatAt: &now,
RegisteredAt: now,
}
mock := &MockAgentService{
ListAgentsFn: func(page, perPage int) ([]domain.Agent, int64, error) {
if page == 1 && perPage == 50 {
return []domain.Agent{agent1, agent2}, 2, nil
}
return nil, 0, nil
},
}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents?page=1&per_page=50", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.ListAgents(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
var response PagedResponse
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response.Total != 2 {
t.Errorf("expected total 2, got %d", response.Total)
}
}
// Test ListAgents - method not allowed
func TestListAgents_MethodNotAllowed(t *testing.T) {
mock := &MockAgentService{}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.ListAgents(w, req)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
}
}
// Test ListAgents - service error
func TestListAgents_ServiceError(t *testing.T) {
mock := &MockAgentService{
ListAgentsFn: func(page, perPage int) ([]domain.Agent, int64, error) {
return nil, 0, ErrMockServiceFailed
},
}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.ListAgents(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
}
}
// Test GetAgent - success case
func TestGetAgent_Success(t *testing.T) {
now := time.Now()
agent := &domain.Agent{
ID: "a-prod-001",
Name: "Production Agent",
Hostname: "prod-server-01",
Status: domain.AgentStatusOnline,
LastHeartbeatAt: &now,
RegisteredAt: now,
}
mock := &MockAgentService{
GetAgentFn: func(id string) (*domain.Agent, error) {
if id == "a-prod-001" {
return agent, nil
}
return nil, ErrMockNotFound
},
}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/a-prod-001", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.GetAgent(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
var response domain.Agent
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response.ID != "a-prod-001" {
t.Errorf("expected ID a-prod-001, got %s", response.ID)
}
}
// Test GetAgent - not found
func TestGetAgent_NotFound(t *testing.T) {
mock := &MockAgentService{
GetAgentFn: func(id string) (*domain.Agent, error) {
return nil, ErrMockNotFound
},
}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/nonexistent", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.GetAgent(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("expected status %d, got %d", http.StatusNotFound, w.Code)
}
}
// Test RegisterAgent - success case
func TestRegisterAgent_Success(t *testing.T) {
now := time.Now()
registered := &domain.Agent{
ID: "a-prod-001",
Name: "Production Agent",
Hostname: "prod-server-01",
Status: domain.AgentStatusOnline,
RegisteredAt: now,
}
mock := &MockAgentService{
RegisterAgentFn: func(agent domain.Agent) (*domain.Agent, error) {
return registered, nil
},
}
handler := NewAgentHandler(mock)
agentBody := domain.Agent{
Name: "Production Agent",
Hostname: "prod-server-01",
}
body, _ := json.Marshal(agentBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents", bytes.NewReader(body))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.RegisterAgent(w, req)
if w.Code != http.StatusCreated {
t.Errorf("expected status %d, got %d", http.StatusCreated, w.Code)
}
var response domain.Agent
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response.ID != "a-prod-001" {
t.Errorf("expected ID a-prod-001, got %s", response.ID)
}
}
// Test RegisterAgent - invalid body
func TestRegisterAgent_InvalidBody(t *testing.T) {
mock := &MockAgentService{}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents", bytes.NewReader([]byte("invalid json")))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.RegisterAgent(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
}
// Test Heartbeat - success case
func TestHeartbeat_Success(t *testing.T) {
mock := &MockAgentService{
HeartbeatFn: func(agentID string) error {
if agentID == "a-prod-001" {
return nil
}
return ErrMockNotFound
},
}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/heartbeat", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.Heartbeat(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
var response map[string]string
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response["status"] != "heartbeat_recorded" {
t.Errorf("expected status 'heartbeat_recorded', got %s", response["status"])
}
}
// Test Heartbeat - service error
func TestHeartbeat_ServiceError(t *testing.T) {
mock := &MockAgentService{
HeartbeatFn: func(agentID string) error {
return ErrMockServiceFailed
},
}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/heartbeat", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.Heartbeat(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
}
}
// Test AgentCSRSubmit - with certificate_id
func TestAgentCSRSubmit_WithCertificateID(t *testing.T) {
csrPEM := "-----BEGIN CERTIFICATE REQUEST-----\nMIIC...\n-----END CERTIFICATE REQUEST-----"
mock := &MockAgentService{
CSRSubmitForCertFn: func(agentID string, certID string, csrPEM string) (string, error) {
if agentID == "a-prod-001" && certID == "mc-prod-001" {
return "csr_submitted", nil
}
return "", ErrMockNotFound
},
}
handler := NewAgentHandler(mock)
reqBody := map[string]string{
"csr_pem": csrPEM,
"certificate_id": "mc-prod-001",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/csr", bytes.NewReader(body))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.AgentCSRSubmit(w, req)
if w.Code != http.StatusAccepted {
t.Errorf("expected status %d, got %d", http.StatusAccepted, w.Code)
}
var response map[string]string
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response["status"] != "csr_submitted" {
t.Errorf("expected status 'csr_submitted', got %s", response["status"])
}
}
// Test AgentCSRSubmit - without certificate_id
func TestAgentCSRSubmit_WithoutCertificateID(t *testing.T) {
csrPEM := "-----BEGIN CERTIFICATE REQUEST-----\nMIIC...\n-----END CERTIFICATE REQUEST-----"
mock := &MockAgentService{
CSRSubmitFn: func(agentID string, csrPEM string) (string, error) {
if agentID == "a-prod-001" {
return "csr_submitted", nil
}
return "", ErrMockNotFound
},
}
handler := NewAgentHandler(mock)
reqBody := map[string]string{
"csr_pem": csrPEM,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/csr", bytes.NewReader(body))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.AgentCSRSubmit(w, req)
if w.Code != http.StatusAccepted {
t.Errorf("expected status %d, got %d", http.StatusAccepted, w.Code)
}
}
// Test AgentCSRSubmit - missing CSR PEM
func TestAgentCSRSubmit_MissingCSRPEM(t *testing.T) {
mock := &MockAgentService{}
handler := NewAgentHandler(mock)
reqBody := map[string]string{
"certificate_id": "mc-prod-001",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/csr", bytes.NewReader(body))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.AgentCSRSubmit(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
}
// Test AgentCSRSubmit - invalid body
func TestAgentCSRSubmit_InvalidBody(t *testing.T) {
mock := &MockAgentService{}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/csr", bytes.NewReader([]byte("invalid")))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.AgentCSRSubmit(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
}
// Test AgentCertificatePickup - success case
func TestAgentCertificatePickup_Success(t *testing.T) {
certPEM := "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----"
mock := &MockAgentService{
CertificatePickupFn: func(agentID, certID string) (string, error) {
if agentID == "a-prod-001" && certID == "mc-prod-001" {
return certPEM, nil
}
return "", ErrMockNotFound
},
}
handler := NewAgentHandler(mock)
// Path structure: /api/v1/agents/{agent_id}/certificates/{cert_id}
// After trim and split: parts[0]="agent_id", parts[1]="certificates", parts[2]="cert_id", parts[3]=""
// Note: handler checks len(parts) < 4, so we need the trailing slash
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/a-prod-001/certificates/mc-prod-001/", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.AgentCertificatePickup(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d (body: %s)", http.StatusOK, w.Code, w.Body.String())
}
var response map[string]string
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response["certificate_pem"] != certPEM {
t.Errorf("expected cert PEM %s, got %s", certPEM, response["certificate_pem"])
}
}
// Test AgentCertificatePickup - not found
func TestAgentCertificatePickup_NotFound(t *testing.T) {
mock := &MockAgentService{
CertificatePickupFn: func(agentID, certID string) (string, error) {
return "", ErrMockNotFound
},
}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/a-prod-001/certificates/nonexistent/", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.AgentCertificatePickup(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("expected status %d, got %d (body: %s)", http.StatusNotFound, w.Code, w.Body.String())
}
}
// Test AgentGetWork - success with items
func TestAgentGetWork_Success(t *testing.T) {
workItem := domain.WorkItem{
ID: "j-deploy-001",
Type: domain.JobTypeDeployment,
CertificateID: "mc-prod-001",
TargetID: stringPtr("t-nginx-001"),
TargetType: "NGINX",
Status: domain.JobStatusPending,
}
mock := &MockAgentService{
GetWorkWithTargetsFn: func(agentID string) ([]domain.WorkItem, error) {
if agentID == "a-prod-001" {
return []domain.WorkItem{workItem}, nil
}
return nil, ErrMockNotFound
},
}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/a-prod-001/work", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.AgentGetWork(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
var response map[string]interface{}
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response["count"] != float64(1) {
t.Errorf("expected count 1, got %v", response["count"])
}
}
// Test AgentGetWork - no work items
func TestAgentGetWork_NoItems(t *testing.T) {
mock := &MockAgentService{
GetWorkWithTargetsFn: func(agentID string) ([]domain.WorkItem, error) {
return nil, nil
},
}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/a-prod-001/work", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.AgentGetWork(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
var response map[string]interface{}
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response["count"] != float64(0) {
t.Errorf("expected count 0, got %v", response["count"])
}
}
// Test AgentGetWork - service error
func TestAgentGetWork_ServiceError(t *testing.T) {
mock := &MockAgentService{
GetWorkWithTargetsFn: func(agentID string) ([]domain.WorkItem, error) {
return nil, ErrMockServiceFailed
},
}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/a-prod-001/work", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.AgentGetWork(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
}
}
// Test AgentReportJobStatus - success case
func TestAgentReportJobStatus_Success(t *testing.T) {
mock := &MockAgentService{
UpdateJobStatusFn: func(agentID string, jobID string, status string, errMsg string) error {
if agentID == "a-prod-001" && jobID == "j-deploy-001" && status == "Completed" {
return nil
}
return ErrMockNotFound
},
}
handler := NewAgentHandler(mock)
statusReq := map[string]string{
"status": "Completed",
}
body, _ := json.Marshal(statusReq)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/jobs/j-deploy-001/status", bytes.NewReader(body))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.AgentReportJobStatus(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
var response map[string]string
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response["status"] != "updated" {
t.Errorf("expected status 'updated', got %s", response["status"])
}
}
// Test AgentReportJobStatus - with error message
func TestAgentReportJobStatus_WithError(t *testing.T) {
mock := &MockAgentService{
UpdateJobStatusFn: func(agentID string, jobID string, status string, errMsg string) error {
if agentID == "a-prod-001" && jobID == "j-deploy-001" && status == "Failed" && errMsg == "timeout" {
return nil
}
return ErrMockNotFound
},
}
handler := NewAgentHandler(mock)
statusReq := map[string]string{
"status": "Failed",
"error": "timeout",
}
body, _ := json.Marshal(statusReq)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/jobs/j-deploy-001/status", bytes.NewReader(body))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.AgentReportJobStatus(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// Test AgentReportJobStatus - missing status
func TestAgentReportJobStatus_MissingStatus(t *testing.T) {
mock := &MockAgentService{}
handler := NewAgentHandler(mock)
statusReq := map[string]string{}
body, _ := json.Marshal(statusReq)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/jobs/j-deploy-001/status", bytes.NewReader(body))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.AgentReportJobStatus(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
}
// Test AgentReportJobStatus - invalid body
func TestAgentReportJobStatus_InvalidBody(t *testing.T) {
mock := &MockAgentService{}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/jobs/j-deploy-001/status", bytes.NewReader([]byte("invalid")))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.AgentReportJobStatus(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
}
// Test ListAgents - invalid pagination parameters
func TestListAgents_InvalidPagination(t *testing.T) {
mock := &MockAgentService{
ListAgentsFn: func(page, perPage int) ([]domain.Agent, int64, error) {
// Should default to page=1, perPage=50 if invalid
if page == 1 && perPage == 50 {
return []domain.Agent{}, 0, nil
}
return nil, 0, nil
},
}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents?page=invalid&per_page=invalid", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.ListAgents(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// Test GetAgent - empty ID
func TestGetAgent_EmptyID(t *testing.T) {
mock := &MockAgentService{}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.GetAgent(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
}
// Test RegisterAgent - service error
func TestRegisterAgent_ServiceError(t *testing.T) {
mock := &MockAgentService{
RegisterAgentFn: func(agent domain.Agent) (*domain.Agent, error) {
return nil, ErrMockServiceFailed
},
}
handler := NewAgentHandler(mock)
agentBody := domain.Agent{
Name: "Production Agent",
Hostname: "prod-server-01",
}
body, _ := json.Marshal(agentBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents", bytes.NewReader(body))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.RegisterAgent(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
}
}
// Test Heartbeat - empty agent ID
func TestHeartbeat_EmptyAgentID(t *testing.T) {
mock := &MockAgentService{}
handler := NewAgentHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents//heartbeat", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.Heartbeat(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
}
// Test AgentCSRSubmit - service error
func TestAgentCSRSubmit_ServiceError(t *testing.T) {
mock := &MockAgentService{
CSRSubmitFn: func(agentID string, csrPEM string) (string, error) {
return "", ErrMockServiceFailed
},
}
handler := NewAgentHandler(mock)
reqBody := map[string]string{
"csr_pem": "-----BEGIN CERTIFICATE REQUEST-----\nMIIC...\n-----END CERTIFICATE REQUEST-----",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/csr", bytes.NewReader(body))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.AgentCSRSubmit(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
}
}
// Test AgentReportJobStatus - service error
func TestAgentReportJobStatus_ServiceError(t *testing.T) {
mock := &MockAgentService{
UpdateJobStatusFn: func(agentID string, jobID string, status string, errMsg string) error {
return ErrMockServiceFailed
},
}
handler := NewAgentHandler(mock)
statusReq := map[string]string{
"status": "Completed",
}
body, _ := json.Marshal(statusReq)
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/jobs/j-deploy-001/status", bytes.NewReader(body))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.AgentReportJobStatus(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
}
}
// Helper function to create a string pointer
func stringPtr(s string) *string {
return &s
}
@@ -0,0 +1,704 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/shankar0123/certctl/internal/api/middleware"
"github.com/shankar0123/certctl/internal/domain"
)
// MockCertificateService is a mock implementation of CertificateService interface.
type MockCertificateService struct {
ListCertificatesFn func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
GetCertificateFn func(id string) (*domain.ManagedCertificate, error)
CreateCertificateFn func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
UpdateCertificateFn func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
ArchiveCertificateFn func(id string) error
GetCertificateVersionsFn func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
TriggerRenewalFn func(certID string) error
TriggerDeploymentFn func(certID string, targetID string) error
}
func (m *MockCertificateService) ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
if m.ListCertificatesFn != nil {
return m.ListCertificatesFn(status, environment, ownerID, teamID, issuerID, page, perPage)
}
return nil, 0, nil
}
func (m *MockCertificateService) GetCertificate(id string) (*domain.ManagedCertificate, error) {
if m.GetCertificateFn != nil {
return m.GetCertificateFn(id)
}
return nil, nil
}
func (m *MockCertificateService) CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
if m.CreateCertificateFn != nil {
return m.CreateCertificateFn(cert)
}
return nil, nil
}
func (m *MockCertificateService) UpdateCertificate(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
if m.UpdateCertificateFn != nil {
return m.UpdateCertificateFn(id, cert)
}
return nil, nil
}
func (m *MockCertificateService) ArchiveCertificate(id string) error {
if m.ArchiveCertificateFn != nil {
return m.ArchiveCertificateFn(id)
}
return nil
}
func (m *MockCertificateService) GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
if m.GetCertificateVersionsFn != nil {
return m.GetCertificateVersionsFn(certID, page, perPage)
}
return nil, 0, nil
}
func (m *MockCertificateService) TriggerRenewal(certID string) error {
if m.TriggerRenewalFn != nil {
return m.TriggerRenewalFn(certID)
}
return nil
}
func (m *MockCertificateService) TriggerDeployment(certID string, targetID string) error {
if m.TriggerDeploymentFn != nil {
return m.TriggerDeploymentFn(certID, targetID)
}
return nil
}
// Helper function to create context with request ID.
func contextWithRequestID() context.Context {
return context.WithValue(context.Background(), middleware.RequestIDKey{}, "test-request-id-123")
}
// Test ListCertificates - success case
func TestListCertificates_Success(t *testing.T) {
cert1 := domain.ManagedCertificate{
ID: "mc-prod-001",
Name: "Production Cert",
CommonName: "example.com",
Status: domain.CertificateStatusActive,
Environment: "prod",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
cert2 := domain.ManagedCertificate{
ID: "mc-prod-002",
Name: "API Cert",
CommonName: "api.example.com",
Status: domain.CertificateStatusActive,
Environment: "prod",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
mock := &MockCertificateService{
ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
if page == 1 && perPage == 50 {
return []domain.ManagedCertificate{cert1, cert2}, 2, nil
}
return nil, 0, nil
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?page=1&per_page=50", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.ListCertificates(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
var response PagedResponse
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response.Total != 2 {
t.Errorf("expected total 2, got %d", response.Total)
}
if response.Page != 1 {
t.Errorf("expected page 1, got %d", response.Page)
}
if response.PerPage != 50 {
t.Errorf("expected per_page 50, got %d", response.PerPage)
}
}
// Test ListCertificates - with filters
func TestListCertificates_WithFilters(t *testing.T) {
mock := &MockCertificateService{
ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
if status == "Active" && environment == "prod" {
return []domain.ManagedCertificate{}, 0, nil
}
return nil, 0, nil
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?status=Active&environment=prod&page=1&per_page=25", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.ListCertificates(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// Test ListCertificates - invalid method
func TestListCertificates_MethodNotAllowed(t *testing.T) {
mock := &MockCertificateService{}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.ListCertificates(w, req)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
}
}
// Test ListCertificates - service error
func TestListCertificates_ServiceError(t *testing.T) {
mock := &MockCertificateService{
ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
return nil, 0, ErrMockServiceFailed
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.ListCertificates(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
}
}
// Test GetCertificate - success case
func TestGetCertificate_Success(t *testing.T) {
cert := &domain.ManagedCertificate{
ID: "mc-prod-001",
Name: "Production Cert",
CommonName: "example.com",
Status: domain.CertificateStatusActive,
Environment: "prod",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
mock := &MockCertificateService{
GetCertificateFn: func(id string) (*domain.ManagedCertificate, error) {
if id == "mc-prod-001" {
return cert, nil
}
return nil, ErrMockNotFound
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/mc-prod-001", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.GetCertificate(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
var response domain.ManagedCertificate
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response.ID != "mc-prod-001" {
t.Errorf("expected ID mc-prod-001, got %s", response.ID)
}
}
// Test GetCertificate - not found
func TestGetCertificate_NotFound(t *testing.T) {
mock := &MockCertificateService{
GetCertificateFn: func(id string) (*domain.ManagedCertificate, error) {
return nil, ErrMockNotFound
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/nonexistent", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.GetCertificate(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("expected status %d, got %d", http.StatusNotFound, w.Code)
}
}
// Test GetCertificate - empty ID
func TestGetCertificate_EmptyID(t *testing.T) {
mock := &MockCertificateService{}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.GetCertificate(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
}
// Test CreateCertificate - success case
func TestCreateCertificate_Success(t *testing.T) {
now := time.Now()
created := &domain.ManagedCertificate{
ID: "mc-prod-001",
Name: "Production Cert",
CommonName: "example.com",
Status: domain.CertificateStatusPending,
Environment: "prod",
CreatedAt: now,
UpdatedAt: now,
}
mock := &MockCertificateService{
CreateCertificateFn: func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
return created, nil
},
}
handler := NewCertificateHandler(mock)
certBody := domain.ManagedCertificate{
Name: "Production Cert",
CommonName: "example.com",
}
body, _ := json.Marshal(certBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", bytes.NewReader(body))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.CreateCertificate(w, req)
if w.Code != http.StatusCreated {
t.Errorf("expected status %d, got %d", http.StatusCreated, w.Code)
}
var response domain.ManagedCertificate
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response.ID != "mc-prod-001" {
t.Errorf("expected ID mc-prod-001, got %s", response.ID)
}
}
// Test CreateCertificate - invalid request body
func TestCreateCertificate_InvalidBody(t *testing.T) {
mock := &MockCertificateService{}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", bytes.NewReader([]byte("invalid json")))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.CreateCertificate(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
}
// Test CreateCertificate - service error
func TestCreateCertificate_ServiceError(t *testing.T) {
mock := &MockCertificateService{
CreateCertificateFn: func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
return nil, ErrMockServiceFailed
},
}
handler := NewCertificateHandler(mock)
certBody := domain.ManagedCertificate{
Name: "Production Cert",
CommonName: "example.com",
}
body, _ := json.Marshal(certBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", bytes.NewReader(body))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.CreateCertificate(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
}
}
// Test UpdateCertificate - success case
func TestUpdateCertificate_Success(t *testing.T) {
updated := &domain.ManagedCertificate{
ID: "mc-prod-001",
Name: "Updated Cert",
CommonName: "example.com",
Status: domain.CertificateStatusActive,
Environment: "prod",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
mock := &MockCertificateService{
UpdateCertificateFn: func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
if id == "mc-prod-001" {
return updated, nil
}
return nil, ErrMockNotFound
},
}
handler := NewCertificateHandler(mock)
certBody := domain.ManagedCertificate{
Name: "Updated Cert",
}
body, _ := json.Marshal(certBody)
req := httptest.NewRequest(http.MethodPut, "/api/v1/certificates/mc-prod-001", bytes.NewReader(body))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.UpdateCertificate(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
var response domain.ManagedCertificate
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response.Name != "Updated Cert" {
t.Errorf("expected name 'Updated Cert', got %s", response.Name)
}
}
// Test UpdateCertificate - invalid body
func TestUpdateCertificate_InvalidBody(t *testing.T) {
mock := &MockCertificateService{}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodPut, "/api/v1/certificates/mc-prod-001", bytes.NewReader([]byte("invalid")))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.UpdateCertificate(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
}
// Test ArchiveCertificate - success case
func TestArchiveCertificate_Success(t *testing.T) {
mock := &MockCertificateService{
ArchiveCertificateFn: func(id string) error {
if id == "mc-prod-001" {
return nil
}
return ErrMockNotFound
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/certificates/mc-prod-001", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.ArchiveCertificate(w, req)
if w.Code != http.StatusNoContent {
t.Errorf("expected status %d, got %d", http.StatusNoContent, w.Code)
}
}
// Test ArchiveCertificate - not found
func TestArchiveCertificate_NotFound(t *testing.T) {
mock := &MockCertificateService{
ArchiveCertificateFn: func(id string) error {
return ErrMockNotFound
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/certificates/nonexistent", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.ArchiveCertificate(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
}
}
// Test GetCertificateVersions - success case
func TestGetCertificateVersions_Success(t *testing.T) {
ver1 := domain.CertificateVersion{
ID: "cv-001",
CertificateID: "mc-prod-001",
SerialNumber: "ABC123",
FingerprintSHA256: "abc123...",
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(0, 0, 365),
CreatedAt: time.Now(),
}
mock := &MockCertificateService{
GetCertificateVersionsFn: func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
if certID == "mc-prod-001" {
return []domain.CertificateVersion{ver1}, 1, nil
}
return nil, 0, ErrMockNotFound
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/mc-prod-001/versions?page=1&per_page=50", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.GetCertificateVersions(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
var response PagedResponse
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response.Total != 1 {
t.Errorf("expected total 1, got %d", response.Total)
}
}
// Test GetCertificateVersions - not found
func TestGetCertificateVersions_NotFound(t *testing.T) {
mock := &MockCertificateService{
GetCertificateVersionsFn: func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
return nil, 0, ErrMockNotFound
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/nonexistent/versions", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.GetCertificateVersions(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("expected status %d, got %d", http.StatusNotFound, w.Code)
}
}
// Test TriggerRenewal - success case
func TestTriggerRenewal_Success(t *testing.T) {
mock := &MockCertificateService{
TriggerRenewalFn: func(certID string) error {
if certID == "mc-prod-001" {
return nil
}
return ErrMockNotFound
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/renew", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.TriggerRenewal(w, req)
if w.Code != http.StatusAccepted {
t.Errorf("expected status %d, got %d", http.StatusAccepted, w.Code)
}
var response map[string]string
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response["status"] != "renewal_triggered" {
t.Errorf("expected status 'renewal_triggered', got %s", response["status"])
}
}
// Test TriggerRenewal - service error
func TestTriggerRenewal_ServiceError(t *testing.T) {
mock := &MockCertificateService{
TriggerRenewalFn: func(certID string) error {
return ErrMockServiceFailed
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/renew", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.TriggerRenewal(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
}
}
// Test TriggerDeployment - success case
func TestTriggerDeployment_Success(t *testing.T) {
mock := &MockCertificateService{
TriggerDeploymentFn: func(certID string, targetID string) error {
if certID == "mc-prod-001" {
return nil
}
return ErrMockNotFound
},
}
handler := NewCertificateHandler(mock)
deployReq := map[string]string{"target_id": "t-nginx-001"}
body, _ := json.Marshal(deployReq)
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/deploy", bytes.NewReader(body))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.TriggerDeployment(w, req)
if w.Code != http.StatusAccepted {
t.Errorf("expected status %d, got %d", http.StatusAccepted, w.Code)
}
var response map[string]string
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if response["status"] != "deployment_triggered" {
t.Errorf("expected status 'deployment_triggered', got %s", response["status"])
}
}
// Test TriggerDeployment - without target ID
func TestTriggerDeployment_NoTargetID(t *testing.T) {
mock := &MockCertificateService{
TriggerDeploymentFn: func(certID string, targetID string) error {
// Should accept empty targetID (deploy to all)
return nil
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/deploy", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.TriggerDeployment(w, req)
if w.Code != http.StatusAccepted {
t.Errorf("expected status %d, got %d", http.StatusAccepted, w.Code)
}
}
// Test ListCertificates - invalid page parameter
func TestListCertificates_InvalidPageParam(t *testing.T) {
mock := &MockCertificateService{
ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
// Should default to page 1
if page == 1 {
return []domain.ManagedCertificate{}, 0, nil
}
return nil, 0, nil
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?page=invalid&per_page=50", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.ListCertificates(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// Test ListCertificates - per_page exceeds max
func TestListCertificates_PerPageExceedsMax(t *testing.T) {
mock := &MockCertificateService{
ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
// Should cap perPage at 500
if perPage == 50 { // defaults to 50 if > 500
return []domain.ManagedCertificate{}, 0, nil
}
return nil, 0, nil
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?per_page=1000", nil)
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.ListCertificates(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
+11
View File
@@ -0,0 +1,11 @@
package handler
import "errors"
var (
// Mock errors for testing
ErrMockServiceFailed = errors.New("mock service error")
ErrMockNotFound = errors.New("mock not found error")
ErrMockUnauthorized = errors.New("mock unauthorized error")
ErrMockConflict = errors.New("mock conflict error")
)
+996
View File
@@ -0,0 +1,996 @@
package integration
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/shankar0123/certctl/internal/api/handler"
"github.com/shankar0123/certctl/internal/api/router"
"github.com/shankar0123/certctl/internal/connector/issuer/local"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/repository"
"github.com/shankar0123/certctl/internal/service"
)
// TestCertificateLifecycle exercises the full certificate lifecycle:
// create -> renew -> process jobs -> verify versions -> register agent -> heartbeat -> audit trail
func TestCertificateLifecycle(t *testing.T) {
ctx := context.Background()
// Setup: Create in-memory mock repositories
certRepo := newMockCertificateRepository()
jobRepo := newMockJobRepository()
auditRepo := newMockAuditRepository()
agentRepo := newMockAgentRepository()
targetRepo := newMockTargetRepository()
notifRepo := newMockNotificationRepository()
policyRepo := newMockPolicyRepository()
renewalPolicyRepo := newMockRenewalPolicyRepository()
issuerRepo := newMockIssuerRepository()
// Create logger
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
// Initialize Local CA issuer connector (real implementation, no mock)
localCA := local.New(nil, logger)
// Build issuer registry with adapter
issuerRegistry := map[string]service.IssuerConnector{
"iss-local": service.NewIssuerConnectorAdapter(localCA),
}
// Initialize services (following dependency graph)
auditService := service.NewAuditService(auditRepo)
policyService := service.NewPolicyService(policyRepo, auditService)
certificateService := service.NewCertificateService(certRepo, policyService, auditService)
notificationService := service.NewNotificationService(notifRepo, make(map[string]service.Notifier))
renewalService := service.NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, auditService, notificationService, issuerRegistry)
deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notificationService)
jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger)
agentService := service.NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
issuerService := service.NewIssuerService(issuerRepo, auditService)
// Initialize handlers
certificateHandler := handler.NewCertificateHandler(certificateService)
issuerHandler := handler.NewIssuerHandler(issuerService)
targetHandler := handler.NewTargetHandler(&mockTargetService{targetRepo: targetRepo, auditService: auditService})
agentHandler := handler.NewAgentHandler(agentService)
jobHandler := handler.NewJobHandler(jobService)
policyHandler := handler.NewPolicyHandler(policyService)
teamHandler := handler.NewTeamHandler(&mockTeamService{})
ownerHandler := handler.NewOwnerHandler(&mockOwnerService{})
auditHandler := handler.NewAuditHandler(auditService)
notificationHandler := handler.NewNotificationHandler(notificationService)
healthHandler := handler.NewHealthHandler()
// Create router and register handlers
r := router.New()
r.RegisterHandlers(
certificateHandler,
issuerHandler,
targetHandler,
agentHandler,
jobHandler,
policyHandler,
teamHandler,
ownerHandler,
auditHandler,
notificationHandler,
healthHandler,
)
// Create test server
server := httptest.NewServer(r)
defer server.Close()
// ======================
// Step 1: Check health
// ======================
t.Run("HealthCheck", func(t *testing.T) {
resp, err := http.Get(server.URL + "/health")
if err != nil {
t.Fatalf("GET /health failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
var body map[string]string
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if body["status"] != "healthy" {
t.Errorf("expected status=healthy, got %s", body["status"])
}
})
// ======================
// Step 2: Create certificate
// ======================
var certID string
t.Run("CreateCertificate", func(t *testing.T) {
now := time.Now()
payload := map[string]interface{}{
"name": "Example Certificate",
"common_name": "example.com",
"sans": []string{"www.example.com", "api.example.com"},
"environment": "production",
"owner_id": "owner-alice",
"team_id": "team-platform",
"issuer_id": "iss-local",
"target_ids": []string{},
"renewal_policy_id": "policy-standard",
"status": "Pending",
"expires_at": now.AddDate(1, 0, 0),
"tags": map[string]string{"environment": "prod"},
}
body, err := json.Marshal(payload)
if err != nil {
t.Fatalf("failed to marshal payload: %v", err)
}
resp, err := http.Post(
server.URL+"/api/v1/certificates",
"application/json",
bytes.NewReader(body),
)
if err != nil {
t.Fatalf("POST /api/v1/certificates failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Errorf("expected status 201, got %d. Body: %s", resp.StatusCode, string(bodyBytes))
}
var cert domain.ManagedCertificate
if err := json.NewDecoder(resp.Body).Decode(&cert); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if cert.ID == "" {
t.Fatalf("response missing id field")
}
certID = cert.ID
t.Logf("Created certificate with ID: %s", certID)
})
// ======================
// Step 3: Verify certificate
// ======================
t.Run("GetCertificate", func(t *testing.T) {
resp, err := http.Get(server.URL + "/api/v1/certificates/" + certID)
if err != nil {
t.Fatalf("GET /api/v1/certificates/{id} failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
var cert domain.ManagedCertificate
if err := json.NewDecoder(resp.Body).Decode(&cert); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if cert.ID != certID {
t.Errorf("expected cert ID %s, got %s", certID, cert.ID)
}
if cert.CommonName != "example.com" {
t.Errorf("expected common_name example.com, got %s", cert.CommonName)
}
if len(cert.SANs) != 2 {
t.Errorf("expected 2 SANs, got %d", len(cert.SANs))
}
})
// ======================
// Step 4: Trigger renewal
// ======================
t.Run("TriggerRenewal", func(t *testing.T) {
resp, err := http.Post(
server.URL+"/api/v1/certificates/"+certID+"/renew",
"application/json",
nil,
)
if err != nil {
t.Fatalf("POST /api/v1/certificates/{id}/renew failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusAccepted {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Errorf("expected status 202, got %d. Body: %s", resp.StatusCode, string(bodyBytes))
}
})
// ======================
// Step 5: Process jobs (simulate scheduler)
// ======================
t.Run("ProcessPendingJobs", func(t *testing.T) {
// Jobs should have been created by the renewal trigger.
// Process them using the job service directly.
if err := jobService.ProcessPendingJobs(ctx); err != nil {
t.Fatalf("failed to process pending jobs: %v", err)
}
// Verify that jobs were processed
jobs, err := jobRepo.ListByStatus(ctx, domain.JobStatusCompleted)
if err != nil {
t.Fatalf("failed to list completed jobs: %v", err)
}
// We expect at least one renewal job to have been processed
if len(jobs) == 0 {
t.Logf("Warning: no completed jobs found. This may indicate the renewal job wasn't processed.")
// Check pending jobs instead
pending, _ := jobRepo.ListByStatus(ctx, domain.JobStatusPending)
t.Logf("Pending jobs: %d", len(pending))
}
})
// ======================
// Step 6: Verify certificate versions
// ======================
t.Run("GetCertificateVersions", func(t *testing.T) {
resp, err := http.Get(server.URL + "/api/v1/certificates/" + certID + "/versions")
if err != nil {
t.Fatalf("GET /api/v1/certificates/{id}/versions failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Errorf("expected status 200, got %d. Body: %s", resp.StatusCode, string(bodyBytes))
}
var respBody map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
// Extract data field which contains the versions array
dataField := respBody["data"]
if dataField == nil {
t.Logf("No versions found yet - this is expected if renewal is still in progress")
} else {
versions, ok := dataField.([]interface{})
if !ok {
t.Errorf("expected data to be array, got %T", dataField)
} else if len(versions) > 0 {
t.Logf("Found %d certificate versions", len(versions))
// Verify the first version has required fields
if version, ok := versions[0].(map[string]interface{}); ok {
if version["pem_chain"] == nil || version["pem_chain"] == "" {
t.Errorf("certificate version missing pem_chain")
}
if version["serial_number"] == nil || version["serial_number"] == "" {
t.Errorf("certificate version missing serial_number")
}
}
}
}
})
// ======================
// Step 7: Register agent
// ======================
var agentID string
t.Run("RegisterAgent", func(t *testing.T) {
payload := map[string]string{
"name": "agent-prod-1",
"hostname": "prod-server-01.example.com",
}
body, err := json.Marshal(payload)
if err != nil {
t.Fatalf("failed to marshal payload: %v", err)
}
resp, err := http.Post(
server.URL+"/api/v1/agents",
"application/json",
bytes.NewReader(body),
)
if err != nil {
t.Fatalf("POST /api/v1/agents failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Errorf("expected status 201, got %d. Body: %s", resp.StatusCode, string(bodyBytes))
}
// The handler returns the agent directly, not wrapped
var agent domain.Agent
if err := json.NewDecoder(resp.Body).Decode(&agent); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
agentID = agent.ID
if agentID == "" {
t.Fatalf("agent id is empty")
}
t.Logf("Registered agent with ID: %s", agentID)
})
// ======================
// Step 8: Agent heartbeat
// ======================
t.Run("AgentHeartbeat", func(t *testing.T) {
payload := map[string]string{
"agent_id": agentID,
}
body, err := json.Marshal(payload)
if err != nil {
t.Fatalf("failed to marshal payload: %v", err)
}
resp, err := http.Post(
server.URL+"/api/v1/agents/"+agentID+"/heartbeat",
"application/json",
bytes.NewReader(body),
)
if err != nil {
t.Fatalf("POST /api/v1/agents/{id}/heartbeat failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Errorf("expected status 200, got %d. Body: %s", resp.StatusCode, string(bodyBytes))
}
// Verify agent heartbeat was updated
agent, err := agentRepo.Get(ctx, agentID)
if err != nil {
t.Fatalf("failed to get agent: %v", err)
}
if agent.LastHeartbeatAt == nil {
t.Errorf("agent LastHeartbeatAt was not updated")
}
})
// ======================
// Step 9: List audit events
// ======================
t.Run("ListAuditEvents", func(t *testing.T) {
resp, err := http.Get(server.URL + "/api/v1/audit?page=1&per_page=50")
if err != nil {
t.Fatalf("GET /api/v1/audit failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
var respBody map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
// Extract data field which contains the events array
dataField := respBody["data"]
if dataField == nil {
t.Logf("No audit events found")
} else {
events, ok := dataField.([]interface{})
if !ok {
t.Errorf("expected data to be array, got %T", dataField)
} else {
t.Logf("Found %d audit events", len(events))
if len(events) == 0 {
t.Logf("Warning: no audit events found. Expected events for certificate_created, agent_registered, etc.")
}
// Verify we have expected event types
eventTypes := make(map[string]int)
for _, evt := range events {
if eventMap, ok := evt.(map[string]interface{}); ok {
if action, ok := eventMap["action"].(string); ok {
eventTypes[action]++
}
}
}
t.Logf("Audit event types: %v", eventTypes)
}
}
})
// ======================
// Step 10: Get agent and verify status
// ======================
t.Run("GetAgent", func(t *testing.T) {
resp, err := http.Get(server.URL + "/api/v1/agents/" + agentID)
if err != nil {
t.Fatalf("GET /api/v1/agents/{id} failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Errorf("expected status 200, got %d. Body: %s", resp.StatusCode, string(bodyBytes))
}
var agent domain.Agent
if err := json.NewDecoder(resp.Body).Decode(&agent); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if agent.ID != agentID {
t.Errorf("expected agent ID %s, got %s", agentID, agent.ID)
}
if agent.Status != domain.AgentStatusOnline {
t.Errorf("expected agent status Online, got %s", agent.Status)
}
})
// ======================
// Summary
// ======================
t.Run("Summary", func(t *testing.T) {
totalCerts, _, _ := certRepo.List(ctx, &repository.CertificateFilter{})
totalJobs, _ := jobRepo.List(ctx)
totalAgents, _ := agentRepo.List(ctx)
totalAuditEvents, _ := auditRepo.List(ctx, &repository.AuditFilter{})
t.Logf("=== Integration Test Summary ===")
t.Logf("Certificates: %d", len(totalCerts))
t.Logf("Jobs: %d", len(totalJobs))
t.Logf("Agents: %d", len(totalAgents))
t.Logf("Audit Events: %d", len(totalAuditEvents))
if len(totalCerts) == 0 {
t.Error("Expected at least 1 certificate")
}
if len(totalAgents) == 0 {
t.Error("Expected at least 1 agent")
}
if len(totalAuditEvents) == 0 {
t.Logf("Warning: Expected audit events, but none found")
}
})
}
// Mock repository implementations for integration testing
// These are simple in-memory implementations similar to testutil_test.go patterns
type mockCertificateRepository struct {
certs map[string]*domain.ManagedCertificate
versions map[string][]*domain.CertificateVersion
}
func newMockCertificateRepository() *mockCertificateRepository {
return &mockCertificateRepository{
certs: make(map[string]*domain.ManagedCertificate),
versions: make(map[string][]*domain.CertificateVersion),
}
}
func (m *mockCertificateRepository) List(ctx context.Context, filter *repository.CertificateFilter) ([]*domain.ManagedCertificate, int, error) {
var certs []*domain.ManagedCertificate
for _, c := range m.certs {
certs = append(certs, c)
}
return certs, len(certs), nil
}
func (m *mockCertificateRepository) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
cert, ok := m.certs[id]
if !ok {
return nil, fmt.Errorf("certificate not found")
}
return cert, nil
}
func (m *mockCertificateRepository) Create(ctx context.Context, cert *domain.ManagedCertificate) error {
m.certs[cert.ID] = cert
return nil
}
func (m *mockCertificateRepository) Update(ctx context.Context, cert *domain.ManagedCertificate) error {
m.certs[cert.ID] = cert
return nil
}
func (m *mockCertificateRepository) Archive(ctx context.Context, id string) error {
cert, ok := m.certs[id]
if !ok {
return fmt.Errorf("certificate not found")
}
cert.Status = domain.CertificateStatusArchived
return nil
}
func (m *mockCertificateRepository) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) {
return m.versions[certID], nil
}
func (m *mockCertificateRepository) CreateVersion(ctx context.Context, version *domain.CertificateVersion) error {
m.versions[version.CertificateID] = append(m.versions[version.CertificateID], version)
return nil
}
func (m *mockCertificateRepository) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) {
var expiring []*domain.ManagedCertificate
for _, c := range m.certs {
if c.ExpiresAt.Before(before) {
expiring = append(expiring, c)
}
}
return expiring, nil
}
type mockJobRepository struct {
jobs map[string]*domain.Job
}
func newMockJobRepository() *mockJobRepository {
return &mockJobRepository{
jobs: make(map[string]*domain.Job),
}
}
func (m *mockJobRepository) List(ctx context.Context) ([]*domain.Job, error) {
var jobs []*domain.Job
for _, j := range m.jobs {
jobs = append(jobs, j)
}
return jobs, nil
}
func (m *mockJobRepository) Get(ctx context.Context, id string) (*domain.Job, error) {
job, ok := m.jobs[id]
if !ok {
return nil, fmt.Errorf("job not found")
}
return job, nil
}
func (m *mockJobRepository) Create(ctx context.Context, job *domain.Job) error {
m.jobs[job.ID] = job
return nil
}
func (m *mockJobRepository) Update(ctx context.Context, job *domain.Job) error {
m.jobs[job.ID] = job
return nil
}
func (m *mockJobRepository) Delete(ctx context.Context, id string) error {
delete(m.jobs, id)
return nil
}
func (m *mockJobRepository) ListByStatus(ctx context.Context, status domain.JobStatus) ([]*domain.Job, error) {
var jobs []*domain.Job
for _, j := range m.jobs {
if j.Status == status {
jobs = append(jobs, j)
}
}
return jobs, nil
}
func (m *mockJobRepository) ListByCertificate(ctx context.Context, certID string) ([]*domain.Job, error) {
var jobs []*domain.Job
for _, j := range m.jobs {
if j.CertificateID == certID {
jobs = append(jobs, j)
}
}
return jobs, nil
}
func (m *mockJobRepository) UpdateStatus(ctx context.Context, id string, status domain.JobStatus, errMsg string) error {
job, ok := m.jobs[id]
if !ok {
return fmt.Errorf("job not found")
}
job.Status = status
if errMsg != "" {
job.LastError = &errMsg
}
return nil
}
func (m *mockJobRepository) GetPendingJobs(ctx context.Context, jobType domain.JobType) ([]*domain.Job, error) {
var jobs []*domain.Job
for _, j := range m.jobs {
if j.Type == jobType && j.Status == domain.JobStatusPending {
jobs = append(jobs, j)
}
}
return jobs, nil
}
type mockAuditRepository struct {
events []*domain.AuditEvent
}
func newMockAuditRepository() *mockAuditRepository {
return &mockAuditRepository{
events: make([]*domain.AuditEvent, 0),
}
}
func (m *mockAuditRepository) Create(ctx context.Context, event *domain.AuditEvent) error {
m.events = append(m.events, event)
return nil
}
func (m *mockAuditRepository) List(ctx context.Context, filter *repository.AuditFilter) ([]*domain.AuditEvent, error) {
return m.events, nil
}
type mockAgentRepository struct {
agents map[string]*domain.Agent
}
func newMockAgentRepository() *mockAgentRepository {
return &mockAgentRepository{
agents: make(map[string]*domain.Agent),
}
}
func (m *mockAgentRepository) List(ctx context.Context) ([]*domain.Agent, error) {
var agents []*domain.Agent
for _, a := range m.agents {
agents = append(agents, a)
}
return agents, nil
}
func (m *mockAgentRepository) Get(ctx context.Context, id string) (*domain.Agent, error) {
agent, ok := m.agents[id]
if !ok {
return nil, fmt.Errorf("agent not found")
}
return agent, nil
}
func (m *mockAgentRepository) Create(ctx context.Context, agent *domain.Agent) error {
m.agents[agent.ID] = agent
return nil
}
func (m *mockAgentRepository) Update(ctx context.Context, agent *domain.Agent) error {
m.agents[agent.ID] = agent
return nil
}
func (m *mockAgentRepository) Delete(ctx context.Context, id string) error {
delete(m.agents, id)
return nil
}
func (m *mockAgentRepository) UpdateHeartbeat(ctx context.Context, id string) error {
agent, ok := m.agents[id]
if !ok {
return fmt.Errorf("agent not found")
}
now := time.Now()
agent.LastHeartbeatAt = &now
return nil
}
func (m *mockAgentRepository) GetByAPIKey(ctx context.Context, keyHash string) (*domain.Agent, error) {
for _, a := range m.agents {
if a.APIKeyHash == keyHash {
return a, nil
}
}
return nil, fmt.Errorf("agent not found")
}
type mockTargetRepository struct {
targets map[string]*domain.DeploymentTarget
}
func newMockTargetRepository() *mockTargetRepository {
return &mockTargetRepository{
targets: make(map[string]*domain.DeploymentTarget),
}
}
func (m *mockTargetRepository) List(ctx context.Context) ([]*domain.DeploymentTarget, error) {
var targets []*domain.DeploymentTarget
for _, t := range m.targets {
targets = append(targets, t)
}
return targets, nil
}
func (m *mockTargetRepository) Get(ctx context.Context, id string) (*domain.DeploymentTarget, error) {
target, ok := m.targets[id]
if !ok {
return nil, fmt.Errorf("target not found")
}
return target, nil
}
func (m *mockTargetRepository) Create(ctx context.Context, target *domain.DeploymentTarget) error {
m.targets[target.ID] = target
return nil
}
func (m *mockTargetRepository) Update(ctx context.Context, target *domain.DeploymentTarget) error {
m.targets[target.ID] = target
return nil
}
func (m *mockTargetRepository) Delete(ctx context.Context, id string) error {
delete(m.targets, id)
return nil
}
func (m *mockTargetRepository) ListByCertificate(ctx context.Context, certID string) ([]*domain.DeploymentTarget, error) {
return m.List(ctx)
}
type mockNotificationRepository struct {
notifications []*domain.NotificationEvent
}
func newMockNotificationRepository() *mockNotificationRepository {
return &mockNotificationRepository{
notifications: make([]*domain.NotificationEvent, 0),
}
}
func (m *mockNotificationRepository) Create(ctx context.Context, notif *domain.NotificationEvent) error {
m.notifications = append(m.notifications, notif)
return nil
}
func (m *mockNotificationRepository) List(ctx context.Context, filter *repository.NotificationFilter) ([]*domain.NotificationEvent, error) {
return m.notifications, nil
}
func (m *mockNotificationRepository) UpdateStatus(ctx context.Context, id string, status string, sentAt time.Time) error {
for _, n := range m.notifications {
if n.ID == id {
n.Status = status
return nil
}
}
return fmt.Errorf("notification not found")
}
type mockPolicyRepository struct {
rules map[string]*domain.PolicyRule
violations []*domain.PolicyViolation
}
func newMockPolicyRepository() *mockPolicyRepository {
return &mockPolicyRepository{
rules: make(map[string]*domain.PolicyRule),
violations: make([]*domain.PolicyViolation, 0),
}
}
func (m *mockPolicyRepository) ListRules(ctx context.Context) ([]*domain.PolicyRule, error) {
var rules []*domain.PolicyRule
for _, r := range m.rules {
rules = append(rules, r)
}
return rules, nil
}
func (m *mockPolicyRepository) GetRule(ctx context.Context, id string) (*domain.PolicyRule, error) {
rule, ok := m.rules[id]
if !ok {
return nil, fmt.Errorf("rule not found")
}
return rule, nil
}
func (m *mockPolicyRepository) CreateRule(ctx context.Context, rule *domain.PolicyRule) error {
m.rules[rule.ID] = rule
return nil
}
func (m *mockPolicyRepository) UpdateRule(ctx context.Context, rule *domain.PolicyRule) error {
m.rules[rule.ID] = rule
return nil
}
func (m *mockPolicyRepository) DeleteRule(ctx context.Context, id string) error {
delete(m.rules, id)
return nil
}
func (m *mockPolicyRepository) CreateViolation(ctx context.Context, violation *domain.PolicyViolation) error {
m.violations = append(m.violations, violation)
return nil
}
func (m *mockPolicyRepository) ListViolations(ctx context.Context, filter *repository.AuditFilter) ([]*domain.PolicyViolation, error) {
return m.violations, nil
}
type mockRenewalPolicyRepository struct {
policies map[string]*domain.RenewalPolicy
}
func newMockRenewalPolicyRepository() *mockRenewalPolicyRepository {
return &mockRenewalPolicyRepository{
policies: make(map[string]*domain.RenewalPolicy),
}
}
func (m *mockRenewalPolicyRepository) Get(ctx context.Context, id string) (*domain.RenewalPolicy, error) {
policy, ok := m.policies[id]
if !ok {
// Return default policy
return &domain.RenewalPolicy{
ID: id,
Name: "Default Policy",
RenewalWindowDays: 30,
AutoRenew: true,
MaxRetries: 3,
RetryInterval: 3600,
AlertThresholdsDays: domain.DefaultAlertThresholds(),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}, nil
}
return policy, nil
}
func (m *mockRenewalPolicyRepository) List(ctx context.Context) ([]*domain.RenewalPolicy, error) {
var policies []*domain.RenewalPolicy
for _, p := range m.policies {
policies = append(policies, p)
}
return policies, nil
}
type mockIssuerRepository struct {
issuers map[string]*domain.Issuer
}
func newMockIssuerRepository() *mockIssuerRepository {
return &mockIssuerRepository{
issuers: make(map[string]*domain.Issuer),
}
}
func (m *mockIssuerRepository) List(ctx context.Context) ([]*domain.Issuer, error) {
var issuers []*domain.Issuer
for _, i := range m.issuers {
issuers = append(issuers, i)
}
return issuers, nil
}
func (m *mockIssuerRepository) Get(ctx context.Context, id string) (*domain.Issuer, error) {
issuer, ok := m.issuers[id]
if !ok {
return nil, fmt.Errorf("issuer not found")
}
return issuer, nil
}
func (m *mockIssuerRepository) Create(ctx context.Context, issuer *domain.Issuer) error {
m.issuers[issuer.ID] = issuer
return nil
}
func (m *mockIssuerRepository) Update(ctx context.Context, issuer *domain.Issuer) error {
m.issuers[issuer.ID] = issuer
return nil
}
func (m *mockIssuerRepository) Delete(ctx context.Context, id string) error {
delete(m.issuers, id)
return nil
}
// Mock service implementations for handlers that need them but aren't tested
type mockTargetService struct {
targetRepo *mockTargetRepository
auditService *service.AuditService
}
func (m *mockTargetService) ListTargets(page, perPage int) ([]domain.DeploymentTarget, int64, error) {
targets, err := m.targetRepo.List(context.Background())
if err != nil {
return nil, 0, err
}
var result []domain.DeploymentTarget
for _, t := range targets {
result = append(result, *t)
}
return result, int64(len(result)), nil
}
func (m *mockTargetService) GetTarget(id string) (*domain.DeploymentTarget, error) {
return m.targetRepo.Get(context.Background(), id)
}
func (m *mockTargetService) CreateTarget(target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
if err := m.targetRepo.Create(context.Background(), &target); err != nil {
return nil, err
}
return &target, nil
}
func (m *mockTargetService) UpdateTarget(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
target.ID = id
if err := m.targetRepo.Update(context.Background(), &target); err != nil {
return nil, err
}
return &target, nil
}
func (m *mockTargetService) DeleteTarget(id string) error {
return m.targetRepo.Delete(context.Background(), id)
}
type mockTeamService struct{}
func (m *mockTeamService) ListTeams(page, perPage int) ([]domain.Team, int64, error) {
return []domain.Team{}, 0, nil
}
func (m *mockTeamService) GetTeam(id string) (*domain.Team, error) {
return nil, fmt.Errorf("team not found")
}
func (m *mockTeamService) CreateTeam(team domain.Team) (*domain.Team, error) {
return &team, nil
}
func (m *mockTeamService) UpdateTeam(id string, team domain.Team) (*domain.Team, error) {
team.ID = id
return &team, nil
}
func (m *mockTeamService) DeleteTeam(id string) error {
return nil
}
type mockOwnerService struct{}
func (m *mockOwnerService) ListOwners(page, perPage int) ([]domain.Owner, int64, error) {
return []domain.Owner{}, 0, nil
}
func (m *mockOwnerService) GetOwner(id string) (*domain.Owner, error) {
return nil, fmt.Errorf("owner not found")
}
func (m *mockOwnerService) CreateOwner(owner domain.Owner) (*domain.Owner, error) {
return &owner, nil
}
func (m *mockOwnerService) UpdateOwner(id string, owner domain.Owner) (*domain.Owner, error) {
owner.ID = id
return &owner, nil
}
func (m *mockOwnerService) DeleteOwner(id string) error {
return nil
}
+467
View File
@@ -0,0 +1,467 @@
package service
import (
"context"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
)
func TestRegisterAgent(t *testing.T) {
ctx := context.Background()
agentRepo := &mockAgentRepo{
Agents: make(map[string]*domain.Agent),
HeartbeatUpdates: make(map[string]time.Time),
}
certRepo := &mockCertRepo{
Certs: make(map[string]*domain.ManagedCertificate),
Versions: make(map[string][]*domain.CertificateVersion),
}
jobRepo := &mockJobRepo{
Jobs: make(map[string]*domain.Job),
StatusUpdates: make(map[string]domain.JobStatus),
}
targetRepo := &mockTargetRepo{
Targets: make(map[string]*domain.DeploymentTarget),
}
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
auditService := NewAuditService(auditRepo)
issuerRegistry := make(map[string]IssuerConnector)
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
agent, apiKey, err := agentService.Register(ctx, "prod-agent-1", "server-01.example.com")
if err != nil {
t.Fatalf("Register failed: %v", err)
}
if agent.Name != "prod-agent-1" {
t.Errorf("expected name prod-agent-1, got %s", agent.Name)
}
if agent.Hostname != "server-01.example.com" {
t.Errorf("expected hostname server-01.example.com, got %s", agent.Hostname)
}
if agent.Status != domain.AgentStatusOnline {
t.Errorf("expected status Online, got %s", agent.Status)
}
if apiKey == "" {
t.Fatal("expected non-empty API key")
}
if len(agentRepo.Agents) != 1 {
t.Errorf("expected 1 agent in repo, got %d", len(agentRepo.Agents))
}
}
func TestHeartbeat(t *testing.T) {
ctx := context.Background()
now := time.Now()
agent := &domain.Agent{
ID: "agent-001",
Name: "prod-agent",
Hostname: "server-01",
Status: domain.AgentStatusOnline,
RegisteredAt: now,
LastHeartbeatAt: &now,
APIKeyHash: "hash123",
}
agentRepo := &mockAgentRepo{
Agents: map[string]*domain.Agent{"agent-001": agent},
HeartbeatUpdates: make(map[string]time.Time),
}
certRepo := &mockCertRepo{
Certs: make(map[string]*domain.ManagedCertificate),
Versions: make(map[string][]*domain.CertificateVersion),
}
jobRepo := &mockJobRepo{
Jobs: make(map[string]*domain.Job),
StatusUpdates: make(map[string]domain.JobStatus),
}
targetRepo := &mockTargetRepo{
Targets: make(map[string]*domain.DeploymentTarget),
}
auditRepo := &mockAuditRepo{}
auditService := NewAuditService(auditRepo)
issuerRegistry := make(map[string]IssuerConnector)
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
err := agentService.HeartbeatWithContext(ctx, "agent-001")
if err != nil {
t.Fatalf("Heartbeat failed: %v", err)
}
if _, ok := agentRepo.HeartbeatUpdates["agent-001"]; !ok {
t.Fatal("heartbeat not recorded")
}
}
func TestHeartbeat_NotFound(t *testing.T) {
ctx := context.Background()
agentRepo := &mockAgentRepo{
Agents: make(map[string]*domain.Agent),
HeartbeatUpdates: make(map[string]time.Time),
}
certRepo := &mockCertRepo{
Certs: make(map[string]*domain.ManagedCertificate),
Versions: make(map[string][]*domain.CertificateVersion),
}
jobRepo := &mockJobRepo{
Jobs: make(map[string]*domain.Job),
StatusUpdates: make(map[string]domain.JobStatus),
}
targetRepo := &mockTargetRepo{
Targets: make(map[string]*domain.DeploymentTarget),
}
auditRepo := &mockAuditRepo{}
auditService := NewAuditService(auditRepo)
issuerRegistry := make(map[string]IssuerConnector)
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
err := agentService.HeartbeatWithContext(ctx, "nonexistent")
if err == nil {
t.Fatal("expected error for nonexistent agent")
}
}
func TestGetPendingWork(t *testing.T) {
ctx := context.Background()
now := time.Now()
agent := &domain.Agent{
ID: "agent-001",
Name: "prod-agent",
Hostname: "server-01",
Status: domain.AgentStatusOnline,
RegisteredAt: now,
LastHeartbeatAt: &now,
APIKeyHash: "hash123",
}
job1 := &domain.Job{
ID: "job-001",
Type: domain.JobTypeDeployment,
CertificateID: "cert-001",
Status: domain.JobStatusPending,
CreatedAt: now,
}
job2 := &domain.Job{
ID: "job-002",
Type: domain.JobTypeRenewal,
CertificateID: "cert-002",
Status: domain.JobStatusPending,
CreatedAt: now,
}
agentRepo := &mockAgentRepo{
Agents: map[string]*domain.Agent{"agent-001": agent},
HeartbeatUpdates: make(map[string]time.Time),
}
certRepo := &mockCertRepo{
Certs: make(map[string]*domain.ManagedCertificate),
Versions: make(map[string][]*domain.CertificateVersion),
}
jobRepo := &mockJobRepo{
Jobs: map[string]*domain.Job{"job-001": job1, "job-002": job2},
StatusUpdates: make(map[string]domain.JobStatus),
}
targetRepo := &mockTargetRepo{
Targets: make(map[string]*domain.DeploymentTarget),
}
auditRepo := &mockAuditRepo{}
auditService := NewAuditService(auditRepo)
issuerRegistry := make(map[string]IssuerConnector)
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
jobs, err := agentService.GetPendingWork(ctx, "agent-001")
if err != nil {
t.Fatalf("GetPendingWork failed: %v", err)
}
if len(jobs) != 1 {
t.Errorf("expected 1 deployment job, got %d", len(jobs))
}
if jobs[0].Type != domain.JobTypeDeployment {
t.Errorf("expected JobTypeDeployment, got %s", jobs[0].Type)
}
}
func TestReportJobStatus(t *testing.T) {
ctx := context.Background()
now := time.Now()
agent := &domain.Agent{
ID: "agent-001",
Name: "prod-agent",
Hostname: "server-01",
Status: domain.AgentStatusOnline,
RegisteredAt: now,
LastHeartbeatAt: &now,
APIKeyHash: "hash123",
}
job := &domain.Job{
ID: "job-001",
Type: domain.JobTypeDeployment,
CertificateID: "cert-001",
Status: domain.JobStatusRunning,
CreatedAt: now,
}
agentRepo := &mockAgentRepo{
Agents: map[string]*domain.Agent{"agent-001": agent},
HeartbeatUpdates: make(map[string]time.Time),
}
certRepo := &mockCertRepo{
Certs: make(map[string]*domain.ManagedCertificate),
Versions: make(map[string][]*domain.CertificateVersion),
}
jobRepo := &mockJobRepo{
Jobs: map[string]*domain.Job{"job-001": job},
StatusUpdates: make(map[string]domain.JobStatus),
}
targetRepo := &mockTargetRepo{
Targets: make(map[string]*domain.DeploymentTarget),
}
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
auditService := NewAuditService(auditRepo)
issuerRegistry := make(map[string]IssuerConnector)
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
err := agentService.ReportJobStatus(ctx, "agent-001", "job-001", domain.JobStatusCompleted, "")
if err != nil {
t.Fatalf("ReportJobStatus failed: %v", err)
}
if jobRepo.StatusUpdates["job-001"] != domain.JobStatusCompleted {
t.Errorf("expected status Completed, got %s", jobRepo.StatusUpdates["job-001"])
}
if len(auditRepo.Events) != 1 {
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
}
}
func TestMarkStaleAgentsOffline(t *testing.T) {
ctx := context.Background()
now := time.Now()
staleTime := now.Add(-3 * time.Hour)
agent1 := &domain.Agent{
ID: "agent-001",
Name: "online-agent",
Hostname: "server-01",
Status: domain.AgentStatusOnline,
RegisteredAt: now,
LastHeartbeatAt: &now,
APIKeyHash: "hash1",
}
agent2 := &domain.Agent{
ID: "agent-002",
Name: "stale-agent",
Hostname: "server-02",
Status: domain.AgentStatusOnline,
RegisteredAt: now.Add(-24 * time.Hour),
LastHeartbeatAt: &staleTime,
APIKeyHash: "hash2",
}
agentRepo := &mockAgentRepo{
Agents: map[string]*domain.Agent{"agent-001": agent1, "agent-002": agent2},
HeartbeatUpdates: make(map[string]time.Time),
}
certRepo := &mockCertRepo{
Certs: make(map[string]*domain.ManagedCertificate),
Versions: make(map[string][]*domain.CertificateVersion),
}
jobRepo := &mockJobRepo{
Jobs: make(map[string]*domain.Job),
StatusUpdates: make(map[string]domain.JobStatus),
}
targetRepo := &mockTargetRepo{
Targets: make(map[string]*domain.DeploymentTarget),
}
auditRepo := &mockAuditRepo{}
auditService := NewAuditService(auditRepo)
issuerRegistry := make(map[string]IssuerConnector)
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
err := agentService.MarkStaleAgentsOffline(ctx, 1*time.Hour)
if err != nil {
t.Fatalf("MarkStaleAgentsOffline failed: %v", err)
}
if agentRepo.Agents["agent-001"].Status != domain.AgentStatusOnline {
t.Errorf("expected agent-001 to be Online, got %s", agentRepo.Agents["agent-001"].Status)
}
if agentRepo.Agents["agent-002"].Status != domain.AgentStatusOffline {
t.Errorf("expected agent-002 to be Offline, got %s", agentRepo.Agents["agent-002"].Status)
}
}
func TestSubmitCSR(t *testing.T) {
ctx := context.Background()
now := time.Now()
agent := &domain.Agent{
ID: "agent-001",
Name: "prod-agent",
Hostname: "server-01",
Status: domain.AgentStatusOnline,
RegisteredAt: now,
LastHeartbeatAt: &now,
APIKeyHash: "hash123",
}
cert := &domain.ManagedCertificate{
ID: "cert-001",
CommonName: "example.com",
IssuerID: "iss-local",
Status: domain.CertificateStatusPending,
ExpiresAt: now.AddDate(1, 0, 0),
CreatedAt: now,
UpdatedAt: now,
}
agentRepo := &mockAgentRepo{
Agents: map[string]*domain.Agent{"agent-001": agent},
HeartbeatUpdates: make(map[string]time.Time),
}
certRepo := &mockCertRepo{
Certs: map[string]*domain.ManagedCertificate{"cert-001": cert},
Versions: make(map[string][]*domain.CertificateVersion),
}
jobRepo := &mockJobRepo{
Jobs: make(map[string]*domain.Job),
StatusUpdates: make(map[string]domain.JobStatus),
}
targetRepo := &mockTargetRepo{
Targets: make(map[string]*domain.DeploymentTarget),
}
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
auditService := NewAuditService(auditRepo)
issuerConnector := &mockIssuerConnector{
Result: &IssuanceResult{
Serial: "serial-123",
CertPEM: "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----",
ChainPEM: "-----BEGIN CERTIFICATE-----\nchain\n-----END CERTIFICATE-----",
NotBefore: now,
NotAfter: now.AddDate(1, 0, 0),
},
}
issuerRegistry := map[string]IssuerConnector{"iss-local": issuerConnector}
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
csrPEM := "-----BEGIN CERTIFICATE REQUEST-----\ntest-csr\n-----END CERTIFICATE REQUEST-----"
err := agentService.SubmitCSR(ctx, "agent-001", "cert-001", []byte(csrPEM))
if err != nil {
t.Fatalf("SubmitCSR failed: %v", err)
}
if len(certRepo.Versions["cert-001"]) != 1 {
t.Errorf("expected 1 certificate version, got %d", len(certRepo.Versions["cert-001"]))
}
if cert.Status != domain.CertificateStatusActive {
t.Errorf("expected certificate status Active, got %s", cert.Status)
}
}
func TestSubmitCSR_EmptyCSR(t *testing.T) {
ctx := context.Background()
now := time.Now()
agent := &domain.Agent{
ID: "agent-001",
Name: "prod-agent",
Hostname: "server-01",
Status: domain.AgentStatusOnline,
RegisteredAt: now,
LastHeartbeatAt: &now,
APIKeyHash: "hash123",
}
agentRepo := &mockAgentRepo{
Agents: map[string]*domain.Agent{"agent-001": agent},
HeartbeatUpdates: make(map[string]time.Time),
}
certRepo := &mockCertRepo{
Certs: make(map[string]*domain.ManagedCertificate),
Versions: make(map[string][]*domain.CertificateVersion),
}
jobRepo := &mockJobRepo{
Jobs: make(map[string]*domain.Job),
StatusUpdates: make(map[string]domain.JobStatus),
}
targetRepo := &mockTargetRepo{
Targets: make(map[string]*domain.DeploymentTarget),
}
auditRepo := &mockAuditRepo{}
auditService := NewAuditService(auditRepo)
issuerRegistry := make(map[string]IssuerConnector)
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
err := agentService.SubmitCSR(ctx, "agent-001", "", []byte{})
if err == nil {
t.Fatal("expected error for empty CSR")
}
}
func TestListAgents(t *testing.T) {
now := time.Now()
agent1 := &domain.Agent{
ID: "agent-001",
Name: "agent1",
Hostname: "server-01",
Status: domain.AgentStatusOnline,
RegisteredAt: now,
LastHeartbeatAt: &now,
APIKeyHash: "hash1",
}
agent2 := &domain.Agent{
ID: "agent-002",
Name: "agent2",
Hostname: "server-02",
Status: domain.AgentStatusOnline,
RegisteredAt: now,
LastHeartbeatAt: &now,
APIKeyHash: "hash2",
}
agentRepo := &mockAgentRepo{
Agents: map[string]*domain.Agent{"agent-001": agent1, "agent-002": agent2},
HeartbeatUpdates: make(map[string]time.Time),
}
certRepo := &mockCertRepo{
Certs: make(map[string]*domain.ManagedCertificate),
Versions: make(map[string][]*domain.CertificateVersion),
}
jobRepo := &mockJobRepo{
Jobs: make(map[string]*domain.Job),
StatusUpdates: make(map[string]domain.JobStatus),
}
targetRepo := &mockTargetRepo{
Targets: make(map[string]*domain.DeploymentTarget),
}
auditRepo := &mockAuditRepo{}
auditService := NewAuditService(auditRepo)
issuerRegistry := make(map[string]IssuerConnector)
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
agents, total, err := agentService.ListAgents(1, 50)
if err != nil {
t.Fatalf("ListAgents failed: %v", err)
}
if len(agents) != 2 {
t.Errorf("expected 2 agents, got %d", len(agents))
}
if total != 2 {
t.Errorf("expected total 2, got %d", total)
}
}
+329
View File
@@ -0,0 +1,329 @@
package service
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/repository"
)
func TestRecordEvent(t *testing.T) {
ctx := context.Background()
auditRepo := &mockAuditRepo{
Events: []*domain.AuditEvent{},
}
service := NewAuditService(auditRepo)
err := service.RecordEvent(ctx, "user123", domain.ActorTypeUser, "certificate_created", "certificate", "cert-001", map[string]interface{}{"common_name": "example.com"})
if err != nil {
t.Fatalf("RecordEvent failed: %v", err)
}
if len(auditRepo.Events) != 1 {
t.Errorf("expected 1 event, got %d", len(auditRepo.Events))
}
event := auditRepo.Events[0]
if event.Actor != "user123" {
t.Errorf("expected actor user123, got %s", event.Actor)
}
if event.ActorType != domain.ActorTypeUser {
t.Errorf("expected actor type User, got %s", event.ActorType)
}
if event.Action != "certificate_created" {
t.Errorf("expected action certificate_created, got %s", event.Action)
}
if event.ResourceType != "certificate" {
t.Errorf("expected resource type certificate, got %s", event.ResourceType)
}
if event.ResourceID != "cert-001" {
t.Errorf("expected resource ID cert-001, got %s", event.ResourceID)
}
}
func TestRecordEvent_RepoError(t *testing.T) {
ctx := context.Background()
auditRepo := &mockAuditRepo{
Events: []*domain.AuditEvent{},
CreateErr: errNotFound,
}
service := NewAuditService(auditRepo)
err := service.RecordEvent(ctx, "user123", domain.ActorTypeUser, "test_action", "resource", "res-001", map[string]interface{}{})
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestListByResource(t *testing.T) {
ctx := context.Background()
auditRepo := &mockAuditRepo{
Events: []*domain.AuditEvent{},
}
service := NewAuditService(auditRepo)
event1 := &domain.AuditEvent{
ID: "audit-1",
Actor: "user1",
ActorType: domain.ActorTypeUser,
Action: "created",
ResourceType: "certificate",
ResourceID: "cert-001",
Timestamp: time.Now(),
}
event2 := &domain.AuditEvent{
ID: "audit-2",
Actor: "user2",
ActorType: domain.ActorTypeUser,
Action: "updated",
ResourceType: "certificate",
ResourceID: "cert-001",
Timestamp: time.Now(),
}
event3 := &domain.AuditEvent{
ID: "audit-3",
Actor: "user1",
ActorType: domain.ActorTypeUser,
Action: "created",
ResourceType: "certificate",
ResourceID: "cert-002",
Timestamp: time.Now(),
}
auditRepo.AddEvent(event1)
auditRepo.AddEvent(event2)
auditRepo.AddEvent(event3)
events, err := service.ListByResource(ctx, "certificate", "cert-001")
if err != nil {
t.Fatalf("ListByResource failed: %v", err)
}
if len(events) != 2 {
t.Errorf("expected 2 events, got %d", len(events))
}
}
func TestListByActor(t *testing.T) {
ctx := context.Background()
auditRepo := &mockAuditRepo{
Events: []*domain.AuditEvent{},
}
service := NewAuditService(auditRepo)
event1 := &domain.AuditEvent{
ID: "audit-1",
Actor: "user1",
ActorType: domain.ActorTypeUser,
Action: "created",
ResourceType: "certificate",
ResourceID: "cert-001",
Timestamp: time.Now(),
}
event2 := &domain.AuditEvent{
ID: "audit-2",
Actor: "user1",
ActorType: domain.ActorTypeUser,
Action: "updated",
ResourceType: "certificate",
ResourceID: "cert-002",
Timestamp: time.Now(),
}
event3 := &domain.AuditEvent{
ID: "audit-3",
Actor: "user2",
ActorType: domain.ActorTypeUser,
Action: "created",
ResourceType: "certificate",
ResourceID: "cert-003",
Timestamp: time.Now(),
}
auditRepo.AddEvent(event1)
auditRepo.AddEvent(event2)
auditRepo.AddEvent(event3)
events, err := service.ListByActor(ctx, "user1")
if err != nil {
t.Fatalf("ListByActor failed: %v", err)
}
if len(events) != 2 {
t.Errorf("expected 2 events, got %d", len(events))
}
}
func TestListByAction(t *testing.T) {
ctx := context.Background()
auditRepo := &mockAuditRepo{
Events: []*domain.AuditEvent{},
}
service := NewAuditService(auditRepo)
now := time.Now()
from := now.Add(-1 * time.Hour)
to := now.Add(1 * time.Hour)
event1 := &domain.AuditEvent{
ID: "audit-1",
Actor: "user1",
ActorType: domain.ActorTypeUser,
Action: "certificate_created",
ResourceType: "certificate",
ResourceID: "cert-001",
Timestamp: now.Add(-30 * time.Minute),
}
event2 := &domain.AuditEvent{
ID: "audit-2",
Actor: "user2",
ActorType: domain.ActorTypeUser,
Action: "certificate_created",
ResourceType: "certificate",
ResourceID: "cert-002",
Timestamp: now.Add(-20 * time.Minute),
}
event3 := &domain.AuditEvent{
ID: "audit-3",
Actor: "user1",
ActorType: domain.ActorTypeUser,
Action: "certificate_updated",
ResourceType: "certificate",
ResourceID: "cert-001",
Timestamp: now.Add(-10 * time.Minute),
}
auditRepo.AddEvent(event1)
auditRepo.AddEvent(event2)
auditRepo.AddEvent(event3)
events, err := service.ListByAction(ctx, "certificate_created", from, to)
if err != nil {
t.Fatalf("ListByAction failed: %v", err)
}
if len(events) != 2 {
t.Errorf("expected 2 events, got %d", len(events))
}
for _, e := range events {
if e.Action != "certificate_created" {
t.Errorf("expected action certificate_created, got %s", e.Action)
}
}
}
func TestListByAction_EmptyRange(t *testing.T) {
ctx := context.Background()
auditRepo := &mockAuditRepo{
Events: []*domain.AuditEvent{},
}
service := NewAuditService(auditRepo)
now := time.Now()
from := now.Add(1 * time.Hour)
to := now.Add(2 * time.Hour)
event := &domain.AuditEvent{
ID: "audit-1",
Actor: "user1",
ActorType: domain.ActorTypeUser,
Action: "certificate_created",
ResourceType: "certificate",
ResourceID: "cert-001",
Timestamp: now.Add(-30 * time.Minute),
}
auditRepo.AddEvent(event)
events, err := service.ListByAction(ctx, "certificate_created", from, to)
if err != nil {
t.Fatalf("ListByAction failed: %v", err)
}
if len(events) != 0 {
t.Errorf("expected 0 events, got %d", len(events))
}
}
func TestRecordEvent_ComplexDetails(t *testing.T) {
ctx := context.Background()
auditRepo := &mockAuditRepo{
Events: []*domain.AuditEvent{},
}
service := NewAuditService(auditRepo)
details := map[string]interface{}{
"common_name": "example.com",
"sans": []string{"www.example.com", "api.example.com"},
"issuer_id": "iss-123",
"count": 5,
}
err := service.RecordEvent(ctx, "user1", domain.ActorTypeUser, "certificate_created", "certificate", "cert-001", details)
if err != nil {
t.Fatalf("RecordEvent failed: %v", err)
}
event := auditRepo.Events[0]
var decoded map[string]interface{}
err = json.Unmarshal(event.Details, &decoded)
if err != nil {
t.Fatalf("failed to unmarshal details: %v", err)
}
if decoded["common_name"] != "example.com" {
t.Errorf("expected common_name example.com, got %v", decoded["common_name"])
}
}
func TestList(t *testing.T) {
ctx := context.Background()
auditRepo := &mockAuditRepo{
Events: []*domain.AuditEvent{},
}
service := NewAuditService(auditRepo)
for i := 0; i < 5; i++ {
event := &domain.AuditEvent{
ID: "audit-" + string(rune(i)),
Actor: "user1",
ActorType: domain.ActorTypeUser,
Action: "test",
ResourceType: "certificate",
ResourceID: "cert-001",
Timestamp: time.Now(),
}
auditRepo.AddEvent(event)
}
filter := &repository.AuditFilter{
Page: 1,
PerPage: 10,
}
events, err := service.List(ctx, filter)
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(events) != 5 {
t.Errorf("expected 5 events, got %d", len(events))
}
}
func TestList_RepoError(t *testing.T) {
ctx := context.Background()
auditRepo := &mockAuditRepo{
ListErr: errNotFound,
}
service := NewAuditService(auditRepo)
filter := &repository.AuditFilter{}
_, err := service.List(ctx, filter)
if err == nil {
t.Fatal("expected error, got nil")
}
}
+383
View File
@@ -0,0 +1,383 @@
package service
import (
"context"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
)
func TestCreateCertificate(t *testing.T) {
ctx := context.Background()
certRepo := &mockCertRepo{
Certs: make(map[string]*domain.ManagedCertificate),
Versions: make(map[string][]*domain.CertificateVersion),
}
auditRepo := &mockAuditRepo{
Events: []*domain.AuditEvent{},
}
policyRepo := &mockPolicyRepo{
Rules: make(map[string]*domain.PolicyRule),
Violations: []*domain.PolicyViolation{},
}
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService)
now := time.Now()
cert := &domain.ManagedCertificate{
ID: "cert-001",
Name: "api-prod",
CommonName: "api.example.com",
SANs: []string{"api.example.com"},
Environment: "production",
OwnerID: "owner-1",
TeamID: "team-1",
IssuerID: "iss-acme",
TargetIDs: []string{"target-1"},
RenewalPolicyID: "policy-1",
Status: domain.CertificateStatusActive,
ExpiresAt: now.AddDate(1, 0, 0),
Tags: map[string]string{"env": "prod"},
CreatedAt: now,
UpdatedAt: now,
}
err := certService.Create(ctx, cert, "user-1")
if err != nil {
t.Fatalf("Create failed: %v", err)
}
if len(certRepo.Certs) != 1 {
t.Errorf("expected 1 cert, got %d", len(certRepo.Certs))
}
storedCert, ok := certRepo.Certs["cert-001"]
if !ok {
t.Fatal("certificate not stored")
}
if storedCert.CommonName != "api.example.com" {
t.Errorf("expected common name api.example.com, got %s", storedCert.CommonName)
}
if len(auditRepo.Events) != 1 {
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
}
}
func TestCreateCertificate_MissingRequired(t *testing.T) {
ctx := context.Background()
certRepo := &mockCertRepo{
Certs: make(map[string]*domain.ManagedCertificate),
Versions: make(map[string][]*domain.CertificateVersion),
}
auditRepo := &mockAuditRepo{}
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService)
cert := &domain.ManagedCertificate{
ID: "cert-001",
// Missing CommonName and IssuerID
}
err := certService.Create(ctx, cert, "user-1")
if err == nil {
t.Fatal("expected error for missing required fields")
}
}
func TestGetCertificate(t *testing.T) {
ctx := context.Background()
now := time.Now()
cert := &domain.ManagedCertificate{
ID: "cert-001",
CommonName: "example.com",
IssuerID: "iss-1",
Status: domain.CertificateStatusActive,
ExpiresAt: now.AddDate(1, 0, 0),
CreatedAt: now,
UpdatedAt: now,
}
certRepo := &mockCertRepo{
Certs: map[string]*domain.ManagedCertificate{"cert-001": cert},
Versions: make(map[string][]*domain.CertificateVersion),
}
auditRepo := &mockAuditRepo{}
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService)
retrieved, err := certService.Get(ctx, "cert-001")
if err != nil {
t.Fatalf("Get failed: %v", err)
}
if retrieved.CommonName != "example.com" {
t.Errorf("expected common name example.com, got %s", retrieved.CommonName)
}
}
func TestGetCertificate_NotFound(t *testing.T) {
ctx := context.Background()
certRepo := &mockCertRepo{
Certs: make(map[string]*domain.ManagedCertificate),
Versions: make(map[string][]*domain.CertificateVersion),
}
auditRepo := &mockAuditRepo{}
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService)
_, err := certService.Get(ctx, "nonexistent")
if err == nil {
t.Fatal("expected error for nonexistent certificate")
}
}
func TestUpdateCertificate(t *testing.T) {
ctx := context.Background()
now := time.Now()
originalCert := &domain.ManagedCertificate{
ID: "cert-001",
CommonName: "example.com",
IssuerID: "iss-1",
Status: domain.CertificateStatusActive,
ExpiresAt: now.AddDate(1, 0, 0),
CreatedAt: now,
UpdatedAt: now,
}
certRepo := &mockCertRepo{
Certs: map[string]*domain.ManagedCertificate{"cert-001": originalCert},
Versions: make(map[string][]*domain.CertificateVersion),
}
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService)
updatedCert := *originalCert
updatedCert.Status = domain.CertificateStatusExpiring
updatedCert.ExpiresAt = now.AddDate(0, 0, 5)
err := certService.Update(ctx, &updatedCert, "user-1")
if err != nil {
t.Fatalf("Update failed: %v", err)
}
stored := certRepo.Certs["cert-001"]
if stored.Status != domain.CertificateStatusExpiring {
t.Errorf("expected status Expiring, got %s", stored.Status)
}
if len(auditRepo.Events) != 1 {
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
}
}
func TestArchiveCertificate(t *testing.T) {
ctx := context.Background()
now := time.Now()
cert := &domain.ManagedCertificate{
ID: "cert-001",
CommonName: "example.com",
IssuerID: "iss-1",
Status: domain.CertificateStatusActive,
ExpiresAt: now.AddDate(1, 0, 0),
CreatedAt: now,
UpdatedAt: now,
}
certRepo := &mockCertRepo{
Certs: map[string]*domain.ManagedCertificate{"cert-001": cert},
Versions: make(map[string][]*domain.CertificateVersion),
}
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService)
err := certService.Archive(ctx, "cert-001", "user-1")
if err != nil {
t.Fatalf("Archive failed: %v", err)
}
archived := certRepo.Certs["cert-001"]
if archived.Status != domain.CertificateStatusArchived {
t.Errorf("expected status Archived, got %s", archived.Status)
}
if len(auditRepo.Events) != 1 {
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
}
}
func TestGetVersions(t *testing.T) {
ctx := context.Background()
now := time.Now()
version1 := &domain.CertificateVersion{
ID: "ver-1",
CertificateID: "cert-001",
SerialNumber: "serial-1",
NotBefore: now.AddDate(-1, 0, 0),
NotAfter: now,
PEMChain: "cert1-pem",
CreatedAt: now.AddDate(-1, 0, 0),
}
version2 := &domain.CertificateVersion{
ID: "ver-2",
CertificateID: "cert-001",
SerialNumber: "serial-2",
NotBefore: now,
NotAfter: now.AddDate(1, 0, 0),
PEMChain: "cert2-pem",
CreatedAt: now,
}
certRepo := &mockCertRepo{
Certs: make(map[string]*domain.ManagedCertificate),
Versions: map[string][]*domain.CertificateVersion{"cert-001": {version1, version2}},
}
auditRepo := &mockAuditRepo{}
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService)
versions, err := certService.GetVersions(ctx, "cert-001")
if err != nil {
t.Fatalf("GetVersions failed: %v", err)
}
if len(versions) != 2 {
t.Errorf("expected 2 versions, got %d", len(versions))
}
}
func TestTriggerRenewal(t *testing.T) {
ctx := context.Background()
now := time.Now()
cert := &domain.ManagedCertificate{
ID: "cert-001",
CommonName: "example.com",
IssuerID: "iss-1",
Status: domain.CertificateStatusActive,
ExpiresAt: now.AddDate(0, 0, 5),
CreatedAt: now,
UpdatedAt: now,
}
certRepo := &mockCertRepo{
Certs: map[string]*domain.ManagedCertificate{"cert-001": cert},
Versions: make(map[string][]*domain.CertificateVersion),
}
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService)
err := certService.TriggerRenewalWithActor(ctx, "cert-001", "user-1")
if err != nil {
t.Fatalf("TriggerRenewal failed: %v", err)
}
renewed := certRepo.Certs["cert-001"]
if renewed.Status != domain.CertificateStatusRenewalInProgress {
t.Errorf("expected status RenewalInProgress, got %s", renewed.Status)
}
if len(auditRepo.Events) != 1 {
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
}
}
func TestTriggerRenewal_Archived(t *testing.T) {
ctx := context.Background()
now := time.Now()
cert := &domain.ManagedCertificate{
ID: "cert-001",
CommonName: "example.com",
IssuerID: "iss-1",
Status: domain.CertificateStatusArchived,
ExpiresAt: now.AddDate(0, 0, 5),
CreatedAt: now,
UpdatedAt: now,
}
certRepo := &mockCertRepo{
Certs: map[string]*domain.ManagedCertificate{"cert-001": cert},
Versions: make(map[string][]*domain.CertificateVersion),
}
auditRepo := &mockAuditRepo{}
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService)
err := certService.TriggerRenewalWithActor(ctx, "cert-001", "user-1")
if err == nil {
t.Fatal("expected error for archived certificate")
}
}
func TestListCertificates(t *testing.T) {
now := time.Now()
cert1 := &domain.ManagedCertificate{
ID: "cert-001",
CommonName: "api.example.com",
Status: domain.CertificateStatusActive,
ExpiresAt: now.AddDate(1, 0, 0),
CreatedAt: now,
UpdatedAt: now,
}
cert2 := &domain.ManagedCertificate{
ID: "cert-002",
CommonName: "web.example.com",
Status: domain.CertificateStatusExpiring,
ExpiresAt: now.AddDate(0, 0, 5),
CreatedAt: now,
UpdatedAt: now,
}
certRepo := &mockCertRepo{
Certs: map[string]*domain.ManagedCertificate{"cert-001": cert1, "cert-002": cert2},
Versions: make(map[string][]*domain.CertificateVersion),
}
auditRepo := &mockAuditRepo{}
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService)
certs, total, err := certService.ListCertificates("", "", "", "", "", 1, 50)
if err != nil {
t.Fatalf("ListCertificates failed: %v", err)
}
if len(certs) != 2 {
t.Errorf("expected 2 certs, got %d", len(certs))
}
if total != 2 {
t.Errorf("expected total 2, got %d", total)
}
}
+244
View File
@@ -0,0 +1,244 @@
package service
import (
"context"
"log/slog"
"os"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
)
// helper to build job service with proper constructor signatures
func newTestJobService(jobRepo *mockJobRepo) *JobService {
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo}))
certRepo := &mockCertRepo{
Certs: make(map[string]*domain.ManagedCertificate),
Versions: make(map[string][]*domain.CertificateVersion),
}
renewalPolicyRepo := &mockRenewalPolicyRepo{
Policies: make(map[string]*domain.RenewalPolicy),
}
auditRepo := &mockAuditRepo{}
auditService := NewAuditService(auditRepo)
notifRepo := newMockNotificationRepository()
notifService := NewNotificationService(notifRepo, make(map[string]Notifier))
targetRepo := &mockTargetRepo{Targets: make(map[string]*domain.DeploymentTarget)}
agentRepo := &mockAgentRepo{Agents: make(map[string]*domain.Agent)}
renewalService := NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, auditService, notifService, make(map[string]IssuerConnector))
deploymentService := NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notifService)
return NewJobService(jobRepo, renewalService, deploymentService, logger)
}
func TestProcessPendingJobs_Renewal(t *testing.T) {
ctx := context.Background()
now := time.Now()
job := &domain.Job{
ID: "job-001",
Type: domain.JobTypeRenewal,
CertificateID: "cert-001",
Status: domain.JobStatusPending,
Attempts: 0,
MaxAttempts: 3,
CreatedAt: now,
ScheduledAt: now,
}
jobRepo := &mockJobRepo{
Jobs: map[string]*domain.Job{"job-001": job},
StatusUpdates: make(map[string]domain.JobStatus),
}
jobService := newTestJobService(jobRepo)
err := jobService.ProcessPendingJobs(ctx)
if err != nil {
t.Logf("ProcessPendingJobs returned error (expected for renewal without cert): %v", err)
}
}
func TestProcessPendingJobs_NoJobs(t *testing.T) {
ctx := context.Background()
jobRepo := &mockJobRepo{
Jobs: make(map[string]*domain.Job),
StatusUpdates: make(map[string]domain.JobStatus),
}
jobService := newTestJobService(jobRepo)
err := jobService.ProcessPendingJobs(ctx)
if err != nil {
t.Fatalf("ProcessPendingJobs failed: %v", err)
}
}
func TestCancelJob(t *testing.T) {
ctx := context.Background()
now := time.Now()
job := &domain.Job{
ID: "job-001",
Type: domain.JobTypeDeployment,
CertificateID: "cert-001",
Status: domain.JobStatusPending,
CreatedAt: now,
ScheduledAt: now,
}
jobRepo := &mockJobRepo{
Jobs: map[string]*domain.Job{"job-001": job},
StatusUpdates: make(map[string]domain.JobStatus),
}
jobService := newTestJobService(jobRepo)
err := jobService.CancelJobWithContext(ctx, "job-001")
if err != nil {
t.Fatalf("CancelJob failed: %v", err)
}
if jobRepo.StatusUpdates["job-001"] != domain.JobStatusCancelled {
t.Errorf("expected status Cancelled, got %s", jobRepo.StatusUpdates["job-001"])
}
}
func TestCancelJob_AlreadyCompleted(t *testing.T) {
ctx := context.Background()
now := time.Now()
job := &domain.Job{
ID: "job-001",
Type: domain.JobTypeDeployment,
CertificateID: "cert-001",
Status: domain.JobStatusCompleted,
CreatedAt: now,
ScheduledAt: now,
}
jobRepo := &mockJobRepo{
Jobs: map[string]*domain.Job{"job-001": job},
StatusUpdates: make(map[string]domain.JobStatus),
}
jobService := newTestJobService(jobRepo)
err := jobService.CancelJobWithContext(ctx, "job-001")
if err == nil {
t.Fatal("expected error for completed job")
}
}
func TestGetJob(t *testing.T) {
now := time.Now()
job := &domain.Job{
ID: "job-001",
Type: domain.JobTypeDeployment,
CertificateID: "cert-001",
Status: domain.JobStatusPending,
CreatedAt: now,
ScheduledAt: now,
}
jobRepo := &mockJobRepo{
Jobs: map[string]*domain.Job{"job-001": job},
StatusUpdates: make(map[string]domain.JobStatus),
}
jobService := newTestJobService(jobRepo)
retrieved, err := jobService.GetJob("job-001")
if err != nil {
t.Fatalf("GetJob failed: %v", err)
}
if retrieved.ID != "job-001" {
t.Errorf("expected job ID job-001, got %s", retrieved.ID)
}
if retrieved.Type != domain.JobTypeDeployment {
t.Errorf("expected job type Deployment, got %s", retrieved.Type)
}
}
func TestListJobs(t *testing.T) {
now := time.Now()
job1 := &domain.Job{
ID: "job-001",
Type: domain.JobTypeDeployment,
CertificateID: "cert-001",
Status: domain.JobStatusCompleted,
CreatedAt: now,
ScheduledAt: now,
}
job2 := &domain.Job{
ID: "job-002",
Type: domain.JobTypeRenewal,
CertificateID: "cert-002",
Status: domain.JobStatusPending,
CreatedAt: now,
ScheduledAt: now,
}
jobRepo := &mockJobRepo{
Jobs: map[string]*domain.Job{"job-001": job1, "job-002": job2},
StatusUpdates: make(map[string]domain.JobStatus),
}
jobService := newTestJobService(jobRepo)
jobs, total, err := jobService.ListJobs("", "", 1, 50)
if err != nil {
t.Fatalf("ListJobs failed: %v", err)
}
if len(jobs) != 2 {
t.Errorf("expected 2 jobs, got %d", len(jobs))
}
if total != 2 {
t.Errorf("expected total 2, got %d", total)
}
}
func TestListJobs_FilterByStatus(t *testing.T) {
now := time.Now()
job1 := &domain.Job{
ID: "job-001",
Type: domain.JobTypeDeployment,
CertificateID: "cert-001",
Status: domain.JobStatusCompleted,
CreatedAt: now,
ScheduledAt: now,
}
job2 := &domain.Job{
ID: "job-002",
Type: domain.JobTypeRenewal,
CertificateID: "cert-002",
Status: domain.JobStatusPending,
CreatedAt: now,
ScheduledAt: now,
}
jobRepo := &mockJobRepo{
Jobs: map[string]*domain.Job{"job-001": job1, "job-002": job2},
StatusUpdates: make(map[string]domain.JobStatus),
}
jobService := newTestJobService(jobRepo)
jobs, total, err := jobService.ListJobs(string(domain.JobStatusPending), "", 1, 50)
if err != nil {
t.Fatalf("ListJobs failed: %v", err)
}
if len(jobs) != 1 {
t.Errorf("expected 1 pending job, got %d", len(jobs))
}
if total != 1 {
t.Errorf("expected total 1, got %d", total)
}
}
+567
View File
@@ -0,0 +1,567 @@
package service
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
)
func TestSendThresholdAlert(t *testing.T) {
t.Helper()
ctx := context.Background()
notifRepo := newMockNotificationRepository()
notifier := newMockNotifier()
registry := map[string]Notifier{
"Email": notifier,
}
svc := NewNotificationService(notifRepo, registry)
cert := &domain.ManagedCertificate{
ID: "mc-test-1",
CommonName: "example.com",
OwnerID: "owner-1",
ExpiresAt: time.Now().AddDate(0, 0, 5),
}
threshold := 7
daysUntilExpiry := 5
err := svc.SendThresholdAlert(ctx, cert, daysUntilExpiry, threshold)
if err != nil {
t.Fatalf("SendThresholdAlert failed: %v", err)
}
if len(notifRepo.Notifications) < 1 {
t.Errorf("expected at least 1 notification, got %d", len(notifRepo.Notifications))
}
notif := notifRepo.Notifications[0]
if notif.Type != domain.NotificationTypeExpirationWarning {
t.Errorf("expected ExpirationWarning, got %s", notif.Type)
}
// Verify message contains threshold tag
if !strings.Contains(notif.Message, "[threshold:7]") {
t.Errorf("expected threshold tag in message, got: %s", notif.Message)
}
// Verify notifier was called
if notifier.getSentCount() != 1 {
t.Errorf("expected 1 sent message, got %d", notifier.getSentCount())
}
}
func TestSendThresholdAlert_Expired(t *testing.T) {
t.Helper()
ctx := context.Background()
notifRepo := newMockNotificationRepository()
notifier := newMockNotifier()
registry := map[string]Notifier{
"Email": notifier,
}
svc := NewNotificationService(notifRepo, registry)
cert := &domain.ManagedCertificate{
ID: "mc-test-expired",
CommonName: "expired.com",
OwnerID: "owner-1",
ExpiresAt: time.Now().AddDate(0, 0, -1),
}
threshold := 0
daysUntilExpiry := -1
err := svc.SendThresholdAlert(ctx, cert, daysUntilExpiry, threshold)
if err != nil {
t.Fatalf("SendThresholdAlert failed: %v", err)
}
// Verify message contains [EXPIRED] prefix
if len(notifRepo.Notifications) > 0 && !strings.Contains(notifRepo.Notifications[0].Message, "[EXPIRED]") {
t.Errorf("expected [EXPIRED] in message, got: %s", notifRepo.Notifications[0].Message)
}
}
func TestHasThresholdNotification_Found(t *testing.T) {
t.Helper()
ctx := context.Background()
notifRepo := newMockNotificationRepository()
registry := map[string]Notifier{}
svc := NewNotificationService(notifRepo, registry)
// Add an existing notification with threshold tag
existingNotif := &domain.NotificationEvent{
ID: "notif-1",
CertificateID: stringPtr("mc-test-1"),
Type: domain.NotificationTypeExpirationWarning,
Channel: domain.NotificationChannelEmail,
Recipient: "owner-1",
Message: "Certificate expires soon\n\n[threshold:30]",
Status: "sent",
CreatedAt: time.Now(),
}
notifRepo.AddNotification(existingNotif)
// Check for existing notification
found, err := svc.HasThresholdNotification(ctx, "mc-test-1", 30)
if err != nil {
t.Fatalf("HasThresholdNotification failed: %v", err)
}
if !found {
t.Errorf("expected to find threshold notification, but didn't")
}
}
func TestHasThresholdNotification_NotFound(t *testing.T) {
t.Helper()
ctx := context.Background()
notifRepo := newMockNotificationRepository()
registry := map[string]Notifier{}
svc := NewNotificationService(notifRepo, registry)
// Check for non-existent notification
found, err := svc.HasThresholdNotification(ctx, "mc-test-1", 30)
if err != nil {
t.Fatalf("HasThresholdNotification failed: %v", err)
}
if found {
t.Errorf("expected not to find threshold notification, but did")
}
}
func TestSendExpirationWarning(t *testing.T) {
t.Helper()
ctx := context.Background()
notifRepo := newMockNotificationRepository()
notifier := newMockNotifier()
registry := map[string]Notifier{
"Email": notifier,
}
svc := NewNotificationService(notifRepo, registry)
cert := &domain.ManagedCertificate{
ID: "mc-test-warning",
CommonName: "warn.com",
OwnerID: "owner-1",
ExpiresAt: time.Now().AddDate(0, 0, 10),
}
err := svc.SendExpirationWarning(ctx, cert, 10)
if err != nil {
t.Fatalf("SendExpirationWarning failed: %v", err)
}
// Verify notification was created
if len(notifRepo.Notifications) < 1 {
t.Errorf("expected at least 1 notification, got %d", len(notifRepo.Notifications))
}
if notifRepo.Notifications[0].Type != domain.NotificationTypeExpirationWarning {
t.Errorf("expected ExpirationWarning type, got %s", notifRepo.Notifications[0].Type)
}
}
func TestSendRenewalNotification_Success(t *testing.T) {
t.Helper()
ctx := context.Background()
notifRepo := newMockNotificationRepository()
notifier := newMockNotifier()
registry := map[string]Notifier{
"Email": notifier,
}
svc := NewNotificationService(notifRepo, registry)
cert := &domain.ManagedCertificate{
ID: "mc-renewed",
CommonName: "renewed.com",
OwnerID: "owner-1",
ExpiresAt: time.Now().AddDate(1, 0, 0),
}
err := svc.SendRenewalNotification(ctx, cert, true, nil)
if err != nil {
t.Fatalf("SendRenewalNotification failed: %v", err)
}
// Verify notification was created with success type
if len(notifRepo.Notifications) < 1 {
t.Errorf("expected at least 1 notification, got %d", len(notifRepo.Notifications))
}
if notifRepo.Notifications[0].Type != domain.NotificationTypeRenewalSuccess {
t.Errorf("expected RenewalSuccess type, got %s", notifRepo.Notifications[0].Type)
}
// Verify message contains success text
if !strings.Contains(notifRepo.Notifications[0].Message, "successfully renewed") {
t.Errorf("expected success message, got: %s", notifRepo.Notifications[0].Message)
}
}
func TestSendRenewalNotification_Failure(t *testing.T) {
t.Helper()
ctx := context.Background()
notifRepo := newMockNotificationRepository()
notifier := newMockNotifier()
registry := map[string]Notifier{
"Email": notifier,
}
svc := NewNotificationService(notifRepo, registry)
cert := &domain.ManagedCertificate{
ID: "mc-failed-renewal",
CommonName: "failed.com",
OwnerID: "owner-1",
ExpiresAt: time.Now().AddDate(0, 0, 5),
}
testErr := fmt.Errorf("issuer unavailable")
err := svc.SendRenewalNotification(ctx, cert, false, testErr)
if err != nil {
t.Fatalf("SendRenewalNotification failed: %v", err)
}
// Verify notification was created with failure type
if len(notifRepo.Notifications) < 1 {
t.Errorf("expected at least 1 notification, got %d", len(notifRepo.Notifications))
}
if notifRepo.Notifications[0].Type != domain.NotificationTypeRenewalFailure {
t.Errorf("expected RenewalFailure type, got %s", notifRepo.Notifications[0].Type)
}
// Verify message contains error info
if !strings.Contains(notifRepo.Notifications[0].Message, "failed to renew") {
t.Errorf("expected failure message, got: %s", notifRepo.Notifications[0].Message)
}
}
func TestProcessPendingNotifications(t *testing.T) {
t.Helper()
ctx := context.Background()
notifRepo := newMockNotificationRepository()
notifier := newMockNotifier()
registry := map[string]Notifier{
"Email": notifier,
}
svc := NewNotificationService(notifRepo, registry)
// Add pending notifications
for i := 0; i < 3; i++ {
notif := &domain.NotificationEvent{
ID: fmt.Sprintf("notif-%d", i),
Type: domain.NotificationTypeExpirationWarning,
Channel: domain.NotificationChannelEmail,
Recipient: "owner-1",
Message: fmt.Sprintf("Test notification %d", i),
Status: "pending",
CreatedAt: time.Now(),
}
notifRepo.AddNotification(notif)
}
err := svc.ProcessPendingNotifications(ctx)
if err != nil {
t.Fatalf("ProcessPendingNotifications failed: %v", err)
}
// Verify all notifications were sent
if notifier.getSentCount() != 3 {
t.Errorf("expected 3 sent notifications, got %d", notifier.getSentCount())
}
// Verify status was updated to sent
for _, notif := range notifRepo.Notifications {
if notif.Status != "sent" {
t.Errorf("expected notification status 'sent', got %s", notif.Status)
}
}
}
func TestProcessPendingNotifications_NoNotifier(t *testing.T) {
t.Helper()
ctx := context.Background()
notifRepo := newMockNotificationRepository()
// No notifier registered - demo mode
registry := map[string]Notifier{}
svc := NewNotificationService(notifRepo, registry)
// Add pending notification
notif := &domain.NotificationEvent{
ID: "notif-demo",
Type: domain.NotificationTypeExpirationWarning,
Channel: domain.NotificationChannelEmail, // Channel not in registry
Recipient: "owner-1",
Message: "Test notification",
Status: "pending",
CreatedAt: time.Now(),
}
notifRepo.AddNotification(notif)
// Should not fail, just mark as sent (demo mode graceful skip)
err := svc.ProcessPendingNotifications(ctx)
if err != nil {
t.Fatalf("ProcessPendingNotifications should not fail in demo mode: %v", err)
}
// Status should still be updated to sent
if len(notifRepo.Notifications) > 0 && notifRepo.Notifications[0].Status == "sent" {
// This is fine - graceful skip marks as sent
}
}
func TestRegisterNotifier(t *testing.T) {
t.Helper()
notifRepo := newMockNotificationRepository()
registry := map[string]Notifier{}
svc := NewNotificationService(notifRepo, registry)
notifier := newMockNotifier()
svc.RegisterNotifier("Email", notifier)
// Verify notifier was registered
if svc.notifierRegistry["Email"] == nil {
t.Errorf("expected notifier to be registered")
}
}
func TestListNotifications(t *testing.T) {
t.Helper()
notifRepo := newMockNotificationRepository()
registry := map[string]Notifier{}
svc := NewNotificationService(notifRepo, registry)
// Add test notifications
for i := 0; i < 5; i++ {
notif := &domain.NotificationEvent{
ID: fmt.Sprintf("notif-list-%d", i),
Type: domain.NotificationTypeExpirationWarning,
Channel: domain.NotificationChannelEmail,
Recipient: fmt.Sprintf("owner-%d", i%2),
Message: fmt.Sprintf("Test notification %d", i),
Status: "sent",
CreatedAt: time.Now(),
}
notifRepo.AddNotification(notif)
}
// List with pagination
notifs, total, err := svc.ListNotifications(1, 3)
if err != nil {
t.Fatalf("ListNotifications failed: %v", err)
}
if len(notifs) == 0 {
t.Errorf("expected notifications, got none")
}
if total == 0 {
t.Errorf("expected total count > 0, got %d", total)
}
}
func TestMarkAsRead(t *testing.T) {
t.Helper()
notifRepo := newMockNotificationRepository()
registry := map[string]Notifier{}
svc := NewNotificationService(notifRepo, registry)
// Add a notification
notif := &domain.NotificationEvent{
ID: "notif-read",
Type: domain.NotificationTypeExpirationWarning,
Channel: domain.NotificationChannelEmail,
Recipient: "owner-1",
Message: "Test notification",
Status: "sent",
CreatedAt: time.Now(),
}
notifRepo.AddNotification(notif)
// Mark as read
err := svc.MarkAsRead(notif.ID)
if err != nil {
t.Fatalf("MarkAsRead failed: %v", err)
}
// Verify status was updated
if len(notifRepo.Notifications) > 0 && notifRepo.Notifications[0].Status != "read" {
t.Errorf("expected status 'read', got %s", notifRepo.Notifications[0].Status)
}
}
func TestGetNotification(t *testing.T) {
t.Helper()
notifRepo := newMockNotificationRepository()
registry := map[string]Notifier{}
svc := NewNotificationService(notifRepo, registry)
// Add a notification
notif := &domain.NotificationEvent{
ID: "notif-get-test",
Type: domain.NotificationTypeExpirationWarning,
Channel: domain.NotificationChannelEmail,
Recipient: "owner-1",
Message: "Test notification",
Status: "sent",
CreatedAt: time.Now(),
}
notifRepo.AddNotification(notif)
// Get the notification
retrieved, err := svc.GetNotification(notif.ID)
if err != nil {
t.Fatalf("GetNotification failed: %v", err)
}
if retrieved == nil {
t.Errorf("expected notification, got nil")
} else if retrieved.ID != notif.ID {
t.Errorf("expected ID %s, got %s", notif.ID, retrieved.ID)
}
}
func TestSendDeploymentNotification_Success(t *testing.T) {
t.Helper()
ctx := context.Background()
notifRepo := newMockNotificationRepository()
notifier := newMockNotifier()
registry := map[string]Notifier{
"Email": notifier,
}
svc := NewNotificationService(notifRepo, registry)
cert := &domain.ManagedCertificate{
ID: "mc-deploy",
CommonName: "deploy.com",
OwnerID: "owner-1",
ExpiresAt: time.Now().AddDate(1, 0, 0),
}
target := &domain.DeploymentTarget{
ID: "target-1",
Name: "NGINX-Prod",
}
err := svc.SendDeploymentNotification(ctx, cert, target, true, nil)
if err != nil {
t.Fatalf("SendDeploymentNotification failed: %v", err)
}
// Verify notification was created
if len(notifRepo.Notifications) < 1 {
t.Errorf("expected at least 1 notification, got %d", len(notifRepo.Notifications))
}
if notifRepo.Notifications[0].Type != domain.NotificationTypeDeploymentSuccess {
t.Errorf("expected DeploymentSuccess type, got %s", notifRepo.Notifications[0].Type)
}
}
func TestSendDeploymentNotification_Failure(t *testing.T) {
t.Helper()
ctx := context.Background()
notifRepo := newMockNotificationRepository()
notifier := newMockNotifier()
registry := map[string]Notifier{
"Email": notifier,
}
svc := NewNotificationService(notifRepo, registry)
cert := &domain.ManagedCertificate{
ID: "mc-deploy-fail",
CommonName: "deploy-fail.com",
OwnerID: "owner-1",
ExpiresAt: time.Now().AddDate(1, 0, 0),
}
target := &domain.DeploymentTarget{
ID: "target-2",
Name: "NGINX-Staging",
}
deployErr := fmt.Errorf("connection timeout")
err := svc.SendDeploymentNotification(ctx, cert, target, false, deployErr)
if err != nil {
t.Fatalf("SendDeploymentNotification failed: %v", err)
}
// Verify notification was created
if len(notifRepo.Notifications) < 1 {
t.Errorf("expected at least 1 notification, got %d", len(notifRepo.Notifications))
}
if notifRepo.Notifications[0].Type != domain.NotificationTypeDeploymentFailure {
t.Errorf("expected DeploymentFailure type, got %s", notifRepo.Notifications[0].Type)
}
}
func TestGetNotificationHistory(t *testing.T) {
t.Helper()
ctx := context.Background()
notifRepo := newMockNotificationRepository()
registry := map[string]Notifier{}
svc := NewNotificationService(notifRepo, registry)
certID := "mc-history"
// Add multiple notifications for same cert
for i := 0; i < 3; i++ {
notif := &domain.NotificationEvent{
ID: fmt.Sprintf("notif-hist-%d", i),
CertificateID: &certID,
Type: domain.NotificationTypeExpirationWarning,
Channel: domain.NotificationChannelEmail,
Recipient: "owner-1",
Message: fmt.Sprintf("Alert %d", i),
Status: "sent",
CreatedAt: time.Now(),
}
notifRepo.AddNotification(notif)
}
// Get history
history, err := svc.GetNotificationHistory(ctx, certID)
if err != nil {
t.Fatalf("GetNotificationHistory failed: %v", err)
}
if len(history) < 1 {
t.Errorf("expected at least 1 notification, got %d", len(history))
}
}
// Helper function
func stringPtr(s string) *string {
return &s
}
+422
View File
@@ -0,0 +1,422 @@
package service
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
)
func TestCreateRule(t *testing.T) {
ctx := context.Background()
policyRepo := &mockPolicyRepo{
Rules: make(map[string]*domain.PolicyRule),
Violations: []*domain.PolicyViolation{},
}
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
auditService := NewAuditService(auditRepo)
policyService := NewPolicyService(policyRepo, auditService)
config := map[string]interface{}{"issuers": []string{"iss-acme"}}
configJSON, _ := json.Marshal(config)
rule := &domain.PolicyRule{
ID: "rule-001",
Name: "Allowed Issuers",
Type: domain.PolicyTypeAllowedIssuers,
Config: configJSON,
Enabled: true,
}
err := policyService.CreateRule(ctx, rule, "user-1")
if err != nil {
t.Fatalf("CreateRule failed: %v", err)
}
if len(policyRepo.Rules) != 1 {
t.Errorf("expected 1 rule, got %d", len(policyRepo.Rules))
}
if len(auditRepo.Events) != 1 {
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
}
}
func TestGetRule(t *testing.T) {
ctx := context.Background()
now := time.Now()
rule := &domain.PolicyRule{
ID: "rule-001",
Name: "Allowed Issuers",
Type: domain.PolicyTypeAllowedIssuers,
Enabled: true,
CreatedAt: now,
UpdatedAt: now,
}
policyRepo := &mockPolicyRepo{
Rules: map[string]*domain.PolicyRule{"rule-001": rule},
Violations: []*domain.PolicyViolation{},
}
auditRepo := &mockAuditRepo{}
auditService := NewAuditService(auditRepo)
policyService := NewPolicyService(policyRepo, auditService)
retrieved, err := policyService.GetRule(ctx, "rule-001")
if err != nil {
t.Fatalf("GetRule failed: %v", err)
}
if retrieved.Name != "Allowed Issuers" {
t.Errorf("expected name Allowed Issuers, got %s", retrieved.Name)
}
}
func TestGetRule_NotFound(t *testing.T) {
ctx := context.Background()
policyRepo := &mockPolicyRepo{
Rules: make(map[string]*domain.PolicyRule),
Violations: []*domain.PolicyViolation{},
}
auditRepo := &mockAuditRepo{}
auditService := NewAuditService(auditRepo)
policyService := NewPolicyService(policyRepo, auditService)
_, err := policyService.GetRule(ctx, "nonexistent")
if err == nil {
t.Fatal("expected error for nonexistent rule")
}
}
func TestListRules(t *testing.T) {
ctx := context.Background()
now := time.Now()
rule1 := &domain.PolicyRule{
ID: "rule-001",
Name: "Allowed Issuers",
Type: domain.PolicyTypeAllowedIssuers,
Enabled: true,
CreatedAt: now,
UpdatedAt: now,
}
rule2 := &domain.PolicyRule{
ID: "rule-002",
Name: "Required Metadata",
Type: domain.PolicyTypeRequiredMetadata,
Enabled: true,
CreatedAt: now,
UpdatedAt: now,
}
policyRepo := &mockPolicyRepo{
Rules: map[string]*domain.PolicyRule{"rule-001": rule1, "rule-002": rule2},
Violations: []*domain.PolicyViolation{},
}
auditRepo := &mockAuditRepo{}
auditService := NewAuditService(auditRepo)
policyService := NewPolicyService(policyRepo, auditService)
rules, err := policyService.ListRules(ctx)
if err != nil {
t.Fatalf("ListRules failed: %v", err)
}
if len(rules) != 2 {
t.Errorf("expected 2 rules, got %d", len(rules))
}
}
func TestUpdateRule(t *testing.T) {
ctx := context.Background()
now := time.Now()
originalRule := &domain.PolicyRule{
ID: "rule-001",
Name: "Allowed Issuers",
Type: domain.PolicyTypeAllowedIssuers,
Enabled: true,
CreatedAt: now,
UpdatedAt: now,
}
policyRepo := &mockPolicyRepo{
Rules: map[string]*domain.PolicyRule{"rule-001": originalRule},
Violations: []*domain.PolicyViolation{},
}
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
auditService := NewAuditService(auditRepo)
policyService := NewPolicyService(policyRepo, auditService)
updatedRule := *originalRule
updatedRule.Enabled = false
err := policyService.UpdateRule(ctx, &updatedRule, "user-1")
if err != nil {
t.Fatalf("UpdateRule failed: %v", err)
}
stored := policyRepo.Rules["rule-001"]
if stored.Enabled {
t.Error("expected rule to be disabled")
}
if len(auditRepo.Events) != 1 {
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
}
}
func TestDeleteRule(t *testing.T) {
ctx := context.Background()
now := time.Now()
rule := &domain.PolicyRule{
ID: "rule-001",
Name: "Allowed Issuers",
Type: domain.PolicyTypeAllowedIssuers,
Enabled: true,
CreatedAt: now,
UpdatedAt: now,
}
policyRepo := &mockPolicyRepo{
Rules: map[string]*domain.PolicyRule{"rule-001": rule},
Violations: []*domain.PolicyViolation{},
}
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
auditService := NewAuditService(auditRepo)
policyService := NewPolicyService(policyRepo, auditService)
err := policyService.DeleteRule(ctx, "rule-001", "user-1")
if err != nil {
t.Fatalf("DeleteRule failed: %v", err)
}
if len(policyRepo.Rules) != 0 {
t.Errorf("expected 0 rules, got %d", len(policyRepo.Rules))
}
if len(auditRepo.Events) != 1 {
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
}
}
func TestValidateCertificate(t *testing.T) {
ctx := context.Background()
now := time.Now()
rule := &domain.PolicyRule{
ID: "rule-001",
Name: "Allowed Issuers",
Type: domain.PolicyTypeAllowedIssuers,
Enabled: true,
CreatedAt: now,
UpdatedAt: now,
}
policyRepo := &mockPolicyRepo{
Rules: map[string]*domain.PolicyRule{"rule-001": rule},
Violations: []*domain.PolicyViolation{},
}
auditRepo := &mockAuditRepo{}
auditService := NewAuditService(auditRepo)
policyService := NewPolicyService(policyRepo, auditService)
cert := &domain.ManagedCertificate{
ID: "cert-001",
CommonName: "example.com",
IssuerID: "iss-acme",
Status: domain.CertificateStatusActive,
ExpiresAt: now.AddDate(1, 0, 0),
CreatedAt: now,
UpdatedAt: now,
}
violations, err := policyService.ValidateCertificate(ctx, cert)
if err != nil {
t.Fatalf("ValidateCertificate failed: %v", err)
}
if len(violations) > 0 {
t.Errorf("expected no violations, got %d", len(violations))
}
}
func TestValidateCertificate_WithViolation(t *testing.T) {
ctx := context.Background()
now := time.Now()
rule := &domain.PolicyRule{
ID: "rule-001",
Name: "Allowed Issuers",
Type: domain.PolicyTypeAllowedIssuers,
Enabled: true,
CreatedAt: now,
UpdatedAt: now,
}
policyRepo := &mockPolicyRepo{
Rules: map[string]*domain.PolicyRule{"rule-001": rule},
Violations: []*domain.PolicyViolation{},
}
auditRepo := &mockAuditRepo{}
auditService := NewAuditService(auditRepo)
policyService := NewPolicyService(policyRepo, auditService)
cert := &domain.ManagedCertificate{
ID: "cert-001",
CommonName: "example.com",
IssuerID: "", // Missing issuer
Status: domain.CertificateStatusActive,
ExpiresAt: now.AddDate(1, 0, 0),
CreatedAt: now,
UpdatedAt: now,
}
violations, err := policyService.ValidateCertificate(ctx, cert)
if err != nil {
t.Fatalf("ValidateCertificate failed: %v", err)
}
if len(violations) != 1 {
t.Errorf("expected 1 violation, got %d", len(violations))
}
if violations[0].CertificateID != "cert-001" {
t.Errorf("expected violation for cert-001, got %s", violations[0].CertificateID)
}
}
func TestValidateCertificate_MultipleViolations(t *testing.T) {
ctx := context.Background()
now := time.Now()
rule1 := &domain.PolicyRule{
ID: "rule-001",
Name: "Allowed Issuers",
Type: domain.PolicyTypeAllowedIssuers,
Enabled: true,
CreatedAt: now,
UpdatedAt: now,
}
rule2 := &domain.PolicyRule{
ID: "rule-002",
Name: "Required Metadata",
Type: domain.PolicyTypeRequiredMetadata,
Enabled: true,
CreatedAt: now,
UpdatedAt: now,
}
policyRepo := &mockPolicyRepo{
Rules: map[string]*domain.PolicyRule{"rule-001": rule1, "rule-002": rule2},
Violations: []*domain.PolicyViolation{},
}
auditRepo := &mockAuditRepo{}
auditService := NewAuditService(auditRepo)
policyService := NewPolicyService(policyRepo, auditService)
cert := &domain.ManagedCertificate{
ID: "cert-001",
CommonName: "example.com",
IssuerID: "", // Missing issuer
Tags: nil, // Missing metadata
Status: domain.CertificateStatusActive,
ExpiresAt: now.AddDate(1, 0, 0),
CreatedAt: now,
UpdatedAt: now,
}
violations, err := policyService.ValidateCertificate(ctx, cert)
if err != nil {
t.Fatalf("ValidateCertificate failed: %v", err)
}
if len(violations) != 2 {
t.Errorf("expected 2 violations, got %d", len(violations))
}
}
func TestListPolicies(t *testing.T) {
now := time.Now()
rule1 := &domain.PolicyRule{
ID: "rule-001",
Name: "Rule 1",
Type: domain.PolicyTypeAllowedIssuers,
Enabled: true,
CreatedAt: now,
UpdatedAt: now,
}
rule2 := &domain.PolicyRule{
ID: "rule-002",
Name: "Rule 2",
Type: domain.PolicyTypeRequiredMetadata,
Enabled: true,
CreatedAt: now,
UpdatedAt: now,
}
policyRepo := &mockPolicyRepo{
Rules: map[string]*domain.PolicyRule{"rule-001": rule1, "rule-002": rule2},
Violations: []*domain.PolicyViolation{},
}
auditRepo := &mockAuditRepo{}
auditService := NewAuditService(auditRepo)
policyService := NewPolicyService(policyRepo, auditService)
policies, total, err := policyService.ListPolicies(1, 50)
if err != nil {
t.Fatalf("ListPolicies failed: %v", err)
}
if len(policies) != 2 {
t.Errorf("expected 2 policies, got %d", len(policies))
}
if total != 2 {
t.Errorf("expected total 2, got %d", total)
}
}
func TestCreatePolicy(t *testing.T) {
now := time.Now()
policyRepo := &mockPolicyRepo{
Rules: make(map[string]*domain.PolicyRule),
Violations: []*domain.PolicyViolation{},
}
auditRepo := &mockAuditRepo{}
auditService := NewAuditService(auditRepo)
policyService := NewPolicyService(policyRepo, auditService)
policy := domain.PolicyRule{
Name: "Test Policy",
Type: domain.PolicyTypeAllowedIssuers,
Enabled: true,
CreatedAt: now,
}
created, err := policyService.CreatePolicy(policy)
if err != nil {
t.Fatalf("CreatePolicy failed: %v", err)
}
if created.ID == "" {
t.Fatal("expected non-empty policy ID")
}
if len(policyRepo.Rules) != 1 {
t.Errorf("expected 1 rule in repo, got %d", len(policyRepo.Rules))
}
}
+866
View File
@@ -0,0 +1,866 @@
package service
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
)
func TestCheckExpiringCertificates_SendsThresholdAlerts(t *testing.T) {
t.Helper()
ctx := context.Background()
certRepo := newMockCertificateRepository()
jobRepo := newMockJobRepository()
policyRepo := newMockRenewalPolicyRepository()
auditRepo := newMockAuditRepository()
notifRepo := newMockNotificationRepository()
notifier := newMockNotifier()
auditSvc := NewAuditService(auditRepo)
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{
"Email": notifier,
})
issuerRegistry := map[string]IssuerConnector{
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
// Create a cert expiring in 10 days
cert := &domain.ManagedCertificate{
ID: "mc-expiring",
Name: "Test Cert",
CommonName: "test.example.com",
SANs: []string{},
OwnerID: "owner-1",
TeamID: "team-1",
IssuerID: "iss-test",
RenewalPolicyID: "rp-standard",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(0, 0, 10),
Tags: make(map[string]string),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
certRepo.AddCert(cert)
// Create policy with thresholds
policy := &domain.RenewalPolicy{
ID: "rp-standard",
Name: "Standard",
RenewalWindowDays: 30,
AutoRenew: true,
MaxRetries: 3,
RetryInterval: 300,
AlertThresholdsDays: []int{30, 14, 7, 0},
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
policyRepo.AddPolicy(policy)
// Run expiry check
err := svc.CheckExpiringCertificates(ctx)
if err != nil {
t.Fatalf("CheckExpiringCertificates failed: %v", err)
}
// Verify alerts were sent
if len(notifRepo.Notifications) < 1 {
t.Errorf("expected at least 1 alert, got %d", len(notifRepo.Notifications))
}
// Verify renewal job was created
if len(jobRepo.Jobs) < 1 {
t.Errorf("expected renewal job to be created")
}
hasRenewalJob := false
for _, job := range jobRepo.Jobs {
if job.Type == domain.JobTypeRenewal {
hasRenewalJob = true
break
}
}
if !hasRenewalJob {
t.Errorf("expected renewal job in jobs")
}
}
func TestCheckExpiringCertificates_DeduplicatesAlerts(t *testing.T) {
t.Helper()
ctx := context.Background()
certRepo := newMockCertificateRepository()
jobRepo := newMockJobRepository()
policyRepo := newMockRenewalPolicyRepository()
auditRepo := newMockAuditRepository()
notifRepo := newMockNotificationRepository()
notifier := newMockNotifier()
auditSvc := NewAuditService(auditRepo)
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{
"Email": notifier,
})
issuerRegistry := map[string]IssuerConnector{
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
// Create cert
cert := &domain.ManagedCertificate{
ID: "mc-dedup",
Name: "Test Cert",
CommonName: "test.example.com",
SANs: []string{},
OwnerID: "owner-1",
TeamID: "team-1",
IssuerID: "iss-test",
RenewalPolicyID: "rp-standard",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(0, 0, 10),
Tags: make(map[string]string),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
certRepo.AddCert(cert)
// Create policy
policy := &domain.RenewalPolicy{
ID: "rp-standard",
Name: "Standard",
RenewalWindowDays: 30,
AutoRenew: true,
MaxRetries: 3,
RetryInterval: 300,
AlertThresholdsDays: []int{30, 14, 7, 0},
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
policyRepo.AddPolicy(policy)
// Add existing threshold alert notification
existingNotif := &domain.NotificationEvent{
ID: "notif-existing",
CertificateID: &cert.ID,
Type: domain.NotificationTypeExpirationWarning,
Channel: domain.NotificationChannelEmail,
Recipient: "owner-1",
Message: "Alert [threshold:7]",
Status: "sent",
CreatedAt: time.Now(),
}
notifRepo.AddNotification(existingNotif)
// Run first check
_ = svc.CheckExpiringCertificates(ctx)
initialCount := notifier.getSentCount()
// Run second check - should deduplicate
_ = svc.CheckExpiringCertificates(ctx)
finalCount := notifier.getSentCount()
// Should not send duplicate alerts
if finalCount > initialCount {
t.Errorf("expected deduplication, but sent new alerts: initial=%d, final=%d", initialCount, finalCount)
}
}
func TestCheckExpiringCertificates_SkipsRenewalInProgress(t *testing.T) {
t.Helper()
ctx := context.Background()
certRepo := newMockCertificateRepository()
jobRepo := newMockJobRepository()
policyRepo := newMockRenewalPolicyRepository()
auditRepo := newMockAuditRepository()
notifRepo := newMockNotificationRepository()
auditSvc := NewAuditService(auditRepo)
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
issuerRegistry := map[string]IssuerConnector{
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
// Create cert with RenewalInProgress status
cert := &domain.ManagedCertificate{
ID: "mc-in-progress",
Name: "Test Cert",
CommonName: "test.example.com",
SANs: []string{},
OwnerID: "owner-1",
TeamID: "team-1",
IssuerID: "iss-test",
RenewalPolicyID: "rp-standard",
Status: domain.CertificateStatusRenewalInProgress,
ExpiresAt: time.Now().AddDate(0, 0, 10),
Tags: make(map[string]string),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
certRepo.AddCert(cert)
// Create policy
policy := &domain.RenewalPolicy{
ID: "rp-standard",
Name: "Standard",
RenewalWindowDays: 30,
AutoRenew: true,
MaxRetries: 3,
RetryInterval: 300,
AlertThresholdsDays: []int{30, 14, 7, 0},
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
policyRepo.AddPolicy(policy)
// Run check
err := svc.CheckExpiringCertificates(ctx)
if err != nil {
t.Fatalf("CheckExpiringCertificates failed: %v", err)
}
// Should not create renewal job for cert already renewing
for _, job := range jobRepo.Jobs {
if job.Type == domain.JobTypeRenewal {
t.Errorf("should not create renewal job for cert with RenewalInProgress status")
}
}
}
func TestCheckExpiringCertificates_UpdatesStatusToExpiring(t *testing.T) {
t.Helper()
ctx := context.Background()
certRepo := newMockCertificateRepository()
jobRepo := newMockJobRepository()
policyRepo := newMockRenewalPolicyRepository()
auditRepo := newMockAuditRepository()
notifRepo := newMockNotificationRepository()
auditSvc := NewAuditService(auditRepo)
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
issuerRegistry := map[string]IssuerConnector{
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
// Create active cert that will become expiring
// Use an issuer NOT in the registry so no renewal job is created (which would override status)
cert := &domain.ManagedCertificate{
ID: "mc-expiring-status",
Name: "Test Cert",
CommonName: "test.example.com",
SANs: []string{},
OwnerID: "owner-1",
TeamID: "team-1",
IssuerID: "iss-unregistered",
RenewalPolicyID: "rp-standard",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(0, 0, 5), // 5 days, within 30-day threshold
Tags: make(map[string]string),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
certRepo.AddCert(cert)
// Create policy with AutoRenew: false so we only test status transition
policy := &domain.RenewalPolicy{
ID: "rp-standard",
Name: "Standard",
RenewalWindowDays: 30,
AutoRenew: false,
MaxRetries: 3,
RetryInterval: 300,
AlertThresholdsDays: []int{30, 14, 7, 0},
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
policyRepo.AddPolicy(policy)
// Run check
_ = svc.CheckExpiringCertificates(ctx)
// Verify status was updated to Expiring
updated, _ := certRepo.Get(ctx, cert.ID)
if updated.Status != domain.CertificateStatusExpiring {
t.Errorf("expected status Expiring, got %s", updated.Status)
}
}
func TestCheckExpiringCertificates_UpdatesStatusToExpired(t *testing.T) {
t.Helper()
ctx := context.Background()
certRepo := newMockCertificateRepository()
jobRepo := newMockJobRepository()
policyRepo := newMockRenewalPolicyRepository()
auditRepo := newMockAuditRepository()
notifRepo := newMockNotificationRepository()
auditSvc := NewAuditService(auditRepo)
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
issuerRegistry := map[string]IssuerConnector{
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
// Create cert that is already expired
// Use an issuer NOT in the registry so no renewal job is created (which would override status)
cert := &domain.ManagedCertificate{
ID: "mc-expired-status",
Name: "Test Cert",
CommonName: "test.example.com",
SANs: []string{},
OwnerID: "owner-1",
TeamID: "team-1",
IssuerID: "iss-unregistered",
RenewalPolicyID: "rp-standard",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(0, 0, -1), // Already expired
Tags: make(map[string]string),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
certRepo.AddCert(cert)
// Create policy with AutoRenew: false so we only test status transition
policy := &domain.RenewalPolicy{
ID: "rp-standard",
Name: "Standard",
RenewalWindowDays: 30,
AutoRenew: false,
MaxRetries: 3,
RetryInterval: 300,
AlertThresholdsDays: []int{30, 14, 7, 0},
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
policyRepo.AddPolicy(policy)
// Run check
_ = svc.CheckExpiringCertificates(ctx)
// Verify status was updated to Expired
updated, _ := certRepo.Get(ctx, cert.ID)
if updated.Status != domain.CertificateStatusExpired {
t.Errorf("expected status Expired, got %s", updated.Status)
}
}
func TestCheckExpiringCertificates_CreatesRenewalJob(t *testing.T) {
t.Helper()
ctx := context.Background()
certRepo := newMockCertificateRepository()
jobRepo := newMockJobRepository()
policyRepo := newMockRenewalPolicyRepository()
auditRepo := newMockAuditRepository()
notifRepo := newMockNotificationRepository()
auditSvc := NewAuditService(auditRepo)
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
issuerRegistry := map[string]IssuerConnector{
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
// Create expiring cert with registered issuer
cert := &domain.ManagedCertificate{
ID: "mc-job-create",
Name: "Test Cert",
CommonName: "test.example.com",
SANs: []string{},
OwnerID: "owner-1",
TeamID: "team-1",
IssuerID: "iss-test", // Registered issuer
RenewalPolicyID: "rp-standard",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(0, 0, 20),
Tags: make(map[string]string),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
certRepo.AddCert(cert)
// Create policy
policy := &domain.RenewalPolicy{
ID: "rp-standard",
Name: "Standard",
RenewalWindowDays: 30,
AutoRenew: true,
MaxRetries: 3,
RetryInterval: 300,
AlertThresholdsDays: []int{30, 14, 7, 0},
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
policyRepo.AddPolicy(policy)
// Run check
_ = svc.CheckExpiringCertificates(ctx)
// Verify renewal job was created
hasRenewalJob := false
for _, job := range jobRepo.Jobs {
if job.Type == domain.JobTypeRenewal && job.Status == domain.JobStatusPending {
hasRenewalJob = true
break
}
}
if !hasRenewalJob {
t.Errorf("expected renewal job to be created")
}
}
func TestCheckExpiringCertificates_SkipsWithoutIssuer(t *testing.T) {
t.Helper()
ctx := context.Background()
certRepo := newMockCertificateRepository()
jobRepo := newMockJobRepository()
policyRepo := newMockRenewalPolicyRepository()
auditRepo := newMockAuditRepository()
notifRepo := newMockNotificationRepository()
auditSvc := NewAuditService(auditRepo)
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
// Empty issuer registry
issuerRegistry := map[string]IssuerConnector{}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
// Create cert with unregistered issuer
cert := &domain.ManagedCertificate{
ID: "mc-no-issuer",
Name: "Test Cert",
CommonName: "test.example.com",
SANs: []string{},
OwnerID: "owner-1",
TeamID: "team-1",
IssuerID: "iss-missing", // Not in registry
RenewalPolicyID: "rp-standard",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(0, 0, 20),
Tags: make(map[string]string),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
certRepo.AddCert(cert)
// Create policy
policy := &domain.RenewalPolicy{
ID: "rp-standard",
Name: "Standard",
RenewalWindowDays: 30,
AutoRenew: true,
MaxRetries: 3,
RetryInterval: 300,
AlertThresholdsDays: []int{30, 14, 7, 0},
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
policyRepo.AddPolicy(policy)
// Run check
_ = svc.CheckExpiringCertificates(ctx)
// Should not create renewal job without issuer
for _, job := range jobRepo.Jobs {
if job.Type == domain.JobTypeRenewal {
t.Errorf("should not create renewal job for cert with missing issuer")
}
}
}
func TestCheckExpiringCertificates_SkipsDuplicateJobs(t *testing.T) {
t.Helper()
ctx := context.Background()
certRepo := newMockCertificateRepository()
jobRepo := newMockJobRepository()
policyRepo := newMockRenewalPolicyRepository()
auditRepo := newMockAuditRepository()
notifRepo := newMockNotificationRepository()
auditSvc := NewAuditService(auditRepo)
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
issuerRegistry := map[string]IssuerConnector{
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
// Create cert
cert := &domain.ManagedCertificate{
ID: "mc-dup-job",
Name: "Test Cert",
CommonName: "test.example.com",
SANs: []string{},
OwnerID: "owner-1",
TeamID: "team-1",
IssuerID: "iss-test",
RenewalPolicyID: "rp-standard",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(0, 0, 20),
Tags: make(map[string]string),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
certRepo.AddCert(cert)
// Create policy
policy := &domain.RenewalPolicy{
ID: "rp-standard",
Name: "Standard",
RenewalWindowDays: 30,
AutoRenew: true,
MaxRetries: 3,
RetryInterval: 300,
AlertThresholdsDays: []int{30, 14, 7, 0},
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
policyRepo.AddPolicy(policy)
// Add existing renewal job
existingJob := &domain.Job{
ID: "job-existing",
CertificateID: cert.ID,
Type: domain.JobTypeRenewal,
Status: domain.JobStatusPending,
MaxAttempts: 3,
ScheduledAt: time.Now(),
CreatedAt: time.Now(),
}
jobRepo.AddJob(existingJob)
// Run first check
_ = svc.CheckExpiringCertificates(ctx)
// Run second check
_ = svc.CheckExpiringCertificates(ctx)
// Should have only 1 renewal job
renewalCount := 0
for _, job := range jobRepo.Jobs {
if job.Type == domain.JobTypeRenewal {
renewalCount++
}
}
if renewalCount > 1 {
t.Errorf("expected 1 renewal job, got %d (duplicate prevention failed)", renewalCount)
}
}
func TestProcessRenewalJob(t *testing.T) {
t.Helper()
ctx := context.Background()
certRepo := newMockCertificateRepository()
jobRepo := newMockJobRepository()
policyRepo := newMockRenewalPolicyRepository()
auditRepo := newMockAuditRepository()
notifRepo := newMockNotificationRepository()
auditSvc := NewAuditService(auditRepo)
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{
"Email": newMockNotifier(),
})
issuerConnector := &mockIssuerConnector{}
issuerRegistry := map[string]IssuerConnector{
"iss-test": issuerConnector,
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
// Create certificate
cert := &domain.ManagedCertificate{
ID: "mc-renewal",
Name: "Test Cert",
CommonName: "test.example.com",
SANs: []string{"www.test.example.com"},
OwnerID: "owner-1",
TeamID: "team-1",
IssuerID: "iss-test",
RenewalPolicyID: "rp-standard",
Status: domain.CertificateStatusActive,
TargetIDs: []string{"target-1", "target-2"},
ExpiresAt: time.Now().AddDate(0, 0, 30),
Tags: make(map[string]string),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
certRepo.AddCert(cert)
// Create renewal job
job := &domain.Job{
ID: "job-renewal-1",
CertificateID: cert.ID,
Type: domain.JobTypeRenewal,
Status: domain.JobStatusPending,
MaxAttempts: 3,
ScheduledAt: time.Now(),
CreatedAt: time.Now(),
}
jobRepo.AddJob(job)
// Process renewal job
err := svc.ProcessRenewalJob(ctx, job)
if err != nil {
t.Fatalf("ProcessRenewalJob failed: %v", err)
}
// Verify cert was updated
updated, _ := certRepo.Get(ctx, cert.ID)
if updated.Status != domain.CertificateStatusActive {
t.Errorf("expected cert status Active, got %s", updated.Status)
}
if updated.LastRenewalAt == nil {
t.Errorf("expected LastRenewalAt to be set")
}
// Verify certificate version was created
if len(certRepo.Versions[cert.ID]) != 1 {
t.Errorf("expected 1 certificate version, got %d", len(certRepo.Versions[cert.ID]))
}
// Verify deployment jobs were created
deploymentCount := 0
for _, j := range jobRepo.Jobs {
if j.Type == domain.JobTypeDeployment {
deploymentCount++
}
}
if deploymentCount != 2 {
t.Errorf("expected 2 deployment jobs (one per target), got %d", deploymentCount)
}
// Verify job was marked as completed
completedJob, _ := jobRepo.Get(ctx, job.ID)
if completedJob.Status != domain.JobStatusCompleted {
t.Errorf("expected job status Completed, got %s", completedJob.Status)
}
}
func TestProcessRenewalJob_IssuerFailure(t *testing.T) {
t.Helper()
ctx := context.Background()
certRepo := newMockCertificateRepository()
jobRepo := newMockJobRepository()
policyRepo := newMockRenewalPolicyRepository()
auditRepo := newMockAuditRepository()
notifRepo := newMockNotificationRepository()
auditSvc := NewAuditService(auditRepo)
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{
"Email": newMockNotifier(),
})
// Create issuer that will fail
issuerConnector := &mockIssuerConnector{
Err: fmt.Errorf("issuer service unavailable"),
}
issuerRegistry := map[string]IssuerConnector{
"iss-test": issuerConnector,
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
// Create certificate
cert := &domain.ManagedCertificate{
ID: "mc-renewal-fail",
Name: "Test Cert",
CommonName: "test.example.com",
SANs: []string{},
OwnerID: "owner-1",
TeamID: "team-1",
IssuerID: "iss-test",
RenewalPolicyID: "rp-standard",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(0, 0, 30),
Tags: make(map[string]string),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
certRepo.AddCert(cert)
// Create renewal job
job := &domain.Job{
ID: "job-renewal-fail",
CertificateID: cert.ID,
Type: domain.JobTypeRenewal,
Status: domain.JobStatusPending,
MaxAttempts: 3,
ScheduledAt: time.Now(),
CreatedAt: time.Now(),
}
jobRepo.AddJob(job)
// Process renewal job (should fail)
err := svc.ProcessRenewalJob(ctx, job)
if err == nil {
t.Fatalf("expected ProcessRenewalJob to fail")
}
// Verify job was marked as failed
failedJob, _ := jobRepo.Get(ctx, job.ID)
if failedJob.Status != domain.JobStatusFailed {
t.Errorf("expected job status Failed, got %s", failedJob.Status)
}
if failedJob.LastError == nil || !strings.Contains(*failedJob.LastError, "issuer service unavailable") {
t.Errorf("expected error message in job, got: %v", failedJob.LastError)
}
// Verify failure notification was sent
if len(notifRepo.Notifications) < 1 {
t.Errorf("expected failure notification to be created")
}
foundFailureNotif := false
for _, notif := range notifRepo.Notifications {
if notif.Type == domain.NotificationTypeRenewalFailure {
foundFailureNotif = true
break
}
}
if !foundFailureNotif {
t.Errorf("expected RenewalFailure notification type")
}
}
func TestRetryFailedJobs(t *testing.T) {
t.Helper()
ctx := context.Background()
certRepo := newMockCertificateRepository()
jobRepo := newMockJobRepository()
policyRepo := newMockRenewalPolicyRepository()
auditRepo := newMockAuditRepository()
notifRepo := newMockNotificationRepository()
auditSvc := NewAuditService(auditRepo)
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
issuerRegistry := map[string]IssuerConnector{
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
// Create failed job with attempts < max_attempts
failedJob := &domain.Job{
ID: "job-failed-1",
CertificateID: "mc-test",
Type: domain.JobTypeRenewal,
Status: domain.JobStatusFailed,
Attempts: 1,
MaxAttempts: 3,
LastError: stringPtr("temporary failure"),
ScheduledAt: time.Now(),
CreatedAt: time.Now().AddDate(0, 0, -1),
}
jobRepo.AddJob(failedJob)
// Create other job types that should be ignored
otherJob := &domain.Job{
ID: "job-other",
CertificateID: "mc-test",
Type: domain.JobTypeDeployment,
Status: domain.JobStatusFailed,
Attempts: 1,
MaxAttempts: 3,
ScheduledAt: time.Now(),
CreatedAt: time.Now(),
}
jobRepo.AddJob(otherJob)
// Retry failed jobs
err := svc.RetryFailedJobs(ctx, 3)
if err != nil {
t.Fatalf("RetryFailedJobs failed: %v", err)
}
// Verify failed renewal job was reset to pending
retried, _ := jobRepo.Get(ctx, failedJob.ID)
if retried.Status != domain.JobStatusPending {
t.Errorf("expected job status Pending after retry, got %s", retried.Status)
}
// Verify other job type was not touched
other, _ := jobRepo.Get(ctx, otherJob.ID)
if other.Status != domain.JobStatusFailed {
t.Errorf("expected non-renewal job to stay Failed, got %s", other.Status)
}
}
func TestProcessRenewalJob_NoCertificate(t *testing.T) {
t.Helper()
ctx := context.Background()
certRepo := newMockCertificateRepository()
jobRepo := newMockJobRepository()
policyRepo := newMockRenewalPolicyRepository()
auditRepo := newMockAuditRepository()
notifRepo := newMockNotificationRepository()
auditSvc := NewAuditService(auditRepo)
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
issuerRegistry := map[string]IssuerConnector{
"iss-test": &mockIssuerConnector{},
}
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
// Create job with non-existent certificate
job := &domain.Job{
ID: "job-no-cert",
CertificateID: "mc-missing",
Type: domain.JobTypeRenewal,
Status: domain.JobStatusPending,
MaxAttempts: 3,
ScheduledAt: time.Now(),
CreatedAt: time.Now(),
}
jobRepo.AddJob(job)
// Process renewal job
err := svc.ProcessRenewalJob(ctx, job)
if err == nil {
t.Fatalf("expected ProcessRenewalJob to fail for missing certificate")
}
// Verify job was marked as failed
failedJob, _ := jobRepo.Get(ctx, job.ID)
if failedJob.Status != domain.JobStatusFailed {
t.Errorf("expected job status Failed, got %s", failedJob.Status)
}
}
// stringPtr is defined in notification_test.go
+771
View File
@@ -0,0 +1,771 @@
package service
import (
"context"
"errors"
"time"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/repository"
)
var errNotFound = errors.New("not found")
// mockCertRepo is a test implementation of CertificateRepository
type mockCertRepo struct {
Certs map[string]*domain.ManagedCertificate
Versions map[string][]*domain.CertificateVersion
CreateErr error
UpdateErr error
GetErr error
ListErr error
ListVersionsErr error
ListVersionsResult []*domain.CertificateVersion
CreateVersionErr error
ArchiveErr error
}
func (m *mockCertRepo) List(ctx context.Context, filter *repository.CertificateFilter) ([]*domain.ManagedCertificate, int, error) {
if m.ListErr != nil {
return nil, 0, m.ListErr
}
var certs []*domain.ManagedCertificate
for _, c := range m.Certs {
certs = append(certs, c)
}
return certs, len(certs), nil
}
func (m *mockCertRepo) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
if m.GetErr != nil {
return nil, m.GetErr
}
cert, ok := m.Certs[id]
if !ok {
return nil, errNotFound
}
return cert, nil
}
func (m *mockCertRepo) Create(ctx context.Context, cert *domain.ManagedCertificate) error {
if m.CreateErr != nil {
return m.CreateErr
}
m.Certs[cert.ID] = cert
return nil
}
func (m *mockCertRepo) Update(ctx context.Context, cert *domain.ManagedCertificate) error {
if m.UpdateErr != nil {
return m.UpdateErr
}
m.Certs[cert.ID] = cert
return nil
}
func (m *mockCertRepo) Archive(ctx context.Context, id string) error {
if m.ArchiveErr != nil {
return m.ArchiveErr
}
cert, ok := m.Certs[id]
if !ok {
return errNotFound
}
cert.Status = domain.CertificateStatusArchived
return nil
}
func (m *mockCertRepo) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) {
if m.ListVersionsErr != nil {
return nil, m.ListVersionsErr
}
if m.ListVersionsResult != nil {
return m.ListVersionsResult, nil
}
return m.Versions[certID], nil
}
func (m *mockCertRepo) CreateVersion(ctx context.Context, version *domain.CertificateVersion) error {
if m.CreateVersionErr != nil {
return m.CreateVersionErr
}
m.Versions[version.CertificateID] = append(m.Versions[version.CertificateID], version)
return nil
}
func (m *mockCertRepo) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) {
var expiring []*domain.ManagedCertificate
for _, c := range m.Certs {
if c.ExpiresAt.Before(before) {
expiring = append(expiring, c)
}
}
return expiring, nil
}
func (m *mockCertRepo) AddCert(cert *domain.ManagedCertificate) {
m.Certs[cert.ID] = cert
}
// mockJobRepo is a test implementation of JobRepository
type mockJobRepo struct {
Jobs map[string]*domain.Job
StatusUpdates map[string]domain.JobStatus
CreateErr error
UpdateErr error
UpdateStatusErr error
GetErr error
ListErr error
ListByStatusErr error
DeleteErr error
}
func (m *mockJobRepo) List(ctx context.Context) ([]*domain.Job, error) {
if m.ListErr != nil {
return nil, m.ListErr
}
var jobs []*domain.Job
for _, j := range m.Jobs {
jobs = append(jobs, j)
}
return jobs, nil
}
func (m *mockJobRepo) Get(ctx context.Context, id string) (*domain.Job, error) {
if m.GetErr != nil {
return nil, m.GetErr
}
job, ok := m.Jobs[id]
if !ok {
return nil, errNotFound
}
return job, nil
}
func (m *mockJobRepo) Create(ctx context.Context, job *domain.Job) error {
if m.CreateErr != nil {
return m.CreateErr
}
m.Jobs[job.ID] = job
return nil
}
func (m *mockJobRepo) Update(ctx context.Context, job *domain.Job) error {
if m.UpdateErr != nil {
return m.UpdateErr
}
m.Jobs[job.ID] = job
return nil
}
func (m *mockJobRepo) Delete(ctx context.Context, id string) error {
if m.DeleteErr != nil {
return m.DeleteErr
}
delete(m.Jobs, id)
return nil
}
func (m *mockJobRepo) ListByStatus(ctx context.Context, status domain.JobStatus) ([]*domain.Job, error) {
if m.ListByStatusErr != nil {
return nil, m.ListByStatusErr
}
var jobs []*domain.Job
for _, j := range m.Jobs {
if j.Status == status {
jobs = append(jobs, j)
}
}
return jobs, nil
}
func (m *mockJobRepo) ListByCertificate(ctx context.Context, certID string) ([]*domain.Job, error) {
var jobs []*domain.Job
for _, j := range m.Jobs {
if j.CertificateID == certID {
jobs = append(jobs, j)
}
}
return jobs, nil
}
func (m *mockJobRepo) UpdateStatus(ctx context.Context, id string, status domain.JobStatus, errMsg string) error {
if m.UpdateStatusErr != nil {
return m.UpdateStatusErr
}
job, ok := m.Jobs[id]
if !ok {
return errNotFound
}
job.Status = status
if errMsg != "" {
job.LastError = &errMsg
}
m.StatusUpdates[id] = status
return nil
}
func (m *mockJobRepo) GetPendingJobs(ctx context.Context, jobType domain.JobType) ([]*domain.Job, error) {
var jobs []*domain.Job
for _, j := range m.Jobs {
if j.Type == jobType && j.Status == domain.JobStatusPending {
jobs = append(jobs, j)
}
}
return jobs, nil
}
func (m *mockJobRepo) AddJob(job *domain.Job) {
m.Jobs[job.ID] = job
}
// mockNotifRepo is a test implementation of NotificationRepository
type mockNotifRepo struct {
Notifications []*domain.NotificationEvent
CreateErr error
ListErr error
UpdateErr error
}
func (m *mockNotifRepo) Create(ctx context.Context, notif *domain.NotificationEvent) error {
if m.CreateErr != nil {
return m.CreateErr
}
m.Notifications = append(m.Notifications, notif)
return nil
}
func (m *mockNotifRepo) List(ctx context.Context, filter *repository.NotificationFilter) ([]*domain.NotificationEvent, error) {
if m.ListErr != nil {
return nil, m.ListErr
}
return m.Notifications, nil
}
func (m *mockNotifRepo) UpdateStatus(ctx context.Context, id string, status string, sentAt time.Time) error {
if m.UpdateErr != nil {
return m.UpdateErr
}
for _, n := range m.Notifications {
if n.ID == id {
n.Status = status
return nil
}
}
return errNotFound
}
func (m *mockNotifRepo) AddNotification(notif *domain.NotificationEvent) {
m.Notifications = append(m.Notifications, notif)
}
// mockAuditRepo is a test implementation of AuditRepository
type mockAuditRepo struct {
Events []*domain.AuditEvent
CreateErr error
ListErr error
}
func (m *mockAuditRepo) Create(ctx context.Context, event *domain.AuditEvent) error {
if m.CreateErr != nil {
return m.CreateErr
}
m.Events = append(m.Events, event)
return nil
}
func (m *mockAuditRepo) List(ctx context.Context, filter *repository.AuditFilter) ([]*domain.AuditEvent, error) {
if m.ListErr != nil {
return nil, m.ListErr
}
// Apply filtering like the real repo
var filtered []*domain.AuditEvent
for _, e := range m.Events {
if filter != nil {
if filter.ResourceType != "" && e.ResourceType != filter.ResourceType {
continue
}
if filter.ResourceID != "" && e.ResourceID != filter.ResourceID {
continue
}
if filter.Actor != "" && e.Actor != filter.Actor {
continue
}
if !filter.From.IsZero() && e.Timestamp.Before(filter.From) {
continue
}
if !filter.To.IsZero() && e.Timestamp.After(filter.To) {
continue
}
}
filtered = append(filtered, e)
}
return filtered, nil
}
func (m *mockAuditRepo) AddEvent(event *domain.AuditEvent) {
m.Events = append(m.Events, event)
}
// mockPolicyRepo is a test implementation of PolicyRepository
type mockPolicyRepo struct {
Rules map[string]*domain.PolicyRule
Violations []*domain.PolicyViolation
CreateRuleErr error
UpdateRuleErr error
DeleteRuleErr error
GetRuleErr error
ListRulesErr error
CreateViolationErr error
ListViolationsErr error
}
func (m *mockPolicyRepo) ListRules(ctx context.Context) ([]*domain.PolicyRule, error) {
if m.ListRulesErr != nil {
return nil, m.ListRulesErr
}
var rules []*domain.PolicyRule
for _, r := range m.Rules {
rules = append(rules, r)
}
return rules, nil
}
func (m *mockPolicyRepo) GetRule(ctx context.Context, id string) (*domain.PolicyRule, error) {
if m.GetRuleErr != nil {
return nil, m.GetRuleErr
}
rule, ok := m.Rules[id]
if !ok {
return nil, errNotFound
}
return rule, nil
}
func (m *mockPolicyRepo) CreateRule(ctx context.Context, rule *domain.PolicyRule) error {
if m.CreateRuleErr != nil {
return m.CreateRuleErr
}
m.Rules[rule.ID] = rule
return nil
}
func (m *mockPolicyRepo) UpdateRule(ctx context.Context, rule *domain.PolicyRule) error {
if m.UpdateRuleErr != nil {
return m.UpdateRuleErr
}
m.Rules[rule.ID] = rule
return nil
}
func (m *mockPolicyRepo) DeleteRule(ctx context.Context, id string) error {
if m.DeleteRuleErr != nil {
return m.DeleteRuleErr
}
delete(m.Rules, id)
return nil
}
func (m *mockPolicyRepo) CreateViolation(ctx context.Context, violation *domain.PolicyViolation) error {
if m.CreateViolationErr != nil {
return m.CreateViolationErr
}
m.Violations = append(m.Violations, violation)
return nil
}
func (m *mockPolicyRepo) ListViolations(ctx context.Context, filter *repository.AuditFilter) ([]*domain.PolicyViolation, error) {
if m.ListViolationsErr != nil {
return nil, m.ListViolationsErr
}
return m.Violations, nil
}
func (m *mockPolicyRepo) AddRule(rule *domain.PolicyRule) {
m.Rules[rule.ID] = rule
}
// mockRenewalPolicyRepo is a test implementation of RenewalPolicyRepository
type mockRenewalPolicyRepo struct {
Policies map[string]*domain.RenewalPolicy
GetErr error
ListErr error
}
func (m *mockRenewalPolicyRepo) Get(ctx context.Context, id string) (*domain.RenewalPolicy, error) {
if m.GetErr != nil {
return nil, m.GetErr
}
policy, ok := m.Policies[id]
if !ok {
return nil, errNotFound
}
return policy, nil
}
func (m *mockRenewalPolicyRepo) List(ctx context.Context) ([]*domain.RenewalPolicy, error) {
if m.ListErr != nil {
return nil, m.ListErr
}
var policies []*domain.RenewalPolicy
for _, p := range m.Policies {
policies = append(policies, p)
}
return policies, nil
}
func (m *mockRenewalPolicyRepo) AddPolicy(policy *domain.RenewalPolicy) {
m.Policies[policy.ID] = policy
}
// mockAgentRepo is a test implementation of AgentRepository
type mockAgentRepo struct {
Agents map[string]*domain.Agent
HeartbeatUpdates map[string]time.Time
CreateErr error
UpdateErr error
DeleteErr error
GetErr error
ListErr error
UpdateHeartbeatErr error
GetByAPIKeyErr error
}
func (m *mockAgentRepo) List(ctx context.Context) ([]*domain.Agent, error) {
if m.ListErr != nil {
return nil, m.ListErr
}
var agents []*domain.Agent
for _, a := range m.Agents {
agents = append(agents, a)
}
return agents, nil
}
func (m *mockAgentRepo) Get(ctx context.Context, id string) (*domain.Agent, error) {
if m.GetErr != nil {
return nil, m.GetErr
}
agent, ok := m.Agents[id]
if !ok {
return nil, errNotFound
}
return agent, nil
}
func (m *mockAgentRepo) Create(ctx context.Context, agent *domain.Agent) error {
if m.CreateErr != nil {
return m.CreateErr
}
m.Agents[agent.ID] = agent
return nil
}
func (m *mockAgentRepo) Update(ctx context.Context, agent *domain.Agent) error {
if m.UpdateErr != nil {
return m.UpdateErr
}
m.Agents[agent.ID] = agent
return nil
}
func (m *mockAgentRepo) Delete(ctx context.Context, id string) error {
if m.DeleteErr != nil {
return m.DeleteErr
}
delete(m.Agents, id)
return nil
}
func (m *mockAgentRepo) UpdateHeartbeat(ctx context.Context, id string) error {
if m.UpdateHeartbeatErr != nil {
return m.UpdateHeartbeatErr
}
agent, ok := m.Agents[id]
if !ok {
return errNotFound
}
now := time.Now()
agent.LastHeartbeatAt = &now
m.HeartbeatUpdates[id] = now
return nil
}
func (m *mockAgentRepo) GetByAPIKey(ctx context.Context, keyHash string) (*domain.Agent, error) {
if m.GetByAPIKeyErr != nil {
return nil, m.GetByAPIKeyErr
}
for _, a := range m.Agents {
if a.APIKeyHash == keyHash {
return a, nil
}
}
return nil, errNotFound
}
func (m *mockAgentRepo) AddAgent(agent *domain.Agent) {
m.Agents[agent.ID] = agent
}
// mockTargetRepo is a test implementation of TargetRepository
type mockTargetRepo struct {
Targets map[string]*domain.DeploymentTarget
CreateErr error
UpdateErr error
DeleteErr error
GetErr error
ListErr error
ListByCertErr error
}
func (m *mockTargetRepo) List(ctx context.Context) ([]*domain.DeploymentTarget, error) {
if m.ListErr != nil {
return nil, m.ListErr
}
var targets []*domain.DeploymentTarget
for _, t := range m.Targets {
targets = append(targets, t)
}
return targets, nil
}
func (m *mockTargetRepo) Get(ctx context.Context, id string) (*domain.DeploymentTarget, error) {
if m.GetErr != nil {
return nil, m.GetErr
}
target, ok := m.Targets[id]
if !ok {
return nil, errNotFound
}
return target, nil
}
func (m *mockTargetRepo) Create(ctx context.Context, target *domain.DeploymentTarget) error {
if m.CreateErr != nil {
return m.CreateErr
}
m.Targets[target.ID] = target
return nil
}
func (m *mockTargetRepo) Update(ctx context.Context, target *domain.DeploymentTarget) error {
if m.UpdateErr != nil {
return m.UpdateErr
}
m.Targets[target.ID] = target
return nil
}
func (m *mockTargetRepo) Delete(ctx context.Context, id string) error {
if m.DeleteErr != nil {
return m.DeleteErr
}
delete(m.Targets, id)
return nil
}
func (m *mockTargetRepo) ListByCertificate(ctx context.Context, certID string) ([]*domain.DeploymentTarget, error) {
if m.ListByCertErr != nil {
return nil, m.ListByCertErr
}
return m.List(ctx)
}
func (m *mockTargetRepo) AddTarget(target *domain.DeploymentTarget) {
m.Targets[target.ID] = target
}
// mockIssuerConnector is a test implementation of IssuerConnector
type mockIssuerConnector struct {
Result *IssuanceResult
Err error
}
func (m *mockIssuerConnector) IssueCertificate(ctx context.Context, commonName string, sans []string, csrPEM string) (*IssuanceResult, error) {
if m.Err != nil {
return nil, m.Err
}
if m.Result != nil {
return m.Result, nil
}
now := time.Now()
return &IssuanceResult{
Serial: "test-serial-123",
CertPEM: "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----",
ChainPEM: "-----BEGIN CERTIFICATE-----\nchain\n-----END CERTIFICATE-----",
NotBefore: now,
NotAfter: now.AddDate(1, 0, 0),
}, nil
}
func (m *mockIssuerConnector) RenewCertificate(ctx context.Context, commonName string, sans []string, csrPEM string) (*IssuanceResult, error) {
if m.Err != nil {
return nil, m.Err
}
return m.IssueCertificate(ctx, commonName, sans, csrPEM)
}
// Constructor functions for mocks
func newMockCertificateRepository() *mockCertRepo {
return &mockCertRepo{
Certs: make(map[string]*domain.ManagedCertificate),
Versions: make(map[string][]*domain.CertificateVersion),
}
}
func newMockJobRepository() *mockJobRepo {
return &mockJobRepo{
Jobs: make(map[string]*domain.Job),
StatusUpdates: make(map[string]domain.JobStatus),
}
}
func newMockNotificationRepository() *mockNotifRepo {
return &mockNotifRepo{
Notifications: make([]*domain.NotificationEvent, 0),
}
}
func newMockAuditRepository() *mockAuditRepo {
return &mockAuditRepo{
Events: make([]*domain.AuditEvent, 0),
}
}
func newMockPolicyRepository() *mockPolicyRepo {
return &mockPolicyRepo{
Rules: make(map[string]*domain.PolicyRule),
Violations: make([]*domain.PolicyViolation, 0),
}
}
func newMockRenewalPolicyRepository() *mockRenewalPolicyRepo {
return &mockRenewalPolicyRepo{
Policies: make(map[string]*domain.RenewalPolicy),
}
}
func newMockAgentRepository() *mockAgentRepo {
return &mockAgentRepo{
Agents: make(map[string]*domain.Agent),
HeartbeatUpdates: make(map[string]time.Time),
}
}
func newMockTargetRepository() *mockTargetRepo {
return &mockTargetRepo{
Targets: make(map[string]*domain.DeploymentTarget),
}
}
func newMockIssuerRepository() *mockIssuerRepository {
return &mockIssuerRepository{
issuers: make(map[string]*domain.Issuer),
}
}
// mockIssuerRepository is a test implementation of IssuerRepository
type mockIssuerRepository struct {
issuers map[string]*domain.Issuer
GetErr error
ListErr error
CreateErr error
UpdateErr error
DeleteErr error
}
func (m *mockIssuerRepository) List(ctx context.Context) ([]*domain.Issuer, error) {
if m.ListErr != nil {
return nil, m.ListErr
}
var issuers []*domain.Issuer
for _, i := range m.issuers {
issuers = append(issuers, i)
}
return issuers, nil
}
func (m *mockIssuerRepository) Get(ctx context.Context, id string) (*domain.Issuer, error) {
if m.GetErr != nil {
return nil, m.GetErr
}
issuer, ok := m.issuers[id]
if !ok {
return nil, errNotFound
}
return issuer, nil
}
func (m *mockIssuerRepository) Create(ctx context.Context, issuer *domain.Issuer) error {
if m.CreateErr != nil {
return m.CreateErr
}
m.issuers[issuer.ID] = issuer
return nil
}
func (m *mockIssuerRepository) Update(ctx context.Context, issuer *domain.Issuer) error {
if m.UpdateErr != nil {
return m.UpdateErr
}
m.issuers[issuer.ID] = issuer
return nil
}
func (m *mockIssuerRepository) Delete(ctx context.Context, id string) error {
if m.DeleteErr != nil {
return m.DeleteErr
}
delete(m.issuers, id)
return nil
}
func (m *mockIssuerRepository) AddIssuer(issuer *domain.Issuer) {
m.issuers[issuer.ID] = issuer
}
// mockNotifier is a simple notifier for testing
type mockNotifier struct {
messages []*mockNotifierMessage
SendErr error
}
type mockNotifierMessage struct {
Recipient string
Subject string
Body string
}
func newMockNotifier() *mockNotifier {
return &mockNotifier{
messages: make([]*mockNotifierMessage, 0),
}
}
func (m *mockNotifier) Send(ctx context.Context, recipient string, subject string, body string) error {
if m.SendErr != nil {
return m.SendErr
}
m.messages = append(m.messages, &mockNotifierMessage{
Recipient: recipient,
Subject: subject,
Body: body,
})
return nil
}
func (m *mockNotifier) Channel() string {
return "Email"
}
func (m *mockNotifier) getSentCount() int {
return len(m.messages)
}
func (m *mockNotifier) getLastMessage() *mockNotifierMessage {
if len(m.messages) == 0 {
return nil
}
return m.messages[len(m.messages)-1]
}