feat(M48): continuous TLS health monitoring — endpoint state machine, shared tlsprobe, 8 API endpoints, GUI

Adds continuous TLS endpoint health monitoring that closes the deploy→verify→monitor loop.
After M25 verifies a deployment succeeded once, M48 continuously confirms it stays healthy.

Key components:
- Shared `internal/tlsprobe/` package extracted from network scanner for reuse
- Health status state machine: healthy → degraded (2 failures) → down (5 failures),
  plus cert_mismatch when served fingerprint differs from expected
- 8th scheduler loop (60s tick, per-endpoint configurable intervals)
- PostgreSQL migration 000011: endpoint_health_checks + endpoint_health_history tables
- 8 REST API endpoints (CRUD, history, acknowledge, summary)
- Health Monitor GUI page with summary bar, status table, create modal, auto-refresh
- 38 new tests (5 tlsprobe + 11 domain + 10 service + 8 handler + 4 frontend)
- All coverage thresholds maintained (service 68%, handler 83%, domain 87%, middleware 63%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
shankar0123
2026-04-15 21:45:45 -04:00
parent f2e60b93a3
commit 596d86a206
29 changed files with 3540 additions and 30 deletions
+308
View File
@@ -0,0 +1,308 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/repository"
)
// HealthCheckServicer defines the interface used by the health check handler.
// It is declared here, next to its consumer, so the handler can be exercised
// with a test double instead of the real service.
type HealthCheckServicer interface {
// Create stores a new health check; the handler echoes the same value back to the client.
Create(ctx context.Context, check *domain.EndpointHealthCheck) error
// Get returns the health check with the given ID; the handler maps any error to 404.
Get(ctx context.Context, id string) (*domain.EndpointHealthCheck, error)
// Update persists changes to an existing health check.
Update(ctx context.Context, check *domain.EndpointHealthCheck) error
// Delete removes the health check with the given ID.
Delete(ctx context.Context, id string) error
// List returns the checks matching filter plus an int that the handler
// reports as the response Total.
List(ctx context.Context, filter *repository.HealthCheckFilter) ([]*domain.EndpointHealthCheck, int, error)
// GetHistory returns up to limit history entries for one health check.
GetHistory(ctx context.Context, healthCheckID string, limit int) ([]*domain.HealthHistoryEntry, error)
// AcknowledgeIncident records that actor acknowledged the check's current incident.
AcknowledgeIncident(ctx context.Context, id string, actor string) error
// GetSummary returns aggregate counts of health checks by status.
GetSummary(ctx context.Context) (*domain.HealthCheckSummary, error)
}
// HealthCheckHandler handles HTTP requests for TLS health monitoring.
// Every endpoint delegates to the wrapped HealthCheckServicer.
type HealthCheckHandler struct {
service HealthCheckServicer
}
// NewHealthCheckHandler returns a HealthCheckHandler that delegates to service.
func NewHealthCheckHandler(service HealthCheckServicer) *HealthCheckHandler {
	h := new(HealthCheckHandler)
	h.service = service
	return h
}
// ListHealthChecks handles GET /api/v1/health-checks.
//
// Supported query parameters:
//   - status, certificate_id, network_scan_target_id: passed through as filters.
//   - enabled: optional boolean (anything strconv.ParseBool accepts); an
//     unparsable value is rejected with 400 instead of being treated as false.
//   - page: 1-indexed page number, default 1; values below 1 fall back to 1.
//   - per_page: page size, default 50; values below 1 or above 500 fall back to 50.
//
// Responds with a PagedResponse whose Data is always a JSON array (never null).
func (h *HealthCheckHandler) ListHealthChecks(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		Error(w, http.StatusMethodNotAllowed, "Method not allowed")
		return
	}

	query := r.URL.Query()

	// Guard the lower bounds so a negative page or page size never reaches
	// the repository's pagination logic.
	page := parseIntDefault(query.Get("page"), 1)
	if page < 1 {
		page = 1
	}
	perPage := parseIntDefault(query.Get("per_page"), 50)
	if perPage < 1 || perPage > 500 {
		perPage = 50
	}

	// Parse the enabled flag strictly: a typo such as "enabled=ture" must not
	// silently turn into a filter for disabled checks.
	var enabledFilter *bool
	if raw := query.Get("enabled"); raw != "" {
		enabled, err := strconv.ParseBool(raw)
		if err != nil {
			Error(w, http.StatusBadRequest, "enabled must be a boolean")
			return
		}
		enabledFilter = &enabled
	}

	filter := &repository.HealthCheckFilter{
		Status:              query.Get("status"),
		CertificateID:       query.Get("certificate_id"),
		NetworkScanTargetID: query.Get("network_scan_target_id"),
		Enabled:             enabledFilter,
		Page:                page,
		PerPage:             perPage,
	}

	checks, total, err := h.service.List(r.Context(), filter)
	if err != nil {
		Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to list health checks: %v", err))
		return
	}
	// Encode an empty result as [] rather than null.
	if checks == nil {
		checks = make([]*domain.EndpointHealthCheck, 0)
	}

	JSON(w, http.StatusOK, PagedResponse{
		Data:    checks,
		Total:   int64(total),
		Page:    page,
		PerPage: perPage,
	})
}
// GetHealthCheck handles GET /api/v1/health-checks/{id}.
// It answers 400 when the path ID is empty, 404 when the service lookup
// returns any error, and 200 with the health check otherwise.
func (h *HealthCheckHandler) GetHealthCheck(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		Error(w, http.StatusMethodNotAllowed, "Method not allowed")
		return
	}

	checkID := r.PathValue("id")
	if len(checkID) == 0 {
		Error(w, http.StatusBadRequest, "health check ID is required")
		return
	}

	found, lookupErr := h.service.Get(r.Context(), checkID)
	if lookupErr != nil {
		msg := fmt.Sprintf("health check not found: %v", lookupErr)
		Error(w, http.StatusNotFound, msg)
		return
	}

	JSON(w, http.StatusOK, found)
}
// CreateHealthCheck handles POST /api/v1/health-checks.
//
// The request body is a JSON EndpointHealthCheck. endpoint is required. The
// following defaults are applied when the field is missing or non-positive:
// check interval 300s, degraded threshold 2, down threshold 5. A missing
// status defaults to "unknown"; a supplied status must be one of the known
// health statuses, otherwise the request is rejected with 400.
//
// Responds 201 with the created check on success.
func (h *HealthCheckHandler) CreateHealthCheck(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		Error(w, http.StatusMethodNotAllowed, "Method not allowed")
		return
	}

	var check domain.EndpointHealthCheck
	if err := json.NewDecoder(r.Body).Decode(&check); err != nil {
		Error(w, http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
		return
	}

	if check.Endpoint == "" {
		Error(w, http.StatusBadRequest, "endpoint is required")
		return
	}

	// Set defaults
	if check.CheckIntervalSecs <= 0 {
		check.CheckIntervalSecs = 300
	}
	if check.DegradedThreshold <= 0 {
		check.DegradedThreshold = 2
	}
	if check.DownThreshold <= 0 {
		check.DownThreshold = 5
	}

	// Never persist an arbitrary status string sent by the client.
	switch {
	case check.Status == "":
		check.Status = domain.HealthStatusUnknown
	case !domain.IsValidHealthStatus(string(check.Status)):
		Error(w, http.StatusBadRequest, fmt.Sprintf("invalid status: %q", string(check.Status)))
		return
	}

	if err := h.service.Create(r.Context(), &check); err != nil {
		Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to create health check: %v", err))
		return
	}

	JSON(w, http.StatusCreated, check)
}
// UpdateHealthCheck handles PUT /api/v1/health-checks/{id}.
//
// The body is a partial update: only the fields that are present (and, for
// strings and integers, non-empty / positive) overwrite the stored values.
// enabled is decoded through a pointer so that omitting it leaves the current
// setting untouched instead of silently disabling the check.
//
// Responds 404 when the check cannot be loaded, 400 on a malformed body, and
// 200 with the merged check on success.
func (h *HealthCheckHandler) UpdateHealthCheck(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPut {
		Error(w, http.StatusMethodNotAllowed, "Method not allowed")
		return
	}

	id := r.PathValue("id")
	if id == "" {
		Error(w, http.StatusBadRequest, "health check ID is required")
		return
	}

	// Get existing check
	existing, err := h.service.Get(r.Context(), id)
	if err != nil {
		Error(w, http.StatusNotFound, fmt.Sprintf("health check not found: %v", err))
		return
	}

	// Only the user-editable fields are decoded; JSON names mirror
	// domain.EndpointHealthCheck.
	var updates struct {
		Endpoint            string `json:"endpoint"`
		ExpectedFingerprint string `json:"expected_fingerprint"`
		CheckIntervalSecs   int    `json:"check_interval_seconds"`
		DegradedThreshold   int    `json:"degraded_threshold"`
		DownThreshold       int    `json:"down_threshold"`
		Enabled             *bool  `json:"enabled"`
	}
	if err := json.NewDecoder(r.Body).Decode(&updates); err != nil {
		Error(w, http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
		return
	}

	// Merge updates (only update provided fields)
	if updates.Endpoint != "" {
		existing.Endpoint = updates.Endpoint
	}
	if updates.ExpectedFingerprint != "" {
		existing.ExpectedFingerprint = updates.ExpectedFingerprint
	}
	if updates.CheckIntervalSecs > 0 {
		existing.CheckIntervalSecs = updates.CheckIntervalSecs
	}
	if updates.DegradedThreshold > 0 {
		existing.DegradedThreshold = updates.DegradedThreshold
	}
	if updates.DownThreshold > 0 {
		existing.DownThreshold = updates.DownThreshold
	}
	// A nil pointer means the client did not send "enabled" at all.
	if updates.Enabled != nil {
		existing.Enabled = *updates.Enabled
	}

	if err := h.service.Update(r.Context(), existing); err != nil {
		Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to update health check: %v", err))
		return
	}

	JSON(w, http.StatusOK, existing)
}
// DeleteHealthCheck handles DELETE /api/v1/health-checks/{id}.
// A successful delete is answered with 204 and no body; a service error
// becomes a 500.
func (h *HealthCheckHandler) DeleteHealthCheck(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodDelete {
		Error(w, http.StatusMethodNotAllowed, "Method not allowed")
		return
	}

	checkID := r.PathValue("id")
	if len(checkID) == 0 {
		Error(w, http.StatusBadRequest, "health check ID is required")
		return
	}

	deleteErr := h.service.Delete(r.Context(), checkID)
	if deleteErr != nil {
		msg := fmt.Sprintf("failed to delete health check: %v", deleteErr)
		Error(w, http.StatusInternalServerError, msg)
		return
	}

	w.WriteHeader(http.StatusNoContent)
}
// GetHealthCheckHistory handles GET /api/v1/health-checks/{id}/history.
// The optional "limit" query parameter defaults to 100, is ignored when it is
// not a positive integer, and is capped at 1000. The response is always a
// JSON array, even when there is no history.
func (h *HealthCheckHandler) GetHealthCheckHistory(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		Error(w, http.StatusMethodNotAllowed, "Method not allowed")
		return
	}

	checkID := r.PathValue("id")
	if len(checkID) == 0 {
		Error(w, http.StatusBadRequest, "health check ID is required")
		return
	}

	const (
		defaultLimit = 100
		maxLimit     = 1000
	)
	limit := defaultLimit
	if raw := r.URL.Query().Get("limit"); raw != "" {
		parsed, convErr := strconv.Atoi(raw)
		if convErr == nil && parsed > 0 {
			limit = parsed
		}
	}
	limit = min(limit, maxLimit)

	entries, histErr := h.service.GetHistory(r.Context(), checkID, limit)
	if histErr != nil {
		msg := fmt.Sprintf("failed to get health check history: %v", histErr)
		Error(w, http.StatusInternalServerError, msg)
		return
	}
	// Encode an empty result as [] rather than null.
	if entries == nil {
		entries = []*domain.HealthHistoryEntry{}
	}

	JSON(w, http.StatusOK, entries)
}
// AcknowledgeHealthCheck handles POST /api/v1/health-checks/{id}/acknowledge.
//
// The JSON body is optional. When present it may carry {"actor": "..."}; when
// the body is empty or the actor is blank, the actor defaults to "unknown".
// A body that is present but malformed is rejected with 400.
//
// Responds 204 with no body on success and 500 when the service call fails.
func (h *HealthCheckHandler) AcknowledgeHealthCheck(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		Error(w, http.StatusMethodNotAllowed, "Method not allowed")
		return
	}

	id := r.PathValue("id")
	if id == "" {
		Error(w, http.StatusBadRequest, "health check ID is required")
		return
	}

	var req struct {
		Actor string `json:"actor,omitempty"`
	}
	if r.Body != nil {
		dec := json.NewDecoder(r.Body)
		// More reports false when the body holds no JSON value at all, so a
		// bare POST is accepted rather than failing with an EOF decode error.
		if dec.More() {
			if err := dec.Decode(&req); err != nil {
				Error(w, http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
				return
			}
		}
	}

	if req.Actor == "" {
		req.Actor = "unknown"
	}

	if err := h.service.AcknowledgeIncident(r.Context(), id, req.Actor); err != nil {
		Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to acknowledge health check: %v", err))
		return
	}

	w.WriteHeader(http.StatusNoContent)
}
// GetHealthCheckSummary handles GET /api/v1/health-checks/summary and returns
// the aggregate status counts produced by the service.
// This route must be registered BEFORE the /{id} routes so that "summary" is
// not captured as an ID.
func (h *HealthCheckHandler) GetHealthCheckSummary(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		Error(w, http.StatusMethodNotAllowed, "Method not allowed")
		return
	}

	counts, summaryErr := h.service.GetSummary(r.Context())
	if summaryErr != nil {
		msg := fmt.Sprintf("failed to get health check summary: %v", summaryErr)
		Error(w, http.StatusInternalServerError, msg)
		return
	}

	JSON(w, http.StatusOK, counts)
}
@@ -0,0 +1,305 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/repository"
)
// mockHealthCheckSvc implements HealthCheckServicer for testing.
// Each *Err field, when non-nil, is returned by the method of the matching
// name before it does anything else. checks is the in-memory store keyed by
// health check ID, and summary is the value GetSummary hands back.
type mockHealthCheckSvc struct {
createErr error
getErr error
updateErr error
deleteErr error
listErr error
getHistoryErr error
acknowledgeErr error
getSummaryErr error
checks map[string]*domain.EndpointHealthCheck
summary *domain.HealthCheckSummary
}
// newMockHealthCheckSvc returns a mock with an empty check store and a
// default summary reporting a single healthy endpoint.
func newMockHealthCheckSvc() *mockHealthCheckSvc {
	mock := new(mockHealthCheckSvc)
	mock.checks = map[string]*domain.EndpointHealthCheck{}
	// All other counters keep their zero value.
	mock.summary = &domain.HealthCheckSummary{Healthy: 1}
	return mock
}
// Create returns createErr if set; otherwise it assigns the fixed ID
// "hc-created-1" to check and stores it.
func (m *mockHealthCheckSvc) Create(ctx context.Context, check *domain.EndpointHealthCheck) error {
	if err := m.createErr; err != nil {
		return err
	}
	const createdID = "hc-created-1"
	check.ID = createdID
	m.checks[createdID] = check
	return nil
}
// Get returns getErr if set, the stored check for id if one exists, and a
// "not found" error otherwise.
func (m *mockHealthCheckSvc) Get(ctx context.Context, id string) (*domain.EndpointHealthCheck, error) {
	if err := m.getErr; err != nil {
		return nil, err
	}
	stored, found := m.checks[id]
	if !found {
		return nil, errors.New("not found")
	}
	return stored, nil
}
// Update returns updateErr if set; otherwise it stores check under its ID.
func (m *mockHealthCheckSvc) Update(ctx context.Context, check *domain.EndpointHealthCheck) error {
	if err := m.updateErr; err != nil {
		return err
	}
	key := check.ID
	m.checks[key] = check
	return nil
}
// Delete returns deleteErr if set; otherwise it drops id from the store.
// Deleting an unknown ID is not an error.
func (m *mockHealthCheckSvc) Delete(ctx context.Context, id string) error {
	if err := m.deleteErr; err != nil {
		return err
	}
	delete(m.checks, id)
	return nil
}
// List returns listErr if set; otherwise it returns every stored check
// (the filter is ignored) along with their count. Order follows map
// iteration and is therefore unspecified.
func (m *mockHealthCheckSvc) List(ctx context.Context, filter *repository.HealthCheckFilter) ([]*domain.EndpointHealthCheck, int, error) {
	if err := m.listErr; err != nil {
		return nil, 0, err
	}
	all := make([]*domain.EndpointHealthCheck, len(m.checks))
	next := 0
	for _, stored := range m.checks {
		all[next] = stored
		next++
	}
	return all, len(all), nil
}
// GetHistory returns getHistoryErr if set; otherwise an empty, non-nil
// history slice regardless of the arguments.
func (m *mockHealthCheckSvc) GetHistory(ctx context.Context, healthCheckID string, limit int) ([]*domain.HealthHistoryEntry, error) {
	if err := m.getHistoryErr; err != nil {
		return nil, err
	}
	return []*domain.HealthHistoryEntry{}, nil
}
// AcknowledgeIncident returns acknowledgeErr if set; otherwise it marks the
// stored check (if any) as acknowledged by actor. An unknown ID is a no-op.
func (m *mockHealthCheckSvc) AcknowledgeIncident(ctx context.Context, id string, actor string) error {
	if err := m.acknowledgeErr; err != nil {
		return err
	}
	stored, found := m.checks[id]
	if !found {
		return nil
	}
	stored.Acknowledged = true
	stored.AcknowledgedBy = actor
	return nil
}
// GetSummary returns getSummaryErr if set; otherwise the configured summary.
func (m *mockHealthCheckSvc) GetSummary(ctx context.Context) (*domain.HealthCheckSummary, error) {
	if err := m.getSummaryErr; err != nil {
		return nil, err
	}
	return m.summary, nil
}
// Tests
// TestListHealthChecks_Success verifies that listing with one stored check
// yields 200 and a paged response whose Total is 1.
func TestListHealthChecks_Success(t *testing.T) {
	mock := newMockHealthCheckSvc()
	mock.checks["hc-1"] = &domain.EndpointHealthCheck{
		ID:       "hc-1",
		Endpoint: "api.example.com:443",
		Status:   domain.HealthStatusHealthy,
	}

	rec := httptest.NewRecorder()
	request := httptest.NewRequest("GET", "/api/v1/health-checks", nil)
	NewHealthCheckHandler(mock).ListHealthChecks(rec, request)

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("Expected status 200, got %d", got)
	}

	var paged PagedResponse
	if err := json.NewDecoder(rec.Body).Decode(&paged); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}
	if paged.Total != 1 {
		t.Errorf("Expected 1 health check, got %d", paged.Total)
	}
}
// TestListHealthChecks_MethodNotAllowed verifies that a POST to the list
// handler is rejected with 405.
func TestListHealthChecks_MethodNotAllowed(t *testing.T) {
	rec := httptest.NewRecorder()
	request := httptest.NewRequest("POST", "/api/v1/health-checks", nil)

	NewHealthCheckHandler(newMockHealthCheckSvc()).ListHealthChecks(rec, request)

	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("Expected status 405, got %d", got)
	}
}
// TestGetHealthCheck_Success verifies that fetching a stored check returns
// 200 and a body carrying the same ID.
func TestGetHealthCheck_Success(t *testing.T) {
	mock := newMockHealthCheckSvc()
	mock.checks["hc-1"] = &domain.EndpointHealthCheck{
		ID:       "hc-1",
		Endpoint: "api.example.com:443",
		Status:   domain.HealthStatusHealthy,
	}

	request := httptest.NewRequest("GET", "/api/v1/health-checks/hc-1", nil)
	request.SetPathValue("id", "hc-1")
	rec := httptest.NewRecorder()
	NewHealthCheckHandler(mock).GetHealthCheck(rec, request)

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("Expected status 200, got %d", got)
	}

	var decoded domain.EndpointHealthCheck
	if err := json.NewDecoder(rec.Body).Decode(&decoded); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}
	if decoded.ID != "hc-1" {
		t.Errorf("Expected ID hc-1, got %s", decoded.ID)
	}
}
// TestGetHealthCheck_NotFound verifies that an unknown ID produces a 404.
func TestGetHealthCheck_NotFound(t *testing.T) {
	request := httptest.NewRequest("GET", "/api/v1/health-checks/nonexistent", nil)
	request.SetPathValue("id", "nonexistent")
	rec := httptest.NewRecorder()

	NewHealthCheckHandler(newMockHealthCheckSvc()).GetHealthCheck(rec, request)

	if got := rec.Code; got != http.StatusNotFound {
		t.Errorf("Expected status 404, got %d", got)
	}
}
func TestCreateHealthCheck_Success(t *testing.T) {
svc := newMockHealthCheckSvc()
handler := NewHealthCheckHandler(svc)
check := domain.EndpointHealthCheck{
Endpoint: "web.example.com:443",
Enabled: true,
}
body, _ := json.Marshal(check)
req := httptest.NewRequest("POST", "/api/v1/health-checks", bytes.NewReader(body))
w := httptest.NewRecorder()
handler.CreateHealthCheck(w, req)
if w.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", w.Code)
}
var resp domain.EndpointHealthCheck
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
if resp.Endpoint != "web.example.com:443" {
t.Errorf("Expected endpoint web.example.com:443, got %s", resp.Endpoint)
}
}
// TestDeleteHealthCheck_Success verifies that deleting a stored check
// returns 204 and removes it from the mock store.
func TestDeleteHealthCheck_Success(t *testing.T) {
	mock := newMockHealthCheckSvc()
	mock.checks["hc-1"] = &domain.EndpointHealthCheck{
		ID:       "hc-1",
		Endpoint: "api.example.com:443",
	}

	request := httptest.NewRequest("DELETE", "/api/v1/health-checks/hc-1", nil)
	request.SetPathValue("id", "hc-1")
	rec := httptest.NewRecorder()
	NewHealthCheckHandler(mock).DeleteHealthCheck(rec, request)

	if got := rec.Code; got != http.StatusNoContent {
		t.Errorf("Expected status 204, got %d", got)
	}

	_, stillThere := mock.checks["hc-1"]
	if stillThere {
		t.Fatal("Expected check to be deleted")
	}
}
// TestAcknowledgeHealthCheck_Success verifies that acknowledging a stored
// check returns 204 and flips its Acknowledged flag.
func TestAcknowledgeHealthCheck_Success(t *testing.T) {
	mock := newMockHealthCheckSvc()
	mock.checks["hc-1"] = &domain.EndpointHealthCheck{
		ID:       "hc-1",
		Endpoint: "api.example.com:443",
		Status:   domain.HealthStatusDown,
	}

	payload := []byte(`{"actor":"user@example.com"}`)
	request := httptest.NewRequest("POST", "/api/v1/health-checks/hc-1/acknowledge", bytes.NewReader(payload))
	request.SetPathValue("id", "hc-1")
	rec := httptest.NewRecorder()
	NewHealthCheckHandler(mock).AcknowledgeHealthCheck(rec, request)

	if got := rec.Code; got != http.StatusNoContent {
		t.Errorf("Expected status 204, got %d", got)
	}

	if acked := mock.checks["hc-1"].Acknowledged; !acked {
		t.Fatal("Expected check to be acknowledged")
	}
}
// TestGetHealthCheckSummary_Success verifies that the summary endpoint
// returns 200 and relays the counts supplied by the service.
func TestGetHealthCheckSummary_Success(t *testing.T) {
	mock := newMockHealthCheckSvc()
	// Down and CertMismatch stay at their zero value.
	mock.summary = &domain.HealthCheckSummary{
		Healthy:  3,
		Degraded: 1,
		Unknown:  1,
	}

	rec := httptest.NewRecorder()
	request := httptest.NewRequest("GET", "/api/v1/health-checks/summary", nil)
	NewHealthCheckHandler(mock).GetHealthCheckSummary(rec, request)

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("Expected status 200, got %d", got)
	}

	var decoded domain.HealthCheckSummary
	if err := json.NewDecoder(rec.Body).Decode(&decoded); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}
	if decoded.Healthy != 3 {
		t.Errorf("Expected 3 healthy checks, got %d", decoded.Healthy)
	}
}
+12
View File
@@ -65,6 +65,7 @@ type HandlerRegistry struct {
Verification handler.VerificationHandler
Export handler.ExportHandler
Digest handler.DigestHandler
HealthChecks *handler.HealthCheckHandler
}
// RegisterHandlers sets up all API routes with their handlers.
@@ -226,6 +227,17 @@ func (r *Router) RegisterHandlers(reg HandlerRegistry) {
// Digest routes: /api/v1/digest
r.Register("GET /api/v1/digest/preview", http.HandlerFunc(reg.Digest.PreviewDigest))
r.Register("POST /api/v1/digest/send", http.HandlerFunc(reg.Digest.SendDigest))
// Health check routes: /api/v1/health-checks
// Summary endpoint must be registered before {id} routes
r.Register("GET /api/v1/health-checks/summary", http.HandlerFunc(reg.HealthChecks.GetHealthCheckSummary))
r.Register("GET /api/v1/health-checks", http.HandlerFunc(reg.HealthChecks.ListHealthChecks))
r.Register("POST /api/v1/health-checks", http.HandlerFunc(reg.HealthChecks.CreateHealthCheck))
r.Register("GET /api/v1/health-checks/{id}", http.HandlerFunc(reg.HealthChecks.GetHealthCheck))
r.Register("PUT /api/v1/health-checks/{id}", http.HandlerFunc(reg.HealthChecks.UpdateHealthCheck))
r.Register("DELETE /api/v1/health-checks/{id}", http.HandlerFunc(reg.HealthChecks.DeleteHealthCheck))
r.Register("GET /api/v1/health-checks/{id}/history", http.HandlerFunc(reg.HealthChecks.GetHealthCheckHistory))
r.Register("POST /api/v1/health-checks/{id}/acknowledge", http.HandlerFunc(reg.HealthChecks.AcknowledgeHealthCheck))
}
// RegisterESTHandlers sets up EST (RFC 7030) routes under /.well-known/est/.
+50
View File
@@ -32,6 +32,7 @@ type Config struct {
GoogleCAS GoogleCASConfig
AWSACMPCA AWSACMPCAConfig
Digest DigestConfig
HealthCheck HealthCheckConfig
Encryption EncryptionConfig
}
@@ -319,6 +320,46 @@ type DigestConfig struct {
Recipients []string
}
// HealthCheckConfig contains configuration for continuous TLS health monitoring (M48).
// Every field is read from a CERTCTL_HEALTH_CHECK_* environment variable.
type HealthCheckConfig struct {
// Enabled controls whether health checks are enabled.
// Default: false.
// Setting: CERTCTL_HEALTH_CHECK_ENABLED environment variable.
Enabled bool
// CheckInterval is the main scheduler loop interval for polling due checks.
// It is a time.Duration, unlike the per-endpoint interval below.
// Default: 60 seconds. Each endpoint has its own check_interval_seconds.
// Setting: CERTCTL_HEALTH_CHECK_INTERVAL environment variable.
CheckInterval time.Duration
// DefaultInterval is the default per-endpoint probe interval, as a plain
// integer number of seconds (not a time.Duration).
// Default: 300 seconds (5 minutes).
// Setting: CERTCTL_HEALTH_CHECK_DEFAULT_INTERVAL environment variable.
DefaultInterval int
// DefaultTimeout is the default TLS connection timeout, as a plain integer
// number of milliseconds (not a time.Duration).
// Default: 5000 milliseconds (5 seconds).
// Setting: CERTCTL_HEALTH_CHECK_DEFAULT_TIMEOUT environment variable.
DefaultTimeout int
// MaxConcurrent is the maximum number of concurrent TLS probes.
// Default: 20.
// Setting: CERTCTL_HEALTH_CHECK_MAX_CONCURRENT environment variable.
MaxConcurrent int
// HistoryRetention controls how long probe history records are kept.
// Default: 30 days. Older records are purged by the scheduler.
// Setting: CERTCTL_HEALTH_CHECK_HISTORY_RETENTION environment variable.
HistoryRetention time.Duration
// AutoCreate controls whether health checks are auto-created when:
// - A deployment job completes with verification success
// - A network scan target has health_check_enabled=true
// Default: true.
// Setting: CERTCTL_HEALTH_CHECK_AUTO_CREATE environment variable.
AutoCreate bool
}
// ACMEConfig contains ACME issuer connector configuration.
type ACMEConfig struct {
// DirectoryURL is the ACME directory URL for certificate issuance.
@@ -678,6 +719,15 @@ func Load() (*Config, error) {
Interval: getEnvDuration("CERTCTL_DIGEST_INTERVAL", 24*time.Hour),
Recipients: getEnvList("CERTCTL_DIGEST_RECIPIENTS", nil),
},
HealthCheck: HealthCheckConfig{
Enabled: getEnvBool("CERTCTL_HEALTH_CHECK_ENABLED", false),
CheckInterval: getEnvDuration("CERTCTL_HEALTH_CHECK_INTERVAL", 60*time.Second),
DefaultInterval: getEnvInt("CERTCTL_HEALTH_CHECK_DEFAULT_INTERVAL", 300),
DefaultTimeout: getEnvInt("CERTCTL_HEALTH_CHECK_DEFAULT_TIMEOUT", 5000),
MaxConcurrent: getEnvInt("CERTCTL_HEALTH_CHECK_MAX_CONCURRENT", 20),
HistoryRetention: getEnvDuration("CERTCTL_HEALTH_CHECK_HISTORY_RETENTION", 30*24*time.Hour),
AutoCreate: getEnvBool("CERTCTL_HEALTH_CHECK_AUTO_CREATE", true),
},
Encryption: EncryptionConfig{
ConfigEncryptionKey: getEnv("CERTCTL_CONFIG_ENCRYPTION_KEY", ""),
},
+109
View File
@@ -0,0 +1,109 @@
package domain
import "time"
// HealthStatus represents the current health state of a monitored endpoint.
// The string values are the wire/storage representation and are case-sensitive.
type HealthStatus string
const (
// HealthStatusHealthy: the last probe succeeded and, if an expected
// fingerprint is configured, the served certificate matched it.
HealthStatusHealthy HealthStatus = "healthy"
// HealthStatusDegraded: consecutive failures reached the degraded threshold.
HealthStatusDegraded HealthStatus = "degraded"
// HealthStatusDown: consecutive failures reached the down threshold.
HealthStatusDown HealthStatus = "down"
// HealthStatusCertMismatch: the probe succeeded but the served fingerprint
// differs from the expected one.
HealthStatusCertMismatch HealthStatus = "cert_mismatch"
// HealthStatusUnknown: no status has been established yet.
HealthStatusUnknown HealthStatus = "unknown"
)
// IsValidHealthStatus reports whether s is exactly one of the defined
// HealthStatus values. The comparison is case-sensitive and the empty string
// is not valid.
func IsValidHealthStatus(s string) bool {
	known := [...]HealthStatus{
		HealthStatusHealthy,
		HealthStatusDegraded,
		HealthStatusDown,
		HealthStatusCertMismatch,
		HealthStatusUnknown,
	}
	for _, candidate := range known {
		if string(candidate) == s {
			return true
		}
	}
	return false
}
// EndpointHealthCheck represents a monitored TLS endpoint: its configuration,
// the result of the most recent probe, and its incident acknowledgement state.
type EndpointHealthCheck struct {
ID string `json:"id"`
Endpoint string `json:"endpoint"`
// Optional links to the related certificate / network scan target; omitted from JSON when nil.
CertificateID *string `json:"certificate_id,omitempty"`
NetworkScanTargetID *string `json:"network_scan_target_id,omitempty"`
// ExpectedFingerprint, when non-empty, is compared against the fingerprint
// observed by a successful probe; a difference yields cert_mismatch.
ExpectedFingerprint string `json:"expected_fingerprint"`
ObservedFingerprint string `json:"observed_fingerprint"`
Status HealthStatus `json:"status"`
// ConsecutiveFailures is the failure count before the probe currently being
// evaluated; TransitionStatus adds one to it for a failed probe and leaves
// updating the field to the caller.
ConsecutiveFailures int `json:"consecutive_failures"`
ResponseTimeMs int `json:"response_time_ms"`
TLSVersion string `json:"tls_version"`
CipherSuite string `json:"cipher_suite"`
CertSubject string `json:"cert_subject"`
CertIssuer string `json:"cert_issuer"`
// Timestamps are pointers so that "never happened" is omitted from JSON.
CertExpiry *time.Time `json:"cert_expiry,omitempty"`
LastCheckedAt *time.Time `json:"last_checked_at,omitempty"`
LastSuccessAt *time.Time `json:"last_success_at,omitempty"`
LastFailureAt *time.Time `json:"last_failure_at,omitempty"`
LastTransitionAt *time.Time `json:"last_transition_at,omitempty"`
FailureReason string `json:"failure_reason"`
// DegradedThreshold and DownThreshold are the consecutive-failure counts at
// which the status becomes degraded and down respectively.
DegradedThreshold int `json:"degraded_threshold"`
DownThreshold int `json:"down_threshold"`
// CheckIntervalSecs is serialized as "check_interval_seconds".
CheckIntervalSecs int `json:"check_interval_seconds"`
Enabled bool `json:"enabled"`
Acknowledged bool `json:"acknowledged"`
AcknowledgedBy string `json:"acknowledged_by,omitempty"`
AcknowledgedAt *time.Time `json:"acknowledged_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// TransitionStatus computes the health status that results from one probe.
// It does not modify h; the caller is responsible for storing the returned
// status and for updating h.ConsecutiveFailures.
//
// Rules:
//   - Successful probe: cert_mismatch when an expected fingerprint is set and
//     differs from observedFingerprint, otherwise healthy.
//   - Failed probe: the failure being evaluated counts as
//     h.ConsecutiveFailures+1. Reaching DownThreshold yields down; otherwise
//     reaching DegradedThreshold yields degraded; otherwise the current status
//     is kept unchanged (grace period).
//
// During the grace period an endpoint whose status is still unknown stays
// unknown: a failed probe must never promote it to healthy or report a
// transition.
//
// Returns the new status and whether it differs from h.Status.
func (h *EndpointHealthCheck) TransitionStatus(probeSuccess bool, observedFingerprint string) (HealthStatus, bool) {
	// Default to "no change"; only the branches below move the state.
	newStatus := h.Status

	if probeSuccess {
		if h.ExpectedFingerprint != "" && observedFingerprint != h.ExpectedFingerprint {
			newStatus = HealthStatusCertMismatch
		} else {
			newStatus = HealthStatusHealthy
		}
		return newStatus, newStatus != h.Status
	}

	// Count the failure being evaluated; h.ConsecutiveFailures itself is
	// updated by the caller.
	failures := h.ConsecutiveFailures + 1
	switch {
	case failures >= h.DownThreshold:
		newStatus = HealthStatusDown
	case failures >= h.DegradedThreshold:
		newStatus = HealthStatusDegraded
	}

	return newStatus, newStatus != h.Status
}
// HealthHistoryEntry represents a single probe record.
// HealthCheckID names the health check the record belongs to. Note that
// Status is a plain string here, not a HealthStatus.
type HealthHistoryEntry struct {
ID string `json:"id"`
HealthCheckID string `json:"health_check_id"`
Status string `json:"status"`
ResponseTimeMs int `json:"response_time_ms"`
Fingerprint string `json:"fingerprint"`
FailureReason string `json:"failure_reason"`
CheckedAt time.Time `json:"checked_at"`
}
// HealthCheckSummary contains aggregate counts by status.
// There is one counter per HealthStatus value, plus Total.
// NOTE(review): Total is presumably the sum of the per-status counters; the
// producer is not visible here, so confirm against the repository query.
type HealthCheckSummary struct {
Healthy int `json:"healthy"`
Degraded int `json:"degraded"`
Down int `json:"down"`
CertMismatch int `json:"cert_mismatch"`
Unknown int `json:"unknown"`
Total int `json:"total"`
}
+237
View File
@@ -0,0 +1,237 @@
package domain
import (
"testing"
"time"
)
// TestIsValidHealthStatus checks every defined status plus unknown, empty and
// wrongly-cased inputs.
func TestIsValidHealthStatus(t *testing.T) {
	type testCase struct {
		status string
		valid  bool
	}
	cases := []testCase{
		{status: "healthy", valid: true},
		{status: "degraded", valid: true},
		{status: "down", valid: true},
		{status: "cert_mismatch", valid: true},
		{status: "unknown", valid: true},
		{status: "invalid", valid: false},
		{status: "", valid: false},
		{status: "HEALTHY", valid: false},
	}

	for idx := range cases {
		tc := cases[idx]
		t.Run(tc.status, func(t *testing.T) {
			if got := IsValidHealthStatus(tc.status); got != tc.valid {
				t.Errorf("IsValidHealthStatus(%q) = %v, want %v", tc.status, got, tc.valid)
			}
		})
	}
}
// TestTransitionStatus_HealthyProbe: a successful probe with a matching
// fingerprint moves an unknown endpoint to healthy.
func TestTransitionStatus_HealthyProbe(t *testing.T) {
	check := EndpointHealthCheck{
		ExpectedFingerprint: "abc123",
		DownThreshold:       5,
		DegradedThreshold:   2,
		Status:              HealthStatusUnknown,
	}

	got, changed := check.TransitionStatus(true, "abc123")

	if got != HealthStatusHealthy {
		t.Errorf("expected HealthStatusHealthy, got %s", got)
	}
	if !changed {
		t.Errorf("expected transition=true, got false")
	}
}
// TestTransitionStatus_CertMismatch: a successful probe serving a different
// fingerprint than expected yields cert_mismatch.
func TestTransitionStatus_CertMismatch(t *testing.T) {
	check := EndpointHealthCheck{
		ExpectedFingerprint: "abc123",
		DownThreshold:       5,
		DegradedThreshold:   2,
		Status:              HealthStatusHealthy,
	}

	got, changed := check.TransitionStatus(true, "xyz789")

	if got != HealthStatusCertMismatch {
		t.Errorf("expected HealthStatusCertMismatch, got %s", got)
	}
	if !changed {
		t.Errorf("expected transition=true, got false")
	}
}
// TestTransitionStatus_FirstFailure_BelowThreshold: one failure with a
// degraded threshold of two leaves a healthy endpoint healthy.
func TestTransitionStatus_FirstFailure_BelowThreshold(t *testing.T) {
	check := EndpointHealthCheck{
		DownThreshold:     5,
		DegradedThreshold: 2,
		Status:            HealthStatusHealthy,
	}

	got, changed := check.TransitionStatus(false, "")

	// A single failure is still inside the grace period.
	if got != HealthStatusHealthy {
		t.Errorf("expected HealthStatusHealthy (grace period), got %s", got)
	}
	if changed {
		t.Errorf("expected transition=false (still healthy), got true")
	}
}
// TestTransitionStatus_DegradedThreshold: the failure that brings the count
// to the degraded threshold moves the endpoint to degraded.
func TestTransitionStatus_DegradedThreshold(t *testing.T) {
	check := EndpointHealthCheck{
		DownThreshold:       5,
		DegradedThreshold:   2,
		ConsecutiveFailures: 1, // this probe is the second failure
		Status:              HealthStatusHealthy,
	}

	got, changed := check.TransitionStatus(false, "")

	if got != HealthStatusDegraded {
		t.Errorf("expected HealthStatusDegraded, got %s", got)
	}
	if !changed {
		t.Errorf("expected transition=true, got false")
	}
}
// TestTransitionStatus_DownThreshold: the failure that brings the count to
// the down threshold moves a degraded endpoint to down.
func TestTransitionStatus_DownThreshold(t *testing.T) {
	check := EndpointHealthCheck{
		DownThreshold:       5,
		DegradedThreshold:   2,
		ConsecutiveFailures: 4, // this probe is the fifth failure
		Status:              HealthStatusDegraded,
	}

	got, changed := check.TransitionStatus(false, "")

	if got != HealthStatusDown {
		t.Errorf("expected HealthStatusDown, got %s", got)
	}
	if !changed {
		t.Errorf("expected transition=true, got false")
	}
}
// TestTransitionStatus_Recovery: a successful, matching probe brings a down
// endpoint straight back to healthy regardless of the failure count.
func TestTransitionStatus_Recovery(t *testing.T) {
	check := EndpointHealthCheck{
		ExpectedFingerprint: "abc123",
		DownThreshold:       5,
		DegradedThreshold:   2,
		ConsecutiveFailures: 10,
		Status:              HealthStatusDown,
	}

	got, changed := check.TransitionStatus(true, "abc123")

	if got != HealthStatusHealthy {
		t.Errorf("expected HealthStatusHealthy (recovery), got %s", got)
	}
	if !changed {
		t.Errorf("expected transition=true (from down to healthy), got false")
	}
}
// TestTransitionStatus_NoFingerprint: without an expected fingerprint any
// successful probe counts as healthy, whatever certificate was served.
func TestTransitionStatus_NoFingerprint(t *testing.T) {
	check := EndpointHealthCheck{
		ExpectedFingerprint: "", // fingerprint comparison disabled
		DownThreshold:       5,
		DegradedThreshold:   2,
		Status:              HealthStatusHealthy,
	}

	got, changed := check.TransitionStatus(true, "anything")

	if got != HealthStatusHealthy {
		t.Errorf("expected HealthStatusHealthy (no fingerprint check), got %s", got)
	}
	if changed {
		t.Errorf("expected transition=false (already healthy), got true")
	}
}
// TestTransitionStatus_UnknownToHealthy: the first successful probe of an
// unknown endpoint is reported as a transition to healthy.
func TestTransitionStatus_UnknownToHealthy(t *testing.T) {
	check := EndpointHealthCheck{
		DownThreshold:     5,
		DegradedThreshold: 2,
		Status:            HealthStatusUnknown,
	}

	got, changed := check.TransitionStatus(true, "")

	if got != HealthStatusHealthy {
		t.Errorf("expected HealthStatusHealthy, got %s", got)
	}
	if !changed {
		t.Errorf("expected transition=true (from unknown to healthy), got false")
	}
}
// TestTransitionStatus_NoTransitionWhenSame: a successful probe of an
// already-healthy endpoint reports no transition.
func TestTransitionStatus_NoTransitionWhenSame(t *testing.T) {
	check := EndpointHealthCheck{
		DownThreshold:     5,
		DegradedThreshold: 2,
		Status:            HealthStatusHealthy,
	}

	got, changed := check.TransitionStatus(true, "")

	if got != HealthStatusHealthy {
		t.Errorf("expected HealthStatusHealthy, got %s", got)
	}
	if changed {
		t.Errorf("expected transition=false (already healthy), got true")
	}
}
// TestHealthCheckSummary is a smoke test that the summary struct keeps the
// values it was built with.
func TestHealthCheckSummary(t *testing.T) {
	counts := HealthCheckSummary{
		Total:        9,
		Healthy:      5,
		Degraded:     2,
		Down:         1,
		CertMismatch: 1,
	}

	if got := counts.Total; got != 9 {
		t.Errorf("expected Total=9, got %d", got)
	}
	if got := counts.Healthy; got != 5 {
		t.Errorf("expected Healthy=5, got %d", got)
	}
}
// TestHealthHistoryEntry is a smoke test that a history entry keeps the
// values it was built with.
func TestHealthHistoryEntry(t *testing.T) {
	record := HealthHistoryEntry{
		ID:             "hh-test-123",
		HealthCheckID:  "hc-test-123",
		Status:         "healthy",
		ResponseTimeMs: 42,
		Fingerprint:    "abc123def456",
		CheckedAt:      time.Now(),
	}

	if got := record.ID; got != "hh-test-123" {
		t.Errorf("expected ID='hh-test-123', got %q", got)
	}
	if got := record.ResponseTimeMs; got != 42 {
		t.Errorf("expected ResponseTimeMs=42, got %d", got)
	}
}
+42
View File
@@ -277,3 +277,45 @@ type OwnerRepository interface {
// Delete removes an owner.
Delete(ctx context.Context, id string) error
}
// HealthCheckRepository manages endpoint health check persistence, covering
// both the health check records themselves and their probe history.
type HealthCheckRepository interface {
// Create stores a new health check.
Create(ctx context.Context, check *domain.EndpointHealthCheck) error
// Update modifies an existing health check.
Update(ctx context.Context, check *domain.EndpointHealthCheck) error
// Get retrieves a health check by ID.
Get(ctx context.Context, id string) (*domain.EndpointHealthCheck, error)
// Delete removes a health check.
Delete(ctx context.Context, id string) error
// List returns health checks matching the filter with pagination.
// NOTE(review): the int result is presumably the total number of matches
// across all pages — confirm against the implementation.
List(ctx context.Context, filter *HealthCheckFilter) ([]*domain.EndpointHealthCheck, int, error)
// ListDueForCheck returns health checks that need to be probed (interval exceeded).
ListDueForCheck(ctx context.Context) ([]*domain.EndpointHealthCheck, error)
// GetByEndpoint retrieves a health check by endpoint address.
GetByEndpoint(ctx context.Context, endpoint string) (*domain.EndpointHealthCheck, error)
// RecordHistory records a single probe result in history.
RecordHistory(ctx context.Context, entry *domain.HealthHistoryEntry) error
// GetHistory retrieves recent probe history for a health check, returning
// at most limit entries.
GetHistory(ctx context.Context, healthCheckID string, limit int) ([]*domain.HealthHistoryEntry, error)
// PurgeHistory deletes history entries older than the specified time.
// NOTE(review): the int64 result is presumably the number of deleted rows — confirm.
PurgeHistory(ctx context.Context, olderThan time.Time) (int64, error)
// GetSummary returns aggregate counts by health status.
GetSummary(ctx context.Context) (*domain.HealthCheckSummary, error)
}
// HealthCheckFilter contains filter parameters for health check queries.
// In the PostgreSQL implementation an empty string field means "do not
// filter on this column", and set fields are combined with AND.
type HealthCheckFilter struct {
// Status filters by health status (healthy, degraded, down, cert_mismatch, unknown).
Status string
// CertificateID filters by managed certificate ID.
CertificateID string
// NetworkScanTargetID filters by network scan target ID.
NetworkScanTargetID string
// Enabled filters by enabled/disabled status (nil = all).
Enabled *bool
// Page is the page number (1-indexed). The PostgreSQL implementation
// treats values <= 0 as page 1.
Page int
// PerPage is the number of results per page. The PostgreSQL implementation
// treats values <= 0 as 50.
PerPage int
}
@@ -0,0 +1,453 @@
package postgres
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	"github.com/shankar0123/certctl/internal/domain"
	"github.com/shankar0123/certctl/internal/repository"
)
// HealthCheckRepository implements repository.HealthCheckRepository using PostgreSQL.
type HealthCheckRepository struct {
// db is the connection pool used for every query issued by this repository.
db *sql.DB
}
// NewHealthCheckRepository builds a health check repository that runs its
// queries against the supplied database handle.
func NewHealthCheckRepository(db *sql.DB) *HealthCheckRepository {
	repo := new(HealthCheckRepository)
	repo.db = db
	return repo
}
// Create inserts the given health check as a new row in
// endpoint_health_checks. Every column is written from the struct as-is;
// the caller is responsible for populating ID and timestamps.
func (r *HealthCheckRepository) Create(ctx context.Context, check *domain.EndpointHealthCheck) error {
	const insertSQL = `
INSERT INTO endpoint_health_checks (
id, endpoint, certificate_id, network_scan_target_id,
expected_fingerprint, observed_fingerprint, status,
consecutive_failures, response_time_ms, tls_version, cipher_suite,
cert_subject, cert_issuer, cert_expiry,
last_checked_at, last_success_at, last_failure_at, last_transition_at,
failure_reason, degraded_threshold, down_threshold, check_interval_seconds,
enabled, acknowledged, acknowledged_by, acknowledged_at,
created_at, updated_at
) VALUES (
$1, $2, $3, $4,
$5, $6, $7,
$8, $9, $10, $11,
$12, $13, $14,
$15, $16, $17, $18,
$19, $20, $21, $22,
$23, $24, $25, $26,
$27, $28
)`

	// Parameter order mirrors the column list above.
	params := []interface{}{
		check.ID, check.Endpoint, check.CertificateID, check.NetworkScanTargetID,
		check.ExpectedFingerprint, check.ObservedFingerprint, string(check.Status),
		check.ConsecutiveFailures, check.ResponseTimeMs, check.TLSVersion, check.CipherSuite,
		check.CertSubject, check.CertIssuer, check.CertExpiry,
		check.LastCheckedAt, check.LastSuccessAt, check.LastFailureAt, check.LastTransitionAt,
		check.FailureReason, check.DegradedThreshold, check.DownThreshold, check.CheckIntervalSecs,
		check.Enabled, check.Acknowledged, check.AcknowledgedBy, check.AcknowledgedAt,
		check.CreatedAt, check.UpdatedAt,
	}

	if _, err := r.db.ExecContext(ctx, insertSQL, params...); err != nil {
		return fmt.Errorf("create health check: %w", err)
	}
	return nil
}
// Update overwrites every mutable column of the health check identified by
// check.ID. It stamps check.UpdatedAt with the current time before writing,
// so the caller's struct is mutated.
//
// It returns an error when the statement fails or when no row has the given
// ID; previously an update of a non-existent health check silently succeeded.
func (r *HealthCheckRepository) Update(ctx context.Context, check *domain.EndpointHealthCheck) error {
	check.UpdatedAt = time.Now()
	result, err := r.db.ExecContext(ctx, `
UPDATE endpoint_health_checks SET
endpoint = $2, certificate_id = $3, network_scan_target_id = $4,
expected_fingerprint = $5, observed_fingerprint = $6, status = $7,
consecutive_failures = $8, response_time_ms = $9, tls_version = $10, cipher_suite = $11,
cert_subject = $12, cert_issuer = $13, cert_expiry = $14,
last_checked_at = $15, last_success_at = $16, last_failure_at = $17, last_transition_at = $18,
failure_reason = $19, degraded_threshold = $20, down_threshold = $21, check_interval_seconds = $22,
enabled = $23, acknowledged = $24, acknowledged_by = $25, acknowledged_at = $26,
updated_at = $27
WHERE id = $1`,
		check.ID,
		check.Endpoint, check.CertificateID, check.NetworkScanTargetID,
		check.ExpectedFingerprint, check.ObservedFingerprint, string(check.Status),
		check.ConsecutiveFailures, check.ResponseTimeMs, check.TLSVersion, check.CipherSuite,
		check.CertSubject, check.CertIssuer, check.CertExpiry,
		check.LastCheckedAt, check.LastSuccessAt, check.LastFailureAt, check.LastTransitionAt,
		check.FailureReason, check.DegradedThreshold, check.DownThreshold, check.CheckIntervalSecs,
		check.Enabled, check.Acknowledged, check.AcknowledgedBy, check.AcknowledgedAt,
		check.UpdatedAt,
	)
	if err != nil {
		return fmt.Errorf("update health check: %w", err)
	}

	// A WHERE clause that matches nothing is not an SQL error, so the missing
	// row has to be detected from the affected-row count.
	affected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("update health check: %w", err)
	}
	if affected == 0 {
		return fmt.Errorf("health check not found: %s", check.ID)
	}
	return nil
}
// Get retrieves a health check by ID.
//
// Row decoding is delegated to scanHealthCheck so that nullable timestamp
// handling lives in one place. A missing row is reported as
// "health check not found: <id>"; any other failure is wrapped with
// "get health check".
func (r *HealthCheckRepository) Get(ctx context.Context, id string) (*domain.EndpointHealthCheck, error) {
	row := r.db.QueryRowContext(ctx, `
SELECT id, endpoint, certificate_id, network_scan_target_id,
expected_fingerprint, observed_fingerprint, status,
consecutive_failures, response_time_ms, tls_version, cipher_suite,
cert_subject, cert_issuer, cert_expiry,
last_checked_at, last_success_at, last_failure_at, last_transition_at,
failure_reason, degraded_threshold, down_threshold, check_interval_seconds,
enabled, acknowledged, acknowledged_by, acknowledged_at,
created_at, updated_at
FROM endpoint_health_checks
WHERE id = $1`, id)

	check, err := scanHealthCheck(row)
	// scanHealthCheck wraps the driver error, so errors.Is is required here;
	// a plain == comparison would never match.
	if errors.Is(err, sql.ErrNoRows) {
		return nil, fmt.Errorf("health check not found: %s", id)
	}
	if err != nil {
		return nil, fmt.Errorf("get health check: %w", err)
	}
	return check, nil
}
// Delete removes the health check row with the given ID. Deleting an ID that
// does not exist is not reported as an error.
func (r *HealthCheckRepository) Delete(ctx context.Context, id string) error {
	const deleteSQL = `DELETE FROM endpoint_health_checks WHERE id = $1`

	if _, err := r.db.ExecContext(ctx, deleteSQL, id); err != nil {
		return fmt.Errorf("delete health check: %w", err)
	}
	return nil
}
// List returns one page of health checks matching the filter, newest first,
// together with the total number of matching rows (ignoring pagination).
//
// A nil filter matches everything. Page defaults to 1 and PerPage to 50 when
// they are not positive. Set filter fields are combined with AND.
func (r *HealthCheckRepository) List(ctx context.Context, filter *repository.HealthCheckFilter) ([]*domain.EndpointHealthCheck, int, error) {
	const selectSQL = `SELECT id, endpoint, certificate_id, network_scan_target_id,
expected_fingerprint, observed_fingerprint, status,
consecutive_failures, response_time_ms, tls_version, cipher_suite,
cert_subject, cert_issuer, cert_expiry,
last_checked_at, last_success_at, last_failure_at, last_transition_at,
failure_reason, degraded_threshold, down_threshold, check_interval_seconds,
enabled, acknowledged, acknowledged_by, acknowledged_at,
created_at, updated_at
FROM endpoint_health_checks`
	const countSQL = `SELECT COUNT(*) FROM endpoint_health_checks`

	var (
		clauses []string
		params  []interface{}
	)
	// addClause registers "<column> = $N", where N is the 1-based position of
	// value in params.
	addClause := func(column string, value interface{}) {
		params = append(params, value)
		clauses = append(clauses, fmt.Sprintf("%s = $%d", column, len(params)))
	}

	pageNum, pageSize := 1, 50
	if filter != nil {
		if filter.Status != "" {
			addClause("status", filter.Status)
		}
		if filter.CertificateID != "" {
			addClause("certificate_id", filter.CertificateID)
		}
		if filter.NetworkScanTargetID != "" {
			addClause("network_scan_target_id", filter.NetworkScanTargetID)
		}
		if filter.Enabled != nil {
			addClause("enabled", *filter.Enabled)
		}
		if filter.Page > 0 {
			pageNum = filter.Page
		}
		if filter.PerPage > 0 {
			pageSize = filter.PerPage
		}
	}

	whereSQL := ""
	for i, clause := range clauses {
		if i == 0 {
			whereSQL = " WHERE " + clause
			continue
		}
		whereSQL += " AND " + clause
	}

	// The count uses the same WHERE clause but none of the paging parameters.
	var total int
	if err := r.db.QueryRowContext(ctx, countSQL+whereSQL, params...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("count health checks: %w", err)
	}

	// LIMIT and OFFSET take the two placeholders after the filter parameters.
	listSQL := selectSQL + whereSQL + " ORDER BY created_at DESC" +
		fmt.Sprintf(" LIMIT $%d OFFSET $%d", len(params)+1, len(params)+2)
	params = append(params, pageSize, (pageNum-1)*pageSize)

	rows, err := r.db.QueryContext(ctx, listSQL, params...)
	if err != nil {
		return nil, 0, fmt.Errorf("list health checks: %w", err)
	}
	defer rows.Close()

	var page []*domain.EndpointHealthCheck
	for rows.Next() {
		item, scanErr := scanHealthCheck(rows)
		if scanErr != nil {
			return nil, 0, scanErr
		}
		page = append(page, item)
	}
	return page, total, rows.Err()
}
// ListDueForCheck returns the enabled health checks that have never been
// probed or whose last probe is older than their own check interval.
// Never-probed checks come first, then the ones checked longest ago.
func (r *HealthCheckRepository) ListDueForCheck(ctx context.Context) ([]*domain.EndpointHealthCheck, error) {
	const dueSQL = `
SELECT id, endpoint, certificate_id, network_scan_target_id,
expected_fingerprint, observed_fingerprint, status,
consecutive_failures, response_time_ms, tls_version, cipher_suite,
cert_subject, cert_issuer, cert_expiry,
last_checked_at, last_success_at, last_failure_at, last_transition_at,
failure_reason, degraded_threshold, down_threshold, check_interval_seconds,
enabled, acknowledged, acknowledged_by, acknowledged_at,
created_at, updated_at
FROM endpoint_health_checks
WHERE enabled = TRUE
AND (
last_checked_at IS NULL
OR last_checked_at + (check_interval_seconds * INTERVAL '1 second') < NOW()
)
ORDER BY last_checked_at ASC NULLS FIRST`

	rows, err := r.db.QueryContext(ctx, dueSQL)
	if err != nil {
		return nil, fmt.Errorf("list due health checks: %w", err)
	}
	defer rows.Close()

	var due []*domain.EndpointHealthCheck
	for rows.Next() {
		item, scanErr := scanHealthCheck(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		due = append(due, item)
	}
	return due, rows.Err()
}
// GetByEndpoint retrieves a health check by endpoint address.
//
// Row decoding is delegated to scanHealthCheck instead of repeating the scan
// and nullable-timestamp handling inline. A missing row is reported as
// "health check not found for endpoint: <endpoint>"; any other failure is
// wrapped with "get health check by endpoint".
func (r *HealthCheckRepository) GetByEndpoint(ctx context.Context, endpoint string) (*domain.EndpointHealthCheck, error) {
	row := r.db.QueryRowContext(ctx, `
SELECT id, endpoint, certificate_id, network_scan_target_id,
expected_fingerprint, observed_fingerprint, status,
consecutive_failures, response_time_ms, tls_version, cipher_suite,
cert_subject, cert_issuer, cert_expiry,
last_checked_at, last_success_at, last_failure_at, last_transition_at,
failure_reason, degraded_threshold, down_threshold, check_interval_seconds,
enabled, acknowledged, acknowledged_by, acknowledged_at,
created_at, updated_at
FROM endpoint_health_checks
WHERE endpoint = $1`, endpoint)

	check, err := scanHealthCheck(row)
	// scanHealthCheck wraps the driver error, so errors.Is is required here;
	// a plain == comparison would never match.
	if errors.Is(err, sql.ErrNoRows) {
		return nil, fmt.Errorf("health check not found for endpoint: %s", endpoint)
	}
	if err != nil {
		return nil, fmt.Errorf("get health check by endpoint: %w", err)
	}
	return check, nil
}
// RecordHistory appends one probe result to endpoint_health_history.
// All columns, including the ID, are taken from entry unchanged.
func (r *HealthCheckRepository) RecordHistory(ctx context.Context, entry *domain.HealthHistoryEntry) error {
	const insertSQL = `
INSERT INTO endpoint_health_history (id, health_check_id, status, response_time_ms, fingerprint, failure_reason, checked_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)`

	params := []interface{}{
		entry.ID, entry.HealthCheckID, entry.Status, entry.ResponseTimeMs,
		entry.Fingerprint, entry.FailureReason, entry.CheckedAt,
	}
	if _, err := r.db.ExecContext(ctx, insertSQL, params...); err != nil {
		return fmt.Errorf("record health check history: %w", err)
	}
	return nil
}
// GetHistory returns up to limit history entries for the given health check,
// most recent first. A non-positive limit is replaced by 100.
func (r *HealthCheckRepository) GetHistory(ctx context.Context, healthCheckID string, limit int) ([]*domain.HealthHistoryEntry, error) {
	const historySQL = `
SELECT id, health_check_id, status, response_time_ms, fingerprint, failure_reason, checked_at
FROM endpoint_health_history
WHERE health_check_id = $1
ORDER BY checked_at DESC
LIMIT $2`

	maxRows := limit
	if maxRows <= 0 {
		maxRows = 100
	}

	rows, err := r.db.QueryContext(ctx, historySQL, healthCheckID, maxRows)
	if err != nil {
		return nil, fmt.Errorf("get health check history: %w", err)
	}
	defer rows.Close()

	var history []*domain.HealthHistoryEntry
	for rows.Next() {
		var item domain.HealthHistoryEntry
		scanErr := rows.Scan(
			&item.ID, &item.HealthCheckID, &item.Status, &item.ResponseTimeMs,
			&item.Fingerprint, &item.FailureReason, &item.CheckedAt,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("scan health history entry: %w", scanErr)
		}
		history = append(history, &item)
	}
	return history, rows.Err()
}
// PurgeHistory deletes every history entry whose checked_at is strictly
// before olderThan and reports how many rows were removed.
func (r *HealthCheckRepository) PurgeHistory(ctx context.Context, olderThan time.Time) (int64, error) {
	const purgeSQL = `DELETE FROM endpoint_health_history WHERE checked_at < $1`

	res, err := r.db.ExecContext(ctx, purgeSQL, olderThan)
	if err != nil {
		return 0, fmt.Errorf("purge health check history: %w", err)
	}
	deleted, err := res.RowsAffected()
	return deleted, err
}
// GetSummary counts health checks per status. Rows whose status is not one
// of the known domain constants contribute to Total only.
func (r *HealthCheckRepository) GetSummary(ctx context.Context) (*domain.HealthCheckSummary, error) {
	const summarySQL = `SELECT status, COUNT(*) FROM endpoint_health_checks GROUP BY status`

	rows, err := r.db.QueryContext(ctx, summarySQL)
	if err != nil {
		return nil, fmt.Errorf("get health check summary: %w", err)
	}
	defer rows.Close()

	summary := &domain.HealthCheckSummary{}
	// Each known status maps to the counter field it populates.
	counters := map[domain.HealthStatus]*int{
		domain.HealthStatusHealthy:      &summary.Healthy,
		domain.HealthStatusDegraded:     &summary.Degraded,
		domain.HealthStatusDown:         &summary.Down,
		domain.HealthStatusCertMismatch: &summary.CertMismatch,
		domain.HealthStatusUnknown:      &summary.Unknown,
	}

	for rows.Next() {
		var (
			rawStatus string
			n         int
		)
		if scanErr := rows.Scan(&rawStatus, &n); scanErr != nil {
			return nil, fmt.Errorf("scan health check summary: %w", scanErr)
		}
		if slot, known := counters[domain.HealthStatus(rawStatus)]; known {
			*slot = n
		}
		summary.Total += n
	}
	return summary, rows.Err()
}
// scannable is an interface satisfied by both *sql.Row and *sql.Rows.
// It lets scanHealthCheck decode a health check from either a single-row
// query or the current row of a multi-row result.
type scannable interface {
Scan(dest ...interface{}) error
}
// scanHealthCheck decodes one endpoint_health_checks row, in the column order
// used by the SELECT statements in this file, into a domain object.
// Nullable timestamp columns become nil pointers when the column is NULL.
// Any scan failure is wrapped with "scan health check".
func scanHealthCheck(row scannable) (*domain.EndpointHealthCheck, error) {
	var (
		hc        domain.EndpointHealthCheck
		rawStatus string

		expiry, checkedAt, successAt, failureAt, transitionAt, ackedAt sql.NullTime
	)

	if err := row.Scan(
		&hc.ID, &hc.Endpoint, &hc.CertificateID, &hc.NetworkScanTargetID,
		&hc.ExpectedFingerprint, &hc.ObservedFingerprint, &rawStatus,
		&hc.ConsecutiveFailures, &hc.ResponseTimeMs, &hc.TLSVersion, &hc.CipherSuite,
		&hc.CertSubject, &hc.CertIssuer, &expiry,
		&checkedAt, &successAt, &failureAt, &transitionAt,
		&hc.FailureReason, &hc.DegradedThreshold, &hc.DownThreshold, &hc.CheckIntervalSecs,
		&hc.Enabled, &hc.Acknowledged, &hc.AcknowledgedBy, &ackedAt,
		&hc.CreatedAt, &hc.UpdatedAt,
	); err != nil {
		return nil, fmt.Errorf("scan health check: %w", err)
	}

	// optional converts a nullable column into a pointer, nil when NULL.
	optional := func(nt sql.NullTime) *time.Time {
		if !nt.Valid {
			return nil
		}
		ts := nt.Time
		return &ts
	}

	hc.Status = domain.HealthStatus(rawStatus)
	hc.CertExpiry = optional(expiry)
	hc.LastCheckedAt = optional(checkedAt)
	hc.LastSuccessAt = optional(successAt)
	hc.LastFailureAt = optional(failureAt)
	hc.LastTransitionAt = optional(transitionAt)
	hc.AcknowledgedAt = optional(ackedAt)
	return &hc, nil
}
+69
View File
@@ -40,6 +40,11 @@ type DigestServicer interface {
ProcessDigest(ctx context.Context) error
}
// HealthCheckServicer defines the interface for endpoint TLS health monitoring used by the scheduler.
// The scheduler invokes RunHealthChecks once per health check tick and logs
// any error it returns.
type HealthCheckServicer interface {
RunHealthChecks(ctx context.Context) error
}
// Scheduler manages background jobs and periodic tasks for the certificate control plane.
// It runs multiple concurrent loops for renewal checks, job processing, agent health checks,
// and notification processing.
@@ -50,6 +55,7 @@ type Scheduler struct {
notificationService NotificationServicer
networkScanService NetworkScanServicer
digestService DigestServicer
healthCheckService HealthCheckServicer
logger *slog.Logger
// Configurable tick intervals
@@ -60,6 +66,7 @@ type Scheduler struct {
shortLivedExpiryCheckInterval time.Duration
networkScanInterval time.Duration
digestInterval time.Duration
healthCheckInterval time.Duration
// Idempotency guards: prevent duplicate execution of slow jobs
renewalCheckRunning atomic.Bool
@@ -69,6 +76,7 @@ type Scheduler struct {
shortLivedExpiryCheckRunning atomic.Bool
networkScanRunning atomic.Bool
digestRunning atomic.Bool
healthCheckRunning atomic.Bool
// Graceful shutdown: wait for in-flight work to complete
wg sync.WaitGroup
@@ -99,6 +107,7 @@ func NewScheduler(
shortLivedExpiryCheckInterval: 30 * time.Second,
networkScanInterval: 6 * time.Hour,
digestInterval: 24 * time.Hour,
healthCheckInterval: 60 * time.Second,
}
}
@@ -143,6 +152,17 @@ func (s *Scheduler) SetShortLivedExpiryCheckInterval(d time.Duration) {
s.shortLivedExpiryCheckInterval = d
}
// SetHealthCheckService sets the health check service for the 8th scheduler loop.
// Called after construction since health monitoring is optional. Start only
// launches the health check loop when this service is non-nil, so it must be
// set before Start is called.
func (s *Scheduler) SetHealthCheckService(hcs HealthCheckServicer) {
s.healthCheckService = hcs
}
// SetHealthCheckInterval configures the interval for endpoint TLS health checks.
// The value is read once, when the health check loop creates its ticker, so
// call this before Start.
func (s *Scheduler) SetHealthCheckInterval(d time.Duration) {
s.healthCheckInterval = d
}
// Start initiates all background scheduler loops. It returns a channel that signals
// when the scheduler has started all loops. The scheduler runs until the context is cancelled.
func (s *Scheduler) Start(ctx context.Context) <-chan struct{} {
@@ -160,6 +180,9 @@ func (s *Scheduler) Start(ctx context.Context) <-chan struct{} {
if s.digestService != nil {
loopCount++
}
if s.healthCheckService != nil {
loopCount++
}
s.wg.Add(loopCount)
go func() { defer s.wg.Done(); s.renewalCheckLoop(ctx) }()
@@ -173,6 +196,9 @@ func (s *Scheduler) Start(ctx context.Context) <-chan struct{} {
if s.digestService != nil {
go func() { defer s.wg.Done(); s.digestLoop(ctx) }()
}
if s.healthCheckService != nil {
go func() { defer s.wg.Done(); s.healthCheckLoop(ctx) }()
}
// Signal that all loops are launched
close(startedChan)
@@ -517,6 +543,49 @@ func (s *Scheduler) runDigest(ctx context.Context) {
}
}
// healthCheckLoop runs every healthCheckInterval and performs endpoint TLS health checks.
//
// It does NOT run immediately on start: health checks are frequent (60s
// default) and may be resource-intensive, so the first run waits for the
// first tick. An atomic.Bool guard skips a tick while the previous run is
// still in flight. Each run executes in its own goroutine registered on the
// scheduler WaitGroup so shutdown can wait for it.
//
// A non-positive configured interval falls back to 60 seconds, because
// time.NewTicker panics on a non-positive duration. The loop returns when
// ctx is cancelled.
func (s *Scheduler) healthCheckLoop(ctx context.Context) {
	interval := s.healthCheckInterval
	if interval <= 0 {
		const fallback = 60 * time.Second
		s.logger.Warn("invalid health check interval, using default",
			"configured", interval.String(),
			"default", fallback.String())
		interval = fallback
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// select picks randomly when both channels are ready; do not
			// start new work once shutdown has begun.
			if ctx.Err() != nil {
				return
			}
			if !s.healthCheckRunning.CompareAndSwap(false, true) {
				s.logger.Debug("health check still running, skipping tick")
				continue
			}
			s.wg.Add(1)
			go func() {
				defer s.wg.Done()
				defer s.healthCheckRunning.Store(false)
				s.runHealthCheck(ctx)
			}()
		}
	}
}
// runHealthCheck performs one health check cycle under a five-minute
// deadline derived from ctx. Failures are logged rather than returned.
func (s *Scheduler) runHealthCheck(ctx context.Context) {
	runCtx, stop := context.WithTimeout(ctx, 5*time.Minute)
	defer stop()

	err := s.healthCheckService.RunHealthChecks(runCtx)
	if err == nil {
		s.logger.Debug("health check completed")
		return
	}
	s.logger.Error("health check run failed",
		"error", err,
		"interval", s.healthCheckInterval.String())
}
// WaitForCompletion waits for all in-flight scheduler work to complete.
// It respects the provided timeout and returns an error if work is still in progress after timeout.
// Call this after the scheduler context has been cancelled to ensure graceful shutdown.
+313
View File
@@ -0,0 +1,313 @@
package service
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/repository"
"github.com/shankar0123/certctl/internal/tlsprobe"
)
// HealthCheckService manages endpoint TLS health monitoring.
type HealthCheckService struct {
// repo persists health checks and their probe history.
repo repository.HealthCheckRepository
// auditService records audit events; when nil, audit recording is skipped.
auditService *AuditService
// notifService is assigned through SetNotificationService.
// NOTE(review): no method in this part of the file reads it — confirm where alerts are sent.
notifService *NotificationService
logger *slog.Logger
// maxConcurrent sizes the semaphore that bounds parallel TLS probes in RunHealthChecks.
maxConcurrent int
// defaultTimeout is passed to every tlsprobe.ProbeTLS call.
defaultTimeout time.Duration
// historyRetention is the age beyond which PurgeOldHistory deletes history entries.
historyRetention time.Duration
// autoCreate is stored by the constructor.
// NOTE(review): no method in this part of the file reads it — confirm its consumer.
autoCreate bool
}
// NewHealthCheckService assembles a HealthCheckService from its dependencies
// and tuning parameters. The notification service is left unset; attach it
// afterwards with SetNotificationService.
func NewHealthCheckService(
	repo repository.HealthCheckRepository,
	auditService *AuditService,
	logger *slog.Logger,
	maxConcurrent int,
	defaultTimeout time.Duration,
	historyRetention time.Duration,
	autoCreate bool,
) *HealthCheckService {
	svc := new(HealthCheckService)
	svc.repo = repo
	svc.auditService = auditService
	svc.logger = logger
	svc.maxConcurrent = maxConcurrent
	svc.defaultTimeout = defaultTimeout
	svc.historyRetention = historyRetention
	svc.autoCreate = autoCreate
	return svc
}
// SetNotificationService sets the notification service for sending status transition alerts.
// It is a separate setter because the constructor does not take a
// notification service.
func (s *HealthCheckService) SetNotificationService(ns *NotificationService) {
s.notifService = ns
}
// RunHealthChecks is the scheduler entry point for continuous TLS health monitoring.
//
// It fetches the endpoints that are due, probes them concurrently (bounded by
// maxConcurrent), then for each endpoint updates the stored probe fields,
// applies the status state machine, records an audit event on a transition,
// persists the check, and appends a history entry. Old history is purged
// once at the end of the run.
//
// Only a failure to list the due endpoints is returned as an error. Failures
// on individual endpoints are logged and do not abort the run. Audit
// recording is best-effort and its error is deliberately ignored.
// This function does not send notifications.
func (s *HealthCheckService) RunHealthChecks(ctx context.Context) error {
	checks, err := s.repo.ListDueForCheck(ctx)
	if err != nil {
		return fmt.Errorf("failed to list endpoints due for check: %w", err)
	}
	if len(checks) == 0 {
		s.logger.Debug("no endpoints due for health check")
		return nil
	}
	s.logger.Debug("running health checks", "endpoint_count", len(checks))

	// A zero-capacity semaphore would block every probe forever and a
	// negative capacity panics in make, so always allow at least one probe.
	limit := s.maxConcurrent
	if limit < 1 {
		limit = 1
	}
	sem := make(chan struct{}, limit)

	// Each goroutine writes only its own slot, so no mutex is needed.
	results := make([]tlsprobe.ProbeResult, len(checks))
	var wg sync.WaitGroup
	for i, check := range checks {
		wg.Add(1)
		go func(idx int, c *domain.EndpointHealthCheck) {
			defer wg.Done()
			sem <- struct{}{}        // acquire
			defer func() { <-sem }() // release
			results[idx] = tlsprobe.ProbeTLS(ctx, c.Endpoint, s.defaultTimeout)
		}(i, check)
	}
	wg.Wait()

	successCount := 0
	failureCount := 0
	transitionCount := 0
	for i, check := range checks {
		result := results[i]
		oldStatus := check.Status

		// One timestamp per endpoint keeps the check record and its history
		// entry consistent with each other.
		now := time.Now()
		check.LastCheckedAt = timePtr(now)
		check.ResponseTimeMs = result.ResponseTimeMs
		if result.Success {
			successCount++
			check.ObservedFingerprint = result.Fingerprint
			check.TLSVersion = result.TLSVersion
			check.CipherSuite = result.CipherSuite
			check.CertSubject = result.Subject
			check.CertIssuer = result.Issuer
			check.CertExpiry = timePtr(result.NotAfter)
			check.FailureReason = ""
			check.LastSuccessAt = timePtr(now)
			check.ConsecutiveFailures = 0
		} else {
			failureCount++
			check.LastFailureAt = timePtr(now)
			check.ConsecutiveFailures++
			check.FailureReason = result.Error
		}

		// Transition state based on consecutive failures and fingerprint match.
		newStatus, transitioned := check.TransitionStatus(result.Success, result.Fingerprint)
		if transitioned {
			transitionCount++
			check.Status = newStatus
			check.LastTransitionAt = timePtr(now)
			// A new incident must be acknowledged again.
			check.Acknowledged = false

			s.logger.Info("health check status transition",
				"endpoint", check.Endpoint,
				"old_status", string(oldStatus),
				"new_status", string(newStatus))

			if s.auditService != nil {
				// Best-effort: an audit failure must not block monitoring.
				_ = s.auditService.RecordEvent(ctx, "system", domain.ActorTypeSystem,
					"health_check_status_transition", "health_check", check.ID,
					map[string]interface{}{
						"endpoint":   check.Endpoint,
						"old_status": string(oldStatus),
						"new_status": string(newStatus),
					})
			}
		}

		if err := s.repo.Update(ctx, check); err != nil {
			s.logger.Error("failed to update health check",
				"endpoint", check.Endpoint,
				"error", err)
			continue
		}

		// The repository inserts the ID as given, so an empty ID would be
		// written for every entry. "hh" is presumably the history prefix,
		// by analogy with "hc" in Create — confirm against generateID.
		if err := s.repo.RecordHistory(ctx, &domain.HealthHistoryEntry{
			ID:             generateID("hh"),
			HealthCheckID:  check.ID,
			Status:         string(check.Status),
			ResponseTimeMs: check.ResponseTimeMs,
			Fingerprint:    check.ObservedFingerprint,
			FailureReason:  check.FailureReason,
			CheckedAt:      now,
		}); err != nil {
			s.logger.Warn("failed to record health check history",
				"endpoint", check.Endpoint,
				"error", err)
		}
	}

	// Purge old history entries once per run.
	if err := s.PurgeOldHistory(ctx); err != nil {
		s.logger.Warn("failed to purge old health check history", "error", err)
	}

	s.logger.Debug("health check run completed",
		"total", len(checks),
		"success", successCount,
		"failure", failureCount,
		"transitions", transitionCount)
	return nil
}
// Create persists a new health check endpoint.
//
// An ID with the "hc" prefix is generated when check.ID is empty. CreatedAt
// and UpdatedAt are set to the same instant, so a freshly created record
// never looks modified. The caller's struct is mutated. When an audit
// service is configured, a "health_check_created" event is recorded
// best-effort; its error is ignored.
func (s *HealthCheckService) Create(ctx context.Context, check *domain.EndpointHealthCheck) error {
	if check.ID == "" {
		check.ID = generateID("hc")
	}
	now := time.Now()
	check.CreatedAt = now
	check.UpdatedAt = now

	if err := s.repo.Create(ctx, check); err != nil {
		return fmt.Errorf("failed to create health check: %w", err)
	}

	if s.auditService != nil {
		// Best-effort: the health check is already stored at this point.
		_ = s.auditService.RecordEvent(ctx, "system", domain.ActorTypeSystem,
			"health_check_created", "health_check", check.ID,
			map[string]interface{}{
				"endpoint": check.Endpoint,
			})
	}
	return nil
}
// Get retrieves a health check by ID.
// It delegates directly to the repository and returns its result unchanged.
func (s *HealthCheckService) Get(ctx context.Context, id string) (*domain.EndpointHealthCheck, error) {
return s.repo.Get(ctx, id)
}
// Update stamps check.UpdatedAt with the current time, persists the health
// check, and records a best-effort "health_check_updated" audit event when
// an audit service is configured.
func (s *HealthCheckService) Update(ctx context.Context, check *domain.EndpointHealthCheck) error {
	check.UpdatedAt = time.Now()

	err := s.repo.Update(ctx, check)
	if err != nil {
		return fmt.Errorf("failed to update health check: %w", err)
	}

	if s.auditService == nil {
		return nil
	}
	details := map[string]interface{}{
		"endpoint": check.Endpoint,
	}
	// Best-effort: the audit error is deliberately ignored.
	_ = s.auditService.RecordEvent(ctx, "system", domain.ActorTypeSystem,
		"health_check_updated", "health_check", check.ID, details)
	return nil
}
// Delete removes the health check with the given ID and records a
// best-effort "health_check_deleted" audit event when an audit service is
// configured.
func (s *HealthCheckService) Delete(ctx context.Context, id string) error {
	err := s.repo.Delete(ctx, id)
	if err != nil {
		return fmt.Errorf("failed to delete health check: %w", err)
	}

	if s.auditService == nil {
		return nil
	}
	// Best-effort: the audit error is deliberately ignored.
	_ = s.auditService.RecordEvent(ctx, "system", domain.ActorTypeSystem,
		"health_check_deleted", "health_check", id,
		map[string]interface{}{})
	return nil
}
// List returns health checks matching the filter along with the total count
// reported by the repository. A nil filter is replaced by an empty one.
func (s *HealthCheckService) List(ctx context.Context, filter *repository.HealthCheckFilter) ([]*domain.EndpointHealthCheck, int, error) {
	effective := filter
	if effective == nil {
		effective = new(repository.HealthCheckFilter)
	}
	return s.repo.List(ctx, effective)
}
// GetHistory returns probe history for a health check. The limit is
// normalised before querying: non-positive values become 100 and values
// above 1000 are capped at 1000.
func (s *HealthCheckService) GetHistory(ctx context.Context, healthCheckID string, limit int) ([]*domain.HealthHistoryEntry, error) {
	switch {
	case limit <= 0:
		limit = 100
	case limit > 1000:
		limit = 1000
	}
	return s.repo.GetHistory(ctx, healthCheckID, limit)
}
// AcknowledgeIncident loads the health check, flags it as acknowledged by
// actor at the current time, and saves it. When an audit service is
// configured, a best-effort "health_check_acknowledged" event is recorded
// with actor as a user actor.
func (s *HealthCheckService) AcknowledgeIncident(ctx context.Context, id string, actor string) error {
	check, err := s.repo.Get(ctx, id)
	if err != nil {
		return fmt.Errorf("failed to get health check: %w", err)
	}

	ackedAt := time.Now()
	check.Acknowledged = true
	check.AcknowledgedBy = actor
	check.AcknowledgedAt = &ackedAt

	if err = s.repo.Update(ctx, check); err != nil {
		return fmt.Errorf("failed to update health check: %w", err)
	}

	if s.auditService == nil {
		return nil
	}
	details := map[string]interface{}{
		"endpoint": check.Endpoint,
	}
	// Best-effort: the audit error is deliberately ignored.
	_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser,
		"health_check_acknowledged", "health_check", id, details)
	return nil
}
// GetSummary returns aggregated health check status counts.
// It delegates directly to the repository and returns its result unchanged.
func (s *HealthCheckService) GetSummary(ctx context.Context) (*domain.HealthCheckSummary, error) {
return s.repo.GetSummary(ctx)
}
// PurgeOldHistory deletes history entries checked earlier than now minus the
// configured retention period. The number of deleted rows is discarded.
func (s *HealthCheckService) PurgeOldHistory(ctx context.Context) error {
	threshold := time.Now().Add(-s.historyRetention)
	if _, err := s.repo.PurgeHistory(ctx, threshold); err != nil {
		return err
	}
	return nil
}
// Helper functions
// timePtr returns a pointer to a fresh copy of t, for populating optional
// *time.Time fields.
func timePtr(t time.Time) *time.Time {
	p := new(time.Time)
	*p = t
	return p
}
+350
View File
@@ -0,0 +1,350 @@
package service
import (
"context"
"errors"
"log/slog"
"os"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/repository"
)
// mockHealthCheckRepo implements the HealthCheckRepository interface for testing.
// State is held in memory; each *Err field, when non-nil, makes the matching
// method fail with that error.
type mockHealthCheckRepo struct {
// checks holds stored health checks keyed by ID.
checks map[string]*domain.EndpointHealthCheck
// history accumulates every entry passed to RecordHistory.
history []*domain.HealthHistoryEntry
createErr error
getErr error
updateErr error
deleteErr error
listErr error
listDueErr error
getHistoryErr error
recordHistoryErr error
purgeHistoryErr error
getSummaryErr error
// getSummaryResult is the value GetSummary returns when getSummaryErr is nil.
getSummaryResult *domain.HealthCheckSummary
}
// newMockHealthCheckRepo returns a mock with an empty check map, an empty
// (non-nil) history slice, and an all-zero summary result.
func newMockHealthCheckRepo() *mockHealthCheckRepo {
	mock := new(mockHealthCheckRepo)
	mock.checks = map[string]*domain.EndpointHealthCheck{}
	mock.history = make([]*domain.HealthHistoryEntry, 0)
	mock.getSummaryResult = new(domain.HealthCheckSummary)
	return mock
}
// Create stores check under its ID, or fails with createErr when set.
func (m *mockHealthCheckRepo) Create(ctx context.Context, check *domain.EndpointHealthCheck) error {
	if err := m.createErr; err != nil {
		return err
	}
	m.checks[check.ID] = check
	return nil
}
// Get returns the stored check for id. It fails with getErr when set, or
// with "not found" when the ID is unknown.
func (m *mockHealthCheckRepo) Get(ctx context.Context, id string) (*domain.EndpointHealthCheck, error) {
	if err := m.getErr; err != nil {
		return nil, err
	}
	stored, found := m.checks[id]
	if !found {
		return nil, errors.New("not found")
	}
	return stored, nil
}
// GetByEndpoint returns a stored check whose Endpoint equals endpoint, or
// "not found" when none matches.
func (m *mockHealthCheckRepo) GetByEndpoint(ctx context.Context, endpoint string) (*domain.EndpointHealthCheck, error) {
	for _, candidate := range m.checks {
		if candidate.Endpoint != endpoint {
			continue
		}
		return candidate, nil
	}
	return nil, errors.New("not found")
}
// Update stores check under its ID (inserting it if absent), or fails with
// updateErr when set.
func (m *mockHealthCheckRepo) Update(ctx context.Context, check *domain.EndpointHealthCheck) error {
	if err := m.updateErr; err != nil {
		return err
	}
	m.checks[check.ID] = check
	return nil
}
// Delete removes the check stored under id, or fails with deleteErr when set.
// Deleting an unknown ID succeeds.
func (m *mockHealthCheckRepo) Delete(ctx context.Context, id string) error {
	if err := m.deleteErr; err != nil {
		return err
	}
	delete(m.checks, id)
	return nil
}
// List returns every stored check together with the count, ignoring filter.
// It returns listErr when that field is set. The order of the returned slice
// follows map iteration and is therefore unspecified.
func (m *mockHealthCheckRepo) List(ctx context.Context, filter *repository.HealthCheckFilter) ([]*domain.EndpointHealthCheck, int, error) {
	if err := m.listErr; err != nil {
		return nil, 0, err
	}

	all := make([]*domain.EndpointHealthCheck, 0, len(m.checks))
	for id := range m.checks {
		all = append(all, m.checks[id])
	}
	return all, len(all), nil
}
// ListDueForCheck returns every stored check whose Enabled flag is true, or
// listDueErr when that field is set. The mock does not consider check
// intervals or timestamps.
func (m *mockHealthCheckRepo) ListDueForCheck(ctx context.Context) ([]*domain.EndpointHealthCheck, error) {
	if err := m.listDueErr; err != nil {
		return nil, err
	}

	enabled := make([]*domain.EndpointHealthCheck, 0, len(m.checks))
	for _, candidate := range m.checks {
		if !candidate.Enabled {
			continue
		}
		enabled = append(enabled, candidate)
	}
	return enabled, nil
}
// GetHistory returns the full recorded history slice, or getHistoryErr when
// that field is set. The mock ignores both healthCheckID and limit.
func (m *mockHealthCheckRepo) GetHistory(ctx context.Context, healthCheckID string, limit int) ([]*domain.HealthHistoryEntry, error) {
	if m.getHistoryErr == nil {
		return m.history, nil
	}
	return nil, m.getHistoryErr
}
// RecordHistory appends entry to the in-memory history, or returns
// recordHistoryErr when that field is set.
func (m *mockHealthCheckRepo) RecordHistory(ctx context.Context, entry *domain.HealthHistoryEntry) error {
	if err := m.recordHistoryErr; err != nil {
		return err
	}

	m.history = append(m.history, entry)
	return nil
}
// PurgeHistory reports zero purged rows, or returns purgeHistoryErr when that
// field is set. The mock never removes anything from its history slice.
func (m *mockHealthCheckRepo) PurgeHistory(ctx context.Context, before time.Time) (int64, error) {
	if m.purgeHistoryErr == nil {
		return 0, nil
	}
	return 0, m.purgeHistoryErr
}
// GetSummary returns the preconfigured getSummaryResult, or getSummaryErr when
// that field is set.
func (m *mockHealthCheckRepo) GetSummary(ctx context.Context) (*domain.HealthCheckSummary, error) {
	if m.getSummaryErr == nil {
		return m.getSummaryResult, nil
	}
	return nil, m.getSummaryErr
}
// Tests

// newTestLogger returns a text-format slog.Logger that writes to stderr with
// the minimum level set to debug.
func newTestLogger() *slog.Logger {
	opts := &slog.HandlerOptions{Level: slog.LevelDebug}
	handler := slog.NewTextHandler(os.Stderr, opts)
	return slog.New(handler)
}
// TestHealthCheckService_Create_Success checks that Create succeeds, leaves a
// non-empty ID on the check, and that the check can be read back from the
// mock repository with its endpoint intact.
func TestHealthCheckService_Create_Success(t *testing.T) {
	const endpoint = "example.com:443"

	ctx := context.Background()
	repo := newMockHealthCheckRepo()
	svc := NewHealthCheckService(repo, nil, newTestLogger(), 10, 5*time.Second, 30*24*time.Hour, false)

	input := &domain.EndpointHealthCheck{
		Endpoint:          endpoint,
		Status:            domain.HealthStatusUnknown,
		Enabled:           true,
		CheckIntervalSecs: 300,
	}
	if err := svc.Create(ctx, input); err != nil {
		t.Fatalf("Create failed: %v", err)
	}
	if input.ID == "" {
		t.Fatal("Expected ID to be set")
	}

	// The lookup error is deliberately not inspected: a nil result is the
	// failure signal checked below.
	stored, _ := repo.Get(ctx, input.ID)
	if stored == nil {
		t.Fatal("Expected check to be in repo")
	}
	if stored.Endpoint != endpoint {
		t.Errorf("Expected endpoint example.com:443, got %s", stored.Endpoint)
	}
}
// TestHealthCheckService_Create_RepoError checks that Create returns an error
// when the mock repository is configured with createErr.
func TestHealthCheckService_Create_RepoError(t *testing.T) {
	repo := newMockHealthCheckRepo()
	repo.createErr = errors.New("db error")
	svc := NewHealthCheckService(repo, nil, newTestLogger(), 10, 5*time.Second, 30*24*time.Hour, false)

	input := &domain.EndpointHealthCheck{
		Endpoint: "example.com:443",
		Enabled:  true,
	}
	if err := svc.Create(context.Background(), input); err == nil {
		t.Fatal("Expected error, got nil")
	}
}
// TestHealthCheckService_Get_Success seeds the mock repository with one check
// and verifies that Get returns it with the expected endpoint.
func TestHealthCheckService_Get_Success(t *testing.T) {
	const checkID = "hc-test-1"

	repo := newMockHealthCheckRepo()
	svc := NewHealthCheckService(repo, nil, newTestLogger(), 10, 5*time.Second, 30*24*time.Hour, false)

	repo.checks[checkID] = &domain.EndpointHealthCheck{
		ID:       checkID,
		Endpoint: "example.com:443",
		Status:   domain.HealthStatusHealthy,
	}

	got, err := svc.Get(context.Background(), checkID)
	if err != nil {
		t.Fatalf("Get failed: %v", err)
	}
	if got.Endpoint != "example.com:443" {
		t.Errorf("Expected endpoint example.com:443, got %s", got.Endpoint)
	}
}
func TestHealthCheckService_Get_NotFound(t *testing.T) {
repo := newMockHealthCheckRepo()
logger := newTestLogger()
svc := NewHealthCheckService(repo, nil, logger, 10, 5*time.Second, 30*24*time.Hour, false)
_, err := svc.Get(context.Background(), "nonexistent")
if err == nil {
t.Fatal("Expected error for nonexistent check")
}
}
// TestHealthCheckService_List_Success seeds two checks and verifies that List
// with a nil filter returns both, along with a total of 2.
func TestHealthCheckService_List_Success(t *testing.T) {
	repo := newMockHealthCheckRepo()
	svc := NewHealthCheckService(repo, nil, newTestLogger(), 10, 5*time.Second, 30*24*time.Hour, false)

	seeded := []*domain.EndpointHealthCheck{
		{ID: "hc-1", Endpoint: "api.example.com:443", Status: domain.HealthStatusHealthy},
		{ID: "hc-2", Endpoint: "web.example.com:443", Status: domain.HealthStatusDegraded},
	}
	for _, c := range seeded {
		repo.checks[c.ID] = c
	}

	listed, total, err := svc.List(context.Background(), nil)
	if err != nil {
		t.Fatalf("List failed: %v", err)
	}
	if got := len(listed); got != 2 {
		t.Errorf("Expected 2 checks, got %d", got)
	}
	if total != 2 {
		t.Errorf("Expected total 2, got %d", total)
	}
}
// TestHealthCheckService_Delete_Success seeds one check and verifies that
// Delete succeeds and the check is gone from the mock repository afterwards.
func TestHealthCheckService_Delete_Success(t *testing.T) {
	const checkID = "hc-test-1"

	repo := newMockHealthCheckRepo()
	svc := NewHealthCheckService(repo, nil, newTestLogger(), 10, 5*time.Second, 30*24*time.Hour, false)

	repo.checks[checkID] = &domain.EndpointHealthCheck{
		ID:       checkID,
		Endpoint: "example.com:443",
	}

	if err := svc.Delete(context.Background(), checkID); err != nil {
		t.Fatalf("Delete failed: %v", err)
	}

	_, stillThere := repo.checks[checkID]
	if stillThere {
		t.Fatal("Expected check to be deleted")
	}
}
// TestHealthCheckService_AcknowledgeIncident_Success seeds an unacknowledged
// check in the "down" state and verifies that AcknowledgeIncident marks it
// acknowledged and records who acknowledged it and when.
func TestHealthCheckService_AcknowledgeIncident_Success(t *testing.T) {
	const (
		checkID = "hc-test-1"
		actor   = "user@example.com"
	)

	repo := newMockHealthCheckRepo()
	svc := NewHealthCheckService(repo, nil, newTestLogger(), 10, 5*time.Second, 30*24*time.Hour, false)

	repo.checks[checkID] = &domain.EndpointHealthCheck{
		ID:           checkID,
		Endpoint:     "example.com:443",
		Status:       domain.HealthStatusDown,
		Acknowledged: false,
	}

	if err := svc.AcknowledgeIncident(context.Background(), checkID, actor); err != nil {
		t.Fatalf("AcknowledgeIncident failed: %v", err)
	}

	// Inspect the repository state directly rather than going through the service.
	stored := repo.checks[checkID]
	if !stored.Acknowledged {
		t.Fatal("Expected Acknowledged to be true")
	}
	if stored.AcknowledgedBy != actor {
		t.Errorf("Expected AcknowledgedBy to be user@example.com, got %s", stored.AcknowledgedBy)
	}
	if stored.AcknowledgedAt == nil {
		t.Fatal("Expected AcknowledgedAt to be set")
	}
}
// TestHealthCheckService_GetSummary_Success configures the mock repository's
// summary and verifies that GetSummary succeeds and reports the configured
// healthy count.
func TestHealthCheckService_GetSummary_Success(t *testing.T) {
	repo := newMockHealthCheckRepo()
	svc := NewHealthCheckService(repo, nil, newTestLogger(), 10, 5*time.Second, 30*24*time.Hour, false)

	canned := &domain.HealthCheckSummary{
		Healthy:      5,
		Degraded:     2,
		Down:         1,
		CertMismatch: 1,
		Unknown:      0,
	}
	repo.getSummaryResult = canned

	got, err := svc.GetSummary(context.Background())
	if err != nil {
		t.Fatalf("GetSummary failed: %v", err)
	}
	if got.Healthy != 5 {
		t.Errorf("Expected 5 healthy, got %d", got.Healthy)
	}
}
func TestHealthCheckService_RunHealthChecks_NoEndpoints(t *testing.T) {
repo := newMockHealthCheckRepo()
logger := newTestLogger()
svc := NewHealthCheckService(repo, nil, logger, 10, 5*time.Second, 30*24*time.Hour, false)
err := svc.RunHealthChecks(context.Background())
if err != nil {
t.Fatalf("RunHealthChecks failed: %v", err)
}
}
func TestHealthCheckService_PurgeOldHistory_Success(t *testing.T) {
repo := newMockHealthCheckRepo()
logger := newTestLogger()
svc := NewHealthCheckService(repo, nil, logger, 10, 5*time.Second, 30*24*time.Hour, false)
err := svc.PurgeOldHistory(context.Background())
if err != nil {
t.Fatalf("PurgeOldHistory failed: %v", err)
}
}
+5 -25
View File
@@ -2,9 +2,6 @@ package service
import (
"context"
"crypto/ecdsa"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/pem"
@@ -16,6 +13,7 @@ import (
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/repository"
"github.com/shankar0123/certctl/internal/tlsprobe"
)
// SentinelAgentID is the agent ID used for network-discovered certificates.
@@ -469,16 +467,15 @@ func (s *NetworkScanService) probeTLS(ctx context.Context, address string, timeo
// tlsCertToEntry converts an x509.Certificate from a TLS handshake into a DiscoveredCertEntry.
func tlsCertToEntry(cert *x509.Certificate, address string) domain.DiscoveredCertEntry {
// Compute SHA-256 fingerprint
fingerprintBytes := sha256.Sum256(cert.Raw)
fingerprint := fmt.Sprintf("%x", fingerprintBytes)
// Compute SHA-256 fingerprint using shared tlsprobe package
fingerprint := tlsprobe.CertFingerprint(cert)
// Encode as PEM
pemBlock := &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
pemData := string(pem.EncodeToMemory(pemBlock))
// Key algorithm and size
keyAlg, keySize := tlsCertKeyInfo(cert)
// Key algorithm and size using shared tlsprobe package
keyAlg, keySize := tlsprobe.CertKeyInfo(cert)
return domain.DiscoveredCertEntry{
FingerprintSHA256: fingerprint,
@@ -497,20 +494,3 @@ func tlsCertToEntry(cert *x509.Certificate, address string) domain.DiscoveredCer
SourceFormat: "network",
}
}
// tlsCertKeyInfo extracts key algorithm name and size from a certificate.
// RSA and ECDSA keys are recognised by the concrete public-key type; for
// anything else the certificate's PublicKeyAlgorithm decides, with Ed25519
// reported as 256 bits and every other algorithm reported with size 0.
func tlsCertKeyInfo(cert *x509.Certificate) (string, int) {
	switch key := cert.PublicKey.(type) {
	case *rsa.PublicKey:
		return "RSA", key.N.BitLen()
	case *ecdsa.PublicKey:
		return "ECDSA", key.Curve.Params().BitSize
	}

	if cert.PublicKeyAlgorithm == x509.Ed25519 {
		return "Ed25519", 256
	}
	return cert.PublicKeyAlgorithm.String(), 0
}
+125
View File
@@ -0,0 +1,125 @@
package tlsprobe
import (
"context"
"crypto/ecdsa"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"fmt"
"net"
"time"
)
// ProbeResult contains the result of probing a TLS endpoint.
// ProbeTLS always sets Address and ResponseTimeMs. Error is set only when the
// dial or handshake fails. TLSVersion and CipherSuite are set on success, and
// the certificate fields are filled only when the peer presented at least one
// certificate (they describe the first certificate in the peer's list).
type ProbeResult struct {
// Address is the address that was probed, as passed to ProbeTLS.
Address string `json:"address"`
// Success is true when the TLS connection was established.
Success bool `json:"success"`
Fingerprint string `json:"fingerprint"` // SHA-256 hex fingerprint of leaf cert
TLSVersion string `json:"tls_version"` // e.g. "TLS 1.3"
CipherSuite string `json:"cipher_suite"` // e.g. "TLS_AES_128_GCM_SHA256"
Subject string `json:"subject"` // cert subject CN
Issuer string `json:"issuer"` // cert issuer CN
// NotBefore and NotAfter are the certificate's validity bounds.
NotBefore time.Time `json:"not_before"`
NotAfter time.Time `json:"not_after"`
// SerialNumber is the certificate serial number rendered in base 16.
SerialNumber string `json:"serial_number"`
// ResponseTimeMs is the elapsed time, in whole milliseconds, from the start
// of the probe until the connection attempt completed or failed.
ResponseTimeMs int `json:"response_time_ms"`
// Error holds the connection error text; it is omitted from JSON when empty.
Error string `json:"error,omitempty"`
}
// ProbeTLS connects to a TLS endpoint, performs a handshake, and extracts
// metadata from the leaf certificate and the negotiated session.
//
// Certificate verification is disabled (InsecureSkipVerify) so that
// self-signed, expired, and private-CA certificates can still be inspected.
// This is safe because the certificate data is extracted and analyzed, not
// validated for trust.
//
// Parameters:
//   - ctx: cancels or bounds the dial and handshake. A nil ctx is treated as
//     context.Background() so callers that relied on it being ignored keep working.
//   - address: "host:port" of the endpoint, dialed over TCP.
//   - timeout: limit applied to the dial and handshake as a whole; zero means
//     no limit beyond whatever ctx imposes.
//
// The returned ProbeResult always carries Address and ResponseTimeMs. On
// failure Success is false and Error holds the error text. On success the TLS
// version and cipher suite are populated, plus the leaf-certificate fields
// when the peer presented a certificate.
func ProbeTLS(ctx context.Context, address string, timeout time.Duration) ProbeResult {
	if ctx == nil {
		ctx = context.Background()
	}

	startTime := time.Now()
	result := ProbeResult{Address: address}

	// tls.Dialer (unlike tls.DialWithDialer) accepts a context, so a cancelled
	// or expired ctx aborts both the TCP connect and the TLS handshake.
	dialer := &tls.Dialer{
		NetDialer: &net.Dialer{Timeout: timeout},
		Config: &tls.Config{
			// SECURITY NOTE: InsecureSkipVerify is intentionally set to true here.
			// The health checker must monitor ALL certificates including self-signed,
			// expired, and internal CA certificates. This setting is scoped to discovery
			// probing only — it is NEVER used for control-plane API calls, issuer
			// connector communication, or any operation that trusts the certificate.
			// The endpoint's certificate chain is extracted and analyzed, not validated.
			// See TICKET-016 for full security audit rationale.
			InsecureSkipVerify: true,
		},
	}

	rawConn, err := dialer.DialContext(ctx, "tcp", address)
	result.ResponseTimeMs = int(time.Since(startTime).Milliseconds())
	if err != nil {
		result.Error = err.Error()
		return result
	}
	defer rawConn.Close()

	// DialContext is documented to return a *tls.Conn; guard anyway so an
	// unexpected type becomes a failed probe instead of a panic.
	conn, ok := rawConn.(*tls.Conn)
	if !ok {
		result.Error = "tlsprobe: dialer returned a non-TLS connection"
		return result
	}
	result.Success = true

	// Extract certificate metadata from the leaf (first) peer certificate.
	state := conn.ConnectionState()
	if len(state.PeerCertificates) > 0 {
		leaf := state.PeerCertificates[0]
		result.Fingerprint = CertFingerprint(leaf)
		result.Subject = leaf.Subject.CommonName
		result.Issuer = leaf.Issuer.CommonName
		result.NotBefore = leaf.NotBefore
		result.NotAfter = leaf.NotAfter
		result.SerialNumber = leaf.SerialNumber.Text(16)
	}

	// Negotiated protocol version and cipher suite, as human-readable names.
	result.TLSVersion = tlsVersionString(state.Version)
	result.CipherSuite = tls.CipherSuiteName(state.CipherSuite)

	return result
}
// CertFingerprint computes the SHA-256 fingerprint of a certificate's raw DER
// bytes and returns it as a lowercase hex string.
func CertFingerprint(cert *x509.Certificate) string {
	digest := sha256.Sum256(cert.Raw)
	encoded := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(encoded, digest[:])
	return string(encoded)
}
// CertKeyInfo extracts key algorithm name and size from a certificate.
// Returns algorithm name (e.g., "RSA", "ECDSA", "Ed25519") and key size in bits.
// RSA and ECDSA are detected from the concrete public-key type; otherwise the
// certificate's PublicKeyAlgorithm is used, with Ed25519 reported as 256 bits
// and any other algorithm reported by name with size 0.
func CertKeyInfo(cert *x509.Certificate) (string, int) {
	if rsaKey, isRSA := cert.PublicKey.(*rsa.PublicKey); isRSA {
		return "RSA", rsaKey.N.BitLen()
	}
	if ecKey, isECDSA := cert.PublicKey.(*ecdsa.PublicKey); isECDSA {
		return "ECDSA", ecKey.Curve.Params().BitSize
	}
	if cert.PublicKeyAlgorithm == x509.Ed25519 {
		return "Ed25519", 256
	}
	return cert.PublicKeyAlgorithm.String(), 0
}
// tlsVersionString converts a TLS version constant to a human-readable string.
// TLS 1.0 through 1.3 map to "TLS 1.x"; any other value is rendered as
// "TLS 0x" followed by the value in lowercase hex.
func tlsVersionString(version uint16) string {
	var label string
	switch version {
	case tls.VersionTLS13:
		label = "TLS 1.3"
	case tls.VersionTLS12:
		label = "TLS 1.2"
	case tls.VersionTLS11:
		label = "TLS 1.1"
	case tls.VersionTLS10:
		label = "TLS 1.0"
	default:
		label = fmt.Sprintf("TLS 0x%x", version)
	}
	return label
}
+175
View File
@@ -0,0 +1,175 @@
package tlsprobe
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
"net"
"net/http/httptest"
"testing"
"time"
)
// TestProbeTLS_ConnectionRefused probes a loopback port that is expected to
// have no listener and verifies the result reports failure with an error
// message and a non-negative response time.
func TestProbeTLS_ConnectionRefused(t *testing.T) {
	const closedAddr = "127.0.0.1:1"

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	probe := ProbeTLS(ctx, closedAddr, time.Second)

	if probe.Success {
		t.Errorf("expected Success=false for unavailable endpoint, got %v", probe.Success)
	}
	if len(probe.Error) == 0 {
		t.Errorf("expected Error to be set for unavailable endpoint, got empty")
	}
	// A refused connection can complete in under a millisecond, so the only
	// safe assertion is that the value is not negative.
	if probe.ResponseTimeMs < 0 {
		t.Errorf("expected ResponseTimeMs >= 0, got %d", probe.ResponseTimeMs)
	}
}
// TestProbeTLS_Success probes a live httptest TLS server (which serves a
// self-signed certificate) and verifies that the probe succeeds and reports a
// fingerprint, a TLS version, and a non-negative response time.
func TestProbeTLS_Success(t *testing.T) {
	// A nil handler is fine: the probe only performs the TLS handshake and
	// never sends an HTTP request.
	server := httptest.NewTLSServer(nil)
	defer server.Close()

	// Build "host:port" from the listener's address rather than parsing server.URL.
	tcpAddr, ok := server.Listener.Addr().(*net.TCPAddr)
	if !ok {
		t.Fatalf("expected *net.TCPAddr listener address, got %T", server.Listener.Addr())
	}
	address := net.JoinHostPort(tcpAddr.IP.String(), fmt.Sprintf("%d", tcpAddr.Port))

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	result := ProbeTLS(ctx, address, 5*time.Second)

	// Stop here on failure: the remaining assertions are meaningless without
	// a completed handshake and would only add noise.
	if !result.Success {
		t.Fatalf("expected Success=true, got false. Error: %s", result.Error)
	}
	if result.Fingerprint == "" {
		t.Errorf("expected Fingerprint to be set, got empty")
	}
	if result.TLSVersion == "" {
		t.Errorf("expected TLSVersion to be set, got empty")
	}
	// A loopback handshake can finish in under one millisecond, which
	// truncates to 0; requiring a strictly positive value makes the test flaky.
	if result.ResponseTimeMs < 0 {
		t.Errorf("expected ResponseTimeMs >= 0, got %d", result.ResponseTimeMs)
	}
}
// TestCertFingerprint_SHA256 verifies that CertFingerprint yields a non-empty,
// 64-character, lowercase-hex string and that repeated calls on the same
// certificate agree.
func TestCertFingerprint_SHA256(t *testing.T) {
	cert, _ := createTestCertWithKey(t, "test.example.com", "rsa")

	first := CertFingerprint(cert)
	if first == "" {
		t.Errorf("expected non-empty fingerprint, got empty")
	}
	if n := len(first); n != 64 {
		t.Errorf("expected fingerprint length 64 (hex SHA-256), got %d", n)
	}

	// Every character must be a digit or a lowercase hex letter.
	for _, ch := range first {
		isDigit := ch >= '0' && ch <= '9'
		isLowerHex := ch >= 'a' && ch <= 'f'
		if !isDigit && !isLowerHex {
			t.Errorf("expected lowercase hex fingerprint, got invalid char: %c", ch)
		}
	}

	// The same certificate must always hash to the same fingerprint.
	if second := CertFingerprint(cert); first != second {
		t.Errorf("fingerprint not consistent: %s vs %s", first, second)
	}
}
// TestCertKeyInfo_RSA verifies that CertKeyInfo reports "RSA" and 2048 bits
// for a certificate generated with the "rsa" key type.
func TestCertKeyInfo_RSA(t *testing.T) {
	cert, _ := createTestCertWithKey(t, "test.example.com", "rsa")

	gotAlg, gotBits := CertKeyInfo(cert)
	if gotAlg != "RSA" {
		t.Errorf("expected algorithm 'RSA', got '%s'", gotAlg)
	}
	if gotBits != 2048 {
		t.Errorf("expected RSA key size 2048, got %d", gotBits)
	}
}
// TestCertKeyInfo_ECDSA verifies that CertKeyInfo reports "ECDSA" and 256 bits
// for a certificate generated with the "ecdsa" key type.
func TestCertKeyInfo_ECDSA(t *testing.T) {
	cert, _ := createTestCertWithKey(t, "test.example.com", "ecdsa")

	gotAlg, gotBits := CertKeyInfo(cert)
	if gotAlg != "ECDSA" {
		t.Errorf("expected algorithm 'ECDSA', got '%s'", gotAlg)
	}
	if gotBits != 256 {
		t.Errorf("expected ECDSA P-256 key size 256, got %d", gotBits)
	}
}
// Helper: createTestCert returns a test certificate for cn backed by an RSA
// key. It delegates to createTestCertWithKey and discards the private key.
func createTestCert(t *testing.T, cn string) *x509.Certificate {
	certificate, _ := createTestCertWithKey(t, cn, "rsa")
	return certificate
}
// Helper: createTestCertWithKey creates a self-signed test certificate with the
// specified key type and returns the parsed certificate plus its private key.
//
// Parameters:
//   - t: the running test; any failure aborts it via t.Fatalf.
//   - cn: used as both the subject common name and the single DNS SAN.
//   - keyType: "rsa" (2048-bit) or "ecdsa" (P-256); anything else fails the test.
//
// The certificate has serial number 1 and is valid from now for 365 days.
func createTestCertWithKey(t *testing.T, cn, keyType string) (*x509.Certificate, interface{}) {
	// Attribute failures to the calling test's line, not to this helper.
	t.Helper()

	var privKey, pubKey interface{}
	switch keyType {
	case "rsa":
		key, err := rsa.GenerateKey(rand.Reader, 2048)
		if err != nil {
			t.Fatalf("failed to generate RSA key: %v", err)
		}
		privKey, pubKey = key, &key.PublicKey
	case "ecdsa":
		key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
		if err != nil {
			t.Fatalf("failed to generate ECDSA key: %v", err)
		}
		privKey, pubKey = key, &key.PublicKey
	default:
		t.Fatalf("unsupported key type: %s", keyType)
	}

	// Read the clock once so the validity window is exactly 365 days.
	now := time.Now()
	template := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			CommonName: cn,
		},
		NotBefore: now,
		NotAfter:  now.Add(365 * 24 * time.Hour),
		KeyUsage:  x509.KeyUsageDigitalSignature,
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageServerAuth,
		},
		DNSNames: []string{cn},
	}

	// Using the template as its own parent makes the certificate self-signed.
	certDER, err := x509.CreateCertificate(rand.Reader, template, template, pubKey, privKey)
	if err != nil {
		t.Fatalf("failed to create certificate: %v", err)
	}

	cert, err := x509.ParseCertificate(certDER)
	if err != nil {
		t.Fatalf("failed to parse certificate: %v", err)
	}

	return cert, privKey
}