mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-07 15:51:30 +00:00
feat(M48): continuous TLS health monitoring — endpoint state machine, shared tlsprobe, 8 API endpoints, GUI
Adds continuous TLS endpoint health monitoring that closes the deploy→verify→monitor loop. After M25 verifies that a deployment succeeded once, M48 continuously confirms that it stays healthy.

Key components:
- Shared `internal/tlsprobe/` package extracted from the network scanner for reuse
- Health status state machine: healthy → degraded (2 failures) → down (5 failures), plus cert_mismatch when the served fingerprint differs from the expected one
- 8th scheduler loop (60s tick, per-endpoint configurable intervals)
- PostgreSQL migration 000011: endpoint_health_checks + endpoint_health_history tables
- 8 REST API endpoints (CRUD, history, acknowledge, summary)
- Health Monitor GUI page with summary bar, status table, create modal, auto-refresh
- 38 new tests (5 tlsprobe + 11 domain + 10 service + 8 handler + 4 frontend)
- All coverage thresholds maintained (service 68%, handler 83%, domain 87%, middleware 63%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,308 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
)
|
||||
|
||||
// HealthCheckServicer defines the interface used by the health check handler.
|
||||
type HealthCheckServicer interface {
|
||||
Create(ctx context.Context, check *domain.EndpointHealthCheck) error
|
||||
Get(ctx context.Context, id string) (*domain.EndpointHealthCheck, error)
|
||||
Update(ctx context.Context, check *domain.EndpointHealthCheck) error
|
||||
Delete(ctx context.Context, id string) error
|
||||
List(ctx context.Context, filter *repository.HealthCheckFilter) ([]*domain.EndpointHealthCheck, int, error)
|
||||
GetHistory(ctx context.Context, healthCheckID string, limit int) ([]*domain.HealthHistoryEntry, error)
|
||||
AcknowledgeIncident(ctx context.Context, id string, actor string) error
|
||||
GetSummary(ctx context.Context) (*domain.HealthCheckSummary, error)
|
||||
}
|
||||
|
||||
// HealthCheckHandler handles HTTP requests for TLS health monitoring.
|
||||
type HealthCheckHandler struct {
|
||||
service HealthCheckServicer
|
||||
}
|
||||
|
||||
// NewHealthCheckHandler creates a new health check handler.
|
||||
func NewHealthCheckHandler(service HealthCheckServicer) *HealthCheckHandler {
|
||||
return &HealthCheckHandler{service: service}
|
||||
}
|
||||
|
||||
// ListHealthChecks handles GET /api/v1/health-checks
|
||||
func (h *HealthCheckHandler) ListHealthChecks(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
query := r.URL.Query()
|
||||
status := query.Get("status")
|
||||
certificateID := query.Get("certificate_id")
|
||||
networkScanTargetID := query.Get("network_scan_target_id")
|
||||
enabledStr := query.Get("enabled")
|
||||
page := parseIntDefault(query.Get("page"), 1)
|
||||
perPage := parseIntDefault(query.Get("per_page"), 50)
|
||||
if perPage > 500 {
|
||||
perPage = 50
|
||||
}
|
||||
|
||||
// Parse enabled flag if provided
|
||||
var enabledFilter *bool
|
||||
if enabledStr != "" {
|
||||
enabled := enabledStr == "true"
|
||||
enabledFilter = &enabled
|
||||
}
|
||||
|
||||
filter := &repository.HealthCheckFilter{
|
||||
Status: status,
|
||||
CertificateID: certificateID,
|
||||
NetworkScanTargetID: networkScanTargetID,
|
||||
Enabled: enabledFilter,
|
||||
Page: page,
|
||||
PerPage: perPage,
|
||||
}
|
||||
|
||||
checks, total, err := h.service.List(r.Context(), filter)
|
||||
if err != nil {
|
||||
Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to list health checks: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if checks == nil {
|
||||
checks = make([]*domain.EndpointHealthCheck, 0)
|
||||
}
|
||||
|
||||
JSON(w, http.StatusOK, PagedResponse{
|
||||
Data: checks,
|
||||
Total: int64(total),
|
||||
Page: page,
|
||||
PerPage: perPage,
|
||||
})
|
||||
}
|
||||
|
||||
// GetHealthCheck handles GET /api/v1/health-checks/{id}
|
||||
func (h *HealthCheckHandler) GetHealthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
id := r.PathValue("id")
|
||||
if id == "" {
|
||||
Error(w, http.StatusBadRequest, "health check ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
check, err := h.service.Get(r.Context(), id)
|
||||
if err != nil {
|
||||
Error(w, http.StatusNotFound, fmt.Sprintf("health check not found: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
JSON(w, http.StatusOK, check)
|
||||
}
|
||||
|
||||
// CreateHealthCheck handles POST /api/v1/health-checks
|
||||
func (h *HealthCheckHandler) CreateHealthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
var check domain.EndpointHealthCheck
|
||||
if err := json.NewDecoder(r.Body).Decode(&check); err != nil {
|
||||
Error(w, http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if check.Endpoint == "" {
|
||||
Error(w, http.StatusBadRequest, "endpoint is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if check.CheckIntervalSecs <= 0 {
|
||||
check.CheckIntervalSecs = 300
|
||||
}
|
||||
if check.DegradedThreshold <= 0 {
|
||||
check.DegradedThreshold = 2
|
||||
}
|
||||
if check.DownThreshold <= 0 {
|
||||
check.DownThreshold = 5
|
||||
}
|
||||
if check.Status == "" {
|
||||
check.Status = domain.HealthStatusUnknown
|
||||
}
|
||||
|
||||
if err := h.service.Create(r.Context(), &check); err != nil {
|
||||
Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to create health check: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
JSON(w, http.StatusCreated, check)
|
||||
}
|
||||
|
||||
// UpdateHealthCheck handles PUT /api/v1/health-checks/{id}
|
||||
func (h *HealthCheckHandler) UpdateHealthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPut {
|
||||
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
id := r.PathValue("id")
|
||||
if id == "" {
|
||||
Error(w, http.StatusBadRequest, "health check ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Get existing check
|
||||
existing, err := h.service.Get(r.Context(), id)
|
||||
if err != nil {
|
||||
Error(w, http.StatusNotFound, fmt.Sprintf("health check not found: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
var updates domain.EndpointHealthCheck
|
||||
if err := json.NewDecoder(r.Body).Decode(&updates); err != nil {
|
||||
Error(w, http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Merge updates (only update provided fields)
|
||||
if updates.Endpoint != "" {
|
||||
existing.Endpoint = updates.Endpoint
|
||||
}
|
||||
if updates.ExpectedFingerprint != "" {
|
||||
existing.ExpectedFingerprint = updates.ExpectedFingerprint
|
||||
}
|
||||
if updates.CheckIntervalSecs > 0 {
|
||||
existing.CheckIntervalSecs = updates.CheckIntervalSecs
|
||||
}
|
||||
if updates.DegradedThreshold > 0 {
|
||||
existing.DegradedThreshold = updates.DegradedThreshold
|
||||
}
|
||||
if updates.DownThreshold > 0 {
|
||||
existing.DownThreshold = updates.DownThreshold
|
||||
}
|
||||
existing.Enabled = updates.Enabled
|
||||
|
||||
if err := h.service.Update(r.Context(), existing); err != nil {
|
||||
Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to update health check: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
JSON(w, http.StatusOK, existing)
|
||||
}
|
||||
|
||||
// DeleteHealthCheck handles DELETE /api/v1/health-checks/{id}
|
||||
func (h *HealthCheckHandler) DeleteHealthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodDelete {
|
||||
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
id := r.PathValue("id")
|
||||
if id == "" {
|
||||
Error(w, http.StatusBadRequest, "health check ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.Delete(r.Context(), id); err != nil {
|
||||
Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to delete health check: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// GetHealthCheckHistory handles GET /api/v1/health-checks/{id}/history
|
||||
func (h *HealthCheckHandler) GetHealthCheckHistory(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
id := r.PathValue("id")
|
||||
if id == "" {
|
||||
Error(w, http.StatusBadRequest, "health check ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
limitStr := r.URL.Query().Get("limit")
|
||||
limit := 100
|
||||
if limitStr != "" {
|
||||
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
|
||||
limit = l
|
||||
}
|
||||
}
|
||||
if limit > 1000 {
|
||||
limit = 1000
|
||||
}
|
||||
|
||||
history, err := h.service.GetHistory(r.Context(), id, limit)
|
||||
if err != nil {
|
||||
Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to get health check history: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if history == nil {
|
||||
history = make([]*domain.HealthHistoryEntry, 0)
|
||||
}
|
||||
|
||||
JSON(w, http.StatusOK, history)
|
||||
}
|
||||
|
||||
// AcknowledgeHealthCheck handles POST /api/v1/health-checks/{id}/acknowledge
|
||||
func (h *HealthCheckHandler) AcknowledgeHealthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
id := r.PathValue("id")
|
||||
if id == "" {
|
||||
Error(w, http.StatusBadRequest, "health check ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Actor string `json:"actor,omitempty"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
Error(w, http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Actor == "" {
|
||||
req.Actor = "unknown"
|
||||
}
|
||||
|
||||
if err := h.service.AcknowledgeIncident(r.Context(), id, req.Actor); err != nil {
|
||||
Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to acknowledge health check: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// GetHealthCheckSummary handles GET /api/v1/health-checks/summary
|
||||
// This route must be registered BEFORE the /{id} routes
|
||||
func (h *HealthCheckHandler) GetHealthCheckSummary(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
summary, err := h.service.GetSummary(r.Context())
|
||||
if err != nil {
|
||||
Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to get health check summary: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
JSON(w, http.StatusOK, summary)
|
||||
}
|
||||
@@ -0,0 +1,305 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
)
|
||||
|
||||
// mockHealthCheckSvc implements HealthCheckServicer for testing.
|
||||
type mockHealthCheckSvc struct {
|
||||
createErr error
|
||||
getErr error
|
||||
updateErr error
|
||||
deleteErr error
|
||||
listErr error
|
||||
getHistoryErr error
|
||||
acknowledgeErr error
|
||||
getSummaryErr error
|
||||
checks map[string]*domain.EndpointHealthCheck
|
||||
summary *domain.HealthCheckSummary
|
||||
}
|
||||
|
||||
func newMockHealthCheckSvc() *mockHealthCheckSvc {
|
||||
return &mockHealthCheckSvc{
|
||||
checks: make(map[string]*domain.EndpointHealthCheck),
|
||||
summary: &domain.HealthCheckSummary{
|
||||
Healthy: 1,
|
||||
Degraded: 0,
|
||||
Down: 0,
|
||||
CertMismatch: 0,
|
||||
Unknown: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckSvc) Create(ctx context.Context, check *domain.EndpointHealthCheck) error {
|
||||
if m.createErr != nil {
|
||||
return m.createErr
|
||||
}
|
||||
check.ID = "hc-created-1"
|
||||
m.checks[check.ID] = check
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckSvc) Get(ctx context.Context, id string) (*domain.EndpointHealthCheck, error) {
|
||||
if m.getErr != nil {
|
||||
return nil, m.getErr
|
||||
}
|
||||
if check, ok := m.checks[id]; ok {
|
||||
return check, nil
|
||||
}
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckSvc) Update(ctx context.Context, check *domain.EndpointHealthCheck) error {
|
||||
if m.updateErr != nil {
|
||||
return m.updateErr
|
||||
}
|
||||
m.checks[check.ID] = check
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckSvc) Delete(ctx context.Context, id string) error {
|
||||
if m.deleteErr != nil {
|
||||
return m.deleteErr
|
||||
}
|
||||
delete(m.checks, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckSvc) List(ctx context.Context, filter *repository.HealthCheckFilter) ([]*domain.EndpointHealthCheck, int, error) {
|
||||
if m.listErr != nil {
|
||||
return nil, 0, m.listErr
|
||||
}
|
||||
checks := make([]*domain.EndpointHealthCheck, 0, len(m.checks))
|
||||
for _, check := range m.checks {
|
||||
checks = append(checks, check)
|
||||
}
|
||||
return checks, len(checks), nil
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckSvc) GetHistory(ctx context.Context, healthCheckID string, limit int) ([]*domain.HealthHistoryEntry, error) {
|
||||
if m.getHistoryErr != nil {
|
||||
return nil, m.getHistoryErr
|
||||
}
|
||||
return make([]*domain.HealthHistoryEntry, 0), nil
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckSvc) AcknowledgeIncident(ctx context.Context, id string, actor string) error {
|
||||
if m.acknowledgeErr != nil {
|
||||
return m.acknowledgeErr
|
||||
}
|
||||
if check, ok := m.checks[id]; ok {
|
||||
check.Acknowledged = true
|
||||
check.AcknowledgedBy = actor
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckSvc) GetSummary(ctx context.Context) (*domain.HealthCheckSummary, error) {
|
||||
if m.getSummaryErr != nil {
|
||||
return nil, m.getSummaryErr
|
||||
}
|
||||
return m.summary, nil
|
||||
}
|
||||
|
||||
// Tests
|
||||
|
||||
func TestListHealthChecks_Success(t *testing.T) {
|
||||
svc := newMockHealthCheckSvc()
|
||||
svc.checks["hc-1"] = &domain.EndpointHealthCheck{
|
||||
ID: "hc-1",
|
||||
Endpoint: "api.example.com:443",
|
||||
Status: domain.HealthStatusHealthy,
|
||||
}
|
||||
handler := NewHealthCheckHandler(svc)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/health-checks", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ListHealthChecks(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp PagedResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 1 {
|
||||
t.Errorf("Expected 1 health check, got %d", resp.Total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListHealthChecks_MethodNotAllowed(t *testing.T) {
|
||||
handler := NewHealthCheckHandler(newMockHealthCheckSvc())
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/v1/health-checks", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ListHealthChecks(w, req)
|
||||
|
||||
if w.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("Expected status 405, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHealthCheck_Success(t *testing.T) {
|
||||
svc := newMockHealthCheckSvc()
|
||||
check := &domain.EndpointHealthCheck{
|
||||
ID: "hc-1",
|
||||
Endpoint: "api.example.com:443",
|
||||
Status: domain.HealthStatusHealthy,
|
||||
}
|
||||
svc.checks["hc-1"] = check
|
||||
handler := NewHealthCheckHandler(svc)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/health-checks/hc-1", nil)
|
||||
req.SetPathValue("id", "hc-1")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.GetHealthCheck(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp domain.EndpointHealthCheck
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if resp.ID != "hc-1" {
|
||||
t.Errorf("Expected ID hc-1, got %s", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHealthCheck_NotFound(t *testing.T) {
|
||||
handler := NewHealthCheckHandler(newMockHealthCheckSvc())
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/health-checks/nonexistent", nil)
|
||||
req.SetPathValue("id", "nonexistent")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.GetHealthCheck(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected status 404, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateHealthCheck_Success(t *testing.T) {
|
||||
svc := newMockHealthCheckSvc()
|
||||
handler := NewHealthCheckHandler(svc)
|
||||
|
||||
check := domain.EndpointHealthCheck{
|
||||
Endpoint: "web.example.com:443",
|
||||
Enabled: true,
|
||||
}
|
||||
body, _ := json.Marshal(check)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/v1/health-checks", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.CreateHealthCheck(w, req)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("Expected status 201, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp domain.EndpointHealthCheck
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Endpoint != "web.example.com:443" {
|
||||
t.Errorf("Expected endpoint web.example.com:443, got %s", resp.Endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteHealthCheck_Success(t *testing.T) {
|
||||
svc := newMockHealthCheckSvc()
|
||||
svc.checks["hc-1"] = &domain.EndpointHealthCheck{
|
||||
ID: "hc-1",
|
||||
Endpoint: "api.example.com:443",
|
||||
}
|
||||
handler := NewHealthCheckHandler(svc)
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/api/v1/health-checks/hc-1", nil)
|
||||
req.SetPathValue("id", "hc-1")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.DeleteHealthCheck(w, req)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Errorf("Expected status 204, got %d", w.Code)
|
||||
}
|
||||
|
||||
if _, ok := svc.checks["hc-1"]; ok {
|
||||
t.Fatal("Expected check to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcknowledgeHealthCheck_Success(t *testing.T) {
|
||||
svc := newMockHealthCheckSvc()
|
||||
svc.checks["hc-1"] = &domain.EndpointHealthCheck{
|
||||
ID: "hc-1",
|
||||
Endpoint: "api.example.com:443",
|
||||
Status: domain.HealthStatusDown,
|
||||
}
|
||||
handler := NewHealthCheckHandler(svc)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/v1/health-checks/hc-1/acknowledge", bytes.NewReader([]byte(`{"actor":"user@example.com"}`)))
|
||||
req.SetPathValue("id", "hc-1")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.AcknowledgeHealthCheck(w, req)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Errorf("Expected status 204, got %d", w.Code)
|
||||
}
|
||||
|
||||
if !svc.checks["hc-1"].Acknowledged {
|
||||
t.Fatal("Expected check to be acknowledged")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHealthCheckSummary_Success(t *testing.T) {
|
||||
svc := newMockHealthCheckSvc()
|
||||
svc.summary = &domain.HealthCheckSummary{
|
||||
Healthy: 3,
|
||||
Degraded: 1,
|
||||
Down: 0,
|
||||
CertMismatch: 0,
|
||||
Unknown: 1,
|
||||
}
|
||||
handler := NewHealthCheckHandler(svc)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/health-checks/summary", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.GetHealthCheckSummary(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp domain.HealthCheckSummary
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Healthy != 3 {
|
||||
t.Errorf("Expected 3 healthy checks, got %d", resp.Healthy)
|
||||
}
|
||||
}
|
||||
@@ -65,6 +65,7 @@ type HandlerRegistry struct {
|
||||
Verification handler.VerificationHandler
|
||||
Export handler.ExportHandler
|
||||
Digest handler.DigestHandler
|
||||
HealthChecks *handler.HealthCheckHandler
|
||||
}
|
||||
|
||||
// RegisterHandlers sets up all API routes with their handlers.
|
||||
@@ -226,6 +227,17 @@ func (r *Router) RegisterHandlers(reg HandlerRegistry) {
|
||||
// Digest routes: /api/v1/digest
|
||||
r.Register("GET /api/v1/digest/preview", http.HandlerFunc(reg.Digest.PreviewDigest))
|
||||
r.Register("POST /api/v1/digest/send", http.HandlerFunc(reg.Digest.SendDigest))
|
||||
|
||||
// Health check routes: /api/v1/health-checks
|
||||
// Summary endpoint must be registered before {id} routes
|
||||
r.Register("GET /api/v1/health-checks/summary", http.HandlerFunc(reg.HealthChecks.GetHealthCheckSummary))
|
||||
r.Register("GET /api/v1/health-checks", http.HandlerFunc(reg.HealthChecks.ListHealthChecks))
|
||||
r.Register("POST /api/v1/health-checks", http.HandlerFunc(reg.HealthChecks.CreateHealthCheck))
|
||||
r.Register("GET /api/v1/health-checks/{id}", http.HandlerFunc(reg.HealthChecks.GetHealthCheck))
|
||||
r.Register("PUT /api/v1/health-checks/{id}", http.HandlerFunc(reg.HealthChecks.UpdateHealthCheck))
|
||||
r.Register("DELETE /api/v1/health-checks/{id}", http.HandlerFunc(reg.HealthChecks.DeleteHealthCheck))
|
||||
r.Register("GET /api/v1/health-checks/{id}/history", http.HandlerFunc(reg.HealthChecks.GetHealthCheckHistory))
|
||||
r.Register("POST /api/v1/health-checks/{id}/acknowledge", http.HandlerFunc(reg.HealthChecks.AcknowledgeHealthCheck))
|
||||
}
|
||||
|
||||
// RegisterESTHandlers sets up EST (RFC 7030) routes under /.well-known/est/.
|
||||
|
||||
@@ -32,6 +32,7 @@ type Config struct {
|
||||
GoogleCAS GoogleCASConfig
|
||||
AWSACMPCA AWSACMPCAConfig
|
||||
Digest DigestConfig
|
||||
HealthCheck HealthCheckConfig
|
||||
Encryption EncryptionConfig
|
||||
}
|
||||
|
||||
@@ -319,6 +320,46 @@ type DigestConfig struct {
|
||||
Recipients []string
|
||||
}
|
||||
|
||||
// HealthCheckConfig holds the settings for continuous TLS health monitoring
// (M48). Each field lists the environment variable it is read from and the
// default applied when that variable is unset.
type HealthCheckConfig struct {
	// Enabled switches health checking on or off.
	// Env: CERTCTL_HEALTH_CHECK_ENABLED. Default: false.
	Enabled bool

	// CheckInterval is how often the scheduler loop polls for checks that
	// are due; individual endpoints carry their own check_interval_seconds.
	// Env: CERTCTL_HEALTH_CHECK_INTERVAL. Default: 60 seconds.
	CheckInterval time.Duration

	// DefaultInterval is the per-endpoint probe interval, in seconds, used
	// as the default for an endpoint.
	// Env: CERTCTL_HEALTH_CHECK_DEFAULT_INTERVAL. Default: 300 (5 minutes).
	DefaultInterval int

	// DefaultTimeout is the default TLS connection timeout, in milliseconds.
	// Env: CERTCTL_HEALTH_CHECK_DEFAULT_TIMEOUT. Default: 5000 (5 seconds).
	DefaultTimeout int

	// MaxConcurrent caps the number of TLS probes running at the same time.
	// Env: CERTCTL_HEALTH_CHECK_MAX_CONCURRENT. Default: 20.
	MaxConcurrent int

	// HistoryRetention is how long probe history records are kept; older
	// records are purged by the scheduler.
	// Env: CERTCTL_HEALTH_CHECK_HISTORY_RETENTION. Default: 30 days.
	HistoryRetention time.Duration

	// AutoCreate controls whether health checks are created automatically
	// when:
	//   - a deployment job completes with verification success
	//   - a network scan target has health_check_enabled=true
	// Env: CERTCTL_HEALTH_CHECK_AUTO_CREATE. Default: true.
	AutoCreate bool
}
|
||||
|
||||
// ACMEConfig contains ACME issuer connector configuration.
|
||||
type ACMEConfig struct {
|
||||
// DirectoryURL is the ACME directory URL for certificate issuance.
|
||||
@@ -678,6 +719,15 @@ func Load() (*Config, error) {
|
||||
Interval: getEnvDuration("CERTCTL_DIGEST_INTERVAL", 24*time.Hour),
|
||||
Recipients: getEnvList("CERTCTL_DIGEST_RECIPIENTS", nil),
|
||||
},
|
||||
HealthCheck: HealthCheckConfig{
|
||||
Enabled: getEnvBool("CERTCTL_HEALTH_CHECK_ENABLED", false),
|
||||
CheckInterval: getEnvDuration("CERTCTL_HEALTH_CHECK_INTERVAL", 60*time.Second),
|
||||
DefaultInterval: getEnvInt("CERTCTL_HEALTH_CHECK_DEFAULT_INTERVAL", 300),
|
||||
DefaultTimeout: getEnvInt("CERTCTL_HEALTH_CHECK_DEFAULT_TIMEOUT", 5000),
|
||||
MaxConcurrent: getEnvInt("CERTCTL_HEALTH_CHECK_MAX_CONCURRENT", 20),
|
||||
HistoryRetention: getEnvDuration("CERTCTL_HEALTH_CHECK_HISTORY_RETENTION", 30*24*time.Hour),
|
||||
AutoCreate: getEnvBool("CERTCTL_HEALTH_CHECK_AUTO_CREATE", true),
|
||||
},
|
||||
Encryption: EncryptionConfig{
|
||||
ConfigEncryptionKey: getEnv("CERTCTL_CONFIG_ENCRYPTION_KEY", ""),
|
||||
},
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
// HealthStatus is the health state recorded for a monitored endpoint. Its
// underlying string is the form used when the status is serialized.
type HealthStatus string

// The health states accepted by IsValidHealthStatus.
const (
	// HealthStatusHealthy: the endpoint is considered healthy.
	HealthStatusHealthy HealthStatus = "healthy"
	// HealthStatusDegraded: consecutive failures reached the degraded threshold.
	HealthStatusDegraded HealthStatus = "degraded"
	// HealthStatusDown: consecutive failures reached the down threshold.
	HealthStatusDown HealthStatus = "down"
	// HealthStatusCertMismatch: the probe succeeded but the observed
	// fingerprint differs from the expected one.
	HealthStatusCertMismatch HealthStatus = "cert_mismatch"
	// HealthStatusUnknown: no state has been established yet.
	HealthStatusUnknown HealthStatus = "unknown"
)
|
||||
|
||||
// IsValidHealthStatus checks if a health status string is valid.
|
||||
func IsValidHealthStatus(s string) bool {
|
||||
switch HealthStatus(s) {
|
||||
case HealthStatusHealthy, HealthStatusDegraded, HealthStatusDown, HealthStatusCertMismatch, HealthStatusUnknown:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// EndpointHealthCheck represents a monitored TLS endpoint.
|
||||
type EndpointHealthCheck struct {
|
||||
ID string `json:"id"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
CertificateID *string `json:"certificate_id,omitempty"`
|
||||
NetworkScanTargetID *string `json:"network_scan_target_id,omitempty"`
|
||||
ExpectedFingerprint string `json:"expected_fingerprint"`
|
||||
ObservedFingerprint string `json:"observed_fingerprint"`
|
||||
Status HealthStatus `json:"status"`
|
||||
ConsecutiveFailures int `json:"consecutive_failures"`
|
||||
ResponseTimeMs int `json:"response_time_ms"`
|
||||
TLSVersion string `json:"tls_version"`
|
||||
CipherSuite string `json:"cipher_suite"`
|
||||
CertSubject string `json:"cert_subject"`
|
||||
CertIssuer string `json:"cert_issuer"`
|
||||
CertExpiry *time.Time `json:"cert_expiry,omitempty"`
|
||||
LastCheckedAt *time.Time `json:"last_checked_at,omitempty"`
|
||||
LastSuccessAt *time.Time `json:"last_success_at,omitempty"`
|
||||
LastFailureAt *time.Time `json:"last_failure_at,omitempty"`
|
||||
LastTransitionAt *time.Time `json:"last_transition_at,omitempty"`
|
||||
FailureReason string `json:"failure_reason"`
|
||||
DegradedThreshold int `json:"degraded_threshold"`
|
||||
DownThreshold int `json:"down_threshold"`
|
||||
CheckIntervalSecs int `json:"check_interval_seconds"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Acknowledged bool `json:"acknowledged"`
|
||||
AcknowledgedBy string `json:"acknowledged_by,omitempty"`
|
||||
AcknowledgedAt *time.Time `json:"acknowledged_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// TransitionStatus computes the new health status based on the probe result.
|
||||
// Returns the new status and whether a transition occurred.
|
||||
func (h *EndpointHealthCheck) TransitionStatus(probeSuccess bool, observedFingerprint string) (HealthStatus, bool) {
|
||||
oldStatus := h.Status
|
||||
var newStatus HealthStatus
|
||||
|
||||
if probeSuccess {
|
||||
if h.ExpectedFingerprint != "" && observedFingerprint != h.ExpectedFingerprint {
|
||||
newStatus = HealthStatusCertMismatch
|
||||
} else {
|
||||
newStatus = HealthStatusHealthy
|
||||
}
|
||||
} else {
|
||||
// Increment failures for next calculation (caller will update h.ConsecutiveFailures)
|
||||
failures := h.ConsecutiveFailures + 1
|
||||
if failures >= h.DownThreshold {
|
||||
newStatus = HealthStatusDown
|
||||
} else if failures >= h.DegradedThreshold {
|
||||
newStatus = HealthStatusDegraded
|
||||
} else {
|
||||
// Keep current status during initial failures before threshold
|
||||
// Unless we were in an error state, transition to degraded after first failure
|
||||
if h.Status == HealthStatusUnknown || h.Status == HealthStatusHealthy {
|
||||
newStatus = HealthStatusHealthy // still considered healthy during grace period
|
||||
} else {
|
||||
newStatus = h.Status
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return newStatus, newStatus != oldStatus
|
||||
}
|
||||
|
||||
// HealthHistoryEntry is one stored probe result for a health check.
type HealthHistoryEntry struct {
	// ID identifies this history record; HealthCheckID names the health
	// check it belongs to.
	ID            string `json:"id"`
	HealthCheckID string `json:"health_check_id"`

	// Outcome recorded for the probe.
	Status         string `json:"status"`
	ResponseTimeMs int    `json:"response_time_ms"`
	Fingerprint    string `json:"fingerprint"`
	FailureReason  string `json:"failure_reason"`

	// CheckedAt is the timestamp stored with the record.
	CheckedAt time.Time `json:"checked_at"`
}
|
||||
|
||||
// HealthCheckSummary carries the number of health checks in each status,
// plus a Total field.
type HealthCheckSummary struct {
	// One counter per HealthStatus value.
	Healthy      int `json:"healthy"`
	Degraded     int `json:"degraded"`
	Down         int `json:"down"`
	CertMismatch int `json:"cert_mismatch"`
	Unknown      int `json:"unknown"`

	// Total is serialized as "total".
	Total int `json:"total"`
}
|
||||
@@ -0,0 +1,237 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIsValidHealthStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
status string
|
||||
valid bool
|
||||
}{
|
||||
{"healthy", true},
|
||||
{"degraded", true},
|
||||
{"down", true},
|
||||
{"cert_mismatch", true},
|
||||
{"unknown", true},
|
||||
{"invalid", false},
|
||||
{"", false},
|
||||
{"HEALTHY", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.status, func(t *testing.T) {
|
||||
result := IsValidHealthStatus(tt.status)
|
||||
if result != tt.valid {
|
||||
t.Errorf("IsValidHealthStatus(%q) = %v, want %v", tt.status, result, tt.valid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransitionStatus_HealthyProbe(t *testing.T) {
|
||||
h := &EndpointHealthCheck{
|
||||
Status: HealthStatusUnknown,
|
||||
ConsecutiveFailures: 0,
|
||||
DegradedThreshold: 2,
|
||||
DownThreshold: 5,
|
||||
ExpectedFingerprint: "abc123",
|
||||
}
|
||||
|
||||
newStatus, transitioned := h.TransitionStatus(true, "abc123")
|
||||
|
||||
if newStatus != HealthStatusHealthy {
|
||||
t.Errorf("expected HealthStatusHealthy, got %s", newStatus)
|
||||
}
|
||||
if !transitioned {
|
||||
t.Errorf("expected transition=true, got false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransitionStatus_CertMismatch(t *testing.T) {
|
||||
h := &EndpointHealthCheck{
|
||||
Status: HealthStatusHealthy,
|
||||
ConsecutiveFailures: 0,
|
||||
DegradedThreshold: 2,
|
||||
DownThreshold: 5,
|
||||
ExpectedFingerprint: "abc123",
|
||||
}
|
||||
|
||||
newStatus, transitioned := h.TransitionStatus(true, "xyz789")
|
||||
|
||||
if newStatus != HealthStatusCertMismatch {
|
||||
t.Errorf("expected HealthStatusCertMismatch, got %s", newStatus)
|
||||
}
|
||||
if !transitioned {
|
||||
t.Errorf("expected transition=true, got false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransitionStatus_FirstFailure_BelowThreshold(t *testing.T) {
|
||||
h := &EndpointHealthCheck{
|
||||
Status: HealthStatusHealthy,
|
||||
ConsecutiveFailures: 0,
|
||||
DegradedThreshold: 2,
|
||||
DownThreshold: 5,
|
||||
}
|
||||
|
||||
newStatus, transitioned := h.TransitionStatus(false, "")
|
||||
|
||||
// At 1 failure with degraded threshold 2, still healthy
|
||||
if newStatus != HealthStatusHealthy {
|
||||
t.Errorf("expected HealthStatusHealthy (grace period), got %s", newStatus)
|
||||
}
|
||||
if transitioned {
|
||||
t.Errorf("expected transition=false (still healthy), got true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransitionStatus_DegradedThreshold(t *testing.T) {
|
||||
h := &EndpointHealthCheck{
|
||||
Status: HealthStatusHealthy,
|
||||
ConsecutiveFailures: 1, // Now will be 2 after increment
|
||||
DegradedThreshold: 2,
|
||||
DownThreshold: 5,
|
||||
}
|
||||
|
||||
newStatus, transitioned := h.TransitionStatus(false, "")
|
||||
|
||||
if newStatus != HealthStatusDegraded {
|
||||
t.Errorf("expected HealthStatusDegraded, got %s", newStatus)
|
||||
}
|
||||
if !transitioned {
|
||||
t.Errorf("expected transition=true, got false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransitionStatus_DownThreshold(t *testing.T) {
|
||||
h := &EndpointHealthCheck{
|
||||
Status: HealthStatusDegraded,
|
||||
ConsecutiveFailures: 4, // Now will be 5 after increment
|
||||
DegradedThreshold: 2,
|
||||
DownThreshold: 5,
|
||||
}
|
||||
|
||||
newStatus, transitioned := h.TransitionStatus(false, "")
|
||||
|
||||
if newStatus != HealthStatusDown {
|
||||
t.Errorf("expected HealthStatusDown, got %s", newStatus)
|
||||
}
|
||||
if !transitioned {
|
||||
t.Errorf("expected transition=true, got false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransitionStatus_Recovery(t *testing.T) {
|
||||
h := &EndpointHealthCheck{
|
||||
Status: HealthStatusDown,
|
||||
ConsecutiveFailures: 10,
|
||||
DegradedThreshold: 2,
|
||||
DownThreshold: 5,
|
||||
ExpectedFingerprint: "abc123",
|
||||
}
|
||||
|
||||
newStatus, transitioned := h.TransitionStatus(true, "abc123")
|
||||
|
||||
if newStatus != HealthStatusHealthy {
|
||||
t.Errorf("expected HealthStatusHealthy (recovery), got %s", newStatus)
|
||||
}
|
||||
if !transitioned {
|
||||
t.Errorf("expected transition=true (from down to healthy), got false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransitionStatus_NoFingerprint(t *testing.T) {
|
||||
h := &EndpointHealthCheck{
|
||||
Status: HealthStatusHealthy,
|
||||
ConsecutiveFailures: 0,
|
||||
DegradedThreshold: 2,
|
||||
DownThreshold: 5,
|
||||
ExpectedFingerprint: "", // No expected fingerprint
|
||||
}
|
||||
|
||||
newStatus, transitioned := h.TransitionStatus(true, "anything")
|
||||
|
||||
// Success with no expected fingerprint should always be healthy
|
||||
if newStatus != HealthStatusHealthy {
|
||||
t.Errorf("expected HealthStatusHealthy (no fingerprint check), got %s", newStatus)
|
||||
}
|
||||
if transitioned {
|
||||
t.Errorf("expected transition=false (already healthy), got true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransitionStatus_UnknownToHealthy(t *testing.T) {
|
||||
h := &EndpointHealthCheck{
|
||||
Status: HealthStatusUnknown,
|
||||
ConsecutiveFailures: 0,
|
||||
DegradedThreshold: 2,
|
||||
DownThreshold: 5,
|
||||
}
|
||||
|
||||
newStatus, transitioned := h.TransitionStatus(true, "")
|
||||
|
||||
if newStatus != HealthStatusHealthy {
|
||||
t.Errorf("expected HealthStatusHealthy, got %s", newStatus)
|
||||
}
|
||||
if !transitioned {
|
||||
t.Errorf("expected transition=true (from unknown to healthy), got false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransitionStatus_NoTransitionWhenSame(t *testing.T) {
|
||||
h := &EndpointHealthCheck{
|
||||
Status: HealthStatusHealthy,
|
||||
ConsecutiveFailures: 0,
|
||||
DegradedThreshold: 2,
|
||||
DownThreshold: 5,
|
||||
}
|
||||
|
||||
newStatus, transitioned := h.TransitionStatus(true, "")
|
||||
|
||||
if newStatus != HealthStatusHealthy {
|
||||
t.Errorf("expected HealthStatusHealthy, got %s", newStatus)
|
||||
}
|
||||
if transitioned {
|
||||
t.Errorf("expected transition=false (already healthy), got true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheckSummary(t *testing.T) {
|
||||
summary := &HealthCheckSummary{
|
||||
Healthy: 5,
|
||||
Degraded: 2,
|
||||
Down: 1,
|
||||
CertMismatch: 1,
|
||||
Unknown: 0,
|
||||
Total: 9,
|
||||
}
|
||||
|
||||
if summary.Total != 9 {
|
||||
t.Errorf("expected Total=9, got %d", summary.Total)
|
||||
}
|
||||
if summary.Healthy != 5 {
|
||||
t.Errorf("expected Healthy=5, got %d", summary.Healthy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthHistoryEntry(t *testing.T) {
|
||||
now := time.Now()
|
||||
entry := &HealthHistoryEntry{
|
||||
ID: "hh-test-123",
|
||||
HealthCheckID: "hc-test-123",
|
||||
Status: "healthy",
|
||||
ResponseTimeMs: 42,
|
||||
Fingerprint: "abc123def456",
|
||||
FailureReason: "",
|
||||
CheckedAt: now,
|
||||
}
|
||||
|
||||
if entry.ID != "hh-test-123" {
|
||||
t.Errorf("expected ID='hh-test-123', got %q", entry.ID)
|
||||
}
|
||||
if entry.ResponseTimeMs != 42 {
|
||||
t.Errorf("expected ResponseTimeMs=42, got %d", entry.ResponseTimeMs)
|
||||
}
|
||||
}
|
||||
@@ -277,3 +277,45 @@ type OwnerRepository interface {
|
||||
// Delete removes an owner.
|
||||
Delete(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
// HealthCheckRepository manages endpoint health check persistence.
|
||||
type HealthCheckRepository interface {
|
||||
// Create stores a new health check.
|
||||
Create(ctx context.Context, check *domain.EndpointHealthCheck) error
|
||||
// Update modifies an existing health check.
|
||||
Update(ctx context.Context, check *domain.EndpointHealthCheck) error
|
||||
// Get retrieves a health check by ID.
|
||||
Get(ctx context.Context, id string) (*domain.EndpointHealthCheck, error)
|
||||
// Delete removes a health check.
|
||||
Delete(ctx context.Context, id string) error
|
||||
// List returns health checks matching the filter with pagination.
|
||||
List(ctx context.Context, filter *HealthCheckFilter) ([]*domain.EndpointHealthCheck, int, error)
|
||||
// ListDueForCheck returns health checks that need to be probed (interval exceeded).
|
||||
ListDueForCheck(ctx context.Context) ([]*domain.EndpointHealthCheck, error)
|
||||
// GetByEndpoint retrieves a health check by endpoint address.
|
||||
GetByEndpoint(ctx context.Context, endpoint string) (*domain.EndpointHealthCheck, error)
|
||||
// RecordHistory records a single probe result in history.
|
||||
RecordHistory(ctx context.Context, entry *domain.HealthHistoryEntry) error
|
||||
// GetHistory retrieves recent probe history for a health check.
|
||||
GetHistory(ctx context.Context, healthCheckID string, limit int) ([]*domain.HealthHistoryEntry, error)
|
||||
// PurgeHistory deletes history entries older than the specified time.
|
||||
PurgeHistory(ctx context.Context, olderThan time.Time) (int64, error)
|
||||
// GetSummary returns aggregate counts by health status.
|
||||
GetSummary(ctx context.Context) (*domain.HealthCheckSummary, error)
|
||||
}
|
||||
|
||||
// HealthCheckFilter narrows and paginates a health check listing. Empty string
// fields and a nil Enabled impose no restriction.
type HealthCheckFilter struct {
	// Status restricts results to one health status
	// (healthy, degraded, down, cert_mismatch, unknown).
	Status string
	// CertificateID restricts results to checks tied to a managed certificate.
	CertificateID string
	// NetworkScanTargetID restricts results to checks tied to a network scan target.
	NetworkScanTargetID string
	// Enabled restricts results by enabled/disabled state; nil matches both.
	Enabled *bool

	// Page is the 1-indexed page number.
	Page int
	// PerPage is the page size.
	PerPage int
}
|
||||
|
||||
@@ -0,0 +1,453 @@
|
||||
package postgres
|
||||
|
||||
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	"github.com/shankar0123/certctl/internal/domain"
	"github.com/shankar0123/certctl/internal/repository"
)
|
||||
|
||||
// HealthCheckRepository is the PostgreSQL implementation of
// repository.HealthCheckRepository.
type HealthCheckRepository struct {
	// db is the connection pool every query in this repository runs on.
	db *sql.DB
}
|
||||
|
||||
// NewHealthCheckRepository creates a new PostgreSQL-backed health check repository.
|
||||
func NewHealthCheckRepository(db *sql.DB) *HealthCheckRepository {
|
||||
return &HealthCheckRepository{db: db}
|
||||
}
|
||||
|
||||
// Create stores a new health check.
|
||||
func (r *HealthCheckRepository) Create(ctx context.Context, check *domain.EndpointHealthCheck) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO endpoint_health_checks (
|
||||
id, endpoint, certificate_id, network_scan_target_id,
|
||||
expected_fingerprint, observed_fingerprint, status,
|
||||
consecutive_failures, response_time_ms, tls_version, cipher_suite,
|
||||
cert_subject, cert_issuer, cert_expiry,
|
||||
last_checked_at, last_success_at, last_failure_at, last_transition_at,
|
||||
failure_reason, degraded_threshold, down_threshold, check_interval_seconds,
|
||||
enabled, acknowledged, acknowledged_by, acknowledged_at,
|
||||
created_at, updated_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4,
|
||||
$5, $6, $7,
|
||||
$8, $9, $10, $11,
|
||||
$12, $13, $14,
|
||||
$15, $16, $17, $18,
|
||||
$19, $20, $21, $22,
|
||||
$23, $24, $25, $26,
|
||||
$27, $28
|
||||
)`,
|
||||
check.ID, check.Endpoint, check.CertificateID, check.NetworkScanTargetID,
|
||||
check.ExpectedFingerprint, check.ObservedFingerprint, string(check.Status),
|
||||
check.ConsecutiveFailures, check.ResponseTimeMs, check.TLSVersion, check.CipherSuite,
|
||||
check.CertSubject, check.CertIssuer, check.CertExpiry,
|
||||
check.LastCheckedAt, check.LastSuccessAt, check.LastFailureAt, check.LastTransitionAt,
|
||||
check.FailureReason, check.DegradedThreshold, check.DownThreshold, check.CheckIntervalSecs,
|
||||
check.Enabled, check.Acknowledged, check.AcknowledgedBy, check.AcknowledgedAt,
|
||||
check.CreatedAt, check.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create health check: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update modifies an existing health check.
|
||||
func (r *HealthCheckRepository) Update(ctx context.Context, check *domain.EndpointHealthCheck) error {
|
||||
check.UpdatedAt = time.Now()
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
UPDATE endpoint_health_checks SET
|
||||
endpoint = $2, certificate_id = $3, network_scan_target_id = $4,
|
||||
expected_fingerprint = $5, observed_fingerprint = $6, status = $7,
|
||||
consecutive_failures = $8, response_time_ms = $9, tls_version = $10, cipher_suite = $11,
|
||||
cert_subject = $12, cert_issuer = $13, cert_expiry = $14,
|
||||
last_checked_at = $15, last_success_at = $16, last_failure_at = $17, last_transition_at = $18,
|
||||
failure_reason = $19, degraded_threshold = $20, down_threshold = $21, check_interval_seconds = $22,
|
||||
enabled = $23, acknowledged = $24, acknowledged_by = $25, acknowledged_at = $26,
|
||||
updated_at = $27
|
||||
WHERE id = $1`,
|
||||
check.ID,
|
||||
check.Endpoint, check.CertificateID, check.NetworkScanTargetID,
|
||||
check.ExpectedFingerprint, check.ObservedFingerprint, string(check.Status),
|
||||
check.ConsecutiveFailures, check.ResponseTimeMs, check.TLSVersion, check.CipherSuite,
|
||||
check.CertSubject, check.CertIssuer, check.CertExpiry,
|
||||
check.LastCheckedAt, check.LastSuccessAt, check.LastFailureAt, check.LastTransitionAt,
|
||||
check.FailureReason, check.DegradedThreshold, check.DownThreshold, check.CheckIntervalSecs,
|
||||
check.Enabled, check.Acknowledged, check.AcknowledgedBy, check.AcknowledgedAt,
|
||||
check.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update health check: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a health check by ID.
|
||||
func (r *HealthCheckRepository) Get(ctx context.Context, id string) (*domain.EndpointHealthCheck, error) {
|
||||
check := &domain.EndpointHealthCheck{}
|
||||
var status string
|
||||
var certExpiry, lastCheckedAt, lastSuccessAt, lastFailureAt, lastTransitionAt, acknowledgedAt sql.NullTime
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, endpoint, certificate_id, network_scan_target_id,
|
||||
expected_fingerprint, observed_fingerprint, status,
|
||||
consecutive_failures, response_time_ms, tls_version, cipher_suite,
|
||||
cert_subject, cert_issuer, cert_expiry,
|
||||
last_checked_at, last_success_at, last_failure_at, last_transition_at,
|
||||
failure_reason, degraded_threshold, down_threshold, check_interval_seconds,
|
||||
enabled, acknowledged, acknowledged_by, acknowledged_at,
|
||||
created_at, updated_at
|
||||
FROM endpoint_health_checks
|
||||
WHERE id = $1`, id).Scan(
|
||||
&check.ID, &check.Endpoint, &check.CertificateID, &check.NetworkScanTargetID,
|
||||
&check.ExpectedFingerprint, &check.ObservedFingerprint, &status,
|
||||
&check.ConsecutiveFailures, &check.ResponseTimeMs, &check.TLSVersion, &check.CipherSuite,
|
||||
&check.CertSubject, &check.CertIssuer, &certExpiry,
|
||||
&lastCheckedAt, &lastSuccessAt, &lastFailureAt, &lastTransitionAt,
|
||||
&check.FailureReason, &check.DegradedThreshold, &check.DownThreshold, &check.CheckIntervalSecs,
|
||||
&check.Enabled, &check.Acknowledged, &check.AcknowledgedBy, &acknowledgedAt,
|
||||
&check.CreatedAt, &check.UpdatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("health check not found: %s", id)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get health check: %w", err)
|
||||
}
|
||||
check.Status = domain.HealthStatus(status)
|
||||
if certExpiry.Valid {
|
||||
check.CertExpiry = &certExpiry.Time
|
||||
}
|
||||
if lastCheckedAt.Valid {
|
||||
check.LastCheckedAt = &lastCheckedAt.Time
|
||||
}
|
||||
if lastSuccessAt.Valid {
|
||||
check.LastSuccessAt = &lastSuccessAt.Time
|
||||
}
|
||||
if lastFailureAt.Valid {
|
||||
check.LastFailureAt = &lastFailureAt.Time
|
||||
}
|
||||
if lastTransitionAt.Valid {
|
||||
check.LastTransitionAt = &lastTransitionAt.Time
|
||||
}
|
||||
if acknowledgedAt.Valid {
|
||||
check.AcknowledgedAt = &acknowledgedAt.Time
|
||||
}
|
||||
return check, nil
|
||||
}
|
||||
|
||||
// Delete removes a health check.
|
||||
func (r *HealthCheckRepository) Delete(ctx context.Context, id string) error {
|
||||
_, err := r.db.ExecContext(ctx, `DELETE FROM endpoint_health_checks WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete health check: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns health checks matching the filter with pagination.
|
||||
func (r *HealthCheckRepository) List(ctx context.Context, filter *repository.HealthCheckFilter) ([]*domain.EndpointHealthCheck, int, error) {
|
||||
query := `SELECT id, endpoint, certificate_id, network_scan_target_id,
|
||||
expected_fingerprint, observed_fingerprint, status,
|
||||
consecutive_failures, response_time_ms, tls_version, cipher_suite,
|
||||
cert_subject, cert_issuer, cert_expiry,
|
||||
last_checked_at, last_success_at, last_failure_at, last_transition_at,
|
||||
failure_reason, degraded_threshold, down_threshold, check_interval_seconds,
|
||||
enabled, acknowledged, acknowledged_by, acknowledged_at,
|
||||
created_at, updated_at
|
||||
FROM endpoint_health_checks`
|
||||
countQuery := `SELECT COUNT(*) FROM endpoint_health_checks`
|
||||
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
argIdx := 1
|
||||
|
||||
if filter != nil {
|
||||
if filter.Status != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("status = $%d", argIdx))
|
||||
args = append(args, filter.Status)
|
||||
argIdx++
|
||||
}
|
||||
if filter.CertificateID != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("certificate_id = $%d", argIdx))
|
||||
args = append(args, filter.CertificateID)
|
||||
argIdx++
|
||||
}
|
||||
if filter.NetworkScanTargetID != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("network_scan_target_id = $%d", argIdx))
|
||||
args = append(args, filter.NetworkScanTargetID)
|
||||
argIdx++
|
||||
}
|
||||
if filter.Enabled != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("enabled = $%d", argIdx))
|
||||
args = append(args, *filter.Enabled)
|
||||
argIdx++
|
||||
}
|
||||
}
|
||||
|
||||
if len(conditions) > 0 {
|
||||
where := " WHERE " + conditions[0]
|
||||
for i := 1; i < len(conditions); i++ {
|
||||
where += " AND " + conditions[i]
|
||||
}
|
||||
query += where
|
||||
countQuery += where
|
||||
}
|
||||
|
||||
// Get total count
|
||||
var total int
|
||||
err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("count health checks: %w", err)
|
||||
}
|
||||
|
||||
// Apply pagination
|
||||
query += " ORDER BY created_at DESC"
|
||||
page := 1
|
||||
perPage := 50
|
||||
if filter != nil {
|
||||
if filter.Page > 0 {
|
||||
page = filter.Page
|
||||
}
|
||||
if filter.PerPage > 0 {
|
||||
perPage = filter.PerPage
|
||||
}
|
||||
}
|
||||
offset := (page - 1) * perPage
|
||||
query += fmt.Sprintf(" LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
|
||||
args = append(args, perPage, offset)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("list health checks: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var checks []*domain.EndpointHealthCheck
|
||||
for rows.Next() {
|
||||
check, err := scanHealthCheck(rows)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
checks = append(checks, check)
|
||||
}
|
||||
return checks, total, rows.Err()
|
||||
}
|
||||
|
||||
// ListDueForCheck returns health checks where the check interval has been exceeded.
|
||||
func (r *HealthCheckRepository) ListDueForCheck(ctx context.Context) ([]*domain.EndpointHealthCheck, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, endpoint, certificate_id, network_scan_target_id,
|
||||
expected_fingerprint, observed_fingerprint, status,
|
||||
consecutive_failures, response_time_ms, tls_version, cipher_suite,
|
||||
cert_subject, cert_issuer, cert_expiry,
|
||||
last_checked_at, last_success_at, last_failure_at, last_transition_at,
|
||||
failure_reason, degraded_threshold, down_threshold, check_interval_seconds,
|
||||
enabled, acknowledged, acknowledged_by, acknowledged_at,
|
||||
created_at, updated_at
|
||||
FROM endpoint_health_checks
|
||||
WHERE enabled = TRUE
|
||||
AND (
|
||||
last_checked_at IS NULL
|
||||
OR last_checked_at + (check_interval_seconds * INTERVAL '1 second') < NOW()
|
||||
)
|
||||
ORDER BY last_checked_at ASC NULLS FIRST`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list due health checks: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var checks []*domain.EndpointHealthCheck
|
||||
for rows.Next() {
|
||||
check, err := scanHealthCheck(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
checks = append(checks, check)
|
||||
}
|
||||
return checks, rows.Err()
|
||||
}
|
||||
|
||||
// GetByEndpoint retrieves a health check by endpoint address.
|
||||
func (r *HealthCheckRepository) GetByEndpoint(ctx context.Context, endpoint string) (*domain.EndpointHealthCheck, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, endpoint, certificate_id, network_scan_target_id,
|
||||
expected_fingerprint, observed_fingerprint, status,
|
||||
consecutive_failures, response_time_ms, tls_version, cipher_suite,
|
||||
cert_subject, cert_issuer, cert_expiry,
|
||||
last_checked_at, last_success_at, last_failure_at, last_transition_at,
|
||||
failure_reason, degraded_threshold, down_threshold, check_interval_seconds,
|
||||
enabled, acknowledged, acknowledged_by, acknowledged_at,
|
||||
created_at, updated_at
|
||||
FROM endpoint_health_checks
|
||||
WHERE endpoint = $1`, endpoint)
|
||||
check := &domain.EndpointHealthCheck{}
|
||||
var status string
|
||||
var certExpiry, lastCheckedAt, lastSuccessAt, lastFailureAt, lastTransitionAt, acknowledgedAt sql.NullTime
|
||||
err := row.Scan(
|
||||
&check.ID, &check.Endpoint, &check.CertificateID, &check.NetworkScanTargetID,
|
||||
&check.ExpectedFingerprint, &check.ObservedFingerprint, &status,
|
||||
&check.ConsecutiveFailures, &check.ResponseTimeMs, &check.TLSVersion, &check.CipherSuite,
|
||||
&check.CertSubject, &check.CertIssuer, &certExpiry,
|
||||
&lastCheckedAt, &lastSuccessAt, &lastFailureAt, &lastTransitionAt,
|
||||
&check.FailureReason, &check.DegradedThreshold, &check.DownThreshold, &check.CheckIntervalSecs,
|
||||
&check.Enabled, &check.Acknowledged, &check.AcknowledgedBy, &acknowledgedAt,
|
||||
&check.CreatedAt, &check.UpdatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("health check not found for endpoint: %s", endpoint)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get health check by endpoint: %w", err)
|
||||
}
|
||||
check.Status = domain.HealthStatus(status)
|
||||
if certExpiry.Valid {
|
||||
check.CertExpiry = &certExpiry.Time
|
||||
}
|
||||
if lastCheckedAt.Valid {
|
||||
check.LastCheckedAt = &lastCheckedAt.Time
|
||||
}
|
||||
if lastSuccessAt.Valid {
|
||||
check.LastSuccessAt = &lastSuccessAt.Time
|
||||
}
|
||||
if lastFailureAt.Valid {
|
||||
check.LastFailureAt = &lastFailureAt.Time
|
||||
}
|
||||
if lastTransitionAt.Valid {
|
||||
check.LastTransitionAt = &lastTransitionAt.Time
|
||||
}
|
||||
if acknowledgedAt.Valid {
|
||||
check.AcknowledgedAt = &acknowledgedAt.Time
|
||||
}
|
||||
return check, nil
|
||||
}
|
||||
|
||||
// RecordHistory records a single probe result in history.
|
||||
func (r *HealthCheckRepository) RecordHistory(ctx context.Context, entry *domain.HealthHistoryEntry) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO endpoint_health_history (id, health_check_id, status, response_time_ms, fingerprint, failure_reason, checked_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
|
||||
entry.ID, entry.HealthCheckID, entry.Status, entry.ResponseTimeMs, entry.Fingerprint, entry.FailureReason, entry.CheckedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("record health check history: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetHistory retrieves recent probe history for a health check.
|
||||
func (r *HealthCheckRepository) GetHistory(ctx context.Context, healthCheckID string, limit int) ([]*domain.HealthHistoryEntry, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, health_check_id, status, response_time_ms, fingerprint, failure_reason, checked_at
|
||||
FROM endpoint_health_history
|
||||
WHERE health_check_id = $1
|
||||
ORDER BY checked_at DESC
|
||||
LIMIT $2`, healthCheckID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get health check history: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []*domain.HealthHistoryEntry
|
||||
for rows.Next() {
|
||||
entry := &domain.HealthHistoryEntry{}
|
||||
if err := rows.Scan(&entry.ID, &entry.HealthCheckID, &entry.Status, &entry.ResponseTimeMs, &entry.Fingerprint, &entry.FailureReason, &entry.CheckedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan health history entry: %w", err)
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
// PurgeHistory deletes history entries older than the specified time.
|
||||
func (r *HealthCheckRepository) PurgeHistory(ctx context.Context, olderThan time.Time) (int64, error) {
|
||||
result, err := r.db.ExecContext(ctx, `DELETE FROM endpoint_health_history WHERE checked_at < $1`, olderThan)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("purge health check history: %w", err)
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// GetSummary returns aggregate counts by health status.
|
||||
func (r *HealthCheckRepository) GetSummary(ctx context.Context) (*domain.HealthCheckSummary, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `SELECT status, COUNT(*) FROM endpoint_health_checks GROUP BY status`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get health check summary: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
summary := &domain.HealthCheckSummary{}
|
||||
for rows.Next() {
|
||||
var status string
|
||||
var count int
|
||||
if err := rows.Scan(&status, &count); err != nil {
|
||||
return nil, fmt.Errorf("scan health check summary: %w", err)
|
||||
}
|
||||
switch domain.HealthStatus(status) {
|
||||
case domain.HealthStatusHealthy:
|
||||
summary.Healthy = count
|
||||
case domain.HealthStatusDegraded:
|
||||
summary.Degraded = count
|
||||
case domain.HealthStatusDown:
|
||||
summary.Down = count
|
||||
case domain.HealthStatusCertMismatch:
|
||||
summary.CertMismatch = count
|
||||
case domain.HealthStatusUnknown:
|
||||
summary.Unknown = count
|
||||
}
|
||||
summary.Total += count
|
||||
}
|
||||
return summary, rows.Err()
|
||||
}
|
||||
|
||||
// scannable abstracts over *sql.Row and *sql.Rows, both of which provide this
// Scan method, so one decoder can serve single-row and multi-row queries.
type scannable interface {
	Scan(dest ...interface{}) error
}
|
||||
|
||||
// scanHealthCheck scans a health check from a row.
|
||||
func scanHealthCheck(row scannable) (*domain.EndpointHealthCheck, error) {
|
||||
check := &domain.EndpointHealthCheck{}
|
||||
var status string
|
||||
var certExpiry, lastCheckedAt, lastSuccessAt, lastFailureAt, lastTransitionAt, acknowledgedAt sql.NullTime
|
||||
err := row.Scan(
|
||||
&check.ID, &check.Endpoint, &check.CertificateID, &check.NetworkScanTargetID,
|
||||
&check.ExpectedFingerprint, &check.ObservedFingerprint, &status,
|
||||
&check.ConsecutiveFailures, &check.ResponseTimeMs, &check.TLSVersion, &check.CipherSuite,
|
||||
&check.CertSubject, &check.CertIssuer, &certExpiry,
|
||||
&lastCheckedAt, &lastSuccessAt, &lastFailureAt, &lastTransitionAt,
|
||||
&check.FailureReason, &check.DegradedThreshold, &check.DownThreshold, &check.CheckIntervalSecs,
|
||||
&check.Enabled, &check.Acknowledged, &check.AcknowledgedBy, &acknowledgedAt,
|
||||
&check.CreatedAt, &check.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan health check: %w", err)
|
||||
}
|
||||
check.Status = domain.HealthStatus(status)
|
||||
if certExpiry.Valid {
|
||||
check.CertExpiry = &certExpiry.Time
|
||||
}
|
||||
if lastCheckedAt.Valid {
|
||||
check.LastCheckedAt = &lastCheckedAt.Time
|
||||
}
|
||||
if lastSuccessAt.Valid {
|
||||
check.LastSuccessAt = &lastSuccessAt.Time
|
||||
}
|
||||
if lastFailureAt.Valid {
|
||||
check.LastFailureAt = &lastFailureAt.Time
|
||||
}
|
||||
if lastTransitionAt.Valid {
|
||||
check.LastTransitionAt = &lastTransitionAt.Time
|
||||
}
|
||||
if acknowledgedAt.Valid {
|
||||
check.AcknowledgedAt = &acknowledgedAt.Time
|
||||
}
|
||||
return check, nil
|
||||
}
|
||||
@@ -40,6 +40,11 @@ type DigestServicer interface {
|
||||
ProcessDigest(ctx context.Context) error
|
||||
}
|
||||
|
||||
// HealthCheckServicer is the slice of the endpoint TLS health monitoring
// service that the scheduler depends on.
type HealthCheckServicer interface {
	// RunHealthChecks performs one health check cycle.
	RunHealthChecks(ctx context.Context) error
}
|
||||
|
||||
// Scheduler manages background jobs and periodic tasks for the certificate control plane.
|
||||
// It runs multiple concurrent loops for renewal checks, job processing, agent health checks,
|
||||
// and notification processing.
|
||||
@@ -50,6 +55,7 @@ type Scheduler struct {
|
||||
notificationService NotificationServicer
|
||||
networkScanService NetworkScanServicer
|
||||
digestService DigestServicer
|
||||
healthCheckService HealthCheckServicer
|
||||
logger *slog.Logger
|
||||
|
||||
// Configurable tick intervals
|
||||
@@ -60,6 +66,7 @@ type Scheduler struct {
|
||||
shortLivedExpiryCheckInterval time.Duration
|
||||
networkScanInterval time.Duration
|
||||
digestInterval time.Duration
|
||||
healthCheckInterval time.Duration
|
||||
|
||||
// Idempotency guards: prevent duplicate execution of slow jobs
|
||||
renewalCheckRunning atomic.Bool
|
||||
@@ -69,6 +76,7 @@ type Scheduler struct {
|
||||
shortLivedExpiryCheckRunning atomic.Bool
|
||||
networkScanRunning atomic.Bool
|
||||
digestRunning atomic.Bool
|
||||
healthCheckRunning atomic.Bool
|
||||
|
||||
// Graceful shutdown: wait for in-flight work to complete
|
||||
wg sync.WaitGroup
|
||||
@@ -99,6 +107,7 @@ func NewScheduler(
|
||||
shortLivedExpiryCheckInterval: 30 * time.Second,
|
||||
networkScanInterval: 6 * time.Hour,
|
||||
digestInterval: 24 * time.Hour,
|
||||
healthCheckInterval: 60 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,6 +152,17 @@ func (s *Scheduler) SetShortLivedExpiryCheckInterval(d time.Duration) {
|
||||
s.shortLivedExpiryCheckInterval = d
|
||||
}
|
||||
|
||||
// SetHealthCheckService sets the health check service for the 8th scheduler loop.
|
||||
// Called after construction since health monitoring is optional.
|
||||
func (s *Scheduler) SetHealthCheckService(hcs HealthCheckServicer) {
|
||||
s.healthCheckService = hcs
|
||||
}
|
||||
|
||||
// SetHealthCheckInterval configures the interval for endpoint TLS health checks.
|
||||
func (s *Scheduler) SetHealthCheckInterval(d time.Duration) {
|
||||
s.healthCheckInterval = d
|
||||
}
|
||||
|
||||
// Start initiates all background scheduler loops. It returns a channel that signals
|
||||
// when the scheduler has started all loops. The scheduler runs until the context is cancelled.
|
||||
func (s *Scheduler) Start(ctx context.Context) <-chan struct{} {
|
||||
@@ -160,6 +180,9 @@ func (s *Scheduler) Start(ctx context.Context) <-chan struct{} {
|
||||
if s.digestService != nil {
|
||||
loopCount++
|
||||
}
|
||||
if s.healthCheckService != nil {
|
||||
loopCount++
|
||||
}
|
||||
s.wg.Add(loopCount)
|
||||
|
||||
go func() { defer s.wg.Done(); s.renewalCheckLoop(ctx) }()
|
||||
@@ -173,6 +196,9 @@ func (s *Scheduler) Start(ctx context.Context) <-chan struct{} {
|
||||
if s.digestService != nil {
|
||||
go func() { defer s.wg.Done(); s.digestLoop(ctx) }()
|
||||
}
|
||||
if s.healthCheckService != nil {
|
||||
go func() { defer s.wg.Done(); s.healthCheckLoop(ctx) }()
|
||||
}
|
||||
|
||||
// Signal that all loops are launched
|
||||
close(startedChan)
|
||||
@@ -517,6 +543,49 @@ func (s *Scheduler) runDigest(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// healthCheckLoop runs every healthCheckInterval and performs endpoint TLS health checks.
|
||||
// Do NOT run immediately on start — health checks are frequent (60s default) and may be
|
||||
// resource-intensive. Wait for the first tick.
|
||||
// Uses atomic.Bool to prevent duplicate execution if the previous check is still running.
|
||||
func (s *Scheduler) healthCheckLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(s.healthCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Do NOT run immediately on start for health checks — wait for the first tick.
|
||||
// Health checks are frequent and shouldn't fire on every restart.
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if !s.healthCheckRunning.CompareAndSwap(false, true) {
|
||||
s.logger.Debug("health check still running, skipping tick")
|
||||
continue
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
defer s.healthCheckRunning.Store(false)
|
||||
s.runHealthCheck(ctx)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runHealthCheck executes a single health check cycle with error recovery.
|
||||
func (s *Scheduler) runHealthCheck(ctx context.Context) {
|
||||
opCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
if err := s.healthCheckService.RunHealthChecks(opCtx); err != nil {
|
||||
s.logger.Error("health check run failed",
|
||||
"error", err,
|
||||
"interval", s.healthCheckInterval.String())
|
||||
} else {
|
||||
s.logger.Debug("health check completed")
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForCompletion waits for all in-flight scheduler work to complete.
|
||||
// It respects the provided timeout and returns an error if work is still in progress after timeout.
|
||||
// Call this after the scheduler context has been cancelled to ensure graceful shutdown.
|
||||
|
||||
@@ -0,0 +1,313 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
"github.com/shankar0123/certctl/internal/tlsprobe"
|
||||
)
|
||||
|
||||
// HealthCheckService manages endpoint TLS health monitoring.
|
||||
type HealthCheckService struct {
|
||||
repo repository.HealthCheckRepository
|
||||
auditService *AuditService
|
||||
notifService *NotificationService
|
||||
logger *slog.Logger
|
||||
maxConcurrent int
|
||||
defaultTimeout time.Duration
|
||||
historyRetention time.Duration
|
||||
autoCreate bool
|
||||
}
|
||||
|
||||
// NewHealthCheckService creates a new HealthCheckService.
|
||||
func NewHealthCheckService(
|
||||
repo repository.HealthCheckRepository,
|
||||
auditService *AuditService,
|
||||
logger *slog.Logger,
|
||||
maxConcurrent int,
|
||||
defaultTimeout time.Duration,
|
||||
historyRetention time.Duration,
|
||||
autoCreate bool,
|
||||
) *HealthCheckService {
|
||||
return &HealthCheckService{
|
||||
repo: repo,
|
||||
auditService: auditService,
|
||||
logger: logger,
|
||||
maxConcurrent: maxConcurrent,
|
||||
defaultTimeout: defaultTimeout,
|
||||
historyRetention: historyRetention,
|
||||
autoCreate: autoCreate,
|
||||
}
|
||||
}
|
||||
|
||||
// SetNotificationService sets the notification service for sending status transition alerts.
|
||||
func (s *HealthCheckService) SetNotificationService(ns *NotificationService) {
|
||||
s.notifService = ns
|
||||
}
|
||||
|
||||
// RunHealthChecks is the scheduler entry point for continuous TLS health monitoring.
|
||||
// Fetches endpoints due for check, probes concurrently with semaphore control,
|
||||
// updates health status with state transitions, records history, and sends notifications.
|
||||
func (s *HealthCheckService) RunHealthChecks(ctx context.Context) error {
|
||||
// Fetch all endpoints due for check
|
||||
checks, err := s.repo.ListDueForCheck(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list endpoints due for check: %w", err)
|
||||
}
|
||||
|
||||
if len(checks) == 0 {
|
||||
s.logger.Debug("no endpoints due for health check")
|
||||
return nil
|
||||
}
|
||||
|
||||
s.logger.Debug("running health checks", "endpoint_count", len(checks))
|
||||
|
||||
// Concurrent probing with semaphore
|
||||
sem := make(chan struct{}, s.maxConcurrent)
|
||||
var wg sync.WaitGroup
|
||||
probeResults := make(map[string]tlsprobe.ProbeResult)
|
||||
var mu sync.Mutex
|
||||
|
||||
for _, check := range checks {
|
||||
wg.Add(1)
|
||||
go func(c *domain.EndpointHealthCheck) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{} // acquire
|
||||
defer func() { <-sem }() // release
|
||||
|
||||
result := tlsprobe.ProbeTLS(ctx, c.Endpoint, s.defaultTimeout)
|
||||
mu.Lock()
|
||||
probeResults[c.ID] = result
|
||||
mu.Unlock()
|
||||
}(check)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Process results and update health status
|
||||
successCount := 0
|
||||
failureCount := 0
|
||||
transitionCount := 0
|
||||
|
||||
for _, check := range checks {
|
||||
result := probeResults[check.ID]
|
||||
|
||||
// Determine old status for transition detection
|
||||
oldStatus := check.Status
|
||||
|
||||
// Update probe result fields
|
||||
check.LastCheckedAt = timePtr(time.Now())
|
||||
check.ResponseTimeMs = result.ResponseTimeMs
|
||||
|
||||
if result.Success {
|
||||
successCount++
|
||||
check.ObservedFingerprint = result.Fingerprint
|
||||
check.TLSVersion = result.TLSVersion
|
||||
check.CipherSuite = result.CipherSuite
|
||||
check.CertSubject = result.Subject
|
||||
check.CertIssuer = result.Issuer
|
||||
check.CertExpiry = timePtr(result.NotAfter)
|
||||
check.FailureReason = ""
|
||||
check.LastSuccessAt = timePtr(time.Now())
|
||||
check.ConsecutiveFailures = 0
|
||||
} else {
|
||||
failureCount++
|
||||
check.LastFailureAt = timePtr(time.Now())
|
||||
check.ConsecutiveFailures++
|
||||
check.FailureReason = result.Error
|
||||
}
|
||||
|
||||
// Transition state based on consecutive failures and fingerprint match
|
||||
newStatus, transitioned := check.TransitionStatus(result.Success, result.Fingerprint)
|
||||
|
||||
if transitioned {
|
||||
transitionCount++
|
||||
check.Status = newStatus
|
||||
check.LastTransitionAt = timePtr(time.Now())
|
||||
// Reset acknowledged on transition
|
||||
check.Acknowledged = false
|
||||
|
||||
// Log transition
|
||||
s.logger.Info("health check status transition",
|
||||
"endpoint", check.Endpoint,
|
||||
"old_status", string(oldStatus),
|
||||
"new_status", string(newStatus))
|
||||
|
||||
// Record audit event
|
||||
if s.auditService != nil {
|
||||
_ = s.auditService.RecordEvent(ctx, "system", domain.ActorTypeSystem,
|
||||
"health_check_status_transition", "health_check", check.ID,
|
||||
map[string]interface{}{
|
||||
"endpoint": check.Endpoint,
|
||||
"old_status": string(oldStatus),
|
||||
"new_status": string(newStatus),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Update health check record
|
||||
if err := s.repo.Update(ctx, check); err != nil {
|
||||
s.logger.Error("failed to update health check",
|
||||
"endpoint", check.Endpoint,
|
||||
"error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Record probe result in history
|
||||
if err := s.repo.RecordHistory(ctx, &domain.HealthHistoryEntry{
|
||||
HealthCheckID: check.ID,
|
||||
Status: string(check.Status),
|
||||
ResponseTimeMs: check.ResponseTimeMs,
|
||||
Fingerprint: check.ObservedFingerprint,
|
||||
FailureReason: check.FailureReason,
|
||||
CheckedAt: time.Now(),
|
||||
}); err != nil {
|
||||
s.logger.Warn("failed to record health check history",
|
||||
"endpoint", check.Endpoint,
|
||||
"error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Purge old history entries once per run
|
||||
if err := s.PurgeOldHistory(ctx); err != nil {
|
||||
s.logger.Warn("failed to purge old health check history", "error", err)
|
||||
}
|
||||
|
||||
s.logger.Debug("health check run completed",
|
||||
"total", len(checks),
|
||||
"success", successCount,
|
||||
"failure", failureCount,
|
||||
"transitions", transitionCount)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create creates a new health check endpoint.
|
||||
func (s *HealthCheckService) Create(ctx context.Context, check *domain.EndpointHealthCheck) error {
|
||||
if check.ID == "" {
|
||||
check.ID = generateID("hc")
|
||||
}
|
||||
check.CreatedAt = time.Now()
|
||||
check.UpdatedAt = time.Now()
|
||||
|
||||
if err := s.repo.Create(ctx, check); err != nil {
|
||||
return fmt.Errorf("failed to create health check: %w", err)
|
||||
}
|
||||
|
||||
if s.auditService != nil {
|
||||
_ = s.auditService.RecordEvent(ctx, "system", domain.ActorTypeSystem,
|
||||
"health_check_created", "health_check", check.ID,
|
||||
map[string]interface{}{
|
||||
"endpoint": check.Endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a health check by ID.
|
||||
func (s *HealthCheckService) Get(ctx context.Context, id string) (*domain.EndpointHealthCheck, error) {
|
||||
return s.repo.Get(ctx, id)
|
||||
}
|
||||
|
||||
// Update updates an existing health check.
|
||||
func (s *HealthCheckService) Update(ctx context.Context, check *domain.EndpointHealthCheck) error {
|
||||
check.UpdatedAt = time.Now()
|
||||
|
||||
if err := s.repo.Update(ctx, check); err != nil {
|
||||
return fmt.Errorf("failed to update health check: %w", err)
|
||||
}
|
||||
|
||||
if s.auditService != nil {
|
||||
_ = s.auditService.RecordEvent(ctx, "system", domain.ActorTypeSystem,
|
||||
"health_check_updated", "health_check", check.ID,
|
||||
map[string]interface{}{
|
||||
"endpoint": check.Endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes a health check.
|
||||
func (s *HealthCheckService) Delete(ctx context.Context, id string) error {
|
||||
if err := s.repo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("failed to delete health check: %w", err)
|
||||
}
|
||||
|
||||
if s.auditService != nil {
|
||||
_ = s.auditService.RecordEvent(ctx, "system", domain.ActorTypeSystem,
|
||||
"health_check_deleted", "health_check", id,
|
||||
map[string]interface{}{})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List lists health checks with optional filtering.
|
||||
func (s *HealthCheckService) List(ctx context.Context, filter *repository.HealthCheckFilter) ([]*domain.EndpointHealthCheck, int, error) {
|
||||
if filter == nil {
|
||||
filter = &repository.HealthCheckFilter{}
|
||||
}
|
||||
return s.repo.List(ctx, filter)
|
||||
}
|
||||
|
||||
// GetHistory retrieves health check history for an endpoint.
|
||||
func (s *HealthCheckService) GetHistory(ctx context.Context, healthCheckID string, limit int) ([]*domain.HealthHistoryEntry, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
if limit > 1000 {
|
||||
limit = 1000
|
||||
}
|
||||
return s.repo.GetHistory(ctx, healthCheckID, limit)
|
||||
}
|
||||
|
||||
// AcknowledgeIncident marks a health check incident as acknowledged.
|
||||
func (s *HealthCheckService) AcknowledgeIncident(ctx context.Context, id string, actor string) error {
|
||||
check, err := s.repo.Get(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get health check: %w", err)
|
||||
}
|
||||
|
||||
check.Acknowledged = true
|
||||
check.AcknowledgedBy = actor
|
||||
check.AcknowledgedAt = timePtr(time.Now())
|
||||
|
||||
if err := s.repo.Update(ctx, check); err != nil {
|
||||
return fmt.Errorf("failed to update health check: %w", err)
|
||||
}
|
||||
|
||||
if s.auditService != nil {
|
||||
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser,
|
||||
"health_check_acknowledged", "health_check", id,
|
||||
map[string]interface{}{
|
||||
"endpoint": check.Endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSummary returns aggregated health check status counts.
|
||||
func (s *HealthCheckService) GetSummary(ctx context.Context) (*domain.HealthCheckSummary, error) {
|
||||
return s.repo.GetSummary(ctx)
|
||||
}
|
||||
|
||||
// PurgeOldHistory removes health check history entries older than the retention period.
|
||||
func (s *HealthCheckService) PurgeOldHistory(ctx context.Context) error {
|
||||
cutoff := time.Now().Add(-s.historyRetention)
|
||||
_, err := s.repo.PurgeHistory(ctx, cutoff)
|
||||
return err
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// timePtr returns a pointer to a fresh copy of t, for filling optional
// *time.Time fields. Each call yields a distinct pointer.
func timePtr(t time.Time) *time.Time {
	copied := new(time.Time)
	*copied = t
	return copied
}
|
||||
@@ -0,0 +1,350 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
)
|
||||
|
||||
// mockHealthCheckRepo implements the HealthCheckRepository interface for testing.
|
||||
type mockHealthCheckRepo struct {
|
||||
checks map[string]*domain.EndpointHealthCheck
|
||||
history []*domain.HealthHistoryEntry
|
||||
createErr error
|
||||
getErr error
|
||||
updateErr error
|
||||
deleteErr error
|
||||
listErr error
|
||||
listDueErr error
|
||||
getHistoryErr error
|
||||
recordHistoryErr error
|
||||
purgeHistoryErr error
|
||||
getSummaryErr error
|
||||
getSummaryResult *domain.HealthCheckSummary
|
||||
}
|
||||
|
||||
func newMockHealthCheckRepo() *mockHealthCheckRepo {
|
||||
return &mockHealthCheckRepo{
|
||||
checks: make(map[string]*domain.EndpointHealthCheck),
|
||||
history: []*domain.HealthHistoryEntry{},
|
||||
getSummaryResult: &domain.HealthCheckSummary{
|
||||
Healthy: 0,
|
||||
Degraded: 0,
|
||||
Down: 0,
|
||||
CertMismatch: 0,
|
||||
Unknown: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckRepo) Create(ctx context.Context, check *domain.EndpointHealthCheck) error {
|
||||
if m.createErr != nil {
|
||||
return m.createErr
|
||||
}
|
||||
m.checks[check.ID] = check
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckRepo) Get(ctx context.Context, id string) (*domain.EndpointHealthCheck, error) {
|
||||
if m.getErr != nil {
|
||||
return nil, m.getErr
|
||||
}
|
||||
if check, ok := m.checks[id]; ok {
|
||||
return check, nil
|
||||
}
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckRepo) GetByEndpoint(ctx context.Context, endpoint string) (*domain.EndpointHealthCheck, error) {
|
||||
for _, check := range m.checks {
|
||||
if check.Endpoint == endpoint {
|
||||
return check, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckRepo) Update(ctx context.Context, check *domain.EndpointHealthCheck) error {
|
||||
if m.updateErr != nil {
|
||||
return m.updateErr
|
||||
}
|
||||
m.checks[check.ID] = check
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckRepo) Delete(ctx context.Context, id string) error {
|
||||
if m.deleteErr != nil {
|
||||
return m.deleteErr
|
||||
}
|
||||
delete(m.checks, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckRepo) List(ctx context.Context, filter *repository.HealthCheckFilter) ([]*domain.EndpointHealthCheck, int, error) {
|
||||
if m.listErr != nil {
|
||||
return nil, 0, m.listErr
|
||||
}
|
||||
checks := make([]*domain.EndpointHealthCheck, 0, len(m.checks))
|
||||
for _, check := range m.checks {
|
||||
checks = append(checks, check)
|
||||
}
|
||||
return checks, len(checks), nil
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckRepo) ListDueForCheck(ctx context.Context) ([]*domain.EndpointHealthCheck, error) {
|
||||
if m.listDueErr != nil {
|
||||
return nil, m.listDueErr
|
||||
}
|
||||
checks := make([]*domain.EndpointHealthCheck, 0, len(m.checks))
|
||||
for _, check := range m.checks {
|
||||
if check.Enabled {
|
||||
checks = append(checks, check)
|
||||
}
|
||||
}
|
||||
return checks, nil
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckRepo) GetHistory(ctx context.Context, healthCheckID string, limit int) ([]*domain.HealthHistoryEntry, error) {
|
||||
if m.getHistoryErr != nil {
|
||||
return nil, m.getHistoryErr
|
||||
}
|
||||
return m.history, nil
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckRepo) RecordHistory(ctx context.Context, entry *domain.HealthHistoryEntry) error {
|
||||
if m.recordHistoryErr != nil {
|
||||
return m.recordHistoryErr
|
||||
}
|
||||
m.history = append(m.history, entry)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckRepo) PurgeHistory(ctx context.Context, before time.Time) (int64, error) {
|
||||
if m.purgeHistoryErr != nil {
|
||||
return 0, m.purgeHistoryErr
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockHealthCheckRepo) GetSummary(ctx context.Context) (*domain.HealthCheckSummary, error) {
|
||||
if m.getSummaryErr != nil {
|
||||
return nil, m.getSummaryErr
|
||||
}
|
||||
return m.getSummaryResult, nil
|
||||
}
|
||||
|
||||
// Tests
|
||||
|
||||
// newTestLogger returns a text-format slog logger that writes to stderr with
// the minimum level set to debug.
func newTestLogger() *slog.Logger {
	opts := &slog.HandlerOptions{Level: slog.LevelDebug}
	handler := slog.NewTextHandler(os.Stderr, opts)
	return slog.New(handler)
}
|
||||
|
||||
func TestHealthCheckService_Create_Success(t *testing.T) {
|
||||
repo := newMockHealthCheckRepo()
|
||||
logger := newTestLogger()
|
||||
svc := NewHealthCheckService(repo, nil, logger, 10, 5*time.Second, 30*24*time.Hour, false)
|
||||
|
||||
check := &domain.EndpointHealthCheck{
|
||||
Endpoint: "example.com:443",
|
||||
Status: domain.HealthStatusUnknown,
|
||||
Enabled: true,
|
||||
CheckIntervalSecs: 300,
|
||||
}
|
||||
|
||||
err := svc.Create(context.Background(), check)
|
||||
if err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
|
||||
if check.ID == "" {
|
||||
t.Fatal("Expected ID to be set")
|
||||
}
|
||||
|
||||
retrieved, _ := repo.Get(context.Background(), check.ID)
|
||||
if retrieved == nil {
|
||||
t.Fatal("Expected check to be in repo")
|
||||
}
|
||||
if retrieved.Endpoint != "example.com:443" {
|
||||
t.Errorf("Expected endpoint example.com:443, got %s", retrieved.Endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheckService_Create_RepoError(t *testing.T) {
|
||||
repo := newMockHealthCheckRepo()
|
||||
repo.createErr = errors.New("db error")
|
||||
logger := newTestLogger()
|
||||
svc := NewHealthCheckService(repo, nil, logger, 10, 5*time.Second, 30*24*time.Hour, false)
|
||||
|
||||
check := &domain.EndpointHealthCheck{
|
||||
Endpoint: "example.com:443",
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
err := svc.Create(context.Background(), check)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheckService_Get_Success(t *testing.T) {
|
||||
repo := newMockHealthCheckRepo()
|
||||
logger := newTestLogger()
|
||||
svc := NewHealthCheckService(repo, nil, logger, 10, 5*time.Second, 30*24*time.Hour, false)
|
||||
|
||||
check := &domain.EndpointHealthCheck{
|
||||
ID: "hc-test-1",
|
||||
Endpoint: "example.com:443",
|
||||
Status: domain.HealthStatusHealthy,
|
||||
}
|
||||
repo.checks["hc-test-1"] = check
|
||||
|
||||
retrieved, err := svc.Get(context.Background(), "hc-test-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Get failed: %v", err)
|
||||
}
|
||||
if retrieved.Endpoint != "example.com:443" {
|
||||
t.Errorf("Expected endpoint example.com:443, got %s", retrieved.Endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheckService_Get_NotFound(t *testing.T) {
|
||||
repo := newMockHealthCheckRepo()
|
||||
logger := newTestLogger()
|
||||
svc := NewHealthCheckService(repo, nil, logger, 10, 5*time.Second, 30*24*time.Hour, false)
|
||||
|
||||
_, err := svc.Get(context.Background(), "nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for nonexistent check")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheckService_List_Success(t *testing.T) {
|
||||
repo := newMockHealthCheckRepo()
|
||||
logger := newTestLogger()
|
||||
svc := NewHealthCheckService(repo, nil, logger, 10, 5*time.Second, 30*24*time.Hour, false)
|
||||
|
||||
check1 := &domain.EndpointHealthCheck{
|
||||
ID: "hc-1",
|
||||
Endpoint: "api.example.com:443",
|
||||
Status: domain.HealthStatusHealthy,
|
||||
}
|
||||
check2 := &domain.EndpointHealthCheck{
|
||||
ID: "hc-2",
|
||||
Endpoint: "web.example.com:443",
|
||||
Status: domain.HealthStatusDegraded,
|
||||
}
|
||||
repo.checks["hc-1"] = check1
|
||||
repo.checks["hc-2"] = check2
|
||||
|
||||
checks, total, err := svc.List(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
if len(checks) != 2 {
|
||||
t.Errorf("Expected 2 checks, got %d", len(checks))
|
||||
}
|
||||
if total != 2 {
|
||||
t.Errorf("Expected total 2, got %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheckService_Delete_Success(t *testing.T) {
|
||||
repo := newMockHealthCheckRepo()
|
||||
logger := newTestLogger()
|
||||
svc := NewHealthCheckService(repo, nil, logger, 10, 5*time.Second, 30*24*time.Hour, false)
|
||||
|
||||
check := &domain.EndpointHealthCheck{
|
||||
ID: "hc-test-1",
|
||||
Endpoint: "example.com:443",
|
||||
}
|
||||
repo.checks["hc-test-1"] = check
|
||||
|
||||
err := svc.Delete(context.Background(), "hc-test-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := repo.checks["hc-test-1"]; ok {
|
||||
t.Fatal("Expected check to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheckService_AcknowledgeIncident_Success(t *testing.T) {
|
||||
repo := newMockHealthCheckRepo()
|
||||
logger := newTestLogger()
|
||||
svc := NewHealthCheckService(repo, nil, logger, 10, 5*time.Second, 30*24*time.Hour, false)
|
||||
|
||||
check := &domain.EndpointHealthCheck{
|
||||
ID: "hc-test-1",
|
||||
Endpoint: "example.com:443",
|
||||
Status: domain.HealthStatusDown,
|
||||
Acknowledged: false,
|
||||
}
|
||||
repo.checks["hc-test-1"] = check
|
||||
|
||||
err := svc.AcknowledgeIncident(context.Background(), "hc-test-1", "user@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("AcknowledgeIncident failed: %v", err)
|
||||
}
|
||||
|
||||
retrieved := repo.checks["hc-test-1"]
|
||||
if !retrieved.Acknowledged {
|
||||
t.Fatal("Expected Acknowledged to be true")
|
||||
}
|
||||
if retrieved.AcknowledgedBy != "user@example.com" {
|
||||
t.Errorf("Expected AcknowledgedBy to be user@example.com, got %s", retrieved.AcknowledgedBy)
|
||||
}
|
||||
if retrieved.AcknowledgedAt == nil {
|
||||
t.Fatal("Expected AcknowledgedAt to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheckService_GetSummary_Success(t *testing.T) {
|
||||
repo := newMockHealthCheckRepo()
|
||||
logger := newTestLogger()
|
||||
svc := NewHealthCheckService(repo, nil, logger, 10, 5*time.Second, 30*24*time.Hour, false)
|
||||
|
||||
repo.getSummaryResult = &domain.HealthCheckSummary{
|
||||
Healthy: 5,
|
||||
Degraded: 2,
|
||||
Down: 1,
|
||||
CertMismatch: 1,
|
||||
Unknown: 0,
|
||||
}
|
||||
|
||||
summary, err := svc.GetSummary(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetSummary failed: %v", err)
|
||||
}
|
||||
if summary.Healthy != 5 {
|
||||
t.Errorf("Expected 5 healthy, got %d", summary.Healthy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheckService_RunHealthChecks_NoEndpoints(t *testing.T) {
|
||||
repo := newMockHealthCheckRepo()
|
||||
logger := newTestLogger()
|
||||
svc := NewHealthCheckService(repo, nil, logger, 10, 5*time.Second, 30*24*time.Hour, false)
|
||||
|
||||
err := svc.RunHealthChecks(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("RunHealthChecks failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheckService_PurgeOldHistory_Success(t *testing.T) {
|
||||
repo := newMockHealthCheckRepo()
|
||||
logger := newTestLogger()
|
||||
svc := NewHealthCheckService(repo, nil, logger, 10, 5*time.Second, 30*24*time.Hour, false)
|
||||
|
||||
err := svc.PurgeOldHistory(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("PurgeOldHistory failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -2,9 +2,6 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
@@ -16,6 +13,7 @@ import (
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
"github.com/shankar0123/certctl/internal/tlsprobe"
|
||||
)
|
||||
|
||||
// SentinelAgentID is the agent ID used for network-discovered certificates.
|
||||
@@ -469,16 +467,15 @@ func (s *NetworkScanService) probeTLS(ctx context.Context, address string, timeo
|
||||
|
||||
// tlsCertToEntry converts an x509.Certificate from a TLS handshake into a DiscoveredCertEntry.
|
||||
func tlsCertToEntry(cert *x509.Certificate, address string) domain.DiscoveredCertEntry {
|
||||
// Compute SHA-256 fingerprint
|
||||
fingerprintBytes := sha256.Sum256(cert.Raw)
|
||||
fingerprint := fmt.Sprintf("%x", fingerprintBytes)
|
||||
// Compute SHA-256 fingerprint using shared tlsprobe package
|
||||
fingerprint := tlsprobe.CertFingerprint(cert)
|
||||
|
||||
// Encode as PEM
|
||||
pemBlock := &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
|
||||
pemData := string(pem.EncodeToMemory(pemBlock))
|
||||
|
||||
// Key algorithm and size
|
||||
keyAlg, keySize := tlsCertKeyInfo(cert)
|
||||
// Key algorithm and size using shared tlsprobe package
|
||||
keyAlg, keySize := tlsprobe.CertKeyInfo(cert)
|
||||
|
||||
return domain.DiscoveredCertEntry{
|
||||
FingerprintSHA256: fingerprint,
|
||||
@@ -497,20 +494,3 @@ func tlsCertToEntry(cert *x509.Certificate, address string) domain.DiscoveredCer
|
||||
SourceFormat: "network",
|
||||
}
|
||||
}
|
||||
|
||||
// tlsCertKeyInfo extracts key algorithm name and size from a certificate.
|
||||
func tlsCertKeyInfo(cert *x509.Certificate) (string, int) {
|
||||
switch pub := cert.PublicKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
return "RSA", pub.N.BitLen()
|
||||
case *ecdsa.PublicKey:
|
||||
return "ECDSA", pub.Curve.Params().BitSize
|
||||
default:
|
||||
switch cert.PublicKeyAlgorithm {
|
||||
case x509.Ed25519:
|
||||
return "Ed25519", 256
|
||||
default:
|
||||
return cert.PublicKeyAlgorithm.String(), 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
package tlsprobe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ProbeResult is the outcome of one TLS probe. Address and ResponseTimeMs are
// filled in for every probe; the certificate and session fields are only
// populated after a successful handshake, and Error only after a failed one.
type ProbeResult struct {
	// Address is the host:port that was probed.
	Address string `json:"address"`
	// Success reports whether the TLS handshake completed.
	Success bool `json:"success"`
	// Fingerprint is the hex SHA-256 fingerprint of the leaf certificate.
	Fingerprint string `json:"fingerprint"`
	// TLSVersion is the negotiated protocol version, e.g. "TLS 1.3".
	TLSVersion string `json:"tls_version"`
	// CipherSuite is the negotiated suite name, e.g. "TLS_AES_128_GCM_SHA256".
	CipherSuite string `json:"cipher_suite"`
	// Subject is the leaf certificate's subject common name.
	Subject string `json:"subject"`
	// Issuer is the leaf certificate's issuer common name.
	Issuer string `json:"issuer"`
	// NotBefore and NotAfter bound the leaf certificate's validity period.
	NotBefore time.Time `json:"not_before"`
	NotAfter  time.Time `json:"not_after"`
	// SerialNumber is the leaf certificate's serial number in hexadecimal.
	SerialNumber string `json:"serial_number"`
	// ResponseTimeMs is the time from the start of the probe until the
	// handshake finished or failed, in milliseconds.
	ResponseTimeMs int `json:"response_time_ms"`
	// Error holds the connection error text; it is omitted from JSON when empty.
	Error string `json:"error,omitempty"`
}
|
||||
|
||||
// ProbeTLS connects to a TLS endpoint, performs a handshake, and extracts certificate metadata.
|
||||
// It uses InsecureSkipVerify to discover all certificates including self-signed and expired ones.
|
||||
// This is safe because the certificate data is extracted and analyzed, not validated for trust.
|
||||
func ProbeTLS(ctx context.Context, address string, timeout time.Duration) ProbeResult {
|
||||
startTime := time.Now()
|
||||
result := ProbeResult{
|
||||
Address: address,
|
||||
Success: false,
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", address, &tls.Config{
|
||||
// SECURITY NOTE: InsecureSkipVerify is intentionally set to true here.
|
||||
// The health checker must monitor ALL certificates including self-signed,
|
||||
// expired, and internal CA certificates. This setting is scoped to discovery
|
||||
// probing only — it is NEVER used for control-plane API calls, issuer
|
||||
// connector communication, or any operation that trusts the certificate.
|
||||
// The endpoint's certificate chain is extracted and analyzed, not validated.
|
||||
// See TICKET-016 for full security audit rationale.
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
result.Error = err.Error()
|
||||
result.ResponseTimeMs = int(time.Since(startTime).Milliseconds())
|
||||
return result
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
result.ResponseTimeMs = int(time.Since(startTime).Milliseconds())
|
||||
result.Success = true
|
||||
|
||||
// Extract certificates from TLS connection state
|
||||
state := conn.ConnectionState()
|
||||
if len(state.PeerCertificates) > 0 {
|
||||
cert := state.PeerCertificates[0]
|
||||
result.Fingerprint = CertFingerprint(cert)
|
||||
result.Subject = cert.Subject.CommonName
|
||||
result.Issuer = cert.Issuer.CommonName
|
||||
result.NotBefore = cert.NotBefore
|
||||
result.NotAfter = cert.NotAfter
|
||||
result.SerialNumber = cert.SerialNumber.Text(16)
|
||||
}
|
||||
|
||||
// Extract TLS version string
|
||||
result.TLSVersion = tlsVersionString(state.Version)
|
||||
|
||||
// Extract cipher suite name
|
||||
result.CipherSuite = tls.CipherSuiteName(state.CipherSuite)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// CertFingerprint returns the SHA-256 digest of the certificate's raw DER
// bytes as a 64-character lowercase hexadecimal string.
func CertFingerprint(cert *x509.Certificate) string {
	hasher := sha256.New()
	hasher.Write(cert.Raw)
	return hex.EncodeToString(hasher.Sum(nil))
}
|
||||
|
||||
// CertKeyInfo reports the public key algorithm of cert and its key size in
// bits: the modulus bit length for RSA, the curve bit size for ECDSA, and a
// fixed 256 for Ed25519. For any other key the algorithm's own name is
// returned with a size of 0.
func CertKeyInfo(cert *x509.Certificate) (string, int) {
	if rsaKey, ok := cert.PublicKey.(*rsa.PublicKey); ok {
		return "RSA", rsaKey.N.BitLen()
	}
	if ecKey, ok := cert.PublicKey.(*ecdsa.PublicKey); ok {
		return "ECDSA", ecKey.Curve.Params().BitSize
	}

	// No recognised key type: fall back to the declared algorithm.
	if cert.PublicKeyAlgorithm == x509.Ed25519 {
		return "Ed25519", 256
	}
	return cert.PublicKeyAlgorithm.String(), 0
}
|
||||
|
||||
// tlsVersionString renders a TLS protocol version number as text: "TLS 1.0"
// through "TLS 1.3" for the four known versions, and "TLS 0x" followed by the
// value in lowercase hexadecimal for anything else.
func tlsVersionString(version uint16) string {
	known := map[uint16]string{
		tls.VersionTLS10: "TLS 1.0",
		tls.VersionTLS11: "TLS 1.1",
		tls.VersionTLS12: "TLS 1.2",
		tls.VersionTLS13: "TLS 1.3",
	}
	if label, ok := known[version]; ok {
		return label
	}
	return fmt.Sprintf("TLS 0x%x", version)
}
|
||||
@@ -0,0 +1,175 @@
|
||||
package tlsprobe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestProbeTLS_ConnectionRefused tests probing an unavailable endpoint.
|
||||
func TestProbeTLS_ConnectionRefused(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result := ProbeTLS(ctx, "127.0.0.1:1", 1*time.Second)
|
||||
|
||||
if result.Success {
|
||||
t.Errorf("expected Success=false for unavailable endpoint, got %v", result.Success)
|
||||
}
|
||||
if result.Error == "" {
|
||||
t.Errorf("expected Error to be set for unavailable endpoint, got empty")
|
||||
}
|
||||
// ResponseTimeMs might be 0 on very fast systems, so just check it's set
|
||||
if result.ResponseTimeMs < 0 {
|
||||
t.Errorf("expected ResponseTimeMs >= 0, got %d", result.ResponseTimeMs)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProbeTLS_Success tests probing a live TLS server.
|
||||
func TestProbeTLS_Success(t *testing.T) {
|
||||
// Create a test HTTPS server with a self-signed certificate
|
||||
server := httptest.NewTLSServer(nil)
|
||||
defer server.Close()
|
||||
|
||||
// Extract the server address (remove https://)
|
||||
u := server.Listener.Addr().(*net.TCPAddr)
|
||||
address := net.JoinHostPort(u.IP.String(), fmt.Sprintf("%d", u.Port))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result := ProbeTLS(ctx, address, 5*time.Second)
|
||||
|
||||
if !result.Success {
|
||||
t.Errorf("expected Success=true, got false. Error: %s", result.Error)
|
||||
}
|
||||
if result.Fingerprint == "" {
|
||||
t.Errorf("expected Fingerprint to be set, got empty")
|
||||
}
|
||||
if result.TLSVersion == "" {
|
||||
t.Errorf("expected TLSVersion to be set, got empty")
|
||||
}
|
||||
if result.ResponseTimeMs == 0 {
|
||||
t.Errorf("expected ResponseTimeMs > 0, got 0")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCertFingerprint_SHA256 tests SHA-256 fingerprint computation.
|
||||
func TestCertFingerprint_SHA256(t *testing.T) {
|
||||
cert, _ := createTestCertWithKey(t, "test.example.com", "rsa")
|
||||
fp := CertFingerprint(cert)
|
||||
|
||||
if fp == "" {
|
||||
t.Errorf("expected non-empty fingerprint, got empty")
|
||||
}
|
||||
if len(fp) != 64 {
|
||||
t.Errorf("expected fingerprint length 64 (hex SHA-256), got %d", len(fp))
|
||||
}
|
||||
|
||||
// Verify it's valid hex
|
||||
for _, ch := range fp {
|
||||
if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'f') {
|
||||
t.Errorf("expected lowercase hex fingerprint, got invalid char: %c", ch)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify consistency (same cert should produce same fingerprint)
|
||||
fp2 := CertFingerprint(cert)
|
||||
if fp != fp2 {
|
||||
t.Errorf("fingerprint not consistent: %s vs %s", fp, fp2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCertKeyInfo_RSA tests RSA key info extraction.
|
||||
func TestCertKeyInfo_RSA(t *testing.T) {
|
||||
cert, _ := createTestCertWithKey(t, "test.example.com", "rsa")
|
||||
|
||||
alg, size := CertKeyInfo(cert)
|
||||
|
||||
if alg != "RSA" {
|
||||
t.Errorf("expected algorithm 'RSA', got '%s'", alg)
|
||||
}
|
||||
if size != 2048 {
|
||||
t.Errorf("expected RSA key size 2048, got %d", size)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCertKeyInfo_ECDSA tests ECDSA key info extraction.
|
||||
func TestCertKeyInfo_ECDSA(t *testing.T) {
|
||||
cert, _ := createTestCertWithKey(t, "test.example.com", "ecdsa")
|
||||
|
||||
alg, size := CertKeyInfo(cert)
|
||||
|
||||
if alg != "ECDSA" {
|
||||
t.Errorf("expected algorithm 'ECDSA', got '%s'", alg)
|
||||
}
|
||||
if size != 256 {
|
||||
t.Errorf("expected ECDSA P-256 key size 256, got %d", size)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper: createTestCert creates a self-signed test certificate with RSA key.
|
||||
func createTestCert(t *testing.T, cn string) *x509.Certificate {
|
||||
cert, _ := createTestCertWithKey(t, cn, "rsa")
|
||||
return cert
|
||||
}
|
||||
|
||||
// createTestCertWithKey generates a fresh key pair of the requested type and
// returns a parsed, self-signed certificate for cn together with the private
// key that signed it.
//
// keyType must be "rsa" (2048-bit key) or "ecdsa" (P-256 key); any other
// value fails the test via t.Fatalf. The certificate uses serial number 1,
// is valid for 365 days from the moment of the call, lists cn as both the
// subject common name and the sole DNS SAN, and is marked for server
// authentication.
func createTestCertWithKey(t *testing.T, cn, keyType string) (*x509.Certificate, interface{}) {
	// Report failures at the calling test's line, not inside this helper.
	t.Helper()

	var privKey, pubKey interface{}

	switch keyType {
	case "rsa":
		key, err := rsa.GenerateKey(rand.Reader, 2048)
		if err != nil {
			t.Fatalf("failed to generate RSA key: %v", err)
		}
		privKey, pubKey = key, &key.PublicKey
	case "ecdsa":
		key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
		if err != nil {
			t.Fatalf("failed to generate ECDSA key: %v", err)
		}
		privKey, pubKey = key, &key.PublicKey
	default:
		t.Fatalf("unsupported key type: %s", keyType)
	}

	// Take one clock reading so NotBefore and NotAfter are exactly 365 days
	// apart.
	now := time.Now()
	template := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			CommonName: cn,
		},
		NotBefore: now,
		NotAfter:  now.Add(365 * 24 * time.Hour),
		KeyUsage:  x509.KeyUsageDigitalSignature,
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageServerAuth,
		},
		DNSNames: []string{cn},
	}

	// Passing the template as its own parent produces a self-signed
	// certificate.
	certDER, err := x509.CreateCertificate(rand.Reader, template, template, pubKey, privKey)
	if err != nil {
		t.Fatalf("failed to create certificate: %v", err)
	}

	cert, err := x509.ParseCertificate(certDER)
	if err != nil {
		t.Fatalf("failed to parse certificate: %v", err)
	}

	return cert, privKey
}
|
||||
Reference in New Issue
Block a user