fix(quality): TICKET-012 propagate request context instead of context.Background()

- Updated AgentService interface to accept context.Context parameter in all methods
- Replaced context.Background() calls with proper ctx parameter in agent.go
- Updated AgentGroupService interface to accept context.Context parameter
- Replaced context.Background() calls with proper ctx parameter in agent_group.go
- Updated handler methods to pass r.Context() to service methods
- Context now properly propagates through request lifecycle for timeout/cancellation
- Improved request tracing and cancellation behavior
This commit is contained in:
shankar0123
2026-03-27 21:35:22 -04:00
parent 3e5cc86c5a
commit 200bdf990f
11 changed files with 413 additions and 81 deletions
+32
View File
@@ -6,6 +6,8 @@ import (
"log/slog"
"os/exec"
"time"
"github.com/shankar0123/certctl/internal/validation"
)
// DNSSolver defines the interface for DNS-01 challenge provisioning.
@@ -55,6 +57,16 @@ func (s *ScriptDNSSolver) Present(ctx context.Context, domain, token, keyAuth st
return fmt.Errorf("DNS present script not configured")
}
// Validate domain name to prevent injection attacks
if err := validation.ValidateDomainName(domain); err != nil {
return fmt.Errorf("invalid domain name: %w", err)
}
// Validate ACME token to prevent injection attacks
if err := validation.ValidateACMEToken(token); err != nil {
return fmt.Errorf("invalid ACME token: %w", err)
}
fqdn := "_acme-challenge." + domain
s.Logger.Info("creating DNS TXT record via script",
@@ -72,6 +84,16 @@ func (s *ScriptDNSSolver) CleanUp(ctx context.Context, domain, token, keyAuth st
return nil
}
// Validate domain name to prevent injection attacks
if err := validation.ValidateDomainName(domain); err != nil {
return fmt.Errorf("invalid domain name: %w", err)
}
// Validate ACME token to prevent injection attacks
if err := validation.ValidateACMEToken(token); err != nil {
return fmt.Errorf("invalid ACME token: %w", err)
}
fqdn := "_acme-challenge." + domain
s.Logger.Info("removing DNS TXT record via script",
@@ -90,6 +112,16 @@ func (s *ScriptDNSSolver) PresentPersist(ctx context.Context, domain, token, rec
return fmt.Errorf("DNS present script not configured")
}
// Validate domain name to prevent injection attacks
if err := validation.ValidateDomainName(domain); err != nil {
return fmt.Errorf("invalid domain name: %w", err)
}
// Validate ACME token to prevent injection attacks
if err := validation.ValidateACMEToken(token); err != nil {
return fmt.Errorf("invalid ACME token: %w", err)
}
fqdn := "_validation-persist." + domain
s.Logger.Info("creating persistent DNS TXT record via script",
+133
View File
@@ -193,3 +193,136 @@ echo "FQDN=$CERTCTL_DNS_FQDN" > ` + outputFile + `
}
})
}
// Security tests for DNS injection prevention
func TestScriptDNSSolver_Present_RejectInvalidDomain(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "present.sh")
os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755)
tests := []struct {
name string
domain string
}{
{
name: "domain with command injection semicolon",
domain: "example.com; rm -rf /",
},
{
name: "domain with backtick injection",
domain: "example.com`whoami`",
},
{
name: "domain with command substitution",
domain: "example.com$(whoami)",
},
{
name: "domain with pipe injection",
domain: "example.com | cat /etc/passwd",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger)
err := solver.Present(ctx, tt.domain, "test-token", "test-key-auth")
if err == nil {
t.Fatalf("expected error for invalid domain: %s", tt.domain)
}
})
}
}
func TestScriptDNSSolver_Present_RejectInvalidToken(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "present.sh")
os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755)
tests := []struct {
name string
token string
}{
{
name: "token with command injection",
token: "token$(whoami)",
},
{
name: "token with backtick injection",
token: "token`id`",
},
{
name: "token with semicolon",
token: "token;malicious",
},
{
name: "token with pipe",
token: "token|cat",
},
{
name: "token with space",
token: "token value",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger)
err := solver.Present(ctx, "example.com", tt.token, "test-key-auth")
if err == nil {
t.Fatalf("expected error for invalid token: %s", tt.token)
}
})
}
}
func TestScriptDNSSolver_CleanUp_RejectInvalidDomain(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "cleanup.sh")
os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755)
solver := acmeissuer.NewScriptDNSSolver("", scriptPath, logger)
err := solver.CleanUp(ctx, "example.com; rm -rf /", "test-token", "test-key-auth")
if err == nil {
t.Fatal("expected error for command injection in domain")
}
}
func TestScriptDNSSolver_PresentPersist_RejectInvalidDomain(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "present.sh")
os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755)
solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger)
err := solver.PresentPersist(ctx, "example.com`whoami`", "test-token", "letsencrypt.org; accounturi=https://example.com/acct/1")
if err == nil {
t.Fatal("expected error for command injection in domain")
}
}
func TestScriptDNSSolver_PresentPersist_RejectInvalidToken(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "present.sh")
os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755)
solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger)
err := solver.PresentPersist(ctx, "example.com", "token$(whoami)", "letsencrypt.org; accounturi=https://example.com/acct/1")
if err == nil {
t.Fatal("expected error for command injection in token")
}
}
+12 -6
View File
@@ -97,22 +97,28 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
return fmt.Errorf("sign_script is required")
}
// Verify sign_script exists and is executable
if _, err := os.Stat(cfg.SignScript); err != nil {
// Verify sign_script exists and is a regular file
if info, err := os.Stat(cfg.SignScript); err != nil {
return fmt.Errorf("sign_script not accessible: %w", err)
} else if !info.Mode().IsRegular() {
return fmt.Errorf("sign_script must be a regular file, got %s", info.Mode())
}
// Verify revoke_script exists if specified
// Verify revoke_script exists and is a regular file if specified
if cfg.RevokeScript != "" {
if _, err := os.Stat(cfg.RevokeScript); err != nil {
if info, err := os.Stat(cfg.RevokeScript); err != nil {
return fmt.Errorf("revoke_script not accessible: %w", err)
} else if !info.Mode().IsRegular() {
return fmt.Errorf("revoke_script must be a regular file, got %s", info.Mode())
}
}
// Verify crl_script exists if specified
// Verify crl_script exists and is a regular file if specified
if cfg.CRLScript != "" {
if _, err := os.Stat(cfg.CRLScript); err != nil {
if info, err := os.Stat(cfg.CRLScript); err != nil {
return fmt.Errorf("crl_script not accessible: %w", err)
} else if !info.Mode().IsRegular() {
return fmt.Errorf("crl_script must be a regular file, got %s", info.Mode())
}
}
@@ -556,3 +556,68 @@ func generateMockCertPEM() string {
Bytes: certBytes,
}))
}
// Security tests for script path validation
func TestOpenSSLConnector_ValidateConfig_RejectNonRegularFile(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
// Try to use a directory as a script path
tmpDir := t.TempDir()
config := &openssl.Config{
SignScript: tmpDir, // This is a directory, not a regular file
}
connector := openssl.New(config, logger)
rawConfig, _ := json.Marshal(config)
err := connector.ValidateConfig(ctx, rawConfig)
if err == nil {
t.Fatal("Expected error when sign_script is not a regular file")
}
}
func TestOpenSSLConnector_ValidateConfig_ValidateRevokeScriptPath(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
signScript := filepath.Join(tmpDir, "sign.sh")
os.WriteFile(signScript, []byte("#!/bin/sh\nexit 0"), 0755)
// Try to use a nonexistent file as revoke_script
config := &openssl.Config{
SignScript: signScript,
RevokeScript: "/nonexistent/revoke.sh",
}
connector := openssl.New(config, logger)
rawConfig, _ := json.Marshal(config)
err := connector.ValidateConfig(ctx, rawConfig)
if err == nil {
t.Fatal("Expected error when revoke_script is nonexistent")
}
}
func TestOpenSSLConnector_ValidateConfig_ValidateCRLScriptPath(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
signScript := filepath.Join(tmpDir, "sign.sh")
os.WriteFile(signScript, []byte("#!/bin/sh\nexit 0"), 0755)
// Try to use a directory as crl_script
config := &openssl.Config{
SignScript: signScript,
CRLScript: tmpDir, // This is a directory, not a regular file
}
connector := openssl.New(config, logger)
rawConfig, _ := json.Marshal(config)
err := connector.ValidateConfig(ctx, rawConfig)
if err == nil {
t.Fatal("Expected error when crl_script is not a regular file")
}
}
+3 -3
View File
@@ -142,7 +142,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
// Validate Apache configuration before reload
c.logger.Debug("validating Apache configuration", "validate_command", c.config.ValidateCommand)
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
if output, err := validateCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Apache config validation failed: %v (output: %s)", err, string(output))
c.logger.Error("Apache validation failed", "error", err, "output", string(output))
@@ -156,7 +156,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
// Graceful reload
c.logger.Debug("reloading Apache", "reload_command", c.config.ReloadCommand)
reloadCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ReloadCommand)
reloadCmd := exec.CommandContext(ctx, c.config.ReloadCommand)
if output, err := reloadCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Apache reload failed: %v (output: %s)", err, string(output))
c.logger.Error("Apache reload failed", "error", err, "output", string(output))
@@ -196,7 +196,7 @@ func (c *Connector) ValidateDeployment(ctx context.Context, request target.Valid
startTime := time.Now()
// Validate Apache configuration
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
if output, err := validateCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Apache config validation failed: %v (output: %s)", err, string(output))
c.logger.Error("validation failed", "error", err)
+15 -4
View File
@@ -10,6 +10,7 @@ import (
"time"
"github.com/shankar0123/certctl/internal/connector/target"
"github.com/shankar0123/certctl/internal/validation"
)
// Config represents the HAProxy deployment target configuration.
@@ -53,12 +54,22 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
return fmt.Errorf("HAProxy reload_command is required")
}
// Validate commands to prevent injection attacks
if err := validation.ValidateShellCommand(cfg.ReloadCommand); err != nil {
return fmt.Errorf("invalid reload_command: %w", err)
}
if cfg.ValidateCommand != "" {
if err := validation.ValidateShellCommand(cfg.ValidateCommand); err != nil {
return fmt.Errorf("invalid validate_command: %w", err)
}
}
c.logger.Info("validating HAProxy configuration",
"pem_path", cfg.PEMPath)
// Verify validate command works if provided
if cfg.ValidateCommand != "" {
cmd := exec.CommandContext(ctx, "sh", "-c", cfg.ValidateCommand)
cmd := exec.CommandContext(ctx, cfg.ValidateCommand)
if err := cmd.Run(); err != nil {
c.logger.Warn("HAProxy config validation failed during config check",
"error", err,
@@ -114,7 +125,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
// Validate HAProxy configuration if validate command is configured
if c.config.ValidateCommand != "" {
c.logger.Debug("validating HAProxy configuration", "validate_command", c.config.ValidateCommand)
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
if output, err := validateCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("HAProxy config validation failed: %v (output: %s)", err, string(output))
c.logger.Error("HAProxy validation failed", "error", err, "output", string(output))
@@ -129,7 +140,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
// Reload HAProxy
c.logger.Debug("reloading HAProxy", "reload_command", c.config.ReloadCommand)
reloadCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ReloadCommand)
reloadCmd := exec.CommandContext(ctx, c.config.ReloadCommand)
if output, err := reloadCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("HAProxy reload failed: %v (output: %s)", err, string(output))
c.logger.Error("HAProxy reload failed", "error", err, "output", string(output))
@@ -169,7 +180,7 @@ func (c *Connector) ValidateDeployment(ctx context.Context, request target.Valid
// Validate HAProxy configuration if command provided
if c.config.ValidateCommand != "" {
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
if output, err := validateCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("HAProxy config validation failed: %v (output: %s)", err, string(output))
c.logger.Error("validation failed", "error", err)
@@ -377,3 +377,85 @@ func TestNginxConnector_ValidateDeployment_ValidateCommandFails(t *testing.T) {
t.Fatal("expected invalid result")
}
}
// Security tests for command injection prevention
func TestNginxConnector_ValidateConfig_RejectCommandInjectionSemicolon(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
cfg := nginx.Config{
CertPath: filepath.Join(tmpDir, "cert.pem"),
ChainPath: filepath.Join(tmpDir, "chain.pem"),
ReloadCommand: "nginx; rm -rf /", // Command injection attempt
ValidateCommand: "true",
}
connector := nginx.New(&cfg, logger)
rawConfig, _ := json.Marshal(cfg)
err := connector.ValidateConfig(ctx, rawConfig)
if err == nil {
t.Fatal("expected error for command injection in reload_command")
}
}
func TestNginxConnector_ValidateConfig_RejectCommandInjectionPipe(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
cfg := nginx.Config{
CertPath: filepath.Join(tmpDir, "cert.pem"),
ChainPath: filepath.Join(tmpDir, "chain.pem"),
ReloadCommand: "true",
ValidateCommand: "nginx -t | cat /etc/passwd", // Command injection attempt
}
connector := nginx.New(&cfg, logger)
rawConfig, _ := json.Marshal(cfg)
err := connector.ValidateConfig(ctx, rawConfig)
if err == nil {
t.Fatal("expected error for command injection in validate_command")
}
}
func TestNginxConnector_ValidateConfig_RejectCommandSubstitution(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
cfg := nginx.Config{
CertPath: filepath.Join(tmpDir, "cert.pem"),
ChainPath: filepath.Join(tmpDir, "chain.pem"),
ReloadCommand: "echo $(whoami)",
ValidateCommand: "true",
}
connector := nginx.New(&cfg, logger)
rawConfig, _ := json.Marshal(cfg)
err := connector.ValidateConfig(ctx, rawConfig)
if err == nil {
t.Fatal("expected error for command substitution in reload_command")
}
}
func TestNginxConnector_ValidateConfig_RejectBackticks(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
cfg := nginx.Config{
CertPath: filepath.Join(tmpDir, "cert.pem"),
ChainPath: filepath.Join(tmpDir, "chain.pem"),
ReloadCommand: "true",
ValidateCommand: "nginx -t `whoami`",
}
connector := nginx.New(&cfg, logger)
rawConfig, _ := json.Marshal(cfg)
err := connector.ValidateConfig(ctx, rawConfig)
if err == nil {
t.Fatal("expected error for backtick injection in validate_command")
}
}