mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-07 13:51:36 +00:00
Implement M4: comprehensive test coverage with 120 tests
Service layer (63 tests): certificate, agent, audit, job, notification, policy, and renewal services with mock repositories covering threshold alerting, deduplication, status transitions, and job processing. Handler layer (46 tests): certificate and agent HTTP handlers using httptest with mock service interfaces, covering success/error paths, pagination, JSON marshaling, and path parameter extraction. Integration (11 subtests): end-to-end certificate lifecycle test exercising real services and Local CA issuer through HTTP API — create cert, trigger renewal, process jobs, register agent, heartbeat, verify audit trail. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,869 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
// MockAgentService is a mock implementation of AgentService interface.
|
||||
type MockAgentService struct {
|
||||
ListAgentsFn func(page, perPage int) ([]domain.Agent, int64, error)
|
||||
GetAgentFn func(id string) (*domain.Agent, error)
|
||||
RegisterAgentFn func(agent domain.Agent) (*domain.Agent, error)
|
||||
HeartbeatFn func(agentID string) error
|
||||
CSRSubmitFn func(agentID string, csrPEM string) (string, error)
|
||||
CSRSubmitForCertFn func(agentID string, certID string, csrPEM string) (string, error)
|
||||
CertificatePickupFn func(agentID, certID string) (string, error)
|
||||
GetWorkFn func(agentID string) ([]domain.Job, error)
|
||||
GetWorkWithTargetsFn func(agentID string) ([]domain.WorkItem, error)
|
||||
UpdateJobStatusFn func(agentID string, jobID string, status string, errMsg string) error
|
||||
}
|
||||
|
||||
func (m *MockAgentService) ListAgents(page, perPage int) ([]domain.Agent, int64, error) {
|
||||
if m.ListAgentsFn != nil {
|
||||
return m.ListAgentsFn(page, perPage)
|
||||
}
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (m *MockAgentService) GetAgent(id string) (*domain.Agent, error) {
|
||||
if m.GetAgentFn != nil {
|
||||
return m.GetAgentFn(id)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockAgentService) RegisterAgent(agent domain.Agent) (*domain.Agent, error) {
|
||||
if m.RegisterAgentFn != nil {
|
||||
return m.RegisterAgentFn(agent)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockAgentService) Heartbeat(agentID string) error {
|
||||
if m.HeartbeatFn != nil {
|
||||
return m.HeartbeatFn(agentID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockAgentService) CSRSubmit(agentID string, csrPEM string) (string, error) {
|
||||
if m.CSRSubmitFn != nil {
|
||||
return m.CSRSubmitFn(agentID, csrPEM)
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *MockAgentService) CSRSubmitForCert(agentID string, certID string, csrPEM string) (string, error) {
|
||||
if m.CSRSubmitForCertFn != nil {
|
||||
return m.CSRSubmitForCertFn(agentID, certID, csrPEM)
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *MockAgentService) CertificatePickup(agentID, certID string) (string, error) {
|
||||
if m.CertificatePickupFn != nil {
|
||||
return m.CertificatePickupFn(agentID, certID)
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *MockAgentService) GetWork(agentID string) ([]domain.Job, error) {
|
||||
if m.GetWorkFn != nil {
|
||||
return m.GetWorkFn(agentID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockAgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, error) {
|
||||
if m.GetWorkWithTargetsFn != nil {
|
||||
return m.GetWorkWithTargetsFn(agentID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockAgentService) UpdateJobStatus(agentID string, jobID string, status string, errMsg string) error {
|
||||
if m.UpdateJobStatusFn != nil {
|
||||
return m.UpdateJobStatusFn(agentID, jobID, status, errMsg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Test ListAgents - success case
|
||||
func TestListAgents_Success(t *testing.T) {
|
||||
now := time.Now()
|
||||
agent1 := domain.Agent{
|
||||
ID: "a-prod-001",
|
||||
Name: "Production Agent",
|
||||
Hostname: "prod-server-01",
|
||||
Status: domain.AgentStatusOnline,
|
||||
LastHeartbeatAt: &now,
|
||||
RegisteredAt: now,
|
||||
}
|
||||
agent2 := domain.Agent{
|
||||
ID: "a-prod-002",
|
||||
Name: "API Agent",
|
||||
Hostname: "api-server-01",
|
||||
Status: domain.AgentStatusOnline,
|
||||
LastHeartbeatAt: &now,
|
||||
RegisteredAt: now,
|
||||
}
|
||||
|
||||
mock := &MockAgentService{
|
||||
ListAgentsFn: func(page, perPage int) ([]domain.Agent, int64, error) {
|
||||
if page == 1 && perPage == 50 {
|
||||
return []domain.Agent{agent1, agent2}, 2, nil
|
||||
}
|
||||
return nil, 0, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents?page=1&per_page=50", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ListAgents(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var response PagedResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response.Total != 2 {
|
||||
t.Errorf("expected total 2, got %d", response.Total)
|
||||
}
|
||||
}
|
||||
|
||||
// Test ListAgents - method not allowed
|
||||
func TestListAgents_MethodNotAllowed(t *testing.T) {
|
||||
mock := &MockAgentService{}
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ListAgents(w, req)
|
||||
|
||||
if w.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test ListAgents - service error
|
||||
func TestListAgents_ServiceError(t *testing.T) {
|
||||
mock := &MockAgentService{
|
||||
ListAgentsFn: func(page, perPage int) ([]domain.Agent, int64, error) {
|
||||
return nil, 0, ErrMockServiceFailed
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ListAgents(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetAgent - success case
|
||||
func TestGetAgent_Success(t *testing.T) {
|
||||
now := time.Now()
|
||||
agent := &domain.Agent{
|
||||
ID: "a-prod-001",
|
||||
Name: "Production Agent",
|
||||
Hostname: "prod-server-01",
|
||||
Status: domain.AgentStatusOnline,
|
||||
LastHeartbeatAt: &now,
|
||||
RegisteredAt: now,
|
||||
}
|
||||
|
||||
mock := &MockAgentService{
|
||||
GetAgentFn: func(id string) (*domain.Agent, error) {
|
||||
if id == "a-prod-001" {
|
||||
return agent, nil
|
||||
}
|
||||
return nil, ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/a-prod-001", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.GetAgent(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var response domain.Agent
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response.ID != "a-prod-001" {
|
||||
t.Errorf("expected ID a-prod-001, got %s", response.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetAgent - not found
|
||||
func TestGetAgent_NotFound(t *testing.T) {
|
||||
mock := &MockAgentService{
|
||||
GetAgentFn: func(id string) (*domain.Agent, error) {
|
||||
return nil, ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/nonexistent", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.GetAgent(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected status %d, got %d", http.StatusNotFound, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test RegisterAgent - success case
|
||||
func TestRegisterAgent_Success(t *testing.T) {
|
||||
now := time.Now()
|
||||
registered := &domain.Agent{
|
||||
ID: "a-prod-001",
|
||||
Name: "Production Agent",
|
||||
Hostname: "prod-server-01",
|
||||
Status: domain.AgentStatusOnline,
|
||||
RegisteredAt: now,
|
||||
}
|
||||
|
||||
mock := &MockAgentService{
|
||||
RegisterAgentFn: func(agent domain.Agent) (*domain.Agent, error) {
|
||||
return registered, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
agentBody := domain.Agent{
|
||||
Name: "Production Agent",
|
||||
Hostname: "prod-server-01",
|
||||
}
|
||||
body, _ := json.Marshal(agentBody)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents", bytes.NewReader(body))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.RegisterAgent(w, req)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("expected status %d, got %d", http.StatusCreated, w.Code)
|
||||
}
|
||||
|
||||
var response domain.Agent
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response.ID != "a-prod-001" {
|
||||
t.Errorf("expected ID a-prod-001, got %s", response.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Test RegisterAgent - invalid body
|
||||
func TestRegisterAgent_InvalidBody(t *testing.T) {
|
||||
mock := &MockAgentService{}
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents", bytes.NewReader([]byte("invalid json")))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.RegisterAgent(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Heartbeat - success case
|
||||
func TestHeartbeat_Success(t *testing.T) {
|
||||
mock := &MockAgentService{
|
||||
HeartbeatFn: func(agentID string) error {
|
||||
if agentID == "a-prod-001" {
|
||||
return nil
|
||||
}
|
||||
return ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/heartbeat", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.Heartbeat(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var response map[string]string
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response["status"] != "heartbeat_recorded" {
|
||||
t.Errorf("expected status 'heartbeat_recorded', got %s", response["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// Test Heartbeat - service error
|
||||
func TestHeartbeat_ServiceError(t *testing.T) {
|
||||
mock := &MockAgentService{
|
||||
HeartbeatFn: func(agentID string) error {
|
||||
return ErrMockServiceFailed
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/heartbeat", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.Heartbeat(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test AgentCSRSubmit - with certificate_id
|
||||
func TestAgentCSRSubmit_WithCertificateID(t *testing.T) {
|
||||
csrPEM := "-----BEGIN CERTIFICATE REQUEST-----\nMIIC...\n-----END CERTIFICATE REQUEST-----"
|
||||
|
||||
mock := &MockAgentService{
|
||||
CSRSubmitForCertFn: func(agentID string, certID string, csrPEM string) (string, error) {
|
||||
if agentID == "a-prod-001" && certID == "mc-prod-001" {
|
||||
return "csr_submitted", nil
|
||||
}
|
||||
return "", ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
reqBody := map[string]string{
|
||||
"csr_pem": csrPEM,
|
||||
"certificate_id": "mc-prod-001",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/csr", bytes.NewReader(body))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AgentCSRSubmit(w, req)
|
||||
|
||||
if w.Code != http.StatusAccepted {
|
||||
t.Errorf("expected status %d, got %d", http.StatusAccepted, w.Code)
|
||||
}
|
||||
|
||||
var response map[string]string
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response["status"] != "csr_submitted" {
|
||||
t.Errorf("expected status 'csr_submitted', got %s", response["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// Test AgentCSRSubmit - without certificate_id
|
||||
func TestAgentCSRSubmit_WithoutCertificateID(t *testing.T) {
|
||||
csrPEM := "-----BEGIN CERTIFICATE REQUEST-----\nMIIC...\n-----END CERTIFICATE REQUEST-----"
|
||||
|
||||
mock := &MockAgentService{
|
||||
CSRSubmitFn: func(agentID string, csrPEM string) (string, error) {
|
||||
if agentID == "a-prod-001" {
|
||||
return "csr_submitted", nil
|
||||
}
|
||||
return "", ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
reqBody := map[string]string{
|
||||
"csr_pem": csrPEM,
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/csr", bytes.NewReader(body))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AgentCSRSubmit(w, req)
|
||||
|
||||
if w.Code != http.StatusAccepted {
|
||||
t.Errorf("expected status %d, got %d", http.StatusAccepted, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test AgentCSRSubmit - missing CSR PEM
|
||||
func TestAgentCSRSubmit_MissingCSRPEM(t *testing.T) {
|
||||
mock := &MockAgentService{}
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
reqBody := map[string]string{
|
||||
"certificate_id": "mc-prod-001",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/csr", bytes.NewReader(body))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AgentCSRSubmit(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test AgentCSRSubmit - invalid body
|
||||
func TestAgentCSRSubmit_InvalidBody(t *testing.T) {
|
||||
mock := &MockAgentService{}
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/csr", bytes.NewReader([]byte("invalid")))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AgentCSRSubmit(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test AgentCertificatePickup - success case
|
||||
func TestAgentCertificatePickup_Success(t *testing.T) {
|
||||
certPEM := "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----"
|
||||
|
||||
mock := &MockAgentService{
|
||||
CertificatePickupFn: func(agentID, certID string) (string, error) {
|
||||
if agentID == "a-prod-001" && certID == "mc-prod-001" {
|
||||
return certPEM, nil
|
||||
}
|
||||
return "", ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
// Path structure: /api/v1/agents/{agent_id}/certificates/{cert_id}
|
||||
// After trim and split: parts[0]="agent_id", parts[1]="certificates", parts[2]="cert_id", parts[3]=""
|
||||
// Note: handler checks len(parts) < 4, so we need the trailing slash
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/a-prod-001/certificates/mc-prod-001/", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AgentCertificatePickup(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d (body: %s)", http.StatusOK, w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var response map[string]string
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response["certificate_pem"] != certPEM {
|
||||
t.Errorf("expected cert PEM %s, got %s", certPEM, response["certificate_pem"])
|
||||
}
|
||||
}
|
||||
|
||||
// Test AgentCertificatePickup - not found
|
||||
func TestAgentCertificatePickup_NotFound(t *testing.T) {
|
||||
mock := &MockAgentService{
|
||||
CertificatePickupFn: func(agentID, certID string) (string, error) {
|
||||
return "", ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/a-prod-001/certificates/nonexistent/", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AgentCertificatePickup(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected status %d, got %d (body: %s)", http.StatusNotFound, w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Test AgentGetWork - success with items
|
||||
func TestAgentGetWork_Success(t *testing.T) {
|
||||
workItem := domain.WorkItem{
|
||||
ID: "j-deploy-001",
|
||||
Type: domain.JobTypeDeployment,
|
||||
CertificateID: "mc-prod-001",
|
||||
TargetID: stringPtr("t-nginx-001"),
|
||||
TargetType: "NGINX",
|
||||
Status: domain.JobStatusPending,
|
||||
}
|
||||
|
||||
mock := &MockAgentService{
|
||||
GetWorkWithTargetsFn: func(agentID string) ([]domain.WorkItem, error) {
|
||||
if agentID == "a-prod-001" {
|
||||
return []domain.WorkItem{workItem}, nil
|
||||
}
|
||||
return nil, ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/a-prod-001/work", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AgentGetWork(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response["count"] != float64(1) {
|
||||
t.Errorf("expected count 1, got %v", response["count"])
|
||||
}
|
||||
}
|
||||
|
||||
// Test AgentGetWork - no work items
|
||||
func TestAgentGetWork_NoItems(t *testing.T) {
|
||||
mock := &MockAgentService{
|
||||
GetWorkWithTargetsFn: func(agentID string) ([]domain.WorkItem, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/a-prod-001/work", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AgentGetWork(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response["count"] != float64(0) {
|
||||
t.Errorf("expected count 0, got %v", response["count"])
|
||||
}
|
||||
}
|
||||
|
||||
// Test AgentGetWork - service error
|
||||
func TestAgentGetWork_ServiceError(t *testing.T) {
|
||||
mock := &MockAgentService{
|
||||
GetWorkWithTargetsFn: func(agentID string) ([]domain.WorkItem, error) {
|
||||
return nil, ErrMockServiceFailed
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/a-prod-001/work", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AgentGetWork(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test AgentReportJobStatus - success case
|
||||
func TestAgentReportJobStatus_Success(t *testing.T) {
|
||||
mock := &MockAgentService{
|
||||
UpdateJobStatusFn: func(agentID string, jobID string, status string, errMsg string) error {
|
||||
if agentID == "a-prod-001" && jobID == "j-deploy-001" && status == "Completed" {
|
||||
return nil
|
||||
}
|
||||
return ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
statusReq := map[string]string{
|
||||
"status": "Completed",
|
||||
}
|
||||
body, _ := json.Marshal(statusReq)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/jobs/j-deploy-001/status", bytes.NewReader(body))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AgentReportJobStatus(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var response map[string]string
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response["status"] != "updated" {
|
||||
t.Errorf("expected status 'updated', got %s", response["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// Test AgentReportJobStatus - with error message
|
||||
func TestAgentReportJobStatus_WithError(t *testing.T) {
|
||||
mock := &MockAgentService{
|
||||
UpdateJobStatusFn: func(agentID string, jobID string, status string, errMsg string) error {
|
||||
if agentID == "a-prod-001" && jobID == "j-deploy-001" && status == "Failed" && errMsg == "timeout" {
|
||||
return nil
|
||||
}
|
||||
return ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
statusReq := map[string]string{
|
||||
"status": "Failed",
|
||||
"error": "timeout",
|
||||
}
|
||||
body, _ := json.Marshal(statusReq)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/jobs/j-deploy-001/status", bytes.NewReader(body))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AgentReportJobStatus(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test AgentReportJobStatus - missing status
|
||||
func TestAgentReportJobStatus_MissingStatus(t *testing.T) {
|
||||
mock := &MockAgentService{}
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
statusReq := map[string]string{}
|
||||
body, _ := json.Marshal(statusReq)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/jobs/j-deploy-001/status", bytes.NewReader(body))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AgentReportJobStatus(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test AgentReportJobStatus - invalid body
|
||||
func TestAgentReportJobStatus_InvalidBody(t *testing.T) {
|
||||
mock := &MockAgentService{}
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/jobs/j-deploy-001/status", bytes.NewReader([]byte("invalid")))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AgentReportJobStatus(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test ListAgents - invalid pagination parameters
|
||||
func TestListAgents_InvalidPagination(t *testing.T) {
|
||||
mock := &MockAgentService{
|
||||
ListAgentsFn: func(page, perPage int) ([]domain.Agent, int64, error) {
|
||||
// Should default to page=1, perPage=50 if invalid
|
||||
if page == 1 && perPage == 50 {
|
||||
return []domain.Agent{}, 0, nil
|
||||
}
|
||||
return nil, 0, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents?page=invalid&per_page=invalid", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ListAgents(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetAgent - empty ID
|
||||
func TestGetAgent_EmptyID(t *testing.T) {
|
||||
mock := &MockAgentService{}
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/agents/", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.GetAgent(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test RegisterAgent - service error
|
||||
func TestRegisterAgent_ServiceError(t *testing.T) {
|
||||
mock := &MockAgentService{
|
||||
RegisterAgentFn: func(agent domain.Agent) (*domain.Agent, error) {
|
||||
return nil, ErrMockServiceFailed
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
agentBody := domain.Agent{
|
||||
Name: "Production Agent",
|
||||
Hostname: "prod-server-01",
|
||||
}
|
||||
body, _ := json.Marshal(agentBody)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents", bytes.NewReader(body))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.RegisterAgent(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Heartbeat - empty agent ID
|
||||
func TestHeartbeat_EmptyAgentID(t *testing.T) {
|
||||
mock := &MockAgentService{}
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents//heartbeat", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.Heartbeat(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test AgentCSRSubmit - service error
|
||||
func TestAgentCSRSubmit_ServiceError(t *testing.T) {
|
||||
mock := &MockAgentService{
|
||||
CSRSubmitFn: func(agentID string, csrPEM string) (string, error) {
|
||||
return "", ErrMockServiceFailed
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
reqBody := map[string]string{
|
||||
"csr_pem": "-----BEGIN CERTIFICATE REQUEST-----\nMIIC...\n-----END CERTIFICATE REQUEST-----",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/csr", bytes.NewReader(body))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AgentCSRSubmit(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test AgentReportJobStatus - service error
|
||||
func TestAgentReportJobStatus_ServiceError(t *testing.T) {
|
||||
mock := &MockAgentService{
|
||||
UpdateJobStatusFn: func(agentID string, jobID string, status string, errMsg string) error {
|
||||
return ErrMockServiceFailed
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewAgentHandler(mock)
|
||||
|
||||
statusReq := map[string]string{
|
||||
"status": "Completed",
|
||||
}
|
||||
body, _ := json.Marshal(statusReq)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/a-prod-001/jobs/j-deploy-001/status", bytes.NewReader(body))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AgentReportJobStatus(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create a string pointer
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
@@ -0,0 +1,704 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/api/middleware"
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
// MockCertificateService is a mock implementation of CertificateService interface.
|
||||
type MockCertificateService struct {
|
||||
ListCertificatesFn func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
|
||||
GetCertificateFn func(id string) (*domain.ManagedCertificate, error)
|
||||
CreateCertificateFn func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
||||
UpdateCertificateFn func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
||||
ArchiveCertificateFn func(id string) error
|
||||
GetCertificateVersionsFn func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
|
||||
TriggerRenewalFn func(certID string) error
|
||||
TriggerDeploymentFn func(certID string, targetID string) error
|
||||
}
|
||||
|
||||
func (m *MockCertificateService) ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
||||
if m.ListCertificatesFn != nil {
|
||||
return m.ListCertificatesFn(status, environment, ownerID, teamID, issuerID, page, perPage)
|
||||
}
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (m *MockCertificateService) GetCertificate(id string) (*domain.ManagedCertificate, error) {
|
||||
if m.GetCertificateFn != nil {
|
||||
return m.GetCertificateFn(id)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockCertificateService) CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
||||
if m.CreateCertificateFn != nil {
|
||||
return m.CreateCertificateFn(cert)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockCertificateService) UpdateCertificate(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
||||
if m.UpdateCertificateFn != nil {
|
||||
return m.UpdateCertificateFn(id, cert)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockCertificateService) ArchiveCertificate(id string) error {
|
||||
if m.ArchiveCertificateFn != nil {
|
||||
return m.ArchiveCertificateFn(id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockCertificateService) GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
|
||||
if m.GetCertificateVersionsFn != nil {
|
||||
return m.GetCertificateVersionsFn(certID, page, perPage)
|
||||
}
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (m *MockCertificateService) TriggerRenewal(certID string) error {
|
||||
if m.TriggerRenewalFn != nil {
|
||||
return m.TriggerRenewalFn(certID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockCertificateService) TriggerDeployment(certID string, targetID string) error {
|
||||
if m.TriggerDeploymentFn != nil {
|
||||
return m.TriggerDeploymentFn(certID, targetID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper function to create context with request ID.
|
||||
func contextWithRequestID() context.Context {
|
||||
return context.WithValue(context.Background(), middleware.RequestIDKey{}, "test-request-id-123")
|
||||
}
|
||||
|
||||
// Test ListCertificates - success case
|
||||
func TestListCertificates_Success(t *testing.T) {
|
||||
cert1 := domain.ManagedCertificate{
|
||||
ID: "mc-prod-001",
|
||||
Name: "Production Cert",
|
||||
CommonName: "example.com",
|
||||
Status: domain.CertificateStatusActive,
|
||||
Environment: "prod",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
cert2 := domain.ManagedCertificate{
|
||||
ID: "mc-prod-002",
|
||||
Name: "API Cert",
|
||||
CommonName: "api.example.com",
|
||||
Status: domain.CertificateStatusActive,
|
||||
Environment: "prod",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
mock := &MockCertificateService{
|
||||
ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
||||
if page == 1 && perPage == 50 {
|
||||
return []domain.ManagedCertificate{cert1, cert2}, 2, nil
|
||||
}
|
||||
return nil, 0, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?page=1&per_page=50", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ListCertificates(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var response PagedResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response.Total != 2 {
|
||||
t.Errorf("expected total 2, got %d", response.Total)
|
||||
}
|
||||
if response.Page != 1 {
|
||||
t.Errorf("expected page 1, got %d", response.Page)
|
||||
}
|
||||
if response.PerPage != 50 {
|
||||
t.Errorf("expected per_page 50, got %d", response.PerPage)
|
||||
}
|
||||
}
|
||||
|
||||
// Test ListCertificates - with filters
|
||||
func TestListCertificates_WithFilters(t *testing.T) {
|
||||
mock := &MockCertificateService{
|
||||
ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
||||
if status == "Active" && environment == "prod" {
|
||||
return []domain.ManagedCertificate{}, 0, nil
|
||||
}
|
||||
return nil, 0, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?status=Active&environment=prod&page=1&per_page=25", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ListCertificates(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test ListCertificates - invalid method
|
||||
func TestListCertificates_MethodNotAllowed(t *testing.T) {
|
||||
mock := &MockCertificateService{}
|
||||
handler := NewCertificateHandler(mock)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ListCertificates(w, req)
|
||||
|
||||
if w.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test ListCertificates - service error
|
||||
func TestListCertificates_ServiceError(t *testing.T) {
|
||||
mock := &MockCertificateService{
|
||||
ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
||||
return nil, 0, ErrMockServiceFailed
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ListCertificates(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetCertificate - success case
|
||||
func TestGetCertificate_Success(t *testing.T) {
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-prod-001",
|
||||
Name: "Production Cert",
|
||||
CommonName: "example.com",
|
||||
Status: domain.CertificateStatusActive,
|
||||
Environment: "prod",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
mock := &MockCertificateService{
|
||||
GetCertificateFn: func(id string) (*domain.ManagedCertificate, error) {
|
||||
if id == "mc-prod-001" {
|
||||
return cert, nil
|
||||
}
|
||||
return nil, ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/mc-prod-001", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.GetCertificate(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var response domain.ManagedCertificate
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response.ID != "mc-prod-001" {
|
||||
t.Errorf("expected ID mc-prod-001, got %s", response.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetCertificate - not found
|
||||
func TestGetCertificate_NotFound(t *testing.T) {
|
||||
mock := &MockCertificateService{
|
||||
GetCertificateFn: func(id string) (*domain.ManagedCertificate, error) {
|
||||
return nil, ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/nonexistent", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.GetCertificate(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected status %d, got %d", http.StatusNotFound, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetCertificate - empty ID
|
||||
func TestGetCertificate_EmptyID(t *testing.T) {
|
||||
mock := &MockCertificateService{}
|
||||
handler := NewCertificateHandler(mock)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.GetCertificate(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test CreateCertificate - success case
|
||||
func TestCreateCertificate_Success(t *testing.T) {
|
||||
now := time.Now()
|
||||
created := &domain.ManagedCertificate{
|
||||
ID: "mc-prod-001",
|
||||
Name: "Production Cert",
|
||||
CommonName: "example.com",
|
||||
Status: domain.CertificateStatusPending,
|
||||
Environment: "prod",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
mock := &MockCertificateService{
|
||||
CreateCertificateFn: func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
||||
return created, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
|
||||
certBody := domain.ManagedCertificate{
|
||||
Name: "Production Cert",
|
||||
CommonName: "example.com",
|
||||
}
|
||||
body, _ := json.Marshal(certBody)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", bytes.NewReader(body))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.CreateCertificate(w, req)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("expected status %d, got %d", http.StatusCreated, w.Code)
|
||||
}
|
||||
|
||||
var response domain.ManagedCertificate
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response.ID != "mc-prod-001" {
|
||||
t.Errorf("expected ID mc-prod-001, got %s", response.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Test CreateCertificate - invalid request body
|
||||
func TestCreateCertificate_InvalidBody(t *testing.T) {
|
||||
mock := &MockCertificateService{}
|
||||
handler := NewCertificateHandler(mock)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", bytes.NewReader([]byte("invalid json")))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.CreateCertificate(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test CreateCertificate - service error
|
||||
func TestCreateCertificate_ServiceError(t *testing.T) {
|
||||
mock := &MockCertificateService{
|
||||
CreateCertificateFn: func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
||||
return nil, ErrMockServiceFailed
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
|
||||
certBody := domain.ManagedCertificate{
|
||||
Name: "Production Cert",
|
||||
CommonName: "example.com",
|
||||
}
|
||||
body, _ := json.Marshal(certBody)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", bytes.NewReader(body))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.CreateCertificate(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test UpdateCertificate - success case
|
||||
func TestUpdateCertificate_Success(t *testing.T) {
|
||||
updated := &domain.ManagedCertificate{
|
||||
ID: "mc-prod-001",
|
||||
Name: "Updated Cert",
|
||||
CommonName: "example.com",
|
||||
Status: domain.CertificateStatusActive,
|
||||
Environment: "prod",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
mock := &MockCertificateService{
|
||||
UpdateCertificateFn: func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
||||
if id == "mc-prod-001" {
|
||||
return updated, nil
|
||||
}
|
||||
return nil, ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
|
||||
certBody := domain.ManagedCertificate{
|
||||
Name: "Updated Cert",
|
||||
}
|
||||
body, _ := json.Marshal(certBody)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/certificates/mc-prod-001", bytes.NewReader(body))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.UpdateCertificate(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var response domain.ManagedCertificate
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response.Name != "Updated Cert" {
|
||||
t.Errorf("expected name 'Updated Cert', got %s", response.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Test UpdateCertificate - invalid body
|
||||
func TestUpdateCertificate_InvalidBody(t *testing.T) {
|
||||
mock := &MockCertificateService{}
|
||||
handler := NewCertificateHandler(mock)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/certificates/mc-prod-001", bytes.NewReader([]byte("invalid")))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.UpdateCertificate(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test ArchiveCertificate - success case
|
||||
func TestArchiveCertificate_Success(t *testing.T) {
|
||||
mock := &MockCertificateService{
|
||||
ArchiveCertificateFn: func(id string) error {
|
||||
if id == "mc-prod-001" {
|
||||
return nil
|
||||
}
|
||||
return ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/certificates/mc-prod-001", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ArchiveCertificate(w, req)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Errorf("expected status %d, got %d", http.StatusNoContent, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test ArchiveCertificate - not found
|
||||
func TestArchiveCertificate_NotFound(t *testing.T) {
|
||||
mock := &MockCertificateService{
|
||||
ArchiveCertificateFn: func(id string) error {
|
||||
return ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/certificates/nonexistent", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ArchiveCertificate(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetCertificateVersions - success case
|
||||
func TestGetCertificateVersions_Success(t *testing.T) {
|
||||
ver1 := domain.CertificateVersion{
|
||||
ID: "cv-001",
|
||||
CertificateID: "mc-prod-001",
|
||||
SerialNumber: "ABC123",
|
||||
FingerprintSHA256: "abc123...",
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(0, 0, 365),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
mock := &MockCertificateService{
|
||||
GetCertificateVersionsFn: func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
|
||||
if certID == "mc-prod-001" {
|
||||
return []domain.CertificateVersion{ver1}, 1, nil
|
||||
}
|
||||
return nil, 0, ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/mc-prod-001/versions?page=1&per_page=50", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.GetCertificateVersions(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var response PagedResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response.Total != 1 {
|
||||
t.Errorf("expected total 1, got %d", response.Total)
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetCertificateVersions - not found
|
||||
func TestGetCertificateVersions_NotFound(t *testing.T) {
|
||||
mock := &MockCertificateService{
|
||||
GetCertificateVersionsFn: func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
|
||||
return nil, 0, ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/nonexistent/versions", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.GetCertificateVersions(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected status %d, got %d", http.StatusNotFound, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test TriggerRenewal - success case
|
||||
func TestTriggerRenewal_Success(t *testing.T) {
|
||||
mock := &MockCertificateService{
|
||||
TriggerRenewalFn: func(certID string) error {
|
||||
if certID == "mc-prod-001" {
|
||||
return nil
|
||||
}
|
||||
return ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/renew", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.TriggerRenewal(w, req)
|
||||
|
||||
if w.Code != http.StatusAccepted {
|
||||
t.Errorf("expected status %d, got %d", http.StatusAccepted, w.Code)
|
||||
}
|
||||
|
||||
var response map[string]string
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response["status"] != "renewal_triggered" {
|
||||
t.Errorf("expected status 'renewal_triggered', got %s", response["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// Test TriggerRenewal - service error
|
||||
func TestTriggerRenewal_ServiceError(t *testing.T) {
|
||||
mock := &MockCertificateService{
|
||||
TriggerRenewalFn: func(certID string) error {
|
||||
return ErrMockServiceFailed
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/renew", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.TriggerRenewal(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test TriggerDeployment - success case
|
||||
func TestTriggerDeployment_Success(t *testing.T) {
|
||||
mock := &MockCertificateService{
|
||||
TriggerDeploymentFn: func(certID string, targetID string) error {
|
||||
if certID == "mc-prod-001" {
|
||||
return nil
|
||||
}
|
||||
return ErrMockNotFound
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
|
||||
deployReq := map[string]string{"target_id": "t-nginx-001"}
|
||||
body, _ := json.Marshal(deployReq)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/deploy", bytes.NewReader(body))
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.TriggerDeployment(w, req)
|
||||
|
||||
if w.Code != http.StatusAccepted {
|
||||
t.Errorf("expected status %d, got %d", http.StatusAccepted, w.Code)
|
||||
}
|
||||
|
||||
var response map[string]string
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response["status"] != "deployment_triggered" {
|
||||
t.Errorf("expected status 'deployment_triggered', got %s", response["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// Test TriggerDeployment - without target ID
|
||||
func TestTriggerDeployment_NoTargetID(t *testing.T) {
|
||||
mock := &MockCertificateService{
|
||||
TriggerDeploymentFn: func(certID string, targetID string) error {
|
||||
// Should accept empty targetID (deploy to all)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/deploy", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.TriggerDeployment(w, req)
|
||||
|
||||
if w.Code != http.StatusAccepted {
|
||||
t.Errorf("expected status %d, got %d", http.StatusAccepted, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test ListCertificates - invalid page parameter
|
||||
func TestListCertificates_InvalidPageParam(t *testing.T) {
|
||||
mock := &MockCertificateService{
|
||||
ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
||||
// Should default to page 1
|
||||
if page == 1 {
|
||||
return []domain.ManagedCertificate{}, 0, nil
|
||||
}
|
||||
return nil, 0, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?page=invalid&per_page=50", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ListCertificates(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Test ListCertificates - per_page exceeds max
|
||||
func TestListCertificates_PerPageExceedsMax(t *testing.T) {
|
||||
mock := &MockCertificateService{
|
||||
ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
||||
// Should cap perPage at 500
|
||||
if perPage == 50 { // defaults to 50 if > 500
|
||||
return []domain.ManagedCertificate{}, 0, nil
|
||||
}
|
||||
return nil, 0, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewCertificateHandler(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?per_page=1000", nil)
|
||||
req = req.WithContext(contextWithRequestID())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ListCertificates(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package handler
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// Mock errors for testing
|
||||
ErrMockServiceFailed = errors.New("mock service error")
|
||||
ErrMockNotFound = errors.New("mock not found error")
|
||||
ErrMockUnauthorized = errors.New("mock unauthorized error")
|
||||
ErrMockConflict = errors.New("mock conflict error")
|
||||
)
|
||||
@@ -0,0 +1,996 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/api/handler"
|
||||
"github.com/shankar0123/certctl/internal/api/router"
|
||||
"github.com/shankar0123/certctl/internal/connector/issuer/local"
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
"github.com/shankar0123/certctl/internal/service"
|
||||
)
|
||||
|
||||
// TestCertificateLifecycle exercises the full certificate lifecycle:
|
||||
// create -> renew -> process jobs -> verify versions -> register agent -> heartbeat -> audit trail
|
||||
func TestCertificateLifecycle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup: Create in-memory mock repositories
|
||||
certRepo := newMockCertificateRepository()
|
||||
jobRepo := newMockJobRepository()
|
||||
auditRepo := newMockAuditRepository()
|
||||
agentRepo := newMockAgentRepository()
|
||||
targetRepo := newMockTargetRepository()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
policyRepo := newMockPolicyRepository()
|
||||
renewalPolicyRepo := newMockRenewalPolicyRepository()
|
||||
issuerRepo := newMockIssuerRepository()
|
||||
|
||||
// Create logger
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
|
||||
// Initialize Local CA issuer connector (real implementation, no mock)
|
||||
localCA := local.New(nil, logger)
|
||||
|
||||
// Build issuer registry with adapter
|
||||
issuerRegistry := map[string]service.IssuerConnector{
|
||||
"iss-local": service.NewIssuerConnectorAdapter(localCA),
|
||||
}
|
||||
|
||||
// Initialize services (following dependency graph)
|
||||
auditService := service.NewAuditService(auditRepo)
|
||||
policyService := service.NewPolicyService(policyRepo, auditService)
|
||||
certificateService := service.NewCertificateService(certRepo, policyService, auditService)
|
||||
notificationService := service.NewNotificationService(notifRepo, make(map[string]service.Notifier))
|
||||
renewalService := service.NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, auditService, notificationService, issuerRegistry)
|
||||
deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notificationService)
|
||||
jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger)
|
||||
agentService := service.NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
|
||||
issuerService := service.NewIssuerService(issuerRepo, auditService)
|
||||
|
||||
// Initialize handlers
|
||||
certificateHandler := handler.NewCertificateHandler(certificateService)
|
||||
issuerHandler := handler.NewIssuerHandler(issuerService)
|
||||
targetHandler := handler.NewTargetHandler(&mockTargetService{targetRepo: targetRepo, auditService: auditService})
|
||||
agentHandler := handler.NewAgentHandler(agentService)
|
||||
jobHandler := handler.NewJobHandler(jobService)
|
||||
policyHandler := handler.NewPolicyHandler(policyService)
|
||||
teamHandler := handler.NewTeamHandler(&mockTeamService{})
|
||||
ownerHandler := handler.NewOwnerHandler(&mockOwnerService{})
|
||||
auditHandler := handler.NewAuditHandler(auditService)
|
||||
notificationHandler := handler.NewNotificationHandler(notificationService)
|
||||
healthHandler := handler.NewHealthHandler()
|
||||
|
||||
// Create router and register handlers
|
||||
r := router.New()
|
||||
r.RegisterHandlers(
|
||||
certificateHandler,
|
||||
issuerHandler,
|
||||
targetHandler,
|
||||
agentHandler,
|
||||
jobHandler,
|
||||
policyHandler,
|
||||
teamHandler,
|
||||
ownerHandler,
|
||||
auditHandler,
|
||||
notificationHandler,
|
||||
healthHandler,
|
||||
)
|
||||
|
||||
// Create test server
|
||||
server := httptest.NewServer(r)
|
||||
defer server.Close()
|
||||
|
||||
// ======================
|
||||
// Step 1: Check health
|
||||
// ======================
|
||||
t.Run("HealthCheck", func(t *testing.T) {
|
||||
resp, err := http.Get(server.URL + "/health")
|
||||
if err != nil {
|
||||
t.Fatalf("GET /health failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var body map[string]string
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if body["status"] != "healthy" {
|
||||
t.Errorf("expected status=healthy, got %s", body["status"])
|
||||
}
|
||||
})
|
||||
|
||||
// ======================
|
||||
// Step 2: Create certificate
|
||||
// ======================
|
||||
var certID string
|
||||
t.Run("CreateCertificate", func(t *testing.T) {
|
||||
now := time.Now()
|
||||
payload := map[string]interface{}{
|
||||
"name": "Example Certificate",
|
||||
"common_name": "example.com",
|
||||
"sans": []string{"www.example.com", "api.example.com"},
|
||||
"environment": "production",
|
||||
"owner_id": "owner-alice",
|
||||
"team_id": "team-platform",
|
||||
"issuer_id": "iss-local",
|
||||
"target_ids": []string{},
|
||||
"renewal_policy_id": "policy-standard",
|
||||
"status": "Pending",
|
||||
"expires_at": now.AddDate(1, 0, 0),
|
||||
"tags": map[string]string{"environment": "prod"},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal payload: %v", err)
|
||||
}
|
||||
|
||||
resp, err := http.Post(
|
||||
server.URL+"/api/v1/certificates",
|
||||
"application/json",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("POST /api/v1/certificates failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Errorf("expected status 201, got %d. Body: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var cert domain.ManagedCertificate
|
||||
if err := json.NewDecoder(resp.Body).Decode(&cert); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if cert.ID == "" {
|
||||
t.Fatalf("response missing id field")
|
||||
}
|
||||
|
||||
certID = cert.ID
|
||||
t.Logf("Created certificate with ID: %s", certID)
|
||||
})
|
||||
|
||||
// ======================
|
||||
// Step 3: Verify certificate
|
||||
// ======================
|
||||
t.Run("GetCertificate", func(t *testing.T) {
|
||||
resp, err := http.Get(server.URL + "/api/v1/certificates/" + certID)
|
||||
if err != nil {
|
||||
t.Fatalf("GET /api/v1/certificates/{id} failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var cert domain.ManagedCertificate
|
||||
if err := json.NewDecoder(resp.Body).Decode(&cert); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if cert.ID != certID {
|
||||
t.Errorf("expected cert ID %s, got %s", certID, cert.ID)
|
||||
}
|
||||
if cert.CommonName != "example.com" {
|
||||
t.Errorf("expected common_name example.com, got %s", cert.CommonName)
|
||||
}
|
||||
if len(cert.SANs) != 2 {
|
||||
t.Errorf("expected 2 SANs, got %d", len(cert.SANs))
|
||||
}
|
||||
})
|
||||
|
||||
// ======================
|
||||
// Step 4: Trigger renewal
|
||||
// ======================
|
||||
t.Run("TriggerRenewal", func(t *testing.T) {
|
||||
resp, err := http.Post(
|
||||
server.URL+"/api/v1/certificates/"+certID+"/renew",
|
||||
"application/json",
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("POST /api/v1/certificates/{id}/renew failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusAccepted {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Errorf("expected status 202, got %d. Body: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
})
|
||||
|
||||
// ======================
|
||||
// Step 5: Process jobs (simulate scheduler)
|
||||
// ======================
|
||||
t.Run("ProcessPendingJobs", func(t *testing.T) {
|
||||
// Jobs should have been created by the renewal trigger.
|
||||
// Process them using the job service directly.
|
||||
if err := jobService.ProcessPendingJobs(ctx); err != nil {
|
||||
t.Fatalf("failed to process pending jobs: %v", err)
|
||||
}
|
||||
|
||||
// Verify that jobs were processed
|
||||
jobs, err := jobRepo.ListByStatus(ctx, domain.JobStatusCompleted)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to list completed jobs: %v", err)
|
||||
}
|
||||
|
||||
// We expect at least one renewal job to have been processed
|
||||
if len(jobs) == 0 {
|
||||
t.Logf("Warning: no completed jobs found. This may indicate the renewal job wasn't processed.")
|
||||
// Check pending jobs instead
|
||||
pending, _ := jobRepo.ListByStatus(ctx, domain.JobStatusPending)
|
||||
t.Logf("Pending jobs: %d", len(pending))
|
||||
}
|
||||
})
|
||||
|
||||
// ======================
|
||||
// Step 6: Verify certificate versions
|
||||
// ======================
|
||||
t.Run("GetCertificateVersions", func(t *testing.T) {
|
||||
resp, err := http.Get(server.URL + "/api/v1/certificates/" + certID + "/versions")
|
||||
if err != nil {
|
||||
t.Fatalf("GET /api/v1/certificates/{id}/versions failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Errorf("expected status 200, got %d. Body: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var respBody map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
// Extract data field which contains the versions array
|
||||
dataField := respBody["data"]
|
||||
if dataField == nil {
|
||||
t.Logf("No versions found yet - this is expected if renewal is still in progress")
|
||||
} else {
|
||||
versions, ok := dataField.([]interface{})
|
||||
if !ok {
|
||||
t.Errorf("expected data to be array, got %T", dataField)
|
||||
} else if len(versions) > 0 {
|
||||
t.Logf("Found %d certificate versions", len(versions))
|
||||
// Verify the first version has required fields
|
||||
if version, ok := versions[0].(map[string]interface{}); ok {
|
||||
if version["pem_chain"] == nil || version["pem_chain"] == "" {
|
||||
t.Errorf("certificate version missing pem_chain")
|
||||
}
|
||||
if version["serial_number"] == nil || version["serial_number"] == "" {
|
||||
t.Errorf("certificate version missing serial_number")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// ======================
|
||||
// Step 7: Register agent
|
||||
// ======================
|
||||
var agentID string
|
||||
t.Run("RegisterAgent", func(t *testing.T) {
|
||||
payload := map[string]string{
|
||||
"name": "agent-prod-1",
|
||||
"hostname": "prod-server-01.example.com",
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal payload: %v", err)
|
||||
}
|
||||
|
||||
resp, err := http.Post(
|
||||
server.URL+"/api/v1/agents",
|
||||
"application/json",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("POST /api/v1/agents failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Errorf("expected status 201, got %d. Body: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// The handler returns the agent directly, not wrapped
|
||||
var agent domain.Agent
|
||||
if err := json.NewDecoder(resp.Body).Decode(&agent); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
agentID = agent.ID
|
||||
if agentID == "" {
|
||||
t.Fatalf("agent id is empty")
|
||||
}
|
||||
|
||||
t.Logf("Registered agent with ID: %s", agentID)
|
||||
})
|
||||
|
||||
// ======================
|
||||
// Step 8: Agent heartbeat
|
||||
// ======================
|
||||
t.Run("AgentHeartbeat", func(t *testing.T) {
|
||||
payload := map[string]string{
|
||||
"agent_id": agentID,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal payload: %v", err)
|
||||
}
|
||||
|
||||
resp, err := http.Post(
|
||||
server.URL+"/api/v1/agents/"+agentID+"/heartbeat",
|
||||
"application/json",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("POST /api/v1/agents/{id}/heartbeat failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Errorf("expected status 200, got %d. Body: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Verify agent heartbeat was updated
|
||||
agent, err := agentRepo.Get(ctx, agentID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get agent: %v", err)
|
||||
}
|
||||
|
||||
if agent.LastHeartbeatAt == nil {
|
||||
t.Errorf("agent LastHeartbeatAt was not updated")
|
||||
}
|
||||
})
|
||||
|
||||
// ======================
|
||||
// Step 9: List audit events
|
||||
// ======================
|
||||
t.Run("ListAuditEvents", func(t *testing.T) {
|
||||
resp, err := http.Get(server.URL + "/api/v1/audit?page=1&per_page=50")
|
||||
if err != nil {
|
||||
t.Fatalf("GET /api/v1/audit failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var respBody map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
// Extract data field which contains the events array
|
||||
dataField := respBody["data"]
|
||||
if dataField == nil {
|
||||
t.Logf("No audit events found")
|
||||
} else {
|
||||
events, ok := dataField.([]interface{})
|
||||
if !ok {
|
||||
t.Errorf("expected data to be array, got %T", dataField)
|
||||
} else {
|
||||
t.Logf("Found %d audit events", len(events))
|
||||
if len(events) == 0 {
|
||||
t.Logf("Warning: no audit events found. Expected events for certificate_created, agent_registered, etc.")
|
||||
}
|
||||
|
||||
// Verify we have expected event types
|
||||
eventTypes := make(map[string]int)
|
||||
for _, evt := range events {
|
||||
if eventMap, ok := evt.(map[string]interface{}); ok {
|
||||
if action, ok := eventMap["action"].(string); ok {
|
||||
eventTypes[action]++
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Logf("Audit event types: %v", eventTypes)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// ======================
|
||||
// Step 10: Get agent and verify status
|
||||
// ======================
|
||||
t.Run("GetAgent", func(t *testing.T) {
|
||||
resp, err := http.Get(server.URL + "/api/v1/agents/" + agentID)
|
||||
if err != nil {
|
||||
t.Fatalf("GET /api/v1/agents/{id} failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Errorf("expected status 200, got %d. Body: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var agent domain.Agent
|
||||
if err := json.NewDecoder(resp.Body).Decode(&agent); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if agent.ID != agentID {
|
||||
t.Errorf("expected agent ID %s, got %s", agentID, agent.ID)
|
||||
}
|
||||
if agent.Status != domain.AgentStatusOnline {
|
||||
t.Errorf("expected agent status Online, got %s", agent.Status)
|
||||
}
|
||||
})
|
||||
|
||||
// ======================
|
||||
// Summary
|
||||
// ======================
|
||||
t.Run("Summary", func(t *testing.T) {
|
||||
totalCerts, _, _ := certRepo.List(ctx, &repository.CertificateFilter{})
|
||||
totalJobs, _ := jobRepo.List(ctx)
|
||||
totalAgents, _ := agentRepo.List(ctx)
|
||||
totalAuditEvents, _ := auditRepo.List(ctx, &repository.AuditFilter{})
|
||||
|
||||
t.Logf("=== Integration Test Summary ===")
|
||||
t.Logf("Certificates: %d", len(totalCerts))
|
||||
t.Logf("Jobs: %d", len(totalJobs))
|
||||
t.Logf("Agents: %d", len(totalAgents))
|
||||
t.Logf("Audit Events: %d", len(totalAuditEvents))
|
||||
|
||||
if len(totalCerts) == 0 {
|
||||
t.Error("Expected at least 1 certificate")
|
||||
}
|
||||
if len(totalAgents) == 0 {
|
||||
t.Error("Expected at least 1 agent")
|
||||
}
|
||||
if len(totalAuditEvents) == 0 {
|
||||
t.Logf("Warning: Expected audit events, but none found")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Mock repository implementations for integration testing
|
||||
// These are simple in-memory implementations similar to testutil_test.go patterns
|
||||
|
||||
type mockCertificateRepository struct {
|
||||
certs map[string]*domain.ManagedCertificate
|
||||
versions map[string][]*domain.CertificateVersion
|
||||
}
|
||||
|
||||
func newMockCertificateRepository() *mockCertificateRepository {
|
||||
return &mockCertificateRepository{
|
||||
certs: make(map[string]*domain.ManagedCertificate),
|
||||
versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockCertificateRepository) List(ctx context.Context, filter *repository.CertificateFilter) ([]*domain.ManagedCertificate, int, error) {
|
||||
var certs []*domain.ManagedCertificate
|
||||
for _, c := range m.certs {
|
||||
certs = append(certs, c)
|
||||
}
|
||||
return certs, len(certs), nil
|
||||
}
|
||||
|
||||
func (m *mockCertificateRepository) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
|
||||
cert, ok := m.certs[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("certificate not found")
|
||||
}
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
func (m *mockCertificateRepository) Create(ctx context.Context, cert *domain.ManagedCertificate) error {
|
||||
m.certs[cert.ID] = cert
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCertificateRepository) Update(ctx context.Context, cert *domain.ManagedCertificate) error {
|
||||
m.certs[cert.ID] = cert
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCertificateRepository) Archive(ctx context.Context, id string) error {
|
||||
cert, ok := m.certs[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("certificate not found")
|
||||
}
|
||||
cert.Status = domain.CertificateStatusArchived
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCertificateRepository) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) {
|
||||
return m.versions[certID], nil
|
||||
}
|
||||
|
||||
func (m *mockCertificateRepository) CreateVersion(ctx context.Context, version *domain.CertificateVersion) error {
|
||||
m.versions[version.CertificateID] = append(m.versions[version.CertificateID], version)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCertificateRepository) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) {
|
||||
var expiring []*domain.ManagedCertificate
|
||||
for _, c := range m.certs {
|
||||
if c.ExpiresAt.Before(before) {
|
||||
expiring = append(expiring, c)
|
||||
}
|
||||
}
|
||||
return expiring, nil
|
||||
}
|
||||
|
||||
type mockJobRepository struct {
|
||||
jobs map[string]*domain.Job
|
||||
}
|
||||
|
||||
func newMockJobRepository() *mockJobRepository {
|
||||
return &mockJobRepository{
|
||||
jobs: make(map[string]*domain.Job),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockJobRepository) List(ctx context.Context) ([]*domain.Job, error) {
|
||||
var jobs []*domain.Job
|
||||
for _, j := range m.jobs {
|
||||
jobs = append(jobs, j)
|
||||
}
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepository) Get(ctx context.Context, id string) (*domain.Job, error) {
|
||||
job, ok := m.jobs[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("job not found")
|
||||
}
|
||||
return job, nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepository) Create(ctx context.Context, job *domain.Job) error {
|
||||
m.jobs[job.ID] = job
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepository) Update(ctx context.Context, job *domain.Job) error {
|
||||
m.jobs[job.ID] = job
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepository) Delete(ctx context.Context, id string) error {
|
||||
delete(m.jobs, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepository) ListByStatus(ctx context.Context, status domain.JobStatus) ([]*domain.Job, error) {
|
||||
var jobs []*domain.Job
|
||||
for _, j := range m.jobs {
|
||||
if j.Status == status {
|
||||
jobs = append(jobs, j)
|
||||
}
|
||||
}
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepository) ListByCertificate(ctx context.Context, certID string) ([]*domain.Job, error) {
|
||||
var jobs []*domain.Job
|
||||
for _, j := range m.jobs {
|
||||
if j.CertificateID == certID {
|
||||
jobs = append(jobs, j)
|
||||
}
|
||||
}
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepository) UpdateStatus(ctx context.Context, id string, status domain.JobStatus, errMsg string) error {
|
||||
job, ok := m.jobs[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("job not found")
|
||||
}
|
||||
job.Status = status
|
||||
if errMsg != "" {
|
||||
job.LastError = &errMsg
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepository) GetPendingJobs(ctx context.Context, jobType domain.JobType) ([]*domain.Job, error) {
|
||||
var jobs []*domain.Job
|
||||
for _, j := range m.jobs {
|
||||
if j.Type == jobType && j.Status == domain.JobStatusPending {
|
||||
jobs = append(jobs, j)
|
||||
}
|
||||
}
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
type mockAuditRepository struct {
|
||||
events []*domain.AuditEvent
|
||||
}
|
||||
|
||||
func newMockAuditRepository() *mockAuditRepository {
|
||||
return &mockAuditRepository{
|
||||
events: make([]*domain.AuditEvent, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockAuditRepository) Create(ctx context.Context, event *domain.AuditEvent) error {
|
||||
m.events = append(m.events, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAuditRepository) List(ctx context.Context, filter *repository.AuditFilter) ([]*domain.AuditEvent, error) {
|
||||
return m.events, nil
|
||||
}
|
||||
|
||||
type mockAgentRepository struct {
|
||||
agents map[string]*domain.Agent
|
||||
}
|
||||
|
||||
func newMockAgentRepository() *mockAgentRepository {
|
||||
return &mockAgentRepository{
|
||||
agents: make(map[string]*domain.Agent),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockAgentRepository) List(ctx context.Context) ([]*domain.Agent, error) {
|
||||
var agents []*domain.Agent
|
||||
for _, a := range m.agents {
|
||||
agents = append(agents, a)
|
||||
}
|
||||
return agents, nil
|
||||
}
|
||||
|
||||
func (m *mockAgentRepository) Get(ctx context.Context, id string) (*domain.Agent, error) {
|
||||
agent, ok := m.agents[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("agent not found")
|
||||
}
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
func (m *mockAgentRepository) Create(ctx context.Context, agent *domain.Agent) error {
|
||||
m.agents[agent.ID] = agent
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAgentRepository) Update(ctx context.Context, agent *domain.Agent) error {
|
||||
m.agents[agent.ID] = agent
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAgentRepository) Delete(ctx context.Context, id string) error {
|
||||
delete(m.agents, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAgentRepository) UpdateHeartbeat(ctx context.Context, id string) error {
|
||||
agent, ok := m.agents[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("agent not found")
|
||||
}
|
||||
now := time.Now()
|
||||
agent.LastHeartbeatAt = &now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAgentRepository) GetByAPIKey(ctx context.Context, keyHash string) (*domain.Agent, error) {
|
||||
for _, a := range m.agents {
|
||||
if a.APIKeyHash == keyHash {
|
||||
return a, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("agent not found")
|
||||
}
|
||||
|
||||
type mockTargetRepository struct {
|
||||
targets map[string]*domain.DeploymentTarget
|
||||
}
|
||||
|
||||
func newMockTargetRepository() *mockTargetRepository {
|
||||
return &mockTargetRepository{
|
||||
targets: make(map[string]*domain.DeploymentTarget),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockTargetRepository) List(ctx context.Context) ([]*domain.DeploymentTarget, error) {
|
||||
var targets []*domain.DeploymentTarget
|
||||
for _, t := range m.targets {
|
||||
targets = append(targets, t)
|
||||
}
|
||||
return targets, nil
|
||||
}
|
||||
|
||||
func (m *mockTargetRepository) Get(ctx context.Context, id string) (*domain.DeploymentTarget, error) {
|
||||
target, ok := m.targets[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("target not found")
|
||||
}
|
||||
return target, nil
|
||||
}
|
||||
|
||||
func (m *mockTargetRepository) Create(ctx context.Context, target *domain.DeploymentTarget) error {
|
||||
m.targets[target.ID] = target
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockTargetRepository) Update(ctx context.Context, target *domain.DeploymentTarget) error {
|
||||
m.targets[target.ID] = target
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockTargetRepository) Delete(ctx context.Context, id string) error {
|
||||
delete(m.targets, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockTargetRepository) ListByCertificate(ctx context.Context, certID string) ([]*domain.DeploymentTarget, error) {
|
||||
return m.List(ctx)
|
||||
}
|
||||
|
||||
type mockNotificationRepository struct {
|
||||
notifications []*domain.NotificationEvent
|
||||
}
|
||||
|
||||
func newMockNotificationRepository() *mockNotificationRepository {
|
||||
return &mockNotificationRepository{
|
||||
notifications: make([]*domain.NotificationEvent, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockNotificationRepository) Create(ctx context.Context, notif *domain.NotificationEvent) error {
|
||||
m.notifications = append(m.notifications, notif)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockNotificationRepository) List(ctx context.Context, filter *repository.NotificationFilter) ([]*domain.NotificationEvent, error) {
|
||||
return m.notifications, nil
|
||||
}
|
||||
|
||||
func (m *mockNotificationRepository) UpdateStatus(ctx context.Context, id string, status string, sentAt time.Time) error {
|
||||
for _, n := range m.notifications {
|
||||
if n.ID == id {
|
||||
n.Status = status
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("notification not found")
|
||||
}
|
||||
|
||||
type mockPolicyRepository struct {
|
||||
rules map[string]*domain.PolicyRule
|
||||
violations []*domain.PolicyViolation
|
||||
}
|
||||
|
||||
func newMockPolicyRepository() *mockPolicyRepository {
|
||||
return &mockPolicyRepository{
|
||||
rules: make(map[string]*domain.PolicyRule),
|
||||
violations: make([]*domain.PolicyViolation, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockPolicyRepository) ListRules(ctx context.Context) ([]*domain.PolicyRule, error) {
|
||||
var rules []*domain.PolicyRule
|
||||
for _, r := range m.rules {
|
||||
rules = append(rules, r)
|
||||
}
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func (m *mockPolicyRepository) GetRule(ctx context.Context, id string) (*domain.PolicyRule, error) {
|
||||
rule, ok := m.rules[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("rule not found")
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *mockPolicyRepository) CreateRule(ctx context.Context, rule *domain.PolicyRule) error {
|
||||
m.rules[rule.ID] = rule
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPolicyRepository) UpdateRule(ctx context.Context, rule *domain.PolicyRule) error {
|
||||
m.rules[rule.ID] = rule
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPolicyRepository) DeleteRule(ctx context.Context, id string) error {
|
||||
delete(m.rules, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPolicyRepository) CreateViolation(ctx context.Context, violation *domain.PolicyViolation) error {
|
||||
m.violations = append(m.violations, violation)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPolicyRepository) ListViolations(ctx context.Context, filter *repository.AuditFilter) ([]*domain.PolicyViolation, error) {
|
||||
return m.violations, nil
|
||||
}
|
||||
|
||||
type mockRenewalPolicyRepository struct {
|
||||
policies map[string]*domain.RenewalPolicy
|
||||
}
|
||||
|
||||
func newMockRenewalPolicyRepository() *mockRenewalPolicyRepository {
|
||||
return &mockRenewalPolicyRepository{
|
||||
policies: make(map[string]*domain.RenewalPolicy),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockRenewalPolicyRepository) Get(ctx context.Context, id string) (*domain.RenewalPolicy, error) {
|
||||
policy, ok := m.policies[id]
|
||||
if !ok {
|
||||
// Return default policy
|
||||
return &domain.RenewalPolicy{
|
||||
ID: id,
|
||||
Name: "Default Policy",
|
||||
RenewalWindowDays: 30,
|
||||
AutoRenew: true,
|
||||
MaxRetries: 3,
|
||||
RetryInterval: 3600,
|
||||
AlertThresholdsDays: domain.DefaultAlertThresholds(),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (m *mockRenewalPolicyRepository) List(ctx context.Context) ([]*domain.RenewalPolicy, error) {
|
||||
var policies []*domain.RenewalPolicy
|
||||
for _, p := range m.policies {
|
||||
policies = append(policies, p)
|
||||
}
|
||||
return policies, nil
|
||||
}
|
||||
|
||||
type mockIssuerRepository struct {
|
||||
issuers map[string]*domain.Issuer
|
||||
}
|
||||
|
||||
func newMockIssuerRepository() *mockIssuerRepository {
|
||||
return &mockIssuerRepository{
|
||||
issuers: make(map[string]*domain.Issuer),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockIssuerRepository) List(ctx context.Context) ([]*domain.Issuer, error) {
|
||||
var issuers []*domain.Issuer
|
||||
for _, i := range m.issuers {
|
||||
issuers = append(issuers, i)
|
||||
}
|
||||
return issuers, nil
|
||||
}
|
||||
|
||||
func (m *mockIssuerRepository) Get(ctx context.Context, id string) (*domain.Issuer, error) {
|
||||
issuer, ok := m.issuers[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("issuer not found")
|
||||
}
|
||||
return issuer, nil
|
||||
}
|
||||
|
||||
func (m *mockIssuerRepository) Create(ctx context.Context, issuer *domain.Issuer) error {
|
||||
m.issuers[issuer.ID] = issuer
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockIssuerRepository) Update(ctx context.Context, issuer *domain.Issuer) error {
|
||||
m.issuers[issuer.ID] = issuer
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockIssuerRepository) Delete(ctx context.Context, id string) error {
|
||||
delete(m.issuers, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Mock service implementations for handlers that need them but aren't tested
|
||||
|
||||
type mockTargetService struct {
|
||||
targetRepo *mockTargetRepository
|
||||
auditService *service.AuditService
|
||||
}
|
||||
|
||||
func (m *mockTargetService) ListTargets(page, perPage int) ([]domain.DeploymentTarget, int64, error) {
|
||||
targets, err := m.targetRepo.List(context.Background())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
var result []domain.DeploymentTarget
|
||||
for _, t := range targets {
|
||||
result = append(result, *t)
|
||||
}
|
||||
return result, int64(len(result)), nil
|
||||
}
|
||||
|
||||
func (m *mockTargetService) GetTarget(id string) (*domain.DeploymentTarget, error) {
|
||||
return m.targetRepo.Get(context.Background(), id)
|
||||
}
|
||||
|
||||
func (m *mockTargetService) CreateTarget(target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
||||
if err := m.targetRepo.Create(context.Background(), &target); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &target, nil
|
||||
}
|
||||
|
||||
func (m *mockTargetService) UpdateTarget(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
||||
target.ID = id
|
||||
if err := m.targetRepo.Update(context.Background(), &target); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &target, nil
|
||||
}
|
||||
|
||||
func (m *mockTargetService) DeleteTarget(id string) error {
|
||||
return m.targetRepo.Delete(context.Background(), id)
|
||||
}
|
||||
|
||||
type mockTeamService struct{}
|
||||
|
||||
func (m *mockTeamService) ListTeams(page, perPage int) ([]domain.Team, int64, error) {
|
||||
return []domain.Team{}, 0, nil
|
||||
}
|
||||
|
||||
func (m *mockTeamService) GetTeam(id string) (*domain.Team, error) {
|
||||
return nil, fmt.Errorf("team not found")
|
||||
}
|
||||
|
||||
func (m *mockTeamService) CreateTeam(team domain.Team) (*domain.Team, error) {
|
||||
return &team, nil
|
||||
}
|
||||
|
||||
func (m *mockTeamService) UpdateTeam(id string, team domain.Team) (*domain.Team, error) {
|
||||
team.ID = id
|
||||
return &team, nil
|
||||
}
|
||||
|
||||
func (m *mockTeamService) DeleteTeam(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockOwnerService struct{}
|
||||
|
||||
func (m *mockOwnerService) ListOwners(page, perPage int) ([]domain.Owner, int64, error) {
|
||||
return []domain.Owner{}, 0, nil
|
||||
}
|
||||
|
||||
func (m *mockOwnerService) GetOwner(id string) (*domain.Owner, error) {
|
||||
return nil, fmt.Errorf("owner not found")
|
||||
}
|
||||
|
||||
func (m *mockOwnerService) CreateOwner(owner domain.Owner) (*domain.Owner, error) {
|
||||
return &owner, nil
|
||||
}
|
||||
|
||||
func (m *mockOwnerService) UpdateOwner(id string, owner domain.Owner) (*domain.Owner, error) {
|
||||
owner.ID = id
|
||||
return &owner, nil
|
||||
}
|
||||
|
||||
func (m *mockOwnerService) DeleteOwner(id string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,467 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
func TestRegisterAgent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agentRepo := &mockAgentRepo{
|
||||
Agents: make(map[string]*domain.Agent),
|
||||
HeartbeatUpdates: make(map[string]time.Time),
|
||||
}
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: make(map[string]*domain.ManagedCertificate),
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: make(map[string]*domain.Job),
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
targetRepo := &mockTargetRepo{
|
||||
Targets: make(map[string]*domain.DeploymentTarget),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
issuerRegistry := make(map[string]IssuerConnector)
|
||||
|
||||
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
|
||||
|
||||
agent, apiKey, err := agentService.Register(ctx, "prod-agent-1", "server-01.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Register failed: %v", err)
|
||||
}
|
||||
|
||||
if agent.Name != "prod-agent-1" {
|
||||
t.Errorf("expected name prod-agent-1, got %s", agent.Name)
|
||||
}
|
||||
if agent.Hostname != "server-01.example.com" {
|
||||
t.Errorf("expected hostname server-01.example.com, got %s", agent.Hostname)
|
||||
}
|
||||
if agent.Status != domain.AgentStatusOnline {
|
||||
t.Errorf("expected status Online, got %s", agent.Status)
|
||||
}
|
||||
if apiKey == "" {
|
||||
t.Fatal("expected non-empty API key")
|
||||
}
|
||||
|
||||
if len(agentRepo.Agents) != 1 {
|
||||
t.Errorf("expected 1 agent in repo, got %d", len(agentRepo.Agents))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeat(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
agent := &domain.Agent{
|
||||
ID: "agent-001",
|
||||
Name: "prod-agent",
|
||||
Hostname: "server-01",
|
||||
Status: domain.AgentStatusOnline,
|
||||
RegisteredAt: now,
|
||||
LastHeartbeatAt: &now,
|
||||
APIKeyHash: "hash123",
|
||||
}
|
||||
|
||||
agentRepo := &mockAgentRepo{
|
||||
Agents: map[string]*domain.Agent{"agent-001": agent},
|
||||
HeartbeatUpdates: make(map[string]time.Time),
|
||||
}
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: make(map[string]*domain.ManagedCertificate),
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: make(map[string]*domain.Job),
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
targetRepo := &mockTargetRepo{
|
||||
Targets: make(map[string]*domain.DeploymentTarget),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
issuerRegistry := make(map[string]IssuerConnector)
|
||||
|
||||
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
|
||||
|
||||
err := agentService.HeartbeatWithContext(ctx, "agent-001")
|
||||
if err != nil {
|
||||
t.Fatalf("Heartbeat failed: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := agentRepo.HeartbeatUpdates["agent-001"]; !ok {
|
||||
t.Fatal("heartbeat not recorded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeat_NotFound(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agentRepo := &mockAgentRepo{
|
||||
Agents: make(map[string]*domain.Agent),
|
||||
HeartbeatUpdates: make(map[string]time.Time),
|
||||
}
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: make(map[string]*domain.ManagedCertificate),
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: make(map[string]*domain.Job),
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
targetRepo := &mockTargetRepo{
|
||||
Targets: make(map[string]*domain.DeploymentTarget),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
issuerRegistry := make(map[string]IssuerConnector)
|
||||
|
||||
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
|
||||
|
||||
err := agentService.HeartbeatWithContext(ctx, "nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPendingWork(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
agent := &domain.Agent{
|
||||
ID: "agent-001",
|
||||
Name: "prod-agent",
|
||||
Hostname: "server-01",
|
||||
Status: domain.AgentStatusOnline,
|
||||
RegisteredAt: now,
|
||||
LastHeartbeatAt: &now,
|
||||
APIKeyHash: "hash123",
|
||||
}
|
||||
|
||||
job1 := &domain.Job{
|
||||
ID: "job-001",
|
||||
Type: domain.JobTypeDeployment,
|
||||
CertificateID: "cert-001",
|
||||
Status: domain.JobStatusPending,
|
||||
CreatedAt: now,
|
||||
}
|
||||
job2 := &domain.Job{
|
||||
ID: "job-002",
|
||||
Type: domain.JobTypeRenewal,
|
||||
CertificateID: "cert-002",
|
||||
Status: domain.JobStatusPending,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
agentRepo := &mockAgentRepo{
|
||||
Agents: map[string]*domain.Agent{"agent-001": agent},
|
||||
HeartbeatUpdates: make(map[string]time.Time),
|
||||
}
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: make(map[string]*domain.ManagedCertificate),
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: map[string]*domain.Job{"job-001": job1, "job-002": job2},
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
targetRepo := &mockTargetRepo{
|
||||
Targets: make(map[string]*domain.DeploymentTarget),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
issuerRegistry := make(map[string]IssuerConnector)
|
||||
|
||||
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
|
||||
|
||||
jobs, err := agentService.GetPendingWork(ctx, "agent-001")
|
||||
if err != nil {
|
||||
t.Fatalf("GetPendingWork failed: %v", err)
|
||||
}
|
||||
|
||||
if len(jobs) != 1 {
|
||||
t.Errorf("expected 1 deployment job, got %d", len(jobs))
|
||||
}
|
||||
if jobs[0].Type != domain.JobTypeDeployment {
|
||||
t.Errorf("expected JobTypeDeployment, got %s", jobs[0].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReportJobStatus(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
agent := &domain.Agent{
|
||||
ID: "agent-001",
|
||||
Name: "prod-agent",
|
||||
Hostname: "server-01",
|
||||
Status: domain.AgentStatusOnline,
|
||||
RegisteredAt: now,
|
||||
LastHeartbeatAt: &now,
|
||||
APIKeyHash: "hash123",
|
||||
}
|
||||
job := &domain.Job{
|
||||
ID: "job-001",
|
||||
Type: domain.JobTypeDeployment,
|
||||
CertificateID: "cert-001",
|
||||
Status: domain.JobStatusRunning,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
agentRepo := &mockAgentRepo{
|
||||
Agents: map[string]*domain.Agent{"agent-001": agent},
|
||||
HeartbeatUpdates: make(map[string]time.Time),
|
||||
}
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: make(map[string]*domain.ManagedCertificate),
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: map[string]*domain.Job{"job-001": job},
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
targetRepo := &mockTargetRepo{
|
||||
Targets: make(map[string]*domain.DeploymentTarget),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
issuerRegistry := make(map[string]IssuerConnector)
|
||||
|
||||
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
|
||||
|
||||
err := agentService.ReportJobStatus(ctx, "agent-001", "job-001", domain.JobStatusCompleted, "")
|
||||
if err != nil {
|
||||
t.Fatalf("ReportJobStatus failed: %v", err)
|
||||
}
|
||||
|
||||
if jobRepo.StatusUpdates["job-001"] != domain.JobStatusCompleted {
|
||||
t.Errorf("expected status Completed, got %s", jobRepo.StatusUpdates["job-001"])
|
||||
}
|
||||
|
||||
if len(auditRepo.Events) != 1 {
|
||||
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkStaleAgentsOffline(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
staleTime := now.Add(-3 * time.Hour)
|
||||
|
||||
agent1 := &domain.Agent{
|
||||
ID: "agent-001",
|
||||
Name: "online-agent",
|
||||
Hostname: "server-01",
|
||||
Status: domain.AgentStatusOnline,
|
||||
RegisteredAt: now,
|
||||
LastHeartbeatAt: &now,
|
||||
APIKeyHash: "hash1",
|
||||
}
|
||||
agent2 := &domain.Agent{
|
||||
ID: "agent-002",
|
||||
Name: "stale-agent",
|
||||
Hostname: "server-02",
|
||||
Status: domain.AgentStatusOnline,
|
||||
RegisteredAt: now.Add(-24 * time.Hour),
|
||||
LastHeartbeatAt: &staleTime,
|
||||
APIKeyHash: "hash2",
|
||||
}
|
||||
|
||||
agentRepo := &mockAgentRepo{
|
||||
Agents: map[string]*domain.Agent{"agent-001": agent1, "agent-002": agent2},
|
||||
HeartbeatUpdates: make(map[string]time.Time),
|
||||
}
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: make(map[string]*domain.ManagedCertificate),
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: make(map[string]*domain.Job),
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
targetRepo := &mockTargetRepo{
|
||||
Targets: make(map[string]*domain.DeploymentTarget),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
issuerRegistry := make(map[string]IssuerConnector)
|
||||
|
||||
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
|
||||
|
||||
err := agentService.MarkStaleAgentsOffline(ctx, 1*time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("MarkStaleAgentsOffline failed: %v", err)
|
||||
}
|
||||
|
||||
if agentRepo.Agents["agent-001"].Status != domain.AgentStatusOnline {
|
||||
t.Errorf("expected agent-001 to be Online, got %s", agentRepo.Agents["agent-001"].Status)
|
||||
}
|
||||
if agentRepo.Agents["agent-002"].Status != domain.AgentStatusOffline {
|
||||
t.Errorf("expected agent-002 to be Offline, got %s", agentRepo.Agents["agent-002"].Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubmitCSR(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
agent := &domain.Agent{
|
||||
ID: "agent-001",
|
||||
Name: "prod-agent",
|
||||
Hostname: "server-01",
|
||||
Status: domain.AgentStatusOnline,
|
||||
RegisteredAt: now,
|
||||
LastHeartbeatAt: &now,
|
||||
APIKeyHash: "hash123",
|
||||
}
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "cert-001",
|
||||
CommonName: "example.com",
|
||||
IssuerID: "iss-local",
|
||||
Status: domain.CertificateStatusPending,
|
||||
ExpiresAt: now.AddDate(1, 0, 0),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
agentRepo := &mockAgentRepo{
|
||||
Agents: map[string]*domain.Agent{"agent-001": agent},
|
||||
HeartbeatUpdates: make(map[string]time.Time),
|
||||
}
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: map[string]*domain.ManagedCertificate{"cert-001": cert},
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: make(map[string]*domain.Job),
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
targetRepo := &mockTargetRepo{
|
||||
Targets: make(map[string]*domain.DeploymentTarget),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
issuerConnector := &mockIssuerConnector{
|
||||
Result: &IssuanceResult{
|
||||
Serial: "serial-123",
|
||||
CertPEM: "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----",
|
||||
ChainPEM: "-----BEGIN CERTIFICATE-----\nchain\n-----END CERTIFICATE-----",
|
||||
NotBefore: now,
|
||||
NotAfter: now.AddDate(1, 0, 0),
|
||||
},
|
||||
}
|
||||
issuerRegistry := map[string]IssuerConnector{"iss-local": issuerConnector}
|
||||
|
||||
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
|
||||
|
||||
csrPEM := "-----BEGIN CERTIFICATE REQUEST-----\ntest-csr\n-----END CERTIFICATE REQUEST-----"
|
||||
err := agentService.SubmitCSR(ctx, "agent-001", "cert-001", []byte(csrPEM))
|
||||
if err != nil {
|
||||
t.Fatalf("SubmitCSR failed: %v", err)
|
||||
}
|
||||
|
||||
if len(certRepo.Versions["cert-001"]) != 1 {
|
||||
t.Errorf("expected 1 certificate version, got %d", len(certRepo.Versions["cert-001"]))
|
||||
}
|
||||
|
||||
if cert.Status != domain.CertificateStatusActive {
|
||||
t.Errorf("expected certificate status Active, got %s", cert.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubmitCSR_EmptyCSR(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
agent := &domain.Agent{
|
||||
ID: "agent-001",
|
||||
Name: "prod-agent",
|
||||
Hostname: "server-01",
|
||||
Status: domain.AgentStatusOnline,
|
||||
RegisteredAt: now,
|
||||
LastHeartbeatAt: &now,
|
||||
APIKeyHash: "hash123",
|
||||
}
|
||||
|
||||
agentRepo := &mockAgentRepo{
|
||||
Agents: map[string]*domain.Agent{"agent-001": agent},
|
||||
HeartbeatUpdates: make(map[string]time.Time),
|
||||
}
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: make(map[string]*domain.ManagedCertificate),
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: make(map[string]*domain.Job),
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
targetRepo := &mockTargetRepo{
|
||||
Targets: make(map[string]*domain.DeploymentTarget),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
issuerRegistry := make(map[string]IssuerConnector)
|
||||
|
||||
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
|
||||
|
||||
err := agentService.SubmitCSR(ctx, "agent-001", "", []byte{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty CSR")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAgents(t *testing.T) {
|
||||
now := time.Now()
|
||||
agent1 := &domain.Agent{
|
||||
ID: "agent-001",
|
||||
Name: "agent1",
|
||||
Hostname: "server-01",
|
||||
Status: domain.AgentStatusOnline,
|
||||
RegisteredAt: now,
|
||||
LastHeartbeatAt: &now,
|
||||
APIKeyHash: "hash1",
|
||||
}
|
||||
agent2 := &domain.Agent{
|
||||
ID: "agent-002",
|
||||
Name: "agent2",
|
||||
Hostname: "server-02",
|
||||
Status: domain.AgentStatusOnline,
|
||||
RegisteredAt: now,
|
||||
LastHeartbeatAt: &now,
|
||||
APIKeyHash: "hash2",
|
||||
}
|
||||
|
||||
agentRepo := &mockAgentRepo{
|
||||
Agents: map[string]*domain.Agent{"agent-001": agent1, "agent-002": agent2},
|
||||
HeartbeatUpdates: make(map[string]time.Time),
|
||||
}
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: make(map[string]*domain.ManagedCertificate),
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: make(map[string]*domain.Job),
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
targetRepo := &mockTargetRepo{
|
||||
Targets: make(map[string]*domain.DeploymentTarget),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
issuerRegistry := make(map[string]IssuerConnector)
|
||||
|
||||
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry)
|
||||
|
||||
agents, total, err := agentService.ListAgents(1, 50)
|
||||
if err != nil {
|
||||
t.Fatalf("ListAgents failed: %v", err)
|
||||
}
|
||||
|
||||
if len(agents) != 2 {
|
||||
t.Errorf("expected 2 agents, got %d", len(agents))
|
||||
}
|
||||
if total != 2 {
|
||||
t.Errorf("expected total 2, got %d", total)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,329 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
)
|
||||
|
||||
func TestRecordEvent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
auditRepo := &mockAuditRepo{
|
||||
Events: []*domain.AuditEvent{},
|
||||
}
|
||||
service := NewAuditService(auditRepo)
|
||||
|
||||
err := service.RecordEvent(ctx, "user123", domain.ActorTypeUser, "certificate_created", "certificate", "cert-001", map[string]interface{}{"common_name": "example.com"})
|
||||
if err != nil {
|
||||
t.Fatalf("RecordEvent failed: %v", err)
|
||||
}
|
||||
|
||||
if len(auditRepo.Events) != 1 {
|
||||
t.Errorf("expected 1 event, got %d", len(auditRepo.Events))
|
||||
}
|
||||
|
||||
event := auditRepo.Events[0]
|
||||
if event.Actor != "user123" {
|
||||
t.Errorf("expected actor user123, got %s", event.Actor)
|
||||
}
|
||||
if event.ActorType != domain.ActorTypeUser {
|
||||
t.Errorf("expected actor type User, got %s", event.ActorType)
|
||||
}
|
||||
if event.Action != "certificate_created" {
|
||||
t.Errorf("expected action certificate_created, got %s", event.Action)
|
||||
}
|
||||
if event.ResourceType != "certificate" {
|
||||
t.Errorf("expected resource type certificate, got %s", event.ResourceType)
|
||||
}
|
||||
if event.ResourceID != "cert-001" {
|
||||
t.Errorf("expected resource ID cert-001, got %s", event.ResourceID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordEvent_RepoError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
auditRepo := &mockAuditRepo{
|
||||
Events: []*domain.AuditEvent{},
|
||||
CreateErr: errNotFound,
|
||||
}
|
||||
service := NewAuditService(auditRepo)
|
||||
|
||||
err := service.RecordEvent(ctx, "user123", domain.ActorTypeUser, "test_action", "resource", "res-001", map[string]interface{}{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListByResource(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
auditRepo := &mockAuditRepo{
|
||||
Events: []*domain.AuditEvent{},
|
||||
}
|
||||
service := NewAuditService(auditRepo)
|
||||
|
||||
event1 := &domain.AuditEvent{
|
||||
ID: "audit-1",
|
||||
Actor: "user1",
|
||||
ActorType: domain.ActorTypeUser,
|
||||
Action: "created",
|
||||
ResourceType: "certificate",
|
||||
ResourceID: "cert-001",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
event2 := &domain.AuditEvent{
|
||||
ID: "audit-2",
|
||||
Actor: "user2",
|
||||
ActorType: domain.ActorTypeUser,
|
||||
Action: "updated",
|
||||
ResourceType: "certificate",
|
||||
ResourceID: "cert-001",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
event3 := &domain.AuditEvent{
|
||||
ID: "audit-3",
|
||||
Actor: "user1",
|
||||
ActorType: domain.ActorTypeUser,
|
||||
Action: "created",
|
||||
ResourceType: "certificate",
|
||||
ResourceID: "cert-002",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
auditRepo.AddEvent(event1)
|
||||
auditRepo.AddEvent(event2)
|
||||
auditRepo.AddEvent(event3)
|
||||
|
||||
events, err := service.ListByResource(ctx, "certificate", "cert-001")
|
||||
if err != nil {
|
||||
t.Fatalf("ListByResource failed: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 2 {
|
||||
t.Errorf("expected 2 events, got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListByActor(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
auditRepo := &mockAuditRepo{
|
||||
Events: []*domain.AuditEvent{},
|
||||
}
|
||||
service := NewAuditService(auditRepo)
|
||||
|
||||
event1 := &domain.AuditEvent{
|
||||
ID: "audit-1",
|
||||
Actor: "user1",
|
||||
ActorType: domain.ActorTypeUser,
|
||||
Action: "created",
|
||||
ResourceType: "certificate",
|
||||
ResourceID: "cert-001",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
event2 := &domain.AuditEvent{
|
||||
ID: "audit-2",
|
||||
Actor: "user1",
|
||||
ActorType: domain.ActorTypeUser,
|
||||
Action: "updated",
|
||||
ResourceType: "certificate",
|
||||
ResourceID: "cert-002",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
event3 := &domain.AuditEvent{
|
||||
ID: "audit-3",
|
||||
Actor: "user2",
|
||||
ActorType: domain.ActorTypeUser,
|
||||
Action: "created",
|
||||
ResourceType: "certificate",
|
||||
ResourceID: "cert-003",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
auditRepo.AddEvent(event1)
|
||||
auditRepo.AddEvent(event2)
|
||||
auditRepo.AddEvent(event3)
|
||||
|
||||
events, err := service.ListByActor(ctx, "user1")
|
||||
if err != nil {
|
||||
t.Fatalf("ListByActor failed: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 2 {
|
||||
t.Errorf("expected 2 events, got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListByAction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
auditRepo := &mockAuditRepo{
|
||||
Events: []*domain.AuditEvent{},
|
||||
}
|
||||
service := NewAuditService(auditRepo)
|
||||
|
||||
now := time.Now()
|
||||
from := now.Add(-1 * time.Hour)
|
||||
to := now.Add(1 * time.Hour)
|
||||
|
||||
event1 := &domain.AuditEvent{
|
||||
ID: "audit-1",
|
||||
Actor: "user1",
|
||||
ActorType: domain.ActorTypeUser,
|
||||
Action: "certificate_created",
|
||||
ResourceType: "certificate",
|
||||
ResourceID: "cert-001",
|
||||
Timestamp: now.Add(-30 * time.Minute),
|
||||
}
|
||||
event2 := &domain.AuditEvent{
|
||||
ID: "audit-2",
|
||||
Actor: "user2",
|
||||
ActorType: domain.ActorTypeUser,
|
||||
Action: "certificate_created",
|
||||
ResourceType: "certificate",
|
||||
ResourceID: "cert-002",
|
||||
Timestamp: now.Add(-20 * time.Minute),
|
||||
}
|
||||
event3 := &domain.AuditEvent{
|
||||
ID: "audit-3",
|
||||
Actor: "user1",
|
||||
ActorType: domain.ActorTypeUser,
|
||||
Action: "certificate_updated",
|
||||
ResourceType: "certificate",
|
||||
ResourceID: "cert-001",
|
||||
Timestamp: now.Add(-10 * time.Minute),
|
||||
}
|
||||
|
||||
auditRepo.AddEvent(event1)
|
||||
auditRepo.AddEvent(event2)
|
||||
auditRepo.AddEvent(event3)
|
||||
|
||||
events, err := service.ListByAction(ctx, "certificate_created", from, to)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByAction failed: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 2 {
|
||||
t.Errorf("expected 2 events, got %d", len(events))
|
||||
}
|
||||
|
||||
for _, e := range events {
|
||||
if e.Action != "certificate_created" {
|
||||
t.Errorf("expected action certificate_created, got %s", e.Action)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListByAction_EmptyRange(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
auditRepo := &mockAuditRepo{
|
||||
Events: []*domain.AuditEvent{},
|
||||
}
|
||||
service := NewAuditService(auditRepo)
|
||||
|
||||
now := time.Now()
|
||||
from := now.Add(1 * time.Hour)
|
||||
to := now.Add(2 * time.Hour)
|
||||
|
||||
event := &domain.AuditEvent{
|
||||
ID: "audit-1",
|
||||
Actor: "user1",
|
||||
ActorType: domain.ActorTypeUser,
|
||||
Action: "certificate_created",
|
||||
ResourceType: "certificate",
|
||||
ResourceID: "cert-001",
|
||||
Timestamp: now.Add(-30 * time.Minute),
|
||||
}
|
||||
auditRepo.AddEvent(event)
|
||||
|
||||
events, err := service.ListByAction(ctx, "certificate_created", from, to)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByAction failed: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 0 {
|
||||
t.Errorf("expected 0 events, got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordEvent_ComplexDetails(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
auditRepo := &mockAuditRepo{
|
||||
Events: []*domain.AuditEvent{},
|
||||
}
|
||||
service := NewAuditService(auditRepo)
|
||||
|
||||
details := map[string]interface{}{
|
||||
"common_name": "example.com",
|
||||
"sans": []string{"www.example.com", "api.example.com"},
|
||||
"issuer_id": "iss-123",
|
||||
"count": 5,
|
||||
}
|
||||
|
||||
err := service.RecordEvent(ctx, "user1", domain.ActorTypeUser, "certificate_created", "certificate", "cert-001", details)
|
||||
if err != nil {
|
||||
t.Fatalf("RecordEvent failed: %v", err)
|
||||
}
|
||||
|
||||
event := auditRepo.Events[0]
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(event.Details, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal details: %v", err)
|
||||
}
|
||||
|
||||
if decoded["common_name"] != "example.com" {
|
||||
t.Errorf("expected common_name example.com, got %v", decoded["common_name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestList(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
auditRepo := &mockAuditRepo{
|
||||
Events: []*domain.AuditEvent{},
|
||||
}
|
||||
service := NewAuditService(auditRepo)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
event := &domain.AuditEvent{
|
||||
ID: "audit-" + string(rune(i)),
|
||||
Actor: "user1",
|
||||
ActorType: domain.ActorTypeUser,
|
||||
Action: "test",
|
||||
ResourceType: "certificate",
|
||||
ResourceID: "cert-001",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
auditRepo.AddEvent(event)
|
||||
}
|
||||
|
||||
filter := &repository.AuditFilter{
|
||||
Page: 1,
|
||||
PerPage: 10,
|
||||
}
|
||||
|
||||
events, err := service.List(ctx, filter)
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 5 {
|
||||
t.Errorf("expected 5 events, got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestList_RepoError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
auditRepo := &mockAuditRepo{
|
||||
ListErr: errNotFound,
|
||||
}
|
||||
service := NewAuditService(auditRepo)
|
||||
|
||||
filter := &repository.AuditFilter{}
|
||||
|
||||
_, err := service.List(ctx, filter)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,383 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
func TestCreateCertificate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: make(map[string]*domain.ManagedCertificate),
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{
|
||||
Events: []*domain.AuditEvent{},
|
||||
}
|
||||
policyRepo := &mockPolicyRepo{
|
||||
Rules: make(map[string]*domain.PolicyRule),
|
||||
Violations: []*domain.PolicyViolation{},
|
||||
}
|
||||
|
||||
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
|
||||
auditService := NewAuditService(auditRepo)
|
||||
certService := NewCertificateService(certRepo, policyService, auditService)
|
||||
|
||||
now := time.Now()
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "cert-001",
|
||||
Name: "api-prod",
|
||||
CommonName: "api.example.com",
|
||||
SANs: []string{"api.example.com"},
|
||||
Environment: "production",
|
||||
OwnerID: "owner-1",
|
||||
TeamID: "team-1",
|
||||
IssuerID: "iss-acme",
|
||||
TargetIDs: []string{"target-1"},
|
||||
RenewalPolicyID: "policy-1",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: now.AddDate(1, 0, 0),
|
||||
Tags: map[string]string{"env": "prod"},
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
err := certService.Create(ctx, cert, "user-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
|
||||
if len(certRepo.Certs) != 1 {
|
||||
t.Errorf("expected 1 cert, got %d", len(certRepo.Certs))
|
||||
}
|
||||
|
||||
storedCert, ok := certRepo.Certs["cert-001"]
|
||||
if !ok {
|
||||
t.Fatal("certificate not stored")
|
||||
}
|
||||
if storedCert.CommonName != "api.example.com" {
|
||||
t.Errorf("expected common name api.example.com, got %s", storedCert.CommonName)
|
||||
}
|
||||
|
||||
if len(auditRepo.Events) != 1 {
|
||||
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCertificate_MissingRequired(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: make(map[string]*domain.ManagedCertificate),
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
|
||||
|
||||
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
|
||||
auditService := NewAuditService(auditRepo)
|
||||
certService := NewCertificateService(certRepo, policyService, auditService)
|
||||
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "cert-001",
|
||||
// Missing CommonName and IssuerID
|
||||
}
|
||||
|
||||
err := certService.Create(ctx, cert, "user-1")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing required fields")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCertificate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "cert-001",
|
||||
CommonName: "example.com",
|
||||
IssuerID: "iss-1",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: now.AddDate(1, 0, 0),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: map[string]*domain.ManagedCertificate{"cert-001": cert},
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
|
||||
|
||||
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
|
||||
auditService := NewAuditService(auditRepo)
|
||||
certService := NewCertificateService(certRepo, policyService, auditService)
|
||||
|
||||
retrieved, err := certService.Get(ctx, "cert-001")
|
||||
if err != nil {
|
||||
t.Fatalf("Get failed: %v", err)
|
||||
}
|
||||
|
||||
if retrieved.CommonName != "example.com" {
|
||||
t.Errorf("expected common name example.com, got %s", retrieved.CommonName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCertificate_NotFound(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: make(map[string]*domain.ManagedCertificate),
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
|
||||
|
||||
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
|
||||
auditService := NewAuditService(auditRepo)
|
||||
certService := NewCertificateService(certRepo, policyService, auditService)
|
||||
|
||||
_, err := certService.Get(ctx, "nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent certificate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateCertificate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
originalCert := &domain.ManagedCertificate{
|
||||
ID: "cert-001",
|
||||
CommonName: "example.com",
|
||||
IssuerID: "iss-1",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: now.AddDate(1, 0, 0),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: map[string]*domain.ManagedCertificate{"cert-001": originalCert},
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
|
||||
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
|
||||
|
||||
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
|
||||
auditService := NewAuditService(auditRepo)
|
||||
certService := NewCertificateService(certRepo, policyService, auditService)
|
||||
|
||||
updatedCert := *originalCert
|
||||
updatedCert.Status = domain.CertificateStatusExpiring
|
||||
updatedCert.ExpiresAt = now.AddDate(0, 0, 5)
|
||||
|
||||
err := certService.Update(ctx, &updatedCert, "user-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Update failed: %v", err)
|
||||
}
|
||||
|
||||
stored := certRepo.Certs["cert-001"]
|
||||
if stored.Status != domain.CertificateStatusExpiring {
|
||||
t.Errorf("expected status Expiring, got %s", stored.Status)
|
||||
}
|
||||
|
||||
if len(auditRepo.Events) != 1 {
|
||||
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestArchiveCertificate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "cert-001",
|
||||
CommonName: "example.com",
|
||||
IssuerID: "iss-1",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: now.AddDate(1, 0, 0),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: map[string]*domain.ManagedCertificate{"cert-001": cert},
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
|
||||
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
|
||||
|
||||
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
|
||||
auditService := NewAuditService(auditRepo)
|
||||
certService := NewCertificateService(certRepo, policyService, auditService)
|
||||
|
||||
err := certService.Archive(ctx, "cert-001", "user-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Archive failed: %v", err)
|
||||
}
|
||||
|
||||
archived := certRepo.Certs["cert-001"]
|
||||
if archived.Status != domain.CertificateStatusArchived {
|
||||
t.Errorf("expected status Archived, got %s", archived.Status)
|
||||
}
|
||||
|
||||
if len(auditRepo.Events) != 1 {
|
||||
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetVersions(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
version1 := &domain.CertificateVersion{
|
||||
ID: "ver-1",
|
||||
CertificateID: "cert-001",
|
||||
SerialNumber: "serial-1",
|
||||
NotBefore: now.AddDate(-1, 0, 0),
|
||||
NotAfter: now,
|
||||
PEMChain: "cert1-pem",
|
||||
CreatedAt: now.AddDate(-1, 0, 0),
|
||||
}
|
||||
version2 := &domain.CertificateVersion{
|
||||
ID: "ver-2",
|
||||
CertificateID: "cert-001",
|
||||
SerialNumber: "serial-2",
|
||||
NotBefore: now,
|
||||
NotAfter: now.AddDate(1, 0, 0),
|
||||
PEMChain: "cert2-pem",
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: make(map[string]*domain.ManagedCertificate),
|
||||
Versions: map[string][]*domain.CertificateVersion{"cert-001": {version1, version2}},
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
|
||||
|
||||
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
|
||||
auditService := NewAuditService(auditRepo)
|
||||
certService := NewCertificateService(certRepo, policyService, auditService)
|
||||
|
||||
versions, err := certService.GetVersions(ctx, "cert-001")
|
||||
if err != nil {
|
||||
t.Fatalf("GetVersions failed: %v", err)
|
||||
}
|
||||
|
||||
if len(versions) != 2 {
|
||||
t.Errorf("expected 2 versions, got %d", len(versions))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTriggerRenewal(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "cert-001",
|
||||
CommonName: "example.com",
|
||||
IssuerID: "iss-1",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: now.AddDate(0, 0, 5),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: map[string]*domain.ManagedCertificate{"cert-001": cert},
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
|
||||
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
|
||||
|
||||
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
|
||||
auditService := NewAuditService(auditRepo)
|
||||
certService := NewCertificateService(certRepo, policyService, auditService)
|
||||
|
||||
err := certService.TriggerRenewalWithActor(ctx, "cert-001", "user-1")
|
||||
if err != nil {
|
||||
t.Fatalf("TriggerRenewal failed: %v", err)
|
||||
}
|
||||
|
||||
renewed := certRepo.Certs["cert-001"]
|
||||
if renewed.Status != domain.CertificateStatusRenewalInProgress {
|
||||
t.Errorf("expected status RenewalInProgress, got %s", renewed.Status)
|
||||
}
|
||||
|
||||
if len(auditRepo.Events) != 1 {
|
||||
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTriggerRenewal_Archived(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "cert-001",
|
||||
CommonName: "example.com",
|
||||
IssuerID: "iss-1",
|
||||
Status: domain.CertificateStatusArchived,
|
||||
ExpiresAt: now.AddDate(0, 0, 5),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: map[string]*domain.ManagedCertificate{"cert-001": cert},
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
|
||||
|
||||
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
|
||||
auditService := NewAuditService(auditRepo)
|
||||
certService := NewCertificateService(certRepo, policyService, auditService)
|
||||
|
||||
err := certService.TriggerRenewalWithActor(ctx, "cert-001", "user-1")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for archived certificate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListCertificates(t *testing.T) {
|
||||
now := time.Now()
|
||||
cert1 := &domain.ManagedCertificate{
|
||||
ID: "cert-001",
|
||||
CommonName: "api.example.com",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: now.AddDate(1, 0, 0),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
cert2 := &domain.ManagedCertificate{
|
||||
ID: "cert-002",
|
||||
CommonName: "web.example.com",
|
||||
Status: domain.CertificateStatusExpiring,
|
||||
ExpiresAt: now.AddDate(0, 0, 5),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: map[string]*domain.ManagedCertificate{"cert-001": cert1, "cert-002": cert2},
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
policyRepo := &mockPolicyRepo{Rules: make(map[string]*domain.PolicyRule)}
|
||||
|
||||
policyService := NewPolicyService(policyRepo, NewAuditService(auditRepo))
|
||||
auditService := NewAuditService(auditRepo)
|
||||
certService := NewCertificateService(certRepo, policyService, auditService)
|
||||
|
||||
certs, total, err := certService.ListCertificates("", "", "", "", "", 1, 50)
|
||||
if err != nil {
|
||||
t.Fatalf("ListCertificates failed: %v", err)
|
||||
}
|
||||
|
||||
if len(certs) != 2 {
|
||||
t.Errorf("expected 2 certs, got %d", len(certs))
|
||||
}
|
||||
if total != 2 {
|
||||
t.Errorf("expected total 2, got %d", total)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
// helper to build job service with proper constructor signatures
|
||||
func newTestJobService(jobRepo *mockJobRepo) *JobService {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo}))
|
||||
|
||||
certRepo := &mockCertRepo{
|
||||
Certs: make(map[string]*domain.ManagedCertificate),
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
renewalPolicyRepo := &mockRenewalPolicyRepo{
|
||||
Policies: make(map[string]*domain.RenewalPolicy),
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
notifRepo := newMockNotificationRepository()
|
||||
notifService := NewNotificationService(notifRepo, make(map[string]Notifier))
|
||||
targetRepo := &mockTargetRepo{Targets: make(map[string]*domain.DeploymentTarget)}
|
||||
agentRepo := &mockAgentRepo{Agents: make(map[string]*domain.Agent)}
|
||||
|
||||
renewalService := NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, auditService, notifService, make(map[string]IssuerConnector))
|
||||
deploymentService := NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notifService)
|
||||
|
||||
return NewJobService(jobRepo, renewalService, deploymentService, logger)
|
||||
}
|
||||
|
||||
func TestProcessPendingJobs_Renewal(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
job := &domain.Job{
|
||||
ID: "job-001",
|
||||
Type: domain.JobTypeRenewal,
|
||||
CertificateID: "cert-001",
|
||||
Status: domain.JobStatusPending,
|
||||
Attempts: 0,
|
||||
MaxAttempts: 3,
|
||||
CreatedAt: now,
|
||||
ScheduledAt: now,
|
||||
}
|
||||
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: map[string]*domain.Job{"job-001": job},
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
|
||||
jobService := newTestJobService(jobRepo)
|
||||
|
||||
err := jobService.ProcessPendingJobs(ctx)
|
||||
if err != nil {
|
||||
t.Logf("ProcessPendingJobs returned error (expected for renewal without cert): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessPendingJobs_NoJobs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: make(map[string]*domain.Job),
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
|
||||
jobService := newTestJobService(jobRepo)
|
||||
|
||||
err := jobService.ProcessPendingJobs(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessPendingJobs failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelJob(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
job := &domain.Job{
|
||||
ID: "job-001",
|
||||
Type: domain.JobTypeDeployment,
|
||||
CertificateID: "cert-001",
|
||||
Status: domain.JobStatusPending,
|
||||
CreatedAt: now,
|
||||
ScheduledAt: now,
|
||||
}
|
||||
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: map[string]*domain.Job{"job-001": job},
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
|
||||
jobService := newTestJobService(jobRepo)
|
||||
|
||||
err := jobService.CancelJobWithContext(ctx, "job-001")
|
||||
if err != nil {
|
||||
t.Fatalf("CancelJob failed: %v", err)
|
||||
}
|
||||
|
||||
if jobRepo.StatusUpdates["job-001"] != domain.JobStatusCancelled {
|
||||
t.Errorf("expected status Cancelled, got %s", jobRepo.StatusUpdates["job-001"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelJob_AlreadyCompleted(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
job := &domain.Job{
|
||||
ID: "job-001",
|
||||
Type: domain.JobTypeDeployment,
|
||||
CertificateID: "cert-001",
|
||||
Status: domain.JobStatusCompleted,
|
||||
CreatedAt: now,
|
||||
ScheduledAt: now,
|
||||
}
|
||||
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: map[string]*domain.Job{"job-001": job},
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
|
||||
jobService := newTestJobService(jobRepo)
|
||||
|
||||
err := jobService.CancelJobWithContext(ctx, "job-001")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for completed job")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetJob(t *testing.T) {
|
||||
now := time.Now()
|
||||
job := &domain.Job{
|
||||
ID: "job-001",
|
||||
Type: domain.JobTypeDeployment,
|
||||
CertificateID: "cert-001",
|
||||
Status: domain.JobStatusPending,
|
||||
CreatedAt: now,
|
||||
ScheduledAt: now,
|
||||
}
|
||||
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: map[string]*domain.Job{"job-001": job},
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
|
||||
jobService := newTestJobService(jobRepo)
|
||||
|
||||
retrieved, err := jobService.GetJob("job-001")
|
||||
if err != nil {
|
||||
t.Fatalf("GetJob failed: %v", err)
|
||||
}
|
||||
|
||||
if retrieved.ID != "job-001" {
|
||||
t.Errorf("expected job ID job-001, got %s", retrieved.ID)
|
||||
}
|
||||
if retrieved.Type != domain.JobTypeDeployment {
|
||||
t.Errorf("expected job type Deployment, got %s", retrieved.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListJobs(t *testing.T) {
|
||||
now := time.Now()
|
||||
job1 := &domain.Job{
|
||||
ID: "job-001",
|
||||
Type: domain.JobTypeDeployment,
|
||||
CertificateID: "cert-001",
|
||||
Status: domain.JobStatusCompleted,
|
||||
CreatedAt: now,
|
||||
ScheduledAt: now,
|
||||
}
|
||||
job2 := &domain.Job{
|
||||
ID: "job-002",
|
||||
Type: domain.JobTypeRenewal,
|
||||
CertificateID: "cert-002",
|
||||
Status: domain.JobStatusPending,
|
||||
CreatedAt: now,
|
||||
ScheduledAt: now,
|
||||
}
|
||||
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: map[string]*domain.Job{"job-001": job1, "job-002": job2},
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
|
||||
jobService := newTestJobService(jobRepo)
|
||||
|
||||
jobs, total, err := jobService.ListJobs("", "", 1, 50)
|
||||
if err != nil {
|
||||
t.Fatalf("ListJobs failed: %v", err)
|
||||
}
|
||||
|
||||
if len(jobs) != 2 {
|
||||
t.Errorf("expected 2 jobs, got %d", len(jobs))
|
||||
}
|
||||
if total != 2 {
|
||||
t.Errorf("expected total 2, got %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListJobs_FilterByStatus(t *testing.T) {
|
||||
now := time.Now()
|
||||
job1 := &domain.Job{
|
||||
ID: "job-001",
|
||||
Type: domain.JobTypeDeployment,
|
||||
CertificateID: "cert-001",
|
||||
Status: domain.JobStatusCompleted,
|
||||
CreatedAt: now,
|
||||
ScheduledAt: now,
|
||||
}
|
||||
job2 := &domain.Job{
|
||||
ID: "job-002",
|
||||
Type: domain.JobTypeRenewal,
|
||||
CertificateID: "cert-002",
|
||||
Status: domain.JobStatusPending,
|
||||
CreatedAt: now,
|
||||
ScheduledAt: now,
|
||||
}
|
||||
|
||||
jobRepo := &mockJobRepo{
|
||||
Jobs: map[string]*domain.Job{"job-001": job1, "job-002": job2},
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
|
||||
jobService := newTestJobService(jobRepo)
|
||||
|
||||
jobs, total, err := jobService.ListJobs(string(domain.JobStatusPending), "", 1, 50)
|
||||
if err != nil {
|
||||
t.Fatalf("ListJobs failed: %v", err)
|
||||
}
|
||||
|
||||
if len(jobs) != 1 {
|
||||
t.Errorf("expected 1 pending job, got %d", len(jobs))
|
||||
}
|
||||
if total != 1 {
|
||||
t.Errorf("expected total 1, got %d", total)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,567 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
func TestSendThresholdAlert(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
notifRepo := newMockNotificationRepository()
|
||||
notifier := newMockNotifier()
|
||||
registry := map[string]Notifier{
|
||||
"Email": notifier,
|
||||
}
|
||||
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-test-1",
|
||||
CommonName: "example.com",
|
||||
OwnerID: "owner-1",
|
||||
ExpiresAt: time.Now().AddDate(0, 0, 5),
|
||||
}
|
||||
|
||||
threshold := 7
|
||||
daysUntilExpiry := 5
|
||||
|
||||
err := svc.SendThresholdAlert(ctx, cert, daysUntilExpiry, threshold)
|
||||
if err != nil {
|
||||
t.Fatalf("SendThresholdAlert failed: %v", err)
|
||||
}
|
||||
|
||||
if len(notifRepo.Notifications) < 1 {
|
||||
t.Errorf("expected at least 1 notification, got %d", len(notifRepo.Notifications))
|
||||
}
|
||||
|
||||
notif := notifRepo.Notifications[0]
|
||||
if notif.Type != domain.NotificationTypeExpirationWarning {
|
||||
t.Errorf("expected ExpirationWarning, got %s", notif.Type)
|
||||
}
|
||||
|
||||
// Verify message contains threshold tag
|
||||
if !strings.Contains(notif.Message, "[threshold:7]") {
|
||||
t.Errorf("expected threshold tag in message, got: %s", notif.Message)
|
||||
}
|
||||
|
||||
// Verify notifier was called
|
||||
if notifier.getSentCount() != 1 {
|
||||
t.Errorf("expected 1 sent message, got %d", notifier.getSentCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendThresholdAlert_Expired(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
notifRepo := newMockNotificationRepository()
|
||||
notifier := newMockNotifier()
|
||||
registry := map[string]Notifier{
|
||||
"Email": notifier,
|
||||
}
|
||||
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-test-expired",
|
||||
CommonName: "expired.com",
|
||||
OwnerID: "owner-1",
|
||||
ExpiresAt: time.Now().AddDate(0, 0, -1),
|
||||
}
|
||||
|
||||
threshold := 0
|
||||
daysUntilExpiry := -1
|
||||
|
||||
err := svc.SendThresholdAlert(ctx, cert, daysUntilExpiry, threshold)
|
||||
if err != nil {
|
||||
t.Fatalf("SendThresholdAlert failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify message contains [EXPIRED] prefix
|
||||
if len(notifRepo.Notifications) > 0 && !strings.Contains(notifRepo.Notifications[0].Message, "[EXPIRED]") {
|
||||
t.Errorf("expected [EXPIRED] in message, got: %s", notifRepo.Notifications[0].Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasThresholdNotification_Found(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
notifRepo := newMockNotificationRepository()
|
||||
registry := map[string]Notifier{}
|
||||
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
// Add an existing notification with threshold tag
|
||||
existingNotif := &domain.NotificationEvent{
|
||||
ID: "notif-1",
|
||||
CertificateID: stringPtr("mc-test-1"),
|
||||
Type: domain.NotificationTypeExpirationWarning,
|
||||
Channel: domain.NotificationChannelEmail,
|
||||
Recipient: "owner-1",
|
||||
Message: "Certificate expires soon\n\n[threshold:30]",
|
||||
Status: "sent",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
notifRepo.AddNotification(existingNotif)
|
||||
|
||||
// Check for existing notification
|
||||
found, err := svc.HasThresholdNotification(ctx, "mc-test-1", 30)
|
||||
if err != nil {
|
||||
t.Fatalf("HasThresholdNotification failed: %v", err)
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Errorf("expected to find threshold notification, but didn't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasThresholdNotification_NotFound(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
notifRepo := newMockNotificationRepository()
|
||||
registry := map[string]Notifier{}
|
||||
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
// Check for non-existent notification
|
||||
found, err := svc.HasThresholdNotification(ctx, "mc-test-1", 30)
|
||||
if err != nil {
|
||||
t.Fatalf("HasThresholdNotification failed: %v", err)
|
||||
}
|
||||
|
||||
if found {
|
||||
t.Errorf("expected not to find threshold notification, but did")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendExpirationWarning(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
notifRepo := newMockNotificationRepository()
|
||||
notifier := newMockNotifier()
|
||||
registry := map[string]Notifier{
|
||||
"Email": notifier,
|
||||
}
|
||||
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-test-warning",
|
||||
CommonName: "warn.com",
|
||||
OwnerID: "owner-1",
|
||||
ExpiresAt: time.Now().AddDate(0, 0, 10),
|
||||
}
|
||||
|
||||
err := svc.SendExpirationWarning(ctx, cert, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("SendExpirationWarning failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify notification was created
|
||||
if len(notifRepo.Notifications) < 1 {
|
||||
t.Errorf("expected at least 1 notification, got %d", len(notifRepo.Notifications))
|
||||
}
|
||||
|
||||
if notifRepo.Notifications[0].Type != domain.NotificationTypeExpirationWarning {
|
||||
t.Errorf("expected ExpirationWarning type, got %s", notifRepo.Notifications[0].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendRenewalNotification_Success(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
notifRepo := newMockNotificationRepository()
|
||||
notifier := newMockNotifier()
|
||||
registry := map[string]Notifier{
|
||||
"Email": notifier,
|
||||
}
|
||||
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-renewed",
|
||||
CommonName: "renewed.com",
|
||||
OwnerID: "owner-1",
|
||||
ExpiresAt: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
|
||||
err := svc.SendRenewalNotification(ctx, cert, true, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SendRenewalNotification failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify notification was created with success type
|
||||
if len(notifRepo.Notifications) < 1 {
|
||||
t.Errorf("expected at least 1 notification, got %d", len(notifRepo.Notifications))
|
||||
}
|
||||
|
||||
if notifRepo.Notifications[0].Type != domain.NotificationTypeRenewalSuccess {
|
||||
t.Errorf("expected RenewalSuccess type, got %s", notifRepo.Notifications[0].Type)
|
||||
}
|
||||
|
||||
// Verify message contains success text
|
||||
if !strings.Contains(notifRepo.Notifications[0].Message, "successfully renewed") {
|
||||
t.Errorf("expected success message, got: %s", notifRepo.Notifications[0].Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendRenewalNotification_Failure(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
notifRepo := newMockNotificationRepository()
|
||||
notifier := newMockNotifier()
|
||||
registry := map[string]Notifier{
|
||||
"Email": notifier,
|
||||
}
|
||||
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-failed-renewal",
|
||||
CommonName: "failed.com",
|
||||
OwnerID: "owner-1",
|
||||
ExpiresAt: time.Now().AddDate(0, 0, 5),
|
||||
}
|
||||
|
||||
testErr := fmt.Errorf("issuer unavailable")
|
||||
err := svc.SendRenewalNotification(ctx, cert, false, testErr)
|
||||
if err != nil {
|
||||
t.Fatalf("SendRenewalNotification failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify notification was created with failure type
|
||||
if len(notifRepo.Notifications) < 1 {
|
||||
t.Errorf("expected at least 1 notification, got %d", len(notifRepo.Notifications))
|
||||
}
|
||||
|
||||
if notifRepo.Notifications[0].Type != domain.NotificationTypeRenewalFailure {
|
||||
t.Errorf("expected RenewalFailure type, got %s", notifRepo.Notifications[0].Type)
|
||||
}
|
||||
|
||||
// Verify message contains error info
|
||||
if !strings.Contains(notifRepo.Notifications[0].Message, "failed to renew") {
|
||||
t.Errorf("expected failure message, got: %s", notifRepo.Notifications[0].Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessPendingNotifications(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
notifRepo := newMockNotificationRepository()
|
||||
notifier := newMockNotifier()
|
||||
registry := map[string]Notifier{
|
||||
"Email": notifier,
|
||||
}
|
||||
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
// Add pending notifications
|
||||
for i := 0; i < 3; i++ {
|
||||
notif := &domain.NotificationEvent{
|
||||
ID: fmt.Sprintf("notif-%d", i),
|
||||
Type: domain.NotificationTypeExpirationWarning,
|
||||
Channel: domain.NotificationChannelEmail,
|
||||
Recipient: "owner-1",
|
||||
Message: fmt.Sprintf("Test notification %d", i),
|
||||
Status: "pending",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
notifRepo.AddNotification(notif)
|
||||
}
|
||||
|
||||
err := svc.ProcessPendingNotifications(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessPendingNotifications failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify all notifications were sent
|
||||
if notifier.getSentCount() != 3 {
|
||||
t.Errorf("expected 3 sent notifications, got %d", notifier.getSentCount())
|
||||
}
|
||||
|
||||
// Verify status was updated to sent
|
||||
for _, notif := range notifRepo.Notifications {
|
||||
if notif.Status != "sent" {
|
||||
t.Errorf("expected notification status 'sent', got %s", notif.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessPendingNotifications_NoNotifier(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
notifRepo := newMockNotificationRepository()
|
||||
// No notifier registered - demo mode
|
||||
registry := map[string]Notifier{}
|
||||
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
// Add pending notification
|
||||
notif := &domain.NotificationEvent{
|
||||
ID: "notif-demo",
|
||||
Type: domain.NotificationTypeExpirationWarning,
|
||||
Channel: domain.NotificationChannelEmail, // Channel not in registry
|
||||
Recipient: "owner-1",
|
||||
Message: "Test notification",
|
||||
Status: "pending",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
notifRepo.AddNotification(notif)
|
||||
|
||||
// Should not fail, just mark as sent (demo mode graceful skip)
|
||||
err := svc.ProcessPendingNotifications(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessPendingNotifications should not fail in demo mode: %v", err)
|
||||
}
|
||||
|
||||
// Status should still be updated to sent
|
||||
if len(notifRepo.Notifications) > 0 && notifRepo.Notifications[0].Status == "sent" {
|
||||
// This is fine - graceful skip marks as sent
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterNotifier(t *testing.T) {
|
||||
t.Helper()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
registry := map[string]Notifier{}
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
notifier := newMockNotifier()
|
||||
svc.RegisterNotifier("Email", notifier)
|
||||
|
||||
// Verify notifier was registered
|
||||
if svc.notifierRegistry["Email"] == nil {
|
||||
t.Errorf("expected notifier to be registered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListNotifications(t *testing.T) {
|
||||
t.Helper()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
registry := map[string]Notifier{}
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
// Add test notifications
|
||||
for i := 0; i < 5; i++ {
|
||||
notif := &domain.NotificationEvent{
|
||||
ID: fmt.Sprintf("notif-list-%d", i),
|
||||
Type: domain.NotificationTypeExpirationWarning,
|
||||
Channel: domain.NotificationChannelEmail,
|
||||
Recipient: fmt.Sprintf("owner-%d", i%2),
|
||||
Message: fmt.Sprintf("Test notification %d", i),
|
||||
Status: "sent",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
notifRepo.AddNotification(notif)
|
||||
}
|
||||
|
||||
// List with pagination
|
||||
notifs, total, err := svc.ListNotifications(1, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("ListNotifications failed: %v", err)
|
||||
}
|
||||
|
||||
if len(notifs) == 0 {
|
||||
t.Errorf("expected notifications, got none")
|
||||
}
|
||||
|
||||
if total == 0 {
|
||||
t.Errorf("expected total count > 0, got %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkAsRead(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
notifRepo := newMockNotificationRepository()
|
||||
registry := map[string]Notifier{}
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
// Add a notification
|
||||
notif := &domain.NotificationEvent{
|
||||
ID: "notif-read",
|
||||
Type: domain.NotificationTypeExpirationWarning,
|
||||
Channel: domain.NotificationChannelEmail,
|
||||
Recipient: "owner-1",
|
||||
Message: "Test notification",
|
||||
Status: "sent",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
notifRepo.AddNotification(notif)
|
||||
|
||||
// Mark as read
|
||||
err := svc.MarkAsRead(notif.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("MarkAsRead failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify status was updated
|
||||
if len(notifRepo.Notifications) > 0 && notifRepo.Notifications[0].Status != "read" {
|
||||
t.Errorf("expected status 'read', got %s", notifRepo.Notifications[0].Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNotification(t *testing.T) {
|
||||
t.Helper()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
registry := map[string]Notifier{}
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
// Add a notification
|
||||
notif := &domain.NotificationEvent{
|
||||
ID: "notif-get-test",
|
||||
Type: domain.NotificationTypeExpirationWarning,
|
||||
Channel: domain.NotificationChannelEmail,
|
||||
Recipient: "owner-1",
|
||||
Message: "Test notification",
|
||||
Status: "sent",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
notifRepo.AddNotification(notif)
|
||||
|
||||
// Get the notification
|
||||
retrieved, err := svc.GetNotification(notif.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetNotification failed: %v", err)
|
||||
}
|
||||
|
||||
if retrieved == nil {
|
||||
t.Errorf("expected notification, got nil")
|
||||
} else if retrieved.ID != notif.ID {
|
||||
t.Errorf("expected ID %s, got %s", notif.ID, retrieved.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendDeploymentNotification_Success(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
notifRepo := newMockNotificationRepository()
|
||||
notifier := newMockNotifier()
|
||||
registry := map[string]Notifier{
|
||||
"Email": notifier,
|
||||
}
|
||||
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-deploy",
|
||||
CommonName: "deploy.com",
|
||||
OwnerID: "owner-1",
|
||||
ExpiresAt: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
|
||||
target := &domain.DeploymentTarget{
|
||||
ID: "target-1",
|
||||
Name: "NGINX-Prod",
|
||||
}
|
||||
|
||||
err := svc.SendDeploymentNotification(ctx, cert, target, true, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SendDeploymentNotification failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify notification was created
|
||||
if len(notifRepo.Notifications) < 1 {
|
||||
t.Errorf("expected at least 1 notification, got %d", len(notifRepo.Notifications))
|
||||
}
|
||||
|
||||
if notifRepo.Notifications[0].Type != domain.NotificationTypeDeploymentSuccess {
|
||||
t.Errorf("expected DeploymentSuccess type, got %s", notifRepo.Notifications[0].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendDeploymentNotification_Failure(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
notifRepo := newMockNotificationRepository()
|
||||
notifier := newMockNotifier()
|
||||
registry := map[string]Notifier{
|
||||
"Email": notifier,
|
||||
}
|
||||
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-deploy-fail",
|
||||
CommonName: "deploy-fail.com",
|
||||
OwnerID: "owner-1",
|
||||
ExpiresAt: time.Now().AddDate(1, 0, 0),
|
||||
}
|
||||
|
||||
target := &domain.DeploymentTarget{
|
||||
ID: "target-2",
|
||||
Name: "NGINX-Staging",
|
||||
}
|
||||
|
||||
deployErr := fmt.Errorf("connection timeout")
|
||||
err := svc.SendDeploymentNotification(ctx, cert, target, false, deployErr)
|
||||
if err != nil {
|
||||
t.Fatalf("SendDeploymentNotification failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify notification was created
|
||||
if len(notifRepo.Notifications) < 1 {
|
||||
t.Errorf("expected at least 1 notification, got %d", len(notifRepo.Notifications))
|
||||
}
|
||||
|
||||
if notifRepo.Notifications[0].Type != domain.NotificationTypeDeploymentFailure {
|
||||
t.Errorf("expected DeploymentFailure type, got %s", notifRepo.Notifications[0].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNotificationHistory(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
notifRepo := newMockNotificationRepository()
|
||||
registry := map[string]Notifier{}
|
||||
svc := NewNotificationService(notifRepo, registry)
|
||||
|
||||
certID := "mc-history"
|
||||
|
||||
// Add multiple notifications for same cert
|
||||
for i := 0; i < 3; i++ {
|
||||
notif := &domain.NotificationEvent{
|
||||
ID: fmt.Sprintf("notif-hist-%d", i),
|
||||
CertificateID: &certID,
|
||||
Type: domain.NotificationTypeExpirationWarning,
|
||||
Channel: domain.NotificationChannelEmail,
|
||||
Recipient: "owner-1",
|
||||
Message: fmt.Sprintf("Alert %d", i),
|
||||
Status: "sent",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
notifRepo.AddNotification(notif)
|
||||
}
|
||||
|
||||
// Get history
|
||||
history, err := svc.GetNotificationHistory(ctx, certID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetNotificationHistory failed: %v", err)
|
||||
}
|
||||
|
||||
if len(history) < 1 {
|
||||
t.Errorf("expected at least 1 notification, got %d", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
@@ -0,0 +1,422 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
func TestCreateRule(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
policyRepo := &mockPolicyRepo{
|
||||
Rules: make(map[string]*domain.PolicyRule),
|
||||
Violations: []*domain.PolicyViolation{},
|
||||
}
|
||||
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
policyService := NewPolicyService(policyRepo, auditService)
|
||||
|
||||
config := map[string]interface{}{"issuers": []string{"iss-acme"}}
|
||||
configJSON, _ := json.Marshal(config)
|
||||
|
||||
rule := &domain.PolicyRule{
|
||||
ID: "rule-001",
|
||||
Name: "Allowed Issuers",
|
||||
Type: domain.PolicyTypeAllowedIssuers,
|
||||
Config: configJSON,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
err := policyService.CreateRule(ctx, rule, "user-1")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateRule failed: %v", err)
|
||||
}
|
||||
|
||||
if len(policyRepo.Rules) != 1 {
|
||||
t.Errorf("expected 1 rule, got %d", len(policyRepo.Rules))
|
||||
}
|
||||
|
||||
if len(auditRepo.Events) != 1 {
|
||||
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRule(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
rule := &domain.PolicyRule{
|
||||
ID: "rule-001",
|
||||
Name: "Allowed Issuers",
|
||||
Type: domain.PolicyTypeAllowedIssuers,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
policyRepo := &mockPolicyRepo{
|
||||
Rules: map[string]*domain.PolicyRule{"rule-001": rule},
|
||||
Violations: []*domain.PolicyViolation{},
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
policyService := NewPolicyService(policyRepo, auditService)
|
||||
|
||||
retrieved, err := policyService.GetRule(ctx, "rule-001")
|
||||
if err != nil {
|
||||
t.Fatalf("GetRule failed: %v", err)
|
||||
}
|
||||
|
||||
if retrieved.Name != "Allowed Issuers" {
|
||||
t.Errorf("expected name Allowed Issuers, got %s", retrieved.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRule_NotFound(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
policyRepo := &mockPolicyRepo{
|
||||
Rules: make(map[string]*domain.PolicyRule),
|
||||
Violations: []*domain.PolicyViolation{},
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
policyService := NewPolicyService(policyRepo, auditService)
|
||||
|
||||
_, err := policyService.GetRule(ctx, "nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent rule")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListRules(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
rule1 := &domain.PolicyRule{
|
||||
ID: "rule-001",
|
||||
Name: "Allowed Issuers",
|
||||
Type: domain.PolicyTypeAllowedIssuers,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
rule2 := &domain.PolicyRule{
|
||||
ID: "rule-002",
|
||||
Name: "Required Metadata",
|
||||
Type: domain.PolicyTypeRequiredMetadata,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
policyRepo := &mockPolicyRepo{
|
||||
Rules: map[string]*domain.PolicyRule{"rule-001": rule1, "rule-002": rule2},
|
||||
Violations: []*domain.PolicyViolation{},
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
policyService := NewPolicyService(policyRepo, auditService)
|
||||
|
||||
rules, err := policyService.ListRules(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ListRules failed: %v", err)
|
||||
}
|
||||
|
||||
if len(rules) != 2 {
|
||||
t.Errorf("expected 2 rules, got %d", len(rules))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRule(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
originalRule := &domain.PolicyRule{
|
||||
ID: "rule-001",
|
||||
Name: "Allowed Issuers",
|
||||
Type: domain.PolicyTypeAllowedIssuers,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
policyRepo := &mockPolicyRepo{
|
||||
Rules: map[string]*domain.PolicyRule{"rule-001": originalRule},
|
||||
Violations: []*domain.PolicyViolation{},
|
||||
}
|
||||
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
policyService := NewPolicyService(policyRepo, auditService)
|
||||
|
||||
updatedRule := *originalRule
|
||||
updatedRule.Enabled = false
|
||||
|
||||
err := policyService.UpdateRule(ctx, &updatedRule, "user-1")
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateRule failed: %v", err)
|
||||
}
|
||||
|
||||
stored := policyRepo.Rules["rule-001"]
|
||||
if stored.Enabled {
|
||||
t.Error("expected rule to be disabled")
|
||||
}
|
||||
|
||||
if len(auditRepo.Events) != 1 {
|
||||
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRule(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
rule := &domain.PolicyRule{
|
||||
ID: "rule-001",
|
||||
Name: "Allowed Issuers",
|
||||
Type: domain.PolicyTypeAllowedIssuers,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
policyRepo := &mockPolicyRepo{
|
||||
Rules: map[string]*domain.PolicyRule{"rule-001": rule},
|
||||
Violations: []*domain.PolicyViolation{},
|
||||
}
|
||||
auditRepo := &mockAuditRepo{Events: []*domain.AuditEvent{}}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
policyService := NewPolicyService(policyRepo, auditService)
|
||||
|
||||
err := policyService.DeleteRule(ctx, "rule-001", "user-1")
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteRule failed: %v", err)
|
||||
}
|
||||
|
||||
if len(policyRepo.Rules) != 0 {
|
||||
t.Errorf("expected 0 rules, got %d", len(policyRepo.Rules))
|
||||
}
|
||||
|
||||
if len(auditRepo.Events) != 1 {
|
||||
t.Errorf("expected 1 audit event, got %d", len(auditRepo.Events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCertificate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
rule := &domain.PolicyRule{
|
||||
ID: "rule-001",
|
||||
Name: "Allowed Issuers",
|
||||
Type: domain.PolicyTypeAllowedIssuers,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
policyRepo := &mockPolicyRepo{
|
||||
Rules: map[string]*domain.PolicyRule{"rule-001": rule},
|
||||
Violations: []*domain.PolicyViolation{},
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
policyService := NewPolicyService(policyRepo, auditService)
|
||||
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "cert-001",
|
||||
CommonName: "example.com",
|
||||
IssuerID: "iss-acme",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: now.AddDate(1, 0, 0),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
violations, err := policyService.ValidateCertificate(ctx, cert)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateCertificate failed: %v", err)
|
||||
}
|
||||
|
||||
if len(violations) > 0 {
|
||||
t.Errorf("expected no violations, got %d", len(violations))
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCertificate_WithViolation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
rule := &domain.PolicyRule{
|
||||
ID: "rule-001",
|
||||
Name: "Allowed Issuers",
|
||||
Type: domain.PolicyTypeAllowedIssuers,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
policyRepo := &mockPolicyRepo{
|
||||
Rules: map[string]*domain.PolicyRule{"rule-001": rule},
|
||||
Violations: []*domain.PolicyViolation{},
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
policyService := NewPolicyService(policyRepo, auditService)
|
||||
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "cert-001",
|
||||
CommonName: "example.com",
|
||||
IssuerID: "", // Missing issuer
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: now.AddDate(1, 0, 0),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
violations, err := policyService.ValidateCertificate(ctx, cert)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateCertificate failed: %v", err)
|
||||
}
|
||||
|
||||
if len(violations) != 1 {
|
||||
t.Errorf("expected 1 violation, got %d", len(violations))
|
||||
}
|
||||
|
||||
if violations[0].CertificateID != "cert-001" {
|
||||
t.Errorf("expected violation for cert-001, got %s", violations[0].CertificateID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCertificate_MultipleViolations(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
rule1 := &domain.PolicyRule{
|
||||
ID: "rule-001",
|
||||
Name: "Allowed Issuers",
|
||||
Type: domain.PolicyTypeAllowedIssuers,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
rule2 := &domain.PolicyRule{
|
||||
ID: "rule-002",
|
||||
Name: "Required Metadata",
|
||||
Type: domain.PolicyTypeRequiredMetadata,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
policyRepo := &mockPolicyRepo{
|
||||
Rules: map[string]*domain.PolicyRule{"rule-001": rule1, "rule-002": rule2},
|
||||
Violations: []*domain.PolicyViolation{},
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
policyService := NewPolicyService(policyRepo, auditService)
|
||||
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "cert-001",
|
||||
CommonName: "example.com",
|
||||
IssuerID: "", // Missing issuer
|
||||
Tags: nil, // Missing metadata
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: now.AddDate(1, 0, 0),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
violations, err := policyService.ValidateCertificate(ctx, cert)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateCertificate failed: %v", err)
|
||||
}
|
||||
|
||||
if len(violations) != 2 {
|
||||
t.Errorf("expected 2 violations, got %d", len(violations))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListPolicies(t *testing.T) {
|
||||
now := time.Now()
|
||||
rule1 := &domain.PolicyRule{
|
||||
ID: "rule-001",
|
||||
Name: "Rule 1",
|
||||
Type: domain.PolicyTypeAllowedIssuers,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
rule2 := &domain.PolicyRule{
|
||||
ID: "rule-002",
|
||||
Name: "Rule 2",
|
||||
Type: domain.PolicyTypeRequiredMetadata,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
policyRepo := &mockPolicyRepo{
|
||||
Rules: map[string]*domain.PolicyRule{"rule-001": rule1, "rule-002": rule2},
|
||||
Violations: []*domain.PolicyViolation{},
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
policyService := NewPolicyService(policyRepo, auditService)
|
||||
|
||||
policies, total, err := policyService.ListPolicies(1, 50)
|
||||
if err != nil {
|
||||
t.Fatalf("ListPolicies failed: %v", err)
|
||||
}
|
||||
|
||||
if len(policies) != 2 {
|
||||
t.Errorf("expected 2 policies, got %d", len(policies))
|
||||
}
|
||||
if total != 2 {
|
||||
t.Errorf("expected total 2, got %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreatePolicy(t *testing.T) {
|
||||
now := time.Now()
|
||||
policyRepo := &mockPolicyRepo{
|
||||
Rules: make(map[string]*domain.PolicyRule),
|
||||
Violations: []*domain.PolicyViolation{},
|
||||
}
|
||||
auditRepo := &mockAuditRepo{}
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
policyService := NewPolicyService(policyRepo, auditService)
|
||||
|
||||
policy := domain.PolicyRule{
|
||||
Name: "Test Policy",
|
||||
Type: domain.PolicyTypeAllowedIssuers,
|
||||
Enabled: true,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
created, err := policyService.CreatePolicy(policy)
|
||||
if err != nil {
|
||||
t.Fatalf("CreatePolicy failed: %v", err)
|
||||
}
|
||||
|
||||
if created.ID == "" {
|
||||
t.Fatal("expected non-empty policy ID")
|
||||
}
|
||||
|
||||
if len(policyRepo.Rules) != 1 {
|
||||
t.Errorf("expected 1 rule in repo, got %d", len(policyRepo.Rules))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,866 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
func TestCheckExpiringCertificates_SendsThresholdAlerts(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
certRepo := newMockCertificateRepository()
|
||||
jobRepo := newMockJobRepository()
|
||||
policyRepo := newMockRenewalPolicyRepository()
|
||||
auditRepo := newMockAuditRepository()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
notifier := newMockNotifier()
|
||||
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{
|
||||
"Email": notifier,
|
||||
})
|
||||
|
||||
issuerRegistry := map[string]IssuerConnector{
|
||||
"iss-test": &mockIssuerConnector{},
|
||||
}
|
||||
|
||||
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
|
||||
|
||||
// Create a cert expiring in 10 days
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-expiring",
|
||||
Name: "Test Cert",
|
||||
CommonName: "test.example.com",
|
||||
SANs: []string{},
|
||||
OwnerID: "owner-1",
|
||||
TeamID: "team-1",
|
||||
IssuerID: "iss-test",
|
||||
RenewalPolicyID: "rp-standard",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: time.Now().AddDate(0, 0, 10),
|
||||
Tags: make(map[string]string),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
certRepo.AddCert(cert)
|
||||
|
||||
// Create policy with thresholds
|
||||
policy := &domain.RenewalPolicy{
|
||||
ID: "rp-standard",
|
||||
Name: "Standard",
|
||||
RenewalWindowDays: 30,
|
||||
AutoRenew: true,
|
||||
MaxRetries: 3,
|
||||
RetryInterval: 300,
|
||||
AlertThresholdsDays: []int{30, 14, 7, 0},
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
policyRepo.AddPolicy(policy)
|
||||
|
||||
// Run expiry check
|
||||
err := svc.CheckExpiringCertificates(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("CheckExpiringCertificates failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify alerts were sent
|
||||
if len(notifRepo.Notifications) < 1 {
|
||||
t.Errorf("expected at least 1 alert, got %d", len(notifRepo.Notifications))
|
||||
}
|
||||
|
||||
// Verify renewal job was created
|
||||
if len(jobRepo.Jobs) < 1 {
|
||||
t.Errorf("expected renewal job to be created")
|
||||
}
|
||||
|
||||
hasRenewalJob := false
|
||||
for _, job := range jobRepo.Jobs {
|
||||
if job.Type == domain.JobTypeRenewal {
|
||||
hasRenewalJob = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasRenewalJob {
|
||||
t.Errorf("expected renewal job in jobs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckExpiringCertificates_DeduplicatesAlerts(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
certRepo := newMockCertificateRepository()
|
||||
jobRepo := newMockJobRepository()
|
||||
policyRepo := newMockRenewalPolicyRepository()
|
||||
auditRepo := newMockAuditRepository()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
notifier := newMockNotifier()
|
||||
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{
|
||||
"Email": notifier,
|
||||
})
|
||||
|
||||
issuerRegistry := map[string]IssuerConnector{
|
||||
"iss-test": &mockIssuerConnector{},
|
||||
}
|
||||
|
||||
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
|
||||
|
||||
// Create cert
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-dedup",
|
||||
Name: "Test Cert",
|
||||
CommonName: "test.example.com",
|
||||
SANs: []string{},
|
||||
OwnerID: "owner-1",
|
||||
TeamID: "team-1",
|
||||
IssuerID: "iss-test",
|
||||
RenewalPolicyID: "rp-standard",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: time.Now().AddDate(0, 0, 10),
|
||||
Tags: make(map[string]string),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
certRepo.AddCert(cert)
|
||||
|
||||
// Create policy
|
||||
policy := &domain.RenewalPolicy{
|
||||
ID: "rp-standard",
|
||||
Name: "Standard",
|
||||
RenewalWindowDays: 30,
|
||||
AutoRenew: true,
|
||||
MaxRetries: 3,
|
||||
RetryInterval: 300,
|
||||
AlertThresholdsDays: []int{30, 14, 7, 0},
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
policyRepo.AddPolicy(policy)
|
||||
|
||||
// Add existing threshold alert notification
|
||||
existingNotif := &domain.NotificationEvent{
|
||||
ID: "notif-existing",
|
||||
CertificateID: &cert.ID,
|
||||
Type: domain.NotificationTypeExpirationWarning,
|
||||
Channel: domain.NotificationChannelEmail,
|
||||
Recipient: "owner-1",
|
||||
Message: "Alert [threshold:7]",
|
||||
Status: "sent",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
notifRepo.AddNotification(existingNotif)
|
||||
|
||||
// Run first check
|
||||
_ = svc.CheckExpiringCertificates(ctx)
|
||||
|
||||
initialCount := notifier.getSentCount()
|
||||
|
||||
// Run second check - should deduplicate
|
||||
_ = svc.CheckExpiringCertificates(ctx)
|
||||
|
||||
finalCount := notifier.getSentCount()
|
||||
|
||||
// Should not send duplicate alerts
|
||||
if finalCount > initialCount {
|
||||
t.Errorf("expected deduplication, but sent new alerts: initial=%d, final=%d", initialCount, finalCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckExpiringCertificates_SkipsRenewalInProgress(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
certRepo := newMockCertificateRepository()
|
||||
jobRepo := newMockJobRepository()
|
||||
policyRepo := newMockRenewalPolicyRepository()
|
||||
auditRepo := newMockAuditRepository()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
|
||||
|
||||
issuerRegistry := map[string]IssuerConnector{
|
||||
"iss-test": &mockIssuerConnector{},
|
||||
}
|
||||
|
||||
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
|
||||
|
||||
// Create cert with RenewalInProgress status
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-in-progress",
|
||||
Name: "Test Cert",
|
||||
CommonName: "test.example.com",
|
||||
SANs: []string{},
|
||||
OwnerID: "owner-1",
|
||||
TeamID: "team-1",
|
||||
IssuerID: "iss-test",
|
||||
RenewalPolicyID: "rp-standard",
|
||||
Status: domain.CertificateStatusRenewalInProgress,
|
||||
ExpiresAt: time.Now().AddDate(0, 0, 10),
|
||||
Tags: make(map[string]string),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
certRepo.AddCert(cert)
|
||||
|
||||
// Create policy
|
||||
policy := &domain.RenewalPolicy{
|
||||
ID: "rp-standard",
|
||||
Name: "Standard",
|
||||
RenewalWindowDays: 30,
|
||||
AutoRenew: true,
|
||||
MaxRetries: 3,
|
||||
RetryInterval: 300,
|
||||
AlertThresholdsDays: []int{30, 14, 7, 0},
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
policyRepo.AddPolicy(policy)
|
||||
|
||||
// Run check
|
||||
err := svc.CheckExpiringCertificates(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("CheckExpiringCertificates failed: %v", err)
|
||||
}
|
||||
|
||||
// Should not create renewal job for cert already renewing
|
||||
for _, job := range jobRepo.Jobs {
|
||||
if job.Type == domain.JobTypeRenewal {
|
||||
t.Errorf("should not create renewal job for cert with RenewalInProgress status")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckExpiringCertificates_UpdatesStatusToExpiring(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
certRepo := newMockCertificateRepository()
|
||||
jobRepo := newMockJobRepository()
|
||||
policyRepo := newMockRenewalPolicyRepository()
|
||||
auditRepo := newMockAuditRepository()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
|
||||
|
||||
issuerRegistry := map[string]IssuerConnector{
|
||||
"iss-test": &mockIssuerConnector{},
|
||||
}
|
||||
|
||||
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
|
||||
|
||||
// Create active cert that will become expiring
|
||||
// Use an issuer NOT in the registry so no renewal job is created (which would override status)
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-expiring-status",
|
||||
Name: "Test Cert",
|
||||
CommonName: "test.example.com",
|
||||
SANs: []string{},
|
||||
OwnerID: "owner-1",
|
||||
TeamID: "team-1",
|
||||
IssuerID: "iss-unregistered",
|
||||
RenewalPolicyID: "rp-standard",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: time.Now().AddDate(0, 0, 5), // 5 days, within 30-day threshold
|
||||
Tags: make(map[string]string),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
certRepo.AddCert(cert)
|
||||
|
||||
// Create policy with AutoRenew: false so we only test status transition
|
||||
policy := &domain.RenewalPolicy{
|
||||
ID: "rp-standard",
|
||||
Name: "Standard",
|
||||
RenewalWindowDays: 30,
|
||||
AutoRenew: false,
|
||||
MaxRetries: 3,
|
||||
RetryInterval: 300,
|
||||
AlertThresholdsDays: []int{30, 14, 7, 0},
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
policyRepo.AddPolicy(policy)
|
||||
|
||||
// Run check
|
||||
_ = svc.CheckExpiringCertificates(ctx)
|
||||
|
||||
// Verify status was updated to Expiring
|
||||
updated, _ := certRepo.Get(ctx, cert.ID)
|
||||
if updated.Status != domain.CertificateStatusExpiring {
|
||||
t.Errorf("expected status Expiring, got %s", updated.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckExpiringCertificates_UpdatesStatusToExpired(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
certRepo := newMockCertificateRepository()
|
||||
jobRepo := newMockJobRepository()
|
||||
policyRepo := newMockRenewalPolicyRepository()
|
||||
auditRepo := newMockAuditRepository()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
|
||||
|
||||
issuerRegistry := map[string]IssuerConnector{
|
||||
"iss-test": &mockIssuerConnector{},
|
||||
}
|
||||
|
||||
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
|
||||
|
||||
// Create cert that is already expired
|
||||
// Use an issuer NOT in the registry so no renewal job is created (which would override status)
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-expired-status",
|
||||
Name: "Test Cert",
|
||||
CommonName: "test.example.com",
|
||||
SANs: []string{},
|
||||
OwnerID: "owner-1",
|
||||
TeamID: "team-1",
|
||||
IssuerID: "iss-unregistered",
|
||||
RenewalPolicyID: "rp-standard",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: time.Now().AddDate(0, 0, -1), // Already expired
|
||||
Tags: make(map[string]string),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
certRepo.AddCert(cert)
|
||||
|
||||
// Create policy with AutoRenew: false so we only test status transition
|
||||
policy := &domain.RenewalPolicy{
|
||||
ID: "rp-standard",
|
||||
Name: "Standard",
|
||||
RenewalWindowDays: 30,
|
||||
AutoRenew: false,
|
||||
MaxRetries: 3,
|
||||
RetryInterval: 300,
|
||||
AlertThresholdsDays: []int{30, 14, 7, 0},
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
policyRepo.AddPolicy(policy)
|
||||
|
||||
// Run check
|
||||
_ = svc.CheckExpiringCertificates(ctx)
|
||||
|
||||
// Verify status was updated to Expired
|
||||
updated, _ := certRepo.Get(ctx, cert.ID)
|
||||
if updated.Status != domain.CertificateStatusExpired {
|
||||
t.Errorf("expected status Expired, got %s", updated.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckExpiringCertificates_CreatesRenewalJob(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
certRepo := newMockCertificateRepository()
|
||||
jobRepo := newMockJobRepository()
|
||||
policyRepo := newMockRenewalPolicyRepository()
|
||||
auditRepo := newMockAuditRepository()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
|
||||
|
||||
issuerRegistry := map[string]IssuerConnector{
|
||||
"iss-test": &mockIssuerConnector{},
|
||||
}
|
||||
|
||||
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
|
||||
|
||||
// Create expiring cert with registered issuer
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-job-create",
|
||||
Name: "Test Cert",
|
||||
CommonName: "test.example.com",
|
||||
SANs: []string{},
|
||||
OwnerID: "owner-1",
|
||||
TeamID: "team-1",
|
||||
IssuerID: "iss-test", // Registered issuer
|
||||
RenewalPolicyID: "rp-standard",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: time.Now().AddDate(0, 0, 20),
|
||||
Tags: make(map[string]string),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
certRepo.AddCert(cert)
|
||||
|
||||
// Create policy
|
||||
policy := &domain.RenewalPolicy{
|
||||
ID: "rp-standard",
|
||||
Name: "Standard",
|
||||
RenewalWindowDays: 30,
|
||||
AutoRenew: true,
|
||||
MaxRetries: 3,
|
||||
RetryInterval: 300,
|
||||
AlertThresholdsDays: []int{30, 14, 7, 0},
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
policyRepo.AddPolicy(policy)
|
||||
|
||||
// Run check
|
||||
_ = svc.CheckExpiringCertificates(ctx)
|
||||
|
||||
// Verify renewal job was created
|
||||
hasRenewalJob := false
|
||||
for _, job := range jobRepo.Jobs {
|
||||
if job.Type == domain.JobTypeRenewal && job.Status == domain.JobStatusPending {
|
||||
hasRenewalJob = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasRenewalJob {
|
||||
t.Errorf("expected renewal job to be created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckExpiringCertificates_SkipsWithoutIssuer(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
certRepo := newMockCertificateRepository()
|
||||
jobRepo := newMockJobRepository()
|
||||
policyRepo := newMockRenewalPolicyRepository()
|
||||
auditRepo := newMockAuditRepository()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
|
||||
|
||||
// Empty issuer registry
|
||||
issuerRegistry := map[string]IssuerConnector{}
|
||||
|
||||
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
|
||||
|
||||
// Create cert with unregistered issuer
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-no-issuer",
|
||||
Name: "Test Cert",
|
||||
CommonName: "test.example.com",
|
||||
SANs: []string{},
|
||||
OwnerID: "owner-1",
|
||||
TeamID: "team-1",
|
||||
IssuerID: "iss-missing", // Not in registry
|
||||
RenewalPolicyID: "rp-standard",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: time.Now().AddDate(0, 0, 20),
|
||||
Tags: make(map[string]string),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
certRepo.AddCert(cert)
|
||||
|
||||
// Create policy
|
||||
policy := &domain.RenewalPolicy{
|
||||
ID: "rp-standard",
|
||||
Name: "Standard",
|
||||
RenewalWindowDays: 30,
|
||||
AutoRenew: true,
|
||||
MaxRetries: 3,
|
||||
RetryInterval: 300,
|
||||
AlertThresholdsDays: []int{30, 14, 7, 0},
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
policyRepo.AddPolicy(policy)
|
||||
|
||||
// Run check
|
||||
_ = svc.CheckExpiringCertificates(ctx)
|
||||
|
||||
// Should not create renewal job without issuer
|
||||
for _, job := range jobRepo.Jobs {
|
||||
if job.Type == domain.JobTypeRenewal {
|
||||
t.Errorf("should not create renewal job for cert with missing issuer")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckExpiringCertificates_SkipsDuplicateJobs(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
certRepo := newMockCertificateRepository()
|
||||
jobRepo := newMockJobRepository()
|
||||
policyRepo := newMockRenewalPolicyRepository()
|
||||
auditRepo := newMockAuditRepository()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
|
||||
|
||||
issuerRegistry := map[string]IssuerConnector{
|
||||
"iss-test": &mockIssuerConnector{},
|
||||
}
|
||||
|
||||
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
|
||||
|
||||
// Create cert
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-dup-job",
|
||||
Name: "Test Cert",
|
||||
CommonName: "test.example.com",
|
||||
SANs: []string{},
|
||||
OwnerID: "owner-1",
|
||||
TeamID: "team-1",
|
||||
IssuerID: "iss-test",
|
||||
RenewalPolicyID: "rp-standard",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: time.Now().AddDate(0, 0, 20),
|
||||
Tags: make(map[string]string),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
certRepo.AddCert(cert)
|
||||
|
||||
// Create policy
|
||||
policy := &domain.RenewalPolicy{
|
||||
ID: "rp-standard",
|
||||
Name: "Standard",
|
||||
RenewalWindowDays: 30,
|
||||
AutoRenew: true,
|
||||
MaxRetries: 3,
|
||||
RetryInterval: 300,
|
||||
AlertThresholdsDays: []int{30, 14, 7, 0},
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
policyRepo.AddPolicy(policy)
|
||||
|
||||
// Add existing renewal job
|
||||
existingJob := &domain.Job{
|
||||
ID: "job-existing",
|
||||
CertificateID: cert.ID,
|
||||
Type: domain.JobTypeRenewal,
|
||||
Status: domain.JobStatusPending,
|
||||
MaxAttempts: 3,
|
||||
ScheduledAt: time.Now(),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
jobRepo.AddJob(existingJob)
|
||||
|
||||
// Run first check
|
||||
_ = svc.CheckExpiringCertificates(ctx)
|
||||
|
||||
// Run second check
|
||||
_ = svc.CheckExpiringCertificates(ctx)
|
||||
|
||||
// Should have only 1 renewal job
|
||||
renewalCount := 0
|
||||
for _, job := range jobRepo.Jobs {
|
||||
if job.Type == domain.JobTypeRenewal {
|
||||
renewalCount++
|
||||
}
|
||||
}
|
||||
if renewalCount > 1 {
|
||||
t.Errorf("expected 1 renewal job, got %d (duplicate prevention failed)", renewalCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessRenewalJob(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
certRepo := newMockCertificateRepository()
|
||||
jobRepo := newMockJobRepository()
|
||||
policyRepo := newMockRenewalPolicyRepository()
|
||||
auditRepo := newMockAuditRepository()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{
|
||||
"Email": newMockNotifier(),
|
||||
})
|
||||
|
||||
issuerConnector := &mockIssuerConnector{}
|
||||
issuerRegistry := map[string]IssuerConnector{
|
||||
"iss-test": issuerConnector,
|
||||
}
|
||||
|
||||
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
|
||||
|
||||
// Create certificate
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-renewal",
|
||||
Name: "Test Cert",
|
||||
CommonName: "test.example.com",
|
||||
SANs: []string{"www.test.example.com"},
|
||||
OwnerID: "owner-1",
|
||||
TeamID: "team-1",
|
||||
IssuerID: "iss-test",
|
||||
RenewalPolicyID: "rp-standard",
|
||||
Status: domain.CertificateStatusActive,
|
||||
TargetIDs: []string{"target-1", "target-2"},
|
||||
ExpiresAt: time.Now().AddDate(0, 0, 30),
|
||||
Tags: make(map[string]string),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
certRepo.AddCert(cert)
|
||||
|
||||
// Create renewal job
|
||||
job := &domain.Job{
|
||||
ID: "job-renewal-1",
|
||||
CertificateID: cert.ID,
|
||||
Type: domain.JobTypeRenewal,
|
||||
Status: domain.JobStatusPending,
|
||||
MaxAttempts: 3,
|
||||
ScheduledAt: time.Now(),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
jobRepo.AddJob(job)
|
||||
|
||||
// Process renewal job
|
||||
err := svc.ProcessRenewalJob(ctx, job)
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessRenewalJob failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify cert was updated
|
||||
updated, _ := certRepo.Get(ctx, cert.ID)
|
||||
if updated.Status != domain.CertificateStatusActive {
|
||||
t.Errorf("expected cert status Active, got %s", updated.Status)
|
||||
}
|
||||
|
||||
if updated.LastRenewalAt == nil {
|
||||
t.Errorf("expected LastRenewalAt to be set")
|
||||
}
|
||||
|
||||
// Verify certificate version was created
|
||||
if len(certRepo.Versions[cert.ID]) != 1 {
|
||||
t.Errorf("expected 1 certificate version, got %d", len(certRepo.Versions[cert.ID]))
|
||||
}
|
||||
|
||||
// Verify deployment jobs were created
|
||||
deploymentCount := 0
|
||||
for _, j := range jobRepo.Jobs {
|
||||
if j.Type == domain.JobTypeDeployment {
|
||||
deploymentCount++
|
||||
}
|
||||
}
|
||||
if deploymentCount != 2 {
|
||||
t.Errorf("expected 2 deployment jobs (one per target), got %d", deploymentCount)
|
||||
}
|
||||
|
||||
// Verify job was marked as completed
|
||||
completedJob, _ := jobRepo.Get(ctx, job.ID)
|
||||
if completedJob.Status != domain.JobStatusCompleted {
|
||||
t.Errorf("expected job status Completed, got %s", completedJob.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessRenewalJob_IssuerFailure(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
certRepo := newMockCertificateRepository()
|
||||
jobRepo := newMockJobRepository()
|
||||
policyRepo := newMockRenewalPolicyRepository()
|
||||
auditRepo := newMockAuditRepository()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{
|
||||
"Email": newMockNotifier(),
|
||||
})
|
||||
|
||||
// Create issuer that will fail
|
||||
issuerConnector := &mockIssuerConnector{
|
||||
Err: fmt.Errorf("issuer service unavailable"),
|
||||
}
|
||||
|
||||
issuerRegistry := map[string]IssuerConnector{
|
||||
"iss-test": issuerConnector,
|
||||
}
|
||||
|
||||
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
|
||||
|
||||
// Create certificate
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-renewal-fail",
|
||||
Name: "Test Cert",
|
||||
CommonName: "test.example.com",
|
||||
SANs: []string{},
|
||||
OwnerID: "owner-1",
|
||||
TeamID: "team-1",
|
||||
IssuerID: "iss-test",
|
||||
RenewalPolicyID: "rp-standard",
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: time.Now().AddDate(0, 0, 30),
|
||||
Tags: make(map[string]string),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
certRepo.AddCert(cert)
|
||||
|
||||
// Create renewal job
|
||||
job := &domain.Job{
|
||||
ID: "job-renewal-fail",
|
||||
CertificateID: cert.ID,
|
||||
Type: domain.JobTypeRenewal,
|
||||
Status: domain.JobStatusPending,
|
||||
MaxAttempts: 3,
|
||||
ScheduledAt: time.Now(),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
jobRepo.AddJob(job)
|
||||
|
||||
// Process renewal job (should fail)
|
||||
err := svc.ProcessRenewalJob(ctx, job)
|
||||
if err == nil {
|
||||
t.Fatalf("expected ProcessRenewalJob to fail")
|
||||
}
|
||||
|
||||
// Verify job was marked as failed
|
||||
failedJob, _ := jobRepo.Get(ctx, job.ID)
|
||||
if failedJob.Status != domain.JobStatusFailed {
|
||||
t.Errorf("expected job status Failed, got %s", failedJob.Status)
|
||||
}
|
||||
|
||||
if failedJob.LastError == nil || !strings.Contains(*failedJob.LastError, "issuer service unavailable") {
|
||||
t.Errorf("expected error message in job, got: %v", failedJob.LastError)
|
||||
}
|
||||
|
||||
// Verify failure notification was sent
|
||||
if len(notifRepo.Notifications) < 1 {
|
||||
t.Errorf("expected failure notification to be created")
|
||||
}
|
||||
|
||||
foundFailureNotif := false
|
||||
for _, notif := range notifRepo.Notifications {
|
||||
if notif.Type == domain.NotificationTypeRenewalFailure {
|
||||
foundFailureNotif = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundFailureNotif {
|
||||
t.Errorf("expected RenewalFailure notification type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryFailedJobs(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
certRepo := newMockCertificateRepository()
|
||||
jobRepo := newMockJobRepository()
|
||||
policyRepo := newMockRenewalPolicyRepository()
|
||||
auditRepo := newMockAuditRepository()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
|
||||
|
||||
issuerRegistry := map[string]IssuerConnector{
|
||||
"iss-test": &mockIssuerConnector{},
|
||||
}
|
||||
|
||||
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
|
||||
|
||||
// Create failed job with attempts < max_attempts
|
||||
failedJob := &domain.Job{
|
||||
ID: "job-failed-1",
|
||||
CertificateID: "mc-test",
|
||||
Type: domain.JobTypeRenewal,
|
||||
Status: domain.JobStatusFailed,
|
||||
Attempts: 1,
|
||||
MaxAttempts: 3,
|
||||
LastError: stringPtr("temporary failure"),
|
||||
ScheduledAt: time.Now(),
|
||||
CreatedAt: time.Now().AddDate(0, 0, -1),
|
||||
}
|
||||
jobRepo.AddJob(failedJob)
|
||||
|
||||
// Create other job types that should be ignored
|
||||
otherJob := &domain.Job{
|
||||
ID: "job-other",
|
||||
CertificateID: "mc-test",
|
||||
Type: domain.JobTypeDeployment,
|
||||
Status: domain.JobStatusFailed,
|
||||
Attempts: 1,
|
||||
MaxAttempts: 3,
|
||||
ScheduledAt: time.Now(),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
jobRepo.AddJob(otherJob)
|
||||
|
||||
// Retry failed jobs
|
||||
err := svc.RetryFailedJobs(ctx, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("RetryFailedJobs failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify failed renewal job was reset to pending
|
||||
retried, _ := jobRepo.Get(ctx, failedJob.ID)
|
||||
if retried.Status != domain.JobStatusPending {
|
||||
t.Errorf("expected job status Pending after retry, got %s", retried.Status)
|
||||
}
|
||||
|
||||
// Verify other job type was not touched
|
||||
other, _ := jobRepo.Get(ctx, otherJob.ID)
|
||||
if other.Status != domain.JobStatusFailed {
|
||||
t.Errorf("expected non-renewal job to stay Failed, got %s", other.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessRenewalJob_NoCertificate(t *testing.T) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
certRepo := newMockCertificateRepository()
|
||||
jobRepo := newMockJobRepository()
|
||||
policyRepo := newMockRenewalPolicyRepository()
|
||||
auditRepo := newMockAuditRepository()
|
||||
notifRepo := newMockNotificationRepository()
|
||||
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
notifSvc := NewNotificationService(notifRepo, map[string]Notifier{})
|
||||
|
||||
issuerRegistry := map[string]IssuerConnector{
|
||||
"iss-test": &mockIssuerConnector{},
|
||||
}
|
||||
|
||||
svc := NewRenewalService(certRepo, jobRepo, policyRepo, auditSvc, notifSvc, issuerRegistry)
|
||||
|
||||
// Create job with non-existent certificate
|
||||
job := &domain.Job{
|
||||
ID: "job-no-cert",
|
||||
CertificateID: "mc-missing",
|
||||
Type: domain.JobTypeRenewal,
|
||||
Status: domain.JobStatusPending,
|
||||
MaxAttempts: 3,
|
||||
ScheduledAt: time.Now(),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
jobRepo.AddJob(job)
|
||||
|
||||
// Process renewal job
|
||||
err := svc.ProcessRenewalJob(ctx, job)
|
||||
if err == nil {
|
||||
t.Fatalf("expected ProcessRenewalJob to fail for missing certificate")
|
||||
}
|
||||
|
||||
// Verify job was marked as failed
|
||||
failedJob, _ := jobRepo.Get(ctx, job.ID)
|
||||
if failedJob.Status != domain.JobStatusFailed {
|
||||
t.Errorf("expected job status Failed, got %s", failedJob.Status)
|
||||
}
|
||||
}
|
||||
|
||||
// stringPtr is defined in notification_test.go
|
||||
@@ -0,0 +1,771 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
)
|
||||
|
||||
var errNotFound = errors.New("not found")
|
||||
|
||||
// mockCertRepo is a test implementation of CertificateRepository
|
||||
type mockCertRepo struct {
|
||||
Certs map[string]*domain.ManagedCertificate
|
||||
Versions map[string][]*domain.CertificateVersion
|
||||
CreateErr error
|
||||
UpdateErr error
|
||||
GetErr error
|
||||
ListErr error
|
||||
ListVersionsErr error
|
||||
ListVersionsResult []*domain.CertificateVersion
|
||||
CreateVersionErr error
|
||||
ArchiveErr error
|
||||
}
|
||||
|
||||
func (m *mockCertRepo) List(ctx context.Context, filter *repository.CertificateFilter) ([]*domain.ManagedCertificate, int, error) {
|
||||
if m.ListErr != nil {
|
||||
return nil, 0, m.ListErr
|
||||
}
|
||||
var certs []*domain.ManagedCertificate
|
||||
for _, c := range m.Certs {
|
||||
certs = append(certs, c)
|
||||
}
|
||||
return certs, len(certs), nil
|
||||
}
|
||||
|
||||
func (m *mockCertRepo) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
|
||||
if m.GetErr != nil {
|
||||
return nil, m.GetErr
|
||||
}
|
||||
cert, ok := m.Certs[id]
|
||||
if !ok {
|
||||
return nil, errNotFound
|
||||
}
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
func (m *mockCertRepo) Create(ctx context.Context, cert *domain.ManagedCertificate) error {
|
||||
if m.CreateErr != nil {
|
||||
return m.CreateErr
|
||||
}
|
||||
m.Certs[cert.ID] = cert
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCertRepo) Update(ctx context.Context, cert *domain.ManagedCertificate) error {
|
||||
if m.UpdateErr != nil {
|
||||
return m.UpdateErr
|
||||
}
|
||||
m.Certs[cert.ID] = cert
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCertRepo) Archive(ctx context.Context, id string) error {
|
||||
if m.ArchiveErr != nil {
|
||||
return m.ArchiveErr
|
||||
}
|
||||
cert, ok := m.Certs[id]
|
||||
if !ok {
|
||||
return errNotFound
|
||||
}
|
||||
cert.Status = domain.CertificateStatusArchived
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCertRepo) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) {
|
||||
if m.ListVersionsErr != nil {
|
||||
return nil, m.ListVersionsErr
|
||||
}
|
||||
if m.ListVersionsResult != nil {
|
||||
return m.ListVersionsResult, nil
|
||||
}
|
||||
return m.Versions[certID], nil
|
||||
}
|
||||
|
||||
func (m *mockCertRepo) CreateVersion(ctx context.Context, version *domain.CertificateVersion) error {
|
||||
if m.CreateVersionErr != nil {
|
||||
return m.CreateVersionErr
|
||||
}
|
||||
m.Versions[version.CertificateID] = append(m.Versions[version.CertificateID], version)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCertRepo) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) {
|
||||
var expiring []*domain.ManagedCertificate
|
||||
for _, c := range m.Certs {
|
||||
if c.ExpiresAt.Before(before) {
|
||||
expiring = append(expiring, c)
|
||||
}
|
||||
}
|
||||
return expiring, nil
|
||||
}
|
||||
|
||||
func (m *mockCertRepo) AddCert(cert *domain.ManagedCertificate) {
|
||||
m.Certs[cert.ID] = cert
|
||||
}
|
||||
|
||||
// mockJobRepo is a test implementation of JobRepository
|
||||
type mockJobRepo struct {
|
||||
Jobs map[string]*domain.Job
|
||||
StatusUpdates map[string]domain.JobStatus
|
||||
CreateErr error
|
||||
UpdateErr error
|
||||
UpdateStatusErr error
|
||||
GetErr error
|
||||
ListErr error
|
||||
ListByStatusErr error
|
||||
DeleteErr error
|
||||
}
|
||||
|
||||
func (m *mockJobRepo) List(ctx context.Context) ([]*domain.Job, error) {
|
||||
if m.ListErr != nil {
|
||||
return nil, m.ListErr
|
||||
}
|
||||
var jobs []*domain.Job
|
||||
for _, j := range m.Jobs {
|
||||
jobs = append(jobs, j)
|
||||
}
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepo) Get(ctx context.Context, id string) (*domain.Job, error) {
|
||||
if m.GetErr != nil {
|
||||
return nil, m.GetErr
|
||||
}
|
||||
job, ok := m.Jobs[id]
|
||||
if !ok {
|
||||
return nil, errNotFound
|
||||
}
|
||||
return job, nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepo) Create(ctx context.Context, job *domain.Job) error {
|
||||
if m.CreateErr != nil {
|
||||
return m.CreateErr
|
||||
}
|
||||
m.Jobs[job.ID] = job
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepo) Update(ctx context.Context, job *domain.Job) error {
|
||||
if m.UpdateErr != nil {
|
||||
return m.UpdateErr
|
||||
}
|
||||
m.Jobs[job.ID] = job
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepo) Delete(ctx context.Context, id string) error {
|
||||
if m.DeleteErr != nil {
|
||||
return m.DeleteErr
|
||||
}
|
||||
delete(m.Jobs, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepo) ListByStatus(ctx context.Context, status domain.JobStatus) ([]*domain.Job, error) {
|
||||
if m.ListByStatusErr != nil {
|
||||
return nil, m.ListByStatusErr
|
||||
}
|
||||
var jobs []*domain.Job
|
||||
for _, j := range m.Jobs {
|
||||
if j.Status == status {
|
||||
jobs = append(jobs, j)
|
||||
}
|
||||
}
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepo) ListByCertificate(ctx context.Context, certID string) ([]*domain.Job, error) {
|
||||
var jobs []*domain.Job
|
||||
for _, j := range m.Jobs {
|
||||
if j.CertificateID == certID {
|
||||
jobs = append(jobs, j)
|
||||
}
|
||||
}
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepo) UpdateStatus(ctx context.Context, id string, status domain.JobStatus, errMsg string) error {
|
||||
if m.UpdateStatusErr != nil {
|
||||
return m.UpdateStatusErr
|
||||
}
|
||||
job, ok := m.Jobs[id]
|
||||
if !ok {
|
||||
return errNotFound
|
||||
}
|
||||
job.Status = status
|
||||
if errMsg != "" {
|
||||
job.LastError = &errMsg
|
||||
}
|
||||
m.StatusUpdates[id] = status
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepo) GetPendingJobs(ctx context.Context, jobType domain.JobType) ([]*domain.Job, error) {
|
||||
var jobs []*domain.Job
|
||||
for _, j := range m.Jobs {
|
||||
if j.Type == jobType && j.Status == domain.JobStatusPending {
|
||||
jobs = append(jobs, j)
|
||||
}
|
||||
}
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepo) AddJob(job *domain.Job) {
|
||||
m.Jobs[job.ID] = job
|
||||
}
|
||||
|
||||
// mockNotifRepo is a test implementation of NotificationRepository
|
||||
type mockNotifRepo struct {
|
||||
Notifications []*domain.NotificationEvent
|
||||
CreateErr error
|
||||
ListErr error
|
||||
UpdateErr error
|
||||
}
|
||||
|
||||
func (m *mockNotifRepo) Create(ctx context.Context, notif *domain.NotificationEvent) error {
|
||||
if m.CreateErr != nil {
|
||||
return m.CreateErr
|
||||
}
|
||||
m.Notifications = append(m.Notifications, notif)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockNotifRepo) List(ctx context.Context, filter *repository.NotificationFilter) ([]*domain.NotificationEvent, error) {
|
||||
if m.ListErr != nil {
|
||||
return nil, m.ListErr
|
||||
}
|
||||
return m.Notifications, nil
|
||||
}
|
||||
|
||||
func (m *mockNotifRepo) UpdateStatus(ctx context.Context, id string, status string, sentAt time.Time) error {
|
||||
if m.UpdateErr != nil {
|
||||
return m.UpdateErr
|
||||
}
|
||||
for _, n := range m.Notifications {
|
||||
if n.ID == id {
|
||||
n.Status = status
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errNotFound
|
||||
}
|
||||
|
||||
func (m *mockNotifRepo) AddNotification(notif *domain.NotificationEvent) {
|
||||
m.Notifications = append(m.Notifications, notif)
|
||||
}
|
||||
|
||||
// mockAuditRepo is a test implementation of AuditRepository
|
||||
type mockAuditRepo struct {
|
||||
Events []*domain.AuditEvent
|
||||
CreateErr error
|
||||
ListErr error
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) Create(ctx context.Context, event *domain.AuditEvent) error {
|
||||
if m.CreateErr != nil {
|
||||
return m.CreateErr
|
||||
}
|
||||
m.Events = append(m.Events, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) List(ctx context.Context, filter *repository.AuditFilter) ([]*domain.AuditEvent, error) {
|
||||
if m.ListErr != nil {
|
||||
return nil, m.ListErr
|
||||
}
|
||||
// Apply filtering like the real repo
|
||||
var filtered []*domain.AuditEvent
|
||||
for _, e := range m.Events {
|
||||
if filter != nil {
|
||||
if filter.ResourceType != "" && e.ResourceType != filter.ResourceType {
|
||||
continue
|
||||
}
|
||||
if filter.ResourceID != "" && e.ResourceID != filter.ResourceID {
|
||||
continue
|
||||
}
|
||||
if filter.Actor != "" && e.Actor != filter.Actor {
|
||||
continue
|
||||
}
|
||||
if !filter.From.IsZero() && e.Timestamp.Before(filter.From) {
|
||||
continue
|
||||
}
|
||||
if !filter.To.IsZero() && e.Timestamp.After(filter.To) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, e)
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func (m *mockAuditRepo) AddEvent(event *domain.AuditEvent) {
|
||||
m.Events = append(m.Events, event)
|
||||
}
|
||||
|
||||
// mockPolicyRepo is a test implementation of PolicyRepository
|
||||
type mockPolicyRepo struct {
|
||||
Rules map[string]*domain.PolicyRule
|
||||
Violations []*domain.PolicyViolation
|
||||
CreateRuleErr error
|
||||
UpdateRuleErr error
|
||||
DeleteRuleErr error
|
||||
GetRuleErr error
|
||||
ListRulesErr error
|
||||
CreateViolationErr error
|
||||
ListViolationsErr error
|
||||
}
|
||||
|
||||
func (m *mockPolicyRepo) ListRules(ctx context.Context) ([]*domain.PolicyRule, error) {
|
||||
if m.ListRulesErr != nil {
|
||||
return nil, m.ListRulesErr
|
||||
}
|
||||
var rules []*domain.PolicyRule
|
||||
for _, r := range m.Rules {
|
||||
rules = append(rules, r)
|
||||
}
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func (m *mockPolicyRepo) GetRule(ctx context.Context, id string) (*domain.PolicyRule, error) {
|
||||
if m.GetRuleErr != nil {
|
||||
return nil, m.GetRuleErr
|
||||
}
|
||||
rule, ok := m.Rules[id]
|
||||
if !ok {
|
||||
return nil, errNotFound
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *mockPolicyRepo) CreateRule(ctx context.Context, rule *domain.PolicyRule) error {
|
||||
if m.CreateRuleErr != nil {
|
||||
return m.CreateRuleErr
|
||||
}
|
||||
m.Rules[rule.ID] = rule
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPolicyRepo) UpdateRule(ctx context.Context, rule *domain.PolicyRule) error {
|
||||
if m.UpdateRuleErr != nil {
|
||||
return m.UpdateRuleErr
|
||||
}
|
||||
m.Rules[rule.ID] = rule
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPolicyRepo) DeleteRule(ctx context.Context, id string) error {
|
||||
if m.DeleteRuleErr != nil {
|
||||
return m.DeleteRuleErr
|
||||
}
|
||||
delete(m.Rules, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPolicyRepo) CreateViolation(ctx context.Context, violation *domain.PolicyViolation) error {
|
||||
if m.CreateViolationErr != nil {
|
||||
return m.CreateViolationErr
|
||||
}
|
||||
m.Violations = append(m.Violations, violation)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPolicyRepo) ListViolations(ctx context.Context, filter *repository.AuditFilter) ([]*domain.PolicyViolation, error) {
|
||||
if m.ListViolationsErr != nil {
|
||||
return nil, m.ListViolationsErr
|
||||
}
|
||||
return m.Violations, nil
|
||||
}
|
||||
|
||||
func (m *mockPolicyRepo) AddRule(rule *domain.PolicyRule) {
|
||||
m.Rules[rule.ID] = rule
|
||||
}
|
||||
|
||||
// mockRenewalPolicyRepo is a test implementation of RenewalPolicyRepository
|
||||
type mockRenewalPolicyRepo struct {
|
||||
Policies map[string]*domain.RenewalPolicy
|
||||
GetErr error
|
||||
ListErr error
|
||||
}
|
||||
|
||||
func (m *mockRenewalPolicyRepo) Get(ctx context.Context, id string) (*domain.RenewalPolicy, error) {
|
||||
if m.GetErr != nil {
|
||||
return nil, m.GetErr
|
||||
}
|
||||
policy, ok := m.Policies[id]
|
||||
if !ok {
|
||||
return nil, errNotFound
|
||||
}
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (m *mockRenewalPolicyRepo) List(ctx context.Context) ([]*domain.RenewalPolicy, error) {
|
||||
if m.ListErr != nil {
|
||||
return nil, m.ListErr
|
||||
}
|
||||
var policies []*domain.RenewalPolicy
|
||||
for _, p := range m.Policies {
|
||||
policies = append(policies, p)
|
||||
}
|
||||
return policies, nil
|
||||
}
|
||||
|
||||
func (m *mockRenewalPolicyRepo) AddPolicy(policy *domain.RenewalPolicy) {
|
||||
m.Policies[policy.ID] = policy
|
||||
}
|
||||
|
||||
// mockAgentRepo is a test implementation of AgentRepository
|
||||
type mockAgentRepo struct {
|
||||
Agents map[string]*domain.Agent
|
||||
HeartbeatUpdates map[string]time.Time
|
||||
CreateErr error
|
||||
UpdateErr error
|
||||
DeleteErr error
|
||||
GetErr error
|
||||
ListErr error
|
||||
UpdateHeartbeatErr error
|
||||
GetByAPIKeyErr error
|
||||
}
|
||||
|
||||
func (m *mockAgentRepo) List(ctx context.Context) ([]*domain.Agent, error) {
|
||||
if m.ListErr != nil {
|
||||
return nil, m.ListErr
|
||||
}
|
||||
var agents []*domain.Agent
|
||||
for _, a := range m.Agents {
|
||||
agents = append(agents, a)
|
||||
}
|
||||
return agents, nil
|
||||
}
|
||||
|
||||
func (m *mockAgentRepo) Get(ctx context.Context, id string) (*domain.Agent, error) {
|
||||
if m.GetErr != nil {
|
||||
return nil, m.GetErr
|
||||
}
|
||||
agent, ok := m.Agents[id]
|
||||
if !ok {
|
||||
return nil, errNotFound
|
||||
}
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
func (m *mockAgentRepo) Create(ctx context.Context, agent *domain.Agent) error {
|
||||
if m.CreateErr != nil {
|
||||
return m.CreateErr
|
||||
}
|
||||
m.Agents[agent.ID] = agent
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAgentRepo) Update(ctx context.Context, agent *domain.Agent) error {
|
||||
if m.UpdateErr != nil {
|
||||
return m.UpdateErr
|
||||
}
|
||||
m.Agents[agent.ID] = agent
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAgentRepo) Delete(ctx context.Context, id string) error {
|
||||
if m.DeleteErr != nil {
|
||||
return m.DeleteErr
|
||||
}
|
||||
delete(m.Agents, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAgentRepo) UpdateHeartbeat(ctx context.Context, id string) error {
|
||||
if m.UpdateHeartbeatErr != nil {
|
||||
return m.UpdateHeartbeatErr
|
||||
}
|
||||
agent, ok := m.Agents[id]
|
||||
if !ok {
|
||||
return errNotFound
|
||||
}
|
||||
now := time.Now()
|
||||
agent.LastHeartbeatAt = &now
|
||||
m.HeartbeatUpdates[id] = now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAgentRepo) GetByAPIKey(ctx context.Context, keyHash string) (*domain.Agent, error) {
|
||||
if m.GetByAPIKeyErr != nil {
|
||||
return nil, m.GetByAPIKeyErr
|
||||
}
|
||||
for _, a := range m.Agents {
|
||||
if a.APIKeyHash == keyHash {
|
||||
return a, nil
|
||||
}
|
||||
}
|
||||
return nil, errNotFound
|
||||
}
|
||||
|
||||
func (m *mockAgentRepo) AddAgent(agent *domain.Agent) {
|
||||
m.Agents[agent.ID] = agent
|
||||
}
|
||||
|
||||
// mockTargetRepo is a test implementation of TargetRepository
|
||||
type mockTargetRepo struct {
|
||||
Targets map[string]*domain.DeploymentTarget
|
||||
CreateErr error
|
||||
UpdateErr error
|
||||
DeleteErr error
|
||||
GetErr error
|
||||
ListErr error
|
||||
ListByCertErr error
|
||||
}
|
||||
|
||||
func (m *mockTargetRepo) List(ctx context.Context) ([]*domain.DeploymentTarget, error) {
|
||||
if m.ListErr != nil {
|
||||
return nil, m.ListErr
|
||||
}
|
||||
var targets []*domain.DeploymentTarget
|
||||
for _, t := range m.Targets {
|
||||
targets = append(targets, t)
|
||||
}
|
||||
return targets, nil
|
||||
}
|
||||
|
||||
func (m *mockTargetRepo) Get(ctx context.Context, id string) (*domain.DeploymentTarget, error) {
|
||||
if m.GetErr != nil {
|
||||
return nil, m.GetErr
|
||||
}
|
||||
target, ok := m.Targets[id]
|
||||
if !ok {
|
||||
return nil, errNotFound
|
||||
}
|
||||
return target, nil
|
||||
}
|
||||
|
||||
func (m *mockTargetRepo) Create(ctx context.Context, target *domain.DeploymentTarget) error {
|
||||
if m.CreateErr != nil {
|
||||
return m.CreateErr
|
||||
}
|
||||
m.Targets[target.ID] = target
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockTargetRepo) Update(ctx context.Context, target *domain.DeploymentTarget) error {
|
||||
if m.UpdateErr != nil {
|
||||
return m.UpdateErr
|
||||
}
|
||||
m.Targets[target.ID] = target
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockTargetRepo) Delete(ctx context.Context, id string) error {
|
||||
if m.DeleteErr != nil {
|
||||
return m.DeleteErr
|
||||
}
|
||||
delete(m.Targets, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockTargetRepo) ListByCertificate(ctx context.Context, certID string) ([]*domain.DeploymentTarget, error) {
|
||||
if m.ListByCertErr != nil {
|
||||
return nil, m.ListByCertErr
|
||||
}
|
||||
return m.List(ctx)
|
||||
}
|
||||
|
||||
func (m *mockTargetRepo) AddTarget(target *domain.DeploymentTarget) {
|
||||
m.Targets[target.ID] = target
|
||||
}
|
||||
|
||||
// mockIssuerConnector is a test implementation of IssuerConnector
|
||||
type mockIssuerConnector struct {
|
||||
Result *IssuanceResult
|
||||
Err error
|
||||
}
|
||||
|
||||
func (m *mockIssuerConnector) IssueCertificate(ctx context.Context, commonName string, sans []string, csrPEM string) (*IssuanceResult, error) {
|
||||
if m.Err != nil {
|
||||
return nil, m.Err
|
||||
}
|
||||
if m.Result != nil {
|
||||
return m.Result, nil
|
||||
}
|
||||
now := time.Now()
|
||||
return &IssuanceResult{
|
||||
Serial: "test-serial-123",
|
||||
CertPEM: "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----",
|
||||
ChainPEM: "-----BEGIN CERTIFICATE-----\nchain\n-----END CERTIFICATE-----",
|
||||
NotBefore: now,
|
||||
NotAfter: now.AddDate(1, 0, 0),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockIssuerConnector) RenewCertificate(ctx context.Context, commonName string, sans []string, csrPEM string) (*IssuanceResult, error) {
|
||||
if m.Err != nil {
|
||||
return nil, m.Err
|
||||
}
|
||||
return m.IssueCertificate(ctx, commonName, sans, csrPEM)
|
||||
}
|
||||
|
||||
// Constructor functions for mocks
|
||||
|
||||
func newMockCertificateRepository() *mockCertRepo {
|
||||
return &mockCertRepo{
|
||||
Certs: make(map[string]*domain.ManagedCertificate),
|
||||
Versions: make(map[string][]*domain.CertificateVersion),
|
||||
}
|
||||
}
|
||||
|
||||
func newMockJobRepository() *mockJobRepo {
|
||||
return &mockJobRepo{
|
||||
Jobs: make(map[string]*domain.Job),
|
||||
StatusUpdates: make(map[string]domain.JobStatus),
|
||||
}
|
||||
}
|
||||
|
||||
func newMockNotificationRepository() *mockNotifRepo {
|
||||
return &mockNotifRepo{
|
||||
Notifications: make([]*domain.NotificationEvent, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func newMockAuditRepository() *mockAuditRepo {
|
||||
return &mockAuditRepo{
|
||||
Events: make([]*domain.AuditEvent, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func newMockPolicyRepository() *mockPolicyRepo {
|
||||
return &mockPolicyRepo{
|
||||
Rules: make(map[string]*domain.PolicyRule),
|
||||
Violations: make([]*domain.PolicyViolation, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func newMockRenewalPolicyRepository() *mockRenewalPolicyRepo {
|
||||
return &mockRenewalPolicyRepo{
|
||||
Policies: make(map[string]*domain.RenewalPolicy),
|
||||
}
|
||||
}
|
||||
|
||||
func newMockAgentRepository() *mockAgentRepo {
|
||||
return &mockAgentRepo{
|
||||
Agents: make(map[string]*domain.Agent),
|
||||
HeartbeatUpdates: make(map[string]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
func newMockTargetRepository() *mockTargetRepo {
|
||||
return &mockTargetRepo{
|
||||
Targets: make(map[string]*domain.DeploymentTarget),
|
||||
}
|
||||
}
|
||||
|
||||
func newMockIssuerRepository() *mockIssuerRepository {
|
||||
return &mockIssuerRepository{
|
||||
issuers: make(map[string]*domain.Issuer),
|
||||
}
|
||||
}
|
||||
|
||||
// mockIssuerRepository is a test implementation of IssuerRepository
|
||||
type mockIssuerRepository struct {
|
||||
issuers map[string]*domain.Issuer
|
||||
GetErr error
|
||||
ListErr error
|
||||
CreateErr error
|
||||
UpdateErr error
|
||||
DeleteErr error
|
||||
}
|
||||
|
||||
func (m *mockIssuerRepository) List(ctx context.Context) ([]*domain.Issuer, error) {
|
||||
if m.ListErr != nil {
|
||||
return nil, m.ListErr
|
||||
}
|
||||
var issuers []*domain.Issuer
|
||||
for _, i := range m.issuers {
|
||||
issuers = append(issuers, i)
|
||||
}
|
||||
return issuers, nil
|
||||
}
|
||||
|
||||
func (m *mockIssuerRepository) Get(ctx context.Context, id string) (*domain.Issuer, error) {
|
||||
if m.GetErr != nil {
|
||||
return nil, m.GetErr
|
||||
}
|
||||
issuer, ok := m.issuers[id]
|
||||
if !ok {
|
||||
return nil, errNotFound
|
||||
}
|
||||
return issuer, nil
|
||||
}
|
||||
|
||||
func (m *mockIssuerRepository) Create(ctx context.Context, issuer *domain.Issuer) error {
|
||||
if m.CreateErr != nil {
|
||||
return m.CreateErr
|
||||
}
|
||||
m.issuers[issuer.ID] = issuer
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockIssuerRepository) Update(ctx context.Context, issuer *domain.Issuer) error {
|
||||
if m.UpdateErr != nil {
|
||||
return m.UpdateErr
|
||||
}
|
||||
m.issuers[issuer.ID] = issuer
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockIssuerRepository) Delete(ctx context.Context, id string) error {
|
||||
if m.DeleteErr != nil {
|
||||
return m.DeleteErr
|
||||
}
|
||||
delete(m.issuers, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockIssuerRepository) AddIssuer(issuer *domain.Issuer) {
|
||||
m.issuers[issuer.ID] = issuer
|
||||
}
|
||||
|
||||
// mockNotifier is a simple notifier for testing
|
||||
type mockNotifier struct {
|
||||
messages []*mockNotifierMessage
|
||||
SendErr error
|
||||
}
|
||||
|
||||
type mockNotifierMessage struct {
|
||||
Recipient string
|
||||
Subject string
|
||||
Body string
|
||||
}
|
||||
|
||||
func newMockNotifier() *mockNotifier {
|
||||
return &mockNotifier{
|
||||
messages: make([]*mockNotifierMessage, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockNotifier) Send(ctx context.Context, recipient string, subject string, body string) error {
|
||||
if m.SendErr != nil {
|
||||
return m.SendErr
|
||||
}
|
||||
m.messages = append(m.messages, &mockNotifierMessage{
|
||||
Recipient: recipient,
|
||||
Subject: subject,
|
||||
Body: body,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockNotifier) Channel() string {
|
||||
return "Email"
|
||||
}
|
||||
|
||||
func (m *mockNotifier) getSentCount() int {
|
||||
return len(m.messages)
|
||||
}
|
||||
|
||||
func (m *mockNotifier) getLastMessage() *mockNotifierMessage {
|
||||
if len(m.messages) == 0 {
|
||||
return nil
|
||||
}
|
||||
return m.messages[len(m.messages)-1]
|
||||
}
|
||||
Reference in New Issue
Block a user